diff options
author | c-blake <c-blake@users.noreply.github.com> | 2018-12-23 07:23:20 -0500 |
---|---|---|
committer | Dominik Picheta <dominikpicheta@googlemail.com> | 2018-12-23 12:23:20 +0000 |
commit | e1d5356ae9fe0e289cc6fe21fec486c54f89400a (patch) | |
tree | 827a126eedd14656cfc2645ada998d4851be9d3d /lib/pure | |
parent | d407af565f164a2182e9d72fa69d1263cd7b3f9f (diff) | |
download | Nim-e1d5356ae9fe0e289cc6fe21fec486c54f89400a.tar.gz |
Add ability to sample elements from openArray according to a weight array (#10072)
* Add the ability to sample elements from an openArray according to a parallel array of weights/unnormalized probabilities (any sort of histogram, basically). Also add a non-thread safe version for convenience. * Address Araq comments on https://github.com/nim-lang/Nim/pull/10072 * import at top of file and space after '#'. * Put in a check for non-zero total weight. * Clarify constraint on `w`. * Rename `rand(openArray[T])` to `sample(openArray[T])` to `sample`, deprecating old name and name new (openArray[T], openArray[U]) variants `sample`. * Rename caller-provided state version of rand(openArray[T]) and also clean up doc comments. * Add test for new non-uniform array sampler. 3 sd bound makes it 99% likely that it will still pass in the future if the random number generator changes. We cannot both have a tight bound to check distribution *and* loose check to ensure resilience to RNG changes. (We cannot *guarantee* resilience, anyway. There's always a small chance any test hits a legitimate random fluctuation.)
Diffstat (limited to 'lib/pure')
-rw-r--r-- | lib/pure/random.nim | 37 |
1 files changed, 35 insertions, 2 deletions
diff --git a/lib/pure/random.nim b/lib/pure/random.nim index c458d51eb..d6501c87e 100644 --- a/lib/pure/random.nim +++ b/lib/pure/random.nim @@ -14,6 +14,8 @@ ## ## **Do not use this module for cryptographic purposes!** +import algorithm #For upperBound + include "system/inclrtl" {.push debugger:off.} @@ -155,14 +157,45 @@ proc rand*[T](x: HSlice[T, T]): T = ## For a slice `a .. b` returns a value in the range `a .. b`. result = rand(state, x) -proc rand*[T](r: var Rand; a: openArray[T]): T = +proc rand*[T](r: var Rand; a: openArray[T]): T {.deprecated.} = ## returns a random element from the openarray `a`. + ## **Deprecated since v0.20.0:** use ``sample`` instead. result = a[rand(r, a.low..a.high)] -proc rand*[T](a: openArray[T]): T = +proc rand*[T](a: openArray[T]): T {.deprecated.} = ## returns a random element from the openarray `a`. + ## **Deprecated since v0.20.0:** use ``sample`` instead. result = a[rand(a.low..a.high)] +proc sample*[T](r: var Rand; a: openArray[T]): T = + ## returns a random element from openArray ``a`` using state in ``r``. + result = a[r.rand(a.low..a.high)] + +proc sample*[T](a: openArray[T]): T = + ## returns a random element from openArray ``a`` using non-thread-safe state. + result = a[rand(a.low..a.high)] + +proc sample*[T, U](r: var Rand; a: openArray[T], w: openArray[U], n=1): seq[T] = + ## Return a sample (with replacement) of size ``n`` from elements of ``a`` + ## according to convertible-to-``float``, not necessarily normalized, and + ## non-negative weights ``w``. Uses state in ``r``. Must have sum ``w > 0.0``. + assert(w.len == a.len) + var cdf = newSeq[float](a.len) # The *unnormalized* CDF + var tot = 0.0 # Unnormalized is fine if we sample up to tot + for i, w in w: + assert(w >= 0) + tot += float(w) + cdf[i] = tot + assert(tot > 0.0) # Need at least one non-zero weight + for i in 0 ..< n: + result.add(a[cdf.upperBound(r.rand(tot))]) + +proc sample*[T, U](a: openArray[T], w: openArray[U], n=1): seq[T] = + ## Return a sample (with replacement) of size ``n`` from elements of ``a`` + ## according to convertible-to-``float``, not necessarily normalized, and + ## non-negative weights ``w``. Uses default non-thread-safe state. + state.sample(a, w, n) + proc initRand*(seed: int64): Rand = ## Creates a new ``Rand`` state from ``seed``. |