summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/math.nim18
-rw-r--r--lib/pure/random.nim41
-rw-r--r--tests/stdlib/tmath.nim57
3 files changed, 86 insertions, 30 deletions
diff --git a/lib/pure/math.nim b/lib/pure/math.nim
index ee32772b1..72abae265 100644
--- a/lib/pure/math.nim
+++ b/lib/pure/math.nim
@@ -163,6 +163,24 @@ proc prod*[T](x: openArray[T]): T {.noSideEffect.} =
   result = 1.T
   for i in items(x): result = result * i
 
+proc cumsummed*[T](x: openArray[T]): seq[T] =
+  ## Return cumulative aka prefix summation of ``x``.
+  ##
+  ## .. code-block:: nim
+  ##   var x = [1, 2, 3, 4]
+  ##   echo x.cumsummed    # [1, 3, 6, 10]
+  result.setLen(x.len)
+  result[0] = x[0]
+  for i in 1 ..< x.len: result[i] = result[i-1] + x[i]
+
+proc cumsum*[T](x: var openArray[T]) =
+  ## Transforms ``x`` in-place into its cumulative aka prefix summation.
+  ##
+  ## .. code-block:: nim
+  ##   var x = [1, 2, 3, 4]
+  ##   x.cumsum;  echo x    # [1, 3, 6, 10]
+  for i in 1 ..< x.len: x[i] = x[i-1] + x[i]
+
 {.push noSideEffect.}
 when not defined(JS): # C
   proc sqrt*(x: float32): float32 {.importc: "sqrtf", header: "<math.h>".}
diff --git a/lib/pure/random.nim b/lib/pure/random.nim
index 26e6740ea..ee728ad4a 100644
--- a/lib/pure/random.nim
+++ b/lib/pure/random.nim
@@ -175,27 +175,26 @@ proc sample*[T](a: openArray[T]): T =
   ## returns a random element from openArray ``a`` using non-thread-safe state.
   result = a[rand(a.low..a.high)]
 
-proc sample*[T, U](r: var Rand; a: openArray[T], w: openArray[U], n=1): seq[T] =
-  ## Return a sample (with replacement) of size ``n`` from elements of ``a``
-  ## according to convertible-to-``float``, not necessarily normalized, and
-  ## non-negative weights ``w``.  Uses state in ``r``.  Must have sum ``w > 0.0``.
-  assert(w.len == a.len)
-  var cdf = newSeq[float](a.len)   # The *unnormalized* CDF
-  var tot = 0.0                    # Unnormalized is fine if we sample up to tot
-  for i, w in w:
-    assert(w >= 0)
-    tot += float(w)
-    cdf[i] = tot
-  assert(tot > 0.0)                # Need at least one non-zero weight
-  for i in 0 ..< n:
-    result.add(a[cdf.upperBound(r.rand(tot))])
-
-proc sample*[T, U](a: openArray[T], w: openArray[U], n=1): seq[T] =
-  ## Return a sample (with replacement) of size ``n`` from elements of ``a``
-  ## according to convertible-to-``float``, not necessarily normalized, and
-  ## non-negative weights ``w``.  Uses default non-thread-safe state.
-  state.sample(a, w, n)
-
+proc sample*[T, U](r: var Rand; a: openArray[T], cdf: openArray[U]): T =
+  ## Sample one element from openArray ``a`` when it has cumulative distribution
+  ## function (CDF) ``cdf`` (not necessarily normalized, any type of elements
+  ## convertible to ``float``). Uses state in ``r``. E.g.:
+  ##
+  ## .. code-block:: nim
+  ##   let val = [ "a", "b", "c", "d" ]  # some values
+  ##   var cnt = [1, 2, 3, 4]            # histogram of counts
+  ##   echo r.sample(val, cnt.cumsummed) # echo a sample
+  assert(cdf.len == a.len)              # Two basic sanity checks.
+  assert(float(cdf[^1]) > 0.0)
+  #While we could check cdf[i-1] <= cdf[i] for i in 1..cdf.len, that could get
+  #awfully expensive even in debugging modes.
+  let u = r.rand(float(cdf[^1]))
+  a[cdf.upperBound(U(u))]
+
+proc sample*[T, U](a: openArray[T], cdf: openArray[U]): T =
+  ## Like ``sample(var Rand; openArray[T], openArray[U])``, but uses default
+  ## non-thread-safe state.
+  state.sample(a, cdf)
 
 proc initRand*(seed: int64): Rand =
   ## Creates a new ``Rand`` state from ``seed``.
