summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/random.nim37
-rw-r--r--tests/stdlib/tmath.nim28
2 files changed, 62 insertions, 3 deletions
diff --git a/lib/pure/random.nim b/lib/pure/random.nim
index c458d51eb..d6501c87e 100644
--- a/lib/pure/random.nim
+++ b/lib/pure/random.nim
@@ -14,6 +14,8 @@
 ##
 ## **Do not use this module for cryptographic purposes!**
 
+import algorithm                    #For upperBound
+
 include "system/inclrtl"
 {.push debugger:off.}
 
@@ -155,14 +157,45 @@ proc rand*[T](x: HSlice[T, T]): T =
   ## For a slice `a .. b` returns a value in the range `a .. b`.
   result = rand(state, x)
 
-proc rand*[T](r: var Rand; a: openArray[T]): T =
+proc rand*[T](r: var Rand; a: openArray[T]): T {.deprecated.} =
   ## returns a random element from the openarray `a`.
+  ## **Deprecated since v0.20.0:** use ``sample`` instead.
   result = a[rand(r, a.low..a.high)]
 
-proc rand*[T](a: openArray[T]): T =
+proc rand*[T](a: openArray[T]): T {.deprecated.} =
   ## returns a random element from the openarray `a`.
+  ## **Deprecated since v0.20.0:** use ``sample`` instead.
   result = a[rand(a.low..a.high)]
 
+proc sample*[T](r: var Rand; a: openArray[T]): T =
+  ## returns a random element from openArray ``a`` using state in ``r``.
+  result = a[r.rand(a.low..a.high)]
+
+proc sample*[T](a: openArray[T]): T =
+  ## returns a random element from openArray ``a`` using non-thread-safe state.
+  result = a[rand(a.low..a.high)]
+
+proc sample*[T, U](r: var Rand; a: openArray[T], w: openArray[U], n=1): seq[T] =
+  ## Return a sample (with replacement) of size ``n`` from elements of ``a``
+  ## according to convertible-to-``float``, not necessarily normalized, and
+  ## non-negative weights ``w``.  Uses state in ``r``.  Must have sum ``w > 0.0``.
+  assert(w.len == a.len)
+  var cdf = newSeq[float](a.len)   # The *unnormalized* CDF
+  var tot = 0.0                    # Unnormalized is fine if we sample up to tot
+  for i, w in w:
+    assert(w >= 0)
+    tot += float(w)
+    cdf[i] = tot
+  assert(tot > 0.0)                # Need at least one non-zero weight
+  for i in 0 ..< n:
+    result.add(a[cdf.upperBound(r.rand(tot))])
+
+proc sample*[T, U](a: openArray[T], w: openArray[U], n=1): seq[T] =
+  ## Return a sample (with replacement) of size ``n`` from elements of ``a``
+  ## according to convertible-to-``float``, not necessarily normalized, and
+  ## non-negative weights ``w``.  Uses default non-thread-safe state.
+  state.sample(a, w, n)
+
 
 proc initRand*(seed: int64): Rand =
   ## Creates a new ``Rand`` state from ``seed``.
diff --git a/tests/stdlib/tmath.nim b/tests/stdlib/tmath.nim
index 581308a7e..7c1851e7a 100644
--- a/tests/stdlib/tmath.nim
+++ b/tests/stdlib/tmath.nim
@@ -4,6 +4,8 @@ discard """
 
 [Suite] random float
 
+[Suite] random sample
+
 [Suite] ^
 
 '''
@@ -11,7 +13,7 @@ discard """
 
 import math, random, os
 import unittest
-import sets
+import sets, tables
 
 suite "random int":
   test "there might be some randomness":
@@ -72,6 +74,30 @@ suite "random float":
     var rand2:float = random(1000000.0)
     check rand1 != rand2
 
+suite "random sample":
+  test "non-uniform array sample":
+    let values = [ 10, 20, 30, 40, 50 ] # values
+    let weight = [ 4, 3, 2, 1, 0 ]      # weights aka unnormalized probabilities
+    let weightSum = 10.0                # sum of weights
+    var histo = initCountTable[int]()
+    for v in sample(values, weight, 5000):
+      histo.inc(v)
+    check histo.len == 4                # number of non-zero in `weight`
+    # Any one bin is a binomial random var for n samples, each with prob p of
+    # adding a count to k; E[k]=p*n, Var k=p*(1-p)*n, approximately Normal for
+    # big n.  So, P(abs(k - p*n)/sqrt(p*(1-p)*n))>3.0) =~ 0.0027, while
+    # P(wholeTestFails) =~ 1 - P(binPasses)^4 =~ 1 - (1-0.0027)^4 =~ 0.01.
+    for i, w in weight:
+      if w == 0:
+        check values[i] notin histo
+        continue
+      let p = float(w) / float(weightSum)
+      let n = 5000.0
+      let expected = p * n
+      let stdDev = sqrt(n * p * (1.0 - p))
+      check abs(float(histo[values[i]]) - expected) <= 3.0 * stdDev
+
+
 suite "^":
   test "compiles for valid types":
     check: compiles(5 ^ 2)