From 76f74fae88fba72f58e43ad5c8fd20a7c4d8b439 Mon Sep 17 00:00:00 2001 From: Timothee Cour Date: Fri, 23 Jul 2021 04:41:16 -0700 Subject: std/random: fix overflow bugs; fixes #16360; fixes #16296; fixes #17670 (#18456) --- lib/pure/random.nim | 32 +++++++++++++++++------- tests/stdlib/trandom.nim | 63 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 10 deletions(-) diff --git a/lib/pure/random.nim b/lib/pure/random.nim index 07db0365c..a292386af 100644 --- a/lib/pure/random.nim +++ b/lib/pure/random.nim @@ -204,6 +204,19 @@ proc skipRandomNumbers*(s: var Rand) = s.a0 = s0 s.a1 = s1 +proc rand[T: uint | uint64](r: var Rand; max: T): T = + # xxx export in future work + if max == 0: return + else: + let max = uint64(max) + when T.high.uint64 == uint64.high: + if max == uint64.high: return T(next(r)) + while true: + let x = next(r) + # avoid `mod` bias + if x <= randMax - (randMax mod max): + return T(x mod (max + 1)) + proc rand*(r: var Rand; max: Natural): int {.benign.} = ## Returns a random integer in the range `0..max` using the given state. ## @@ -213,15 +226,13 @@ proc rand*(r: var Rand; max: Natural): int {.benign.} = ## * `rand proc<#rand,Rand,HSlice[T: Ordinal or float or float32 or float64,T: Ordinal or float or float32 or float64]>`_ ## that accepts a slice ## * `rand proc<#rand,typedesc[T]>`_ that accepts an integer or range type - runnableExamples("-r:off"): + runnableExamples: var r = initRand(123) - assert r.rand(100) == 96 # implementation defined - - if max == 0: return - while true: - let x = next(r) - if x <= randMax - (randMax mod Ui(max)): - return int(x mod (uint64(max) + 1u64)) + if false: + assert r.rand(100) == 96 # implementation defined + # bootstrap: can't use `runnableExamples("-r:off")` + cast[int](rand(r, uint64(max))) + # xxx toUnsigned pending https://github.com/nim-lang/Nim/pull/18445 proc rand*(max: int): int {.benign.} = ## Returns a random integer in the range `0..max`. @@ -306,7 +317,10 @@ proc rand*[T: Ordinal or SomeFloat](r: var Rand; x: HSlice[T, T]): T = when T is SomeFloat: result = rand(r, x.b - x.a) + x.a else: # Integers and Enum types - result = T(rand(r, int(x.b) - int(x.a)) + int(x.a)) + when defined(js): + result = cast[T](rand(r, cast[uint](x.b) - cast[uint](x.a)) + cast[uint](x.a)) + else: + result = cast[T](rand(r, cast[uint64](x.b) - cast[uint64](x.a)) + cast[uint64](x.a)) proc rand*[T: Ordinal or SomeFloat](x: HSlice[T, T]): T = ## For a slice `a..b`, returns a value in the range `a..b`. diff --git a/tests/stdlib/trandom.nim b/tests/stdlib/trandom.nim index e47ddad66..39ccca85b 100644 --- a/tests/stdlib/trandom.nim +++ b/tests/stdlib/trandom.nim @@ -3,7 +3,9 @@ discard """ targets: "c js" """ -import std/[random, math, os, stats, sets, tables] +import std/[random, math, stats, sets, tables] +when not defined(js): + import std/os randomize(233) @@ -187,3 +189,62 @@ block: # bug #17467 doAssert x > 1e-4, $(x, i) # This used to fail for each i in 0..<26844, i.e. the 1st produced value # was predictable and < 1e-4, skewing distributions. + +const withUint = false # pending exporting `proc rand[T: uint | uint64](r: var Rand; max: T): T =` + +block: # bug #16360 + var r = initRand() + template test(a) = + let a2 = a + block: + let a3 = r.rand(a2) + doAssert a3 <= a2 + doAssert a3.type is a2.type + block: + let a3 = rand(a2) + doAssert a3 <= a2 + doAssert a3.type is a2.type + when withUint: + test cast[uint](int.high) + test cast[uint](int.high) + 1 + when not defined(js): + # pending bug #16411 + test uint64.high + test uint64.high - 1 + test uint.high - 2 + test uint.high - 1 + test uint.high + test int.high + test int.high - 1 + test int.high - 2 + test 0 + when withUint: + test 0'u + test 0'u64 + +block: # bug #16296 + var r = initRand() + template test(x) = + let a2 = x + let a3 = r.rand(a2) + doAssert a3 <= a2.b + doAssert a3 >= a2.a + doAssert a3.type is a2.a.type + test(-2 .. int.high-1) + test(int.low .. int.high) + test(int.low+1 .. int.high) + test(int.low .. int.high-1) + test(int.low .. 0) + test(int.low .. -1) + test(int.low .. 1) + test(int64.low .. 1'i64) + when not defined(js): + # pending bug #16411 + test(10'u64 .. uint64.high) + +block: # bug #17670 + when not defined(js): + # pending bug #16411 + type UInt48 = range[0'u64..2'u64^48-1] + let x = rand(UInt48) + doAssert x is UInt48 -- cgit 1.4.1-2-gfad0