summary refs log tree commit diff stats
path: root/tests/concepts/trandomvars.nim
blob: db41aa9019904044de743989cede75942802288c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
discard """
output: '''
true
true
true
3
18.0
324.0
'''
"""

# Stand-in random number generator for this test; it has no fields.
type
  RNG = object

proc random(rng: var RNG): float =
  ## Always yields 1.0, which keeps every sample in this test
  ## deterministic. `rng` is accepted but not consulted.
  result = 1.0

type
  # Concept matched by any value `x` for which `rng.sample(x)`, with `rng`
  # a mutable `RNG`, is an expression of type `A`.
  RandomVar[A] = concept x
    var rng: RNG
    rng.sample(x) is A

  # Wrapper around a single stored value of type `A`.
  Constant[A] = object
    value: A

  # Pair of floats `a` and `b`; the `sample` overload for `Uniform`
  # interpolates between them.
  Uniform = object
    a, b: float

  # Procedure type: takes a mutable `RNG` and returns an `A`.
  ClosureVar[A] = proc(rng: var RNG): A

proc sample[A](rng: var RNG, c: Constant[A]): A =
  ## Sampling a `Constant` hands back its stored value; `rng` is unused.
  result = c.value

proc sample(rng: var RNG, u: Uniform): float =
  ## Interpolates between `u.a` and `u.b`, using the value returned by
  ## `rng.random()` as the interpolation factor.
  let span = u.b - u.a
  result = u.a + span * rng.random()

proc sample[A](rng: var RNG, c: ClosureVar[A]): A =
  ## Sampling a `ClosureVar` invokes it with `rng` and returns its result.
  result = c(rng)

proc constant[A](a: A): Constant[A] =
  ## Builds a `Constant[A]` whose stored value is `a`.
  result = Constant[A](value: a)

proc uniform(a, b: float): Uniform =
  ## Builds a `Uniform` from the two endpoints `a` and `b`.
  result = Uniform(a: a, b: b)

proc lift1[A, B](f: proc(a: A): B, r: RandomVar[A]): ClosureVar[B] =
  ## Lifts the unary function `f` over the random variable `r`.
  ##
  ## The returned closure samples `r` with the supplied `RNG` and applies
  ## `f` to the sampled value.
  proc sampler(rng: var RNG): B =
    # `f` and `r` are captured from the enclosing call.
    let drawn = rng.sample(r)
    result = f(drawn)

  result = sampler

when isMainModule:
  # Squaring function that gets lifted over a random variable below.
  proc sq(x: float): float =
    result = x * x

  var rng: RNG

  let c = constant(3)
  let u = uniform(2, 18)
  let t = lift1(sq, u)

  # Each of the three values must satisfy the `RandomVar` concept.
  echo(c is RandomVar[int])
  echo(u is RandomVar[float])
  echo(t is RandomVar[float])

  # `random` always returns 1.0, so these print 3, 18.0 and 324.0.
  echo sample(rng, c)
  echo sample(rng, u)
  echo sample(rng, t)