diff --git a/tests/stdlib/tmath.nim b/tests/stdlib/tmath.nim
index 7c1851e7a..544fa8dc0 100644
--- a/tests/stdlib/tmath.nim
+++ b/tests/stdlib/tmath.nim
@@ -4,6 +4,8 @@ discard """
 
 [Suite] random float
 
+[Suite] cumsum
+
 [Suite] random sample
 
 [Suite] ^
@@ -74,29 +76,66 @@ suite "random float":
     var rand2:float = random(1000000.0)
     check rand1 != rand2
 
+suite "cumsum":
+  test "cumsum int seq return":
+    let counts = [ 1, 2, 3, 4 ]
+    check counts.cumsummed == [ 1, 3, 6, 10 ]
+
+  test "cumsum float seq return":
+    let counts = [ 1.0, 2.0, 3.0, 4.0 ]
+    check counts.cumsummed == [ 1.0, 3.0, 6.0, 10.0 ]
+
+  test "cumsum int in-place":
+    var counts = [ 1, 2, 3, 4 ]
+    counts.cumsum
+    check counts == [ 1, 3, 6, 10 ]
+
+  test "cumsum float in-place":
+    var counts = [ 1.0, 2.0, 3.0, 4.0 ]
+    counts.cumsum
+    check counts == [ 1.0, 3.0, 6.0, 10.0 ]
+
 suite "random sample":
-  test "non-uniform array sample":
+  test "non-uniform array sample unnormalized int CDF":
     let values = [ 10, 20, 30, 40, 50 ] # values
-    let weight = [ 4, 3, 2, 1, 0 ]      # weights aka unnormalized probabilities
-    let weightSum = 10.0                # sum of weights
+    let counts = [ 4, 3, 2, 1, 0 ]      # weights aka unnormalized probabilities
     var histo = initCountTable[int]()
-    for v in sample(values, weight, 5000):
-      histo.inc(v)
-    check histo.len == 4                # number of non-zero in `weight`
+    let cdf = counts.cumsummed          # unnormalized CDF
+    for i in 0 ..< 5000:
+      histo.inc(sample(values, cdf))
+    check histo.len == 4                # number of non-zero in `counts`
     # Any one bin is a binomial random var for n samples, each with prob p of
     # adding a count to k; E[k]=p*n, Var k=p*(1-p)*n, approximately Normal for
     # big n.  So, P(abs(k - p*n)/sqrt(p*(1-p)*n))>3.0) =~ 0.0027, while
     # P(wholeTestFails) =~ 1 - P(binPasses)^4 =~ 1 - (1-0.0027)^4 =~ 0.01.
-    for i, w in weight:
-      if w == 0:
+    for i, c in counts:
+      if c == 0:
         check values[i] notin histo
         continue
-      let p = float(w) / float(weightSum)
+      let p = float(c) / float(cdf[^1])
       let n = 5000.0
       let expected = p * n
       let stdDev = sqrt(n * p * (1.0 - p))
       check abs(float(histo[values[i]]) - expected) <= 3.0 * stdDev
 
+  test "non-uniform array sample normalized float CDF":
+    let values = [ 10, 20, 30, 40, 50 ]     # values
+    let counts = [ 0.4, 0.3, 0.2, 0.1, 0 ]  # probabilities
+    var histo = initCountTable[int]()
+    let cdf = counts.cumsummed              # normalized CDF
+    for i in 0 ..< 5000:
+      histo.inc(sample(values, cdf))
+    check histo.len == 4                    # number of non-zero in ``counts``
+    for i, c in counts:
+      if c == 0:
+        check values[i] notin histo
+        continue
+      let p = float(c) / float(cdf[^1])
+      let n = 5000.0
+      let expected = p * n
+      let stdDev = sqrt(n * p * (1.0 - p))
+      # NOTE: like unnormalized int CDF test, P(wholeTestFails) =~ 0.01.
+      check abs(float(histo[values[i]]) - expected) <= 3.0 * stdDev
 
 suite "^":
   test "compiles for valid types":