summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorc-blake <c-blake@users.noreply.github.com>2018-12-23 07:23:20 -0500
committerDominik Picheta <dominikpicheta@googlemail.com>2018-12-23 12:23:20 +0000
commite1d5356ae9fe0e289cc6fe21fec486c54f89400a (patch)
tree827a126eedd14656cfc2645ada998d4851be9d3d
parentd407af565f164a2182e9d72fa69d1263cd7b3f9f (diff)
downloadNim-e1d5356ae9fe0e289cc6fe21fec486c54f89400a.tar.gz
Add ability to sample elements from openArray according to a weight array (#10072)
* Add the ability to sample elements from an openArray according to a parallel
array of weights/unnormalized probabilities (any sort of histogram, basically).
Also add a non-thread safe version for convenience.

* Address Araq comments on https://github.com/nim-lang/Nim/pull/10072

* import at top of file and space after '#'.

* Put in a check for non-zero total weight.

* Clarify constraint on `w`.

* Rename `rand(openArray[T])` to `sample(openArray[T])` to `sample`, deprecating
old name and name new (openArray[T], openArray[U]) variants `sample`.

* Rename caller-provided state version of rand(openArray[T]) and also clean
up doc comments.

* Add test for new non-uniform array sampler.  3 sd bound makes it 99% likely
that it will still pass in the future if the random number generator changes.
We cannot both have a tight bound to check distribution *and* loose check to
ensure resilience to RNG changes.  (We cannot *guarantee* resilience, anyway.
There's always a small chance any test hits a legitimate random fluctuation.)
-rw-r--r--lib/pure/random.nim37
-rw-r--r--tests/stdlib/tmath.nim28
2 files changed, 62 insertions, 3 deletions
diff --git a/lib/pure/random.nim b/lib/pure/random.nim
index c458d51eb..d6501c87e 100644
--- a/lib/pure/random.nim
+++ b/lib/pure/random.nim
@@ -14,6 +14,8 @@
 ##
 ## **Do not use this module for cryptographic purposes!**
 
+import algorithm                    #For upperBound
+
 include "system/inclrtl"
 {.push debugger:off.}
 
@@ -155,14 +157,45 @@ proc rand*[T](x: HSlice[T, T]): T =
   ## For a slice `a .. b` returns a value in the range `a .. b`.
   result = rand(state, x)
 
-proc rand*[T](r: var Rand; a: openArray[T]): T =
+proc rand*[T](r: var Rand; a: openArray[T]): T {.deprecated.} =
   ## returns a random element from the openarray `a`.
+  ## **Deprecated since v0.20.0:** use ``sample`` instead.
   result = a[rand(r, a.low..a.high)]
 
-proc rand*[T](a: openArray[T]): T =
+proc rand*[T](a: openArray[T]): T {.deprecated.} =
   ## returns a random element from the openarray `a`.
+  ## **Deprecated since v0.20.0:** use ``sample`` instead.
   result = a[rand(a.low..a.high)]
 
+proc sample*[T](r: var Rand; a: openArray[T]): T =
+  ## returns a random element from openArray ``a`` using state in ``r``.
+  result = a[r.rand(a.low..a.high)]
+
+proc sample*[T](a: openArray[T]): T =
+  ## returns a random element from openArray ``a`` using non-thread-safe state.
+  result = a[rand(a.low..a.high)]
+
+proc sample*[T, U](r: var Rand; a: openArray[T], w: openArray[U], n=1): seq[T] =
+  ## Return a sample (with replacement) of size ``n`` from elements of ``a``
+  ## according to convertible-to-``float``, not necessarily normalized, and
+  ## non-negative weights ``w``.  Uses state in ``r``.  Must have sum ``w > 0.0``.
+  assert(w.len == a.len)
+  var cdf = newSeq[float](a.len)   # The *unnormalized* CDF
+  var tot = 0.0                    # Unnormalized is fine if we sample up to tot
+  for i, w in w:
+    assert(w >= 0)
+    tot += float(w)
+    cdf[i] = tot
+  assert(tot > 0.0)                # Need at least one non-zero weight
+  for i in 0 ..< n:
+    result.add(a[cdf.upperBound(r.rand(tot))])
+
+proc sample*[T, U](a: openArray[T], w: openArray[U], n=1): seq[T] =
+  ## Return a sample (with replacement) of size ``n`` from elements of ``a``
+  ## according to convertible-to-``float``, not necessarily normalized, and
+  ## non-negative weights ``w``.  Uses default non-thread-safe state.
+  state.sample(a, w, n)
+
 
 proc initRand*(seed: int64): Rand =
   ## Creates a new ``Rand`` state from ``seed``.
diff --git a/tests/stdlib/tmath.nim b/tests/stdlib/tmath.nim
index 581308a7e..7c1851e7a 100644
--- a/tests/stdlib/tmath.nim
+++ b/tests/stdlib/tmath.nim
@@ -4,6 +4,8 @@ discard """
 
 [Suite] random float
 
+[Suite] random sample
+
 [Suite] ^
 
 '''
@@ -11,7 +13,7 @@ discard """
 
 import math, random, os
 import unittest
-import sets
+import sets, tables
 
 suite "random int":
   test "there might be some randomness":
@@ -72,6 +74,30 @@ suite "random float":
     var rand2:float = random(1000000.0)
     check rand1 != rand2
 
+suite "random sample":
+  test "non-uniform array sample":
+    let values = [ 10, 20, 30, 40, 50 ] # values
+    let weight = [ 4, 3, 2, 1, 0 ]      # weights aka unnormalized probabilities
+    let weightSum = 10.0                # sum of weights
+    var histo = initCountTable[int]()
+    for v in sample(values, weight, 5000):
+      histo.inc(v)
+    check histo.len == 4                # number of non-zero in `weight`
+    # Any one bin is a binomial random var for n samples, each with prob p of
+    # adding a count to k; E[k]=p*n, Var k=p*(1-p)*n, approximately Normal for
+    # big n.  So, P(abs(k - p*n)/sqrt(p*(1-p)*n))>3.0) =~ 0.0027, while
+    # P(wholeTestFails) =~ 1 - P(binPasses)^4 =~ 1 - (1-0.0027)^4 =~ 0.01.
+    for i, w in weight:
+      if w == 0:
+        check values[i] notin histo
+        continue
+      let p = float(w) / float(weightSum)
+      let n = 5000.0
+      let expected = p * n
+      let stdDev = sqrt(n * p * (1.0 - p))
+      check abs(float(histo[values[i]]) - expected) <= 3.0 * stdDev
+
+
 suite "^":
   test "compiles for valid types":
     check: compiles(5 ^ 2)
='n360' href='#n360'>360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413