summary refs log tree commit diff stats
path: root/lib/pure
diff options
context:
space:
mode:
authorc-blake <c-blake@users.noreply.github.com>2018-12-23 07:23:20 -0500
committerDominik Picheta <dominikpicheta@googlemail.com>2018-12-23 12:23:20 +0000
commite1d5356ae9fe0e289cc6fe21fec486c54f89400a (patch)
tree827a126eedd14656cfc2645ada998d4851be9d3d /lib/pure
parentd407af565f164a2182e9d72fa69d1263cd7b3f9f (diff)
downloadNim-e1d5356ae9fe0e289cc6fe21fec486c54f89400a.tar.gz
Add ability to sample elements from openArray according to a weight array (#10072)
* Add the ability to sample elements from an openArray according to a parallel
array of weights/unnormalized probabilities (any sort of histogram, basically).
Also add a non-thread safe version for convenience.

* Address Araq comments on https://github.com/nim-lang/Nim/pull/10072

* import at top of file and space after '#'.

* Put in a check for non-zero total weight.

* Clarify constraint on `w`.

* Rename `rand(openArray[T])` to `sample(openArray[T])` to `sample`, deprecating
old name and name new (openArray[T], openArray[U]) variants `sample`.

* Rename caller-provided state version of rand(openArray[T]) and also clean
up doc comments.

* Add test for new non-uniform array sampler.  3 sd bound makes it 99% likely
that it will still pass in the future if the random number generator changes.
We cannot both have a tight bound to check distribution *and* loose check to
ensure resilience to RNG changes.  (We cannot *guarantee* resilience, anyway.
There's always a small chance any test hits a legitimate random fluctuation.)
Diffstat (limited to 'lib/pure')
-rw-r--r--lib/pure/random.nim37
1 files changed, 35 insertions, 2 deletions
diff --git a/lib/pure/random.nim b/lib/pure/random.nim
index c458d51eb..d6501c87e 100644
--- a/lib/pure/random.nim
+++ b/lib/pure/random.nim
@@ -14,6 +14,8 @@
 ##
 ## **Do not use this module for cryptographic purposes!**
 
+import algorithm                    #For upperBound
+
 include "system/inclrtl"
 {.push debugger:off.}
 
@@ -155,14 +157,45 @@ proc rand*[T](x: HSlice[T, T]): T =
   ## For a slice `a .. b` returns a value in the range `a .. b`.
   result = rand(state, x)
 
-proc rand*[T](r: var Rand; a: openArray[T]): T =
+proc rand*[T](r: var Rand; a: openArray[T]): T {.deprecated.} =
   ## returns a random element from the openarray `a`.
+  ## **Deprecated since v0.20.0:** use ``sample`` instead.
   result = a[rand(r, a.low..a.high)]
 
-proc rand*[T](a: openArray[T]): T =
+proc rand*[T](a: openArray[T]): T {.deprecated.} =
   ## returns a random element from the openarray `a`.
+  ## **Deprecated since v0.20.0:** use ``sample`` instead.
   result = a[rand(a.low..a.high)]
 
+proc sample*[T](r: var Rand; a: openArray[T]): T =
+  ## returns a random element from openArray ``a`` using state in ``r``.
+  result = a[r.rand(a.low..a.high)]
+
+proc sample*[T](a: openArray[T]): T =
+  ## returns a random element from openArray ``a`` using non-thread-safe state.
+  result = a[rand(a.low..a.high)]
+
+proc sample*[T, U](r: var Rand; a: openArray[T], w: openArray[U], n=1): seq[T] =
+  ## Return a sample (with replacement) of size ``n`` from elements of ``a``
+  ## according to convertible-to-``float``, not necessarily normalized, and
+  ## non-negative weights ``w``.  Uses state in ``r``.  Must have sum ``w > 0.0``.
+  assert(w.len == a.len)
+  var cdf = newSeq[float](a.len)   # The *unnormalized* CDF
+  var tot = 0.0                    # Unnormalized is fine if we sample up to tot
+  for i, w in w:
+    assert(w >= 0)
+    tot += float(w)
+    cdf[i] = tot
+  assert(tot > 0.0)                # Need at least one non-zero weight
+  for i in 0 ..< n:
+    result.add(a[cdf.upperBound(r.rand(tot))])
+
+proc sample*[T, U](a: openArray[T], w: openArray[U], n=1): seq[T] =
+  ## Return a sample (with replacement) of size ``n`` from elements of ``a``
+  ## according to convertible-to-``float``, not necessarily normalized, and
+  ## non-negative weights ``w``.  Uses default non-thread-safe state.
+  state.sample(a, w, n)
+
 
 proc initRand*(seed: int64): Rand =
   ## Creates a new ``Rand`` state from ``seed``.