summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorTimothee Cour <timothee.cour2@gmail.com>2021-07-23 04:41:16 -0700
committerGitHub <noreply@github.com>2021-07-23 13:41:16 +0200
commit76f74fae88fba72f58e43ad5c8fd20a7c4d8b439 (patch)
treeea38cd8625e1fafa7ea30598f967c74795e0f0cb
parentf62f4159f8c793ba24fa214b26c4dc68a530bc2e (diff)
downloadNim-76f74fae88fba72f58e43ad5c8fd20a7c4d8b439.tar.gz
std/random: fix overflow bugs; fixes #16360; fixes #16296; fixes #17670 (#18456)
-rw-r--r--lib/pure/random.nim32
-rw-r--r--tests/stdlib/trandom.nim63
2 files changed, 85 insertions, 10 deletions
diff --git a/lib/pure/random.nim b/lib/pure/random.nim
index 07db0365c..a292386af 100644
--- a/lib/pure/random.nim
+++ b/lib/pure/random.nim
@@ -204,6 +204,19 @@ proc skipRandomNumbers*(s: var Rand) =
   s.a0 = s0
   s.a1 = s1
 
+proc rand[T: uint | uint64](r: var Rand; max: T): T =
+  # xxx export in future work
+  if max == 0: return
+  else:
+    let max = uint64(max)
+    when T.high.uint64 == uint64.high:
+      if max == uint64.high: return T(next(r))
+    while true:
+      let x = next(r)
+      # avoid `mod` bias
+      if x <= randMax - (randMax mod max):
+        return T(x mod (max + 1))
+
 proc rand*(r: var Rand; max: Natural): int {.benign.} =
   ## Returns a random integer in the range `0..max` using the given state.
   ##
@@ -213,15 +226,13 @@ proc rand*(r: var Rand; max: Natural): int {.benign.} =
   ## * `rand proc<#rand,Rand,HSlice[T: Ordinal or float or float32 or float64,T: Ordinal or float or float32 or float64]>`_
   ##   that accepts a slice
   ## * `rand proc<#rand,typedesc[T]>`_ that accepts an integer or range type
-  runnableExamples("-r:off"):
+  runnableExamples:
     var r = initRand(123)
-    assert r.rand(100) == 96 # implementation defined
-
-  if max == 0: return
-  while true:
-    let x = next(r)
-    if x <= randMax - (randMax mod Ui(max)):
-      return int(x mod (uint64(max) + 1u64))
+    if false:
+      assert r.rand(100) == 96 # implementation defined
+  # bootstrap: can't use `runnableExamples("-r:off")`
+  cast[int](rand(r, uint64(max)))
+    # xxx toUnsigned pending https://github.com/nim-lang/Nim/pull/18445
 
 proc rand*(max: int): int {.benign.} =
   ## Returns a random integer in the range `0..max`.
@@ -306,7 +317,10 @@ proc rand*[T: Ordinal or SomeFloat](r: var Rand; x: HSlice[T, T]): T =
   when T is SomeFloat:
     result = rand(r, x.b - x.a) + x.a
   else: # Integers and Enum types
-    result = T(rand(r, int(x.b) - int(x.a)) + int(x.a))
+    when defined(js):
+      result = cast[T](rand(r, cast[uint](x.b) - cast[uint](x.a)) + cast[uint](x.a))
+    else:
+      result = cast[T](rand(r, cast[uint64](x.b) - cast[uint64](x.a)) + cast[uint64](x.a))
 
 proc rand*[T: Ordinal or SomeFloat](x: HSlice[T, T]): T =
   ## For a slice `a..b`, returns a value in the range `a..b`.
diff --git a/tests/stdlib/trandom.nim b/tests/stdlib/trandom.nim
index e47ddad66..39ccca85b 100644
--- a/tests/stdlib/trandom.nim
+++ b/tests/stdlib/trandom.nim
@@ -3,7 +3,9 @@ discard """
   targets: "c js"
 """
 
-import std/[random, math, os, stats, sets, tables]
+import std/[random, math, stats, sets, tables]
+when not defined(js):
+  import std/os
 
 randomize(233)
 
@@ -187,3 +189,62 @@ block: # bug #17467
     doAssert x > 1e-4, $(x, i)
       # This used to fail for each i in 0..<26844, i.e. the 1st produced value
       # was predictable and < 1e-4, skewing distributions.
+
+const withUint = false # pending exporting `proc rand[T: uint | uint64](r: var Rand; max: T): T =`
+
+block: # bug #16360
+  var r = initRand()
+  template test(a) =
+    let a2 = a
+    block:
+      let a3 = r.rand(a2)
+      doAssert a3 <= a2
+      doAssert a3.type is a2.type
+    block:
+      let a3 = rand(a2)
+      doAssert a3 <= a2
+      doAssert a3.type is a2.type
+  when withUint:
+    test cast[uint](int.high)
+    test cast[uint](int.high) + 1
+    when not defined(js):
+      # pending bug #16411
+      test uint64.high
+      test uint64.high - 1
+    test uint.high - 2
+    test uint.high - 1
+    test uint.high
+  test int.high
+  test int.high - 1
+  test int.high - 2
+  test 0
+  when withUint:
+    test 0'u
+    test 0'u64
+
+block: # bug #16296
+  var r = initRand()
+  template test(x) =
+    let a2 = x
+    let a3 = r.rand(a2)
+    doAssert a3 <= a2.b
+    doAssert a3 >= a2.a
+    doAssert a3.type is a2.a.type
+  test(-2 .. int.high-1)
+  test(int.low .. int.high)
+  test(int.low+1 .. int.high)
+  test(int.low .. int.high-1)
+  test(int.low .. 0)
+  test(int.low .. -1)
+  test(int.low .. 1)
+  test(int64.low .. 1'i64)
+  when not defined(js):
+    # pending bug #16411
+    test(10'u64 .. uint64.high)
+
+block: # bug #17670
+  when not defined(js):
+    # pending bug #16411
+    type UInt48 = range[0'u64..2'u64^48-1]
+    let x = rand(UInt48)
+    doAssert x is UInt48