summary refs log tree commit diff stats
path: root/tests/macros/tincremental.nim
blob: 401d6f3f8446fe7ba9594cd88bfab972869cd6cb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Test specification (presumably read by the test runner): the expected
# stdout of the demo at the end of this file. Each "<name>_impl is called"
# line is printed by the cache-miss branch generated by `incremental`.
discard """
  output: '''heavy_calc_impl is called
sub_calc1_impl is called
sub_calc2_impl is called
** no changes recompute effectively
** change one input and recompute effectively
heavy_calc_impl is called
sub_calc2_impl is called'''
"""

# Sample: incremental (memoized) computation over a dependency graph built with macros

import tables
import macros

# Run-time state:
#   inputs - current value of each input, keyed by the input key ("a1", ...)
#   cache  - memoized node results, keyed by graph node key
#   dep_tree - compile-time dependency edges, child node key -> parent node
#              key, filled by `graph_discovery` while the graph is compiled
var
  inputs = initTable[string, float]()
  cache = initTable[string, float]()
  dep_tree {.compileTime.} = initTable[string, string]()

macro symHash(s: typed{nkSym}): string =
  ## Expands to a string literal holding the `symBodyHash` of symbol `s`.
  newLit(s.symBodyHash)

#######################################################################################

# Custom pragma marking a routine as a node of the dependency graph; `key`
# is the node's key and is read back from the routine definition by `tag`.
template graph_node(key: string) {.pragma.}

proc tag(n: NimNode): NimNode =
  ## Returns the key expression passed to the `graph_node` pragma of the
  ## routine definition `n`, or nil when `n` carries no such pragma (i.e. it
  ## is not a graph node).
  ##
  ## Compilation aborts (via `expectKind`) when `n` is not a proc or func
  ## definition.
  expectKind(n, {nnkProcDef, nnkFuncDef})
  for p in n.pragma:
    # Require the key argument as well as the pragma head; checking only
    # `p.len > 0` would let `p[1]` index past the end of a one-child node.
    if p.len > 1 and p[0] == bindSym"graph_node":
      return p[1]
  result = nil

macro graph_node_key(n: typed{nkSym}): untyped =
  ## Expands to the cache key of graph node `n`: a string literal holding the
  ## `symBodyHash` of symbol `n`.
  let digest = symBodyHash(n)
  result = newLit(digest)

macro graph_discovery(n: typed{nkSym}): untyped =
  ## Walks the body of routine `n` and records in the compile-time `dep_tree`
  ## an edge "callee key -> key of `n`" for every graph node it reaches.
  ## Calls to ordinary procs/funcs (no `graph_node` pragma) are followed into
  ## their bodies; calls to graph nodes are recorded and not descended into.
  ## Expands to an empty node: only the compile-time side effect matters.
  # discovers the graph dependency tree and updates the dep_tree global var
  let mytag = newStrLitNode(n.symBodyHash)  # key of the calling node `n`
  var visited: seq[NimNode]  # routine syms already handled; avoids re-walks
  proc discover(n: NimNode) = 
    case n.kind:
      of nnkNone..pred(nnkSym), succ(nnkSym)..nnkNilLit: discard  # non-sym leaves
      of nnkSym:
        if n.symKind in {nskFunc, nskProc}:
          if n notin visited:
            visited.add n
            # NOTE(review): assumes getImpl yields a proc/func definition;
            # `tag` rejects any other node kind.
            let tag = n.getImpl.tag
            if tag != nil:
              # Callee is a graph node: remember "callee -> this node".
              # NOTE(review): dep_tree holds one parent per key, so a node
              # called from several graph nodes keeps only the last parent
              # recorded here - confirm this is acceptable.
              dep_tree[tag.strVal] =  mytag.strVal
            else:
              # Ordinary routine: search its body for nested graph nodes.
              discover(n.getImpl.body)
      else:
        for child in n:
          discover(child)
  discover(n.getImpl.body)
  result = newEmptyNode()

#######################################################################################

macro incremental_input(key: static[string], n: untyped{nkFuncDef}): untyped =
  ## Turns func `n` into a leaf (input) node of the graph: its pragmas are
  ## replaced by `{.graph_node(key).}` and its body by a lookup of `key` in
  ## the global `inputs` table.
  template readInput(key) {.dirty.} =
    {.noSideEffect.}:
      inputs[key]
  n.pragma = nnkPragma.newTree(newCall(bindSym"graph_node", newLit(key)))
  n.body = getAst(readInput(key))
  result = n

macro incremental(n: untyped{nkFuncDef}): untyped =
  ## Incrementalizes a side-effect-free computation.
  ##
  ## The user's func `name` is renamed to `name_impl`, and a new func `name`
  ## with the same signature is emitted alongside it. The new func:
  ## * runs `graph_discovery` on `name_impl` to record its dependencies,
  ## * returns the entry cached under `graph_node_key(name_impl)` if present,
  ## * otherwise echoes "<name>_impl is called", calls `name_impl` with all of
  ##   its parameters, and caches the result.
  ## The new func is tagged as a graph node keyed by `symHash(name_impl)`.
  ##
  ## NOTE(review): the cache key depends only on the implementation, not on
  ## the argument values, so calls with different arguments share one cache
  ## entry - confirm this is intended.
  template cache_func_body(func_name, func_name_str, func_call) {.dirty.} =
    {.noSideEffect.}:
      graph_discovery(func_name)
      let key = graph_node_key(func_name)
      if key in cache:
        result = cache[key]
      else:
        echo func_name_str & " is called"
        result = func_call
        cache[key] = result

  let func_name = n.name.strVal & "_impl"
  let func_call = nnkCall.newTree(ident func_name)
  # Forward every parameter. A single nnkIdentDefs may declare several names
  # (`a, b: float`); its last two children are the type and default value,
  # so every child before those is a parameter name.
  for i in 1..<n.params.len:
    let defs = n.params[i]
    for j in 0 ..< defs.len - 2:
      func_call.add defs[j]
  let cache_func = n.copyNimTree
  cache_func.body = getAst(cache_func_body(ident func_name, func_name, func_call))
  cache_func.pragma = nnkPragma.newTree(newCall(bindSym"graph_node",
    newCall(bindSym"symHash", ident func_name)))

  n.name = ident(func_name)
  result = nnkStmtList.newTree(n, cache_func)

###########################################################################
### Example
###########################################################################

# Leaf nodes: values are read from `inputs` under keys "a1" and "a2".
func input1(): float {.incremental_input("a1").}

func input2(): float {.incremental_input("a2").}

# Intermediate nodes: memoized in `cache`; each depends on one input.
func sub_calc1(a: float): float  {.incremental.} = 
  a + input1()

func sub_calc2(b: float): float  {.incremental.} = 
  b + input2()

# Root node: depends on both intermediate nodes.
func heavy_calc(a: float, b: float): float {.incremental.} = 
  sub_calc1(a) + sub_calc2(b)

###########################################################################
## graph finalize and inputs
###########################################################################

macro finalize_dep_tree(): untyped =
  ## Freezes the compile-time `dep_tree` into a run-time expression:
  ## `toTable({child: parent, ...})`. When no dependency was discovered, it
  ## emits `initTable[string, string]()` instead, because an empty table
  ## constructor would leave `toTable` without an element type to infer.
  if dep_tree.len == 0:
    return newCall(nnkBracketExpr.newTree(bindSym"initTable",
      ident"string", ident"string"))
  var entries = nnkTableConstr.newNimNode
  for key, val in dep_tree:
    entries.add nnkExprColonExpr.newTree(newStrLitNode key, newStrLitNode val)
  result = nnkCall.newTree(bindSym"toTable", entries)

# Run-time snapshot of the dependency graph, consumed by `set_input`. It is
# placed after the incremental definitions so that their `graph_discovery`
# expansions have (presumably) already populated `dep_tree`.
const dep_tree_final = finalize_dep_tree()

proc set_input(key: string, val: float) =
  ## Sets input `key` to `val` and invalidates every cached graph node that
  ## depends on it. It walks `dep_tree_final` (child -> parent) upward from
  ## `key` and drops each visited node's `cache` entry.
  ##
  ## Each node is visited at most once, so a cycle in `dep_tree_final` cannot
  ## make the walk loop forever.
  inputs[key] = val
  var seen = @[key]
  var node = dep_tree_final.getOrDefault(key, "")
  while node != "" and node notin seen:
    seen.add node
    cache.del(node)
    node = dep_tree_final.getOrDefault(node, "")

###########################################################################
## demo
###########################################################################

template recalc() =
  ## Evaluates the root node; each cache miss prints "<name>_impl is called".
  discard heavy_calc(5.0, 10.0)

# First run: every node misses the cache.
set_input("a1", 5.0)
set_input("a2", 2.0)
recalc()

echo "** no changes recompute effectively"
recalc()

echo "** change one input and recompute effectively"

# Only the nodes that depend on "a2" were invalidated.
set_input("a2", 10.0)
recalc()