summary refs log tree commit diff stats
path: root/compiler/semparallel.nim
blob: 5571d1bedb5356381b81ab00627abe9160594ca6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
#
#
#           The Nimrod Compiler
#        (c) Copyright 2014 Andreas Rumpf
#
#    See the file "copying.txt", included in this
#    distribution, for details about the copyright.
#

## Semantic checking for 'parallel'.

# - codegen needs to support mSlice (+)
# - lowerings must not perform unnecessary copies (+)
# - slices should become "nocopy" to openArray (+)
#   - need to perform bound checks (+)
#
# - parallel needs to insert a barrier (+)
# - passed arguments need to be ensured to be "const"
#   - what about 'f(a)'? --> f shouldn't have side effects anyway
# - passed arrays need to be ensured not to alias
# - passed slices need to be ensured to be disjoint (+)
# - output slices need special logic (+)

import
  ast, astalgo, idents, lowerings, magicsys, guards, sempass2, msgs,
  renderer, types
from trees import getMagic
from strutils import `%`

discard """

one major problem:
  spawn f(a[i])
  inc i
  spawn f(a[i])
is valid, but
  spawn f(a[i])
  spawn f(a[i])
  inc i
is not! However, 
  spawn f(a[i])
  if guard: inc i
  spawn f(a[i])
is not valid either! --> We need a flow dependent analysis here.

However:
  while foo:
    spawn f(a[i])
    inc i
    spawn f(a[i])

Is not valid either! --> We should really restrict 'inc' to loop endings?

The heuristic that we implement here (that has no false positives) is: Usage
of 'i' in a slice *after* we determined the stride is invalid!
"""

type
  TDirection = enum
    ascending, descending
  MonotonicVar = object
    # Bookkeeping for one local variable that may act as a loop counter.
    v, alias: PSym        # 'v' is the tracked variable itself.
                          # to support the ordinary 'countup' iterator
                          # we need to detect aliases
    # 'lower' is the initial value taken from the variable's 'var' section;
    # 'stride' is the smallest positive literal increment seen via
    # 'inc'/'+='. 'upper' is not written anywhere in this part of the file.
    lower, upper, stride: PNode
    dir: TDirection       # not referenced in this part of the file
    blacklisted: bool     # blacklisted variables that are not monotonic
  AnalysisCtx = object
    # State accumulated while analysing the body of a 'parallel' section.
    locals: seq[MonotonicVar]
    # One entry per array access: 'x' is the accessed expression, 'a'..'b'
    # the (canonicalized) index range, 'spawnId' the spawn it occurred in and
    # 'inLoop' whether it was seen inside a 'while' loop.
    slices: seq[tuple[x,a,b: PNode, spawnId: int, inLoop: bool]]
    guards: TModel      # nested guards
    args: seq[PSym]     # args must be deeply immutable
    spawns: int         # we can check that at least 1 spawn is used in
                        # the 'parallel' section
    currentSpawnId: int # id of the 'spawn' currently being analysed
    inLoop: int         # nesting depth of enclosing 'while' loops

# Magic symbol for the slice operation; 'transformSlices' rewrites calls of
# system's `[]` so that they call this symbol instead.
let opSlice = createMagic("slice", mSlice)

proc initAnalysisCtx(): AnalysisCtx =
  ## Creates an analysis context whose sequences are all empty; the integer
  ## counters keep their default value of zero.
  result.guards = @[]
  result.args = @[]
  result.slices = @[]
  result.locals = @[]

proc lookupSlot(c: AnalysisCtx; s: PSym): int =
  ## Returns the index in `c.locals` of the entry that tracks `s`, either
  ## directly or through its alias, or -1 if there is no such entry.
  result = -1
  for idx, entry in c.locals:
    if entry.v == s or entry.alias == s: return idx

proc getSlot(c: var AnalysisCtx; v: PSym): ptr MonotonicVar =
  ## Returns a pointer to the entry tracking `v`, appending a fresh entry
  ## when none exists yet.
  ## The pointer refers into `c.locals`, so it should be used before the
  ## sequence grows again.
  var idx = lookupSlot(c, v)
  if idx < 0:
    # not tracked so far: append a new zeroed entry for 'v'
    idx = c.locals.len
    c.locals.setLen(idx+1)
    c.locals[idx].v = v
  result = addr(c.locals[idx])

proc gatherArgs(c: var AnalysisCtx; n: PNode) =
  ## Recursively walks the sons of `n` and records the root symbol of each
  ## of them in `c.args`; every symbol is stored at most once.
  for idx in 0 .. <n.safeLen:
    let child = n[idx]
    let root = getRoot(child)
    if root != nil and root notin c.args:
      c.args.add root
    gatherArgs(c, child)

proc isSingleAssignable(n: PNode): bool =
  ## True if `n` is a symbol of kind temp, for-loop variable or 'let' that
  ## is neither global nor has its address taken.
  if n.kind == nkSym:
    let sym = n.sym
    result = sym.kind in {skTemp, skForVar, skLet} and
             sym.flags * {sfAddrTaken, sfGlobal} == {}

proc isLocal(n: PNode): bool =
  ## True if `n` is a symbol of kind result, temp, for-loop variable, 'var'
  ## or 'let' that is neither global nor has its address taken.
  if n.kind == nkSym:
    let sym = n.sym
    result = sym.kind in {skResult, skTemp, skForVar, skVar, skLet} and
             sym.flags * {sfAddrTaken, sfGlobal} == {}

proc checkLocal(c: AnalysisCtx; n: PNode) =
  ## Reports an error for every local inside `n` whose stride has already
  ## been determined, i.e. a counter that is used after its increment.
  if not isLocal(n):
    # not a local itself: look for locals in the subtrees
    for i in 0 .. <n.safeLen: checkLocal(c, n.sons[i])
    return
  let slot = c.lookupSlot(n.sym)
  if slot >= 0 and not c.locals[slot].stride.isNil:
    localError(n.info, "invalid usage of counter after increment")

# Prefix shorthand used in error messages: `?node` is `renderTree(node)`.
template `?`(x): expr = renderTree(x)

proc checkLe(c: AnalysisCtx; a, b: PNode) =
  ## Requires that `a <= b` follows from the current guards; reports an
  ## error both when the relation is unknown and when it is refuted.
  let verdict = proveLe(c.guards, a, b)
  if verdict == impUnknown:
    localError(a.info, "cannot prove: " & ?a & " <= " & ?b)
  elif verdict == impNo:
    localError(a.info, "can prove: " & ?a & " > " & ?b)

proc checkBounds(c: AnalysisCtx; arr, idx: PNode) =
  ## Statically checks `lowBound(arr) <= idx` and `idx <= highBound(arr)`.
  let lo = lowBound(arr)
  checkLe(c, lo, idx)
  let hi = highBound(arr)
  checkLe(c, idx, hi)

proc addLowerBoundAsFacts(c: var AnalysisCtx) =
  ## Adds the fact `lower <= v` to `c.guards` for every tracked local `v`
  ## that has not been blacklisted and whose lower bound is known.
  for v in c.locals:
    # Slots created by 'getSlot' for an increment or an alias never had
    # 'lower' set (only a 'var' section with an initializer sets it); for
    # those there is no fact to state and a nil node must not be passed on.
    if not v.blacklisted and not v.lower.isNil:
      c.guards.addFactLe(v.lower, newSymNode(v.v))

proc addSlice(c: var AnalysisCtx; n: PNode; x, le, ri: PNode) =
  ## Registers the access `x[le..ri]` found at node `n`: the counter usage
  ## and both bounds are checked right away, then the slice is recorded
  ## together with the current spawn id and loop state.
  checkLocal(c, n)
  let lo = canon(le)
  let hi = canon(ri)
  # perform static bounds checking here; and not later!
  # The lower bound facts are only needed for these checks, so the guards
  # are cut back to their previous length afterwards.
  let factCount = c.guards.len
  addLowerBoundAsFacts(c)
  checkBounds(c, x, lo)
  checkBounds(c, x, hi)
  setLen(c.guards, factCount)
  c.slices.add((x, lo, hi, c.currentSpawnId, c.inLoop > 0))

proc overlap(m: TModel; x,y,c,d: PNode) =
  ## Reports an error unless the ranges `x..y` and `c..d` can be proven
  ## disjoint under the facts in `m`.
  #  X..Y and C..D overlap iff (X <= D and C <= Y)
  #  Hence the ranges are disjoint as soon as either X > D or C > Y holds.
  case proveLe(m, x, d)
  of impUnknown:
    # X > D could not be established; the ranges are nevertheless disjoint
    # if C > Y can be proven, so consult the second comparison before
    # giving up:
    if proveLe(m, c, y) != impNo:
      localError(x.info,
        "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
          [?x, ?d, ?x, ?y, ?c, ?d])
  of impYes:
    # X <= D holds, so disjointness now depends solely on C > Y:
    case proveLe(m, c, y)
    of impUnknown:
      localError(x.info,
        "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
          [?c, ?y, ?x, ?y, ?c, ?d])
    of impYes:
      localError(x.info, "($#)..($#) not disjoint from ($#)..($#)" % [?x, ?y, ?c, ?d])
    of impNo: discard
  of impNo: discard

proc stride(c: AnalysisCtx; n: PNode): BiggestInt =
  ## Sums up the known strides of all locals occurring in `n`; locals
  ## without a recorded stride contribute 0.
  if not isLocal(n):
    for i in 0 .. <n.safeLen: result += stride(c, n.sons[i])
    return
  let slot = c.lookupSlot(n.sym)
  if slot >= 0 and not c.locals[slot].stride.isNil:
    result = c.locals[slot].stride.intVal

proc subStride(c: AnalysisCtx; n: PNode): PNode =
  ## Returns `n` with every local that has a known stride replaced by
  ## `local +@ stride`. Nodes without sons and locals without a stride are
  ## returned unchanged; inner nodes are shallow-copied before rewriting.
  result = n
  if isLocal(n):
    let slot = c.lookupSlot(n.sym)
    if slot >= 0 and not c.locals[slot].stride.isNil:
      # substitute with stride:
      result = n +@ c.locals[slot].stride.intVal
  elif n.safeLen > 0:
    result = shallowCopy(n)
    for i in 0 .. <n.len: result.sons[i] = subStride(c, n.sons[i])

proc checkSlicesAreDisjoint(c: var AnalysisCtx) =
  ## Checks all slices recorded in `c.slices` for disjointness and reports
  ## an error for every pair (or loop-carried slice) that cannot be proven
  ## disjoint. The lower bound facts are added to `c.guards` first and are
  ## not removed again by this proc.
  # this is the only thing that we need to perform after we have traversed
  # the whole tree so that the strides are available.
  # First we need to add all the computed lower bounds:
  addLowerBoundAsFacts(c)
  # Every slice used in a loop needs to be disjoint with itself:
  # (the slice is compared against a copy of itself in which each counter
  # has been advanced by its stride)
  for x,a,b,id,inLoop in items(c.slices):
    if inLoop: overlap(c.guards, a,b, c.subStride(a), c.subStride(b))
  # Another tricky example is:
  #   while true:
  #     spawn f(a[i])
  #     spawn f(a[i+1])
  #     inc i  # inc i, 2  would be correct here
  #
  # Or even worse:
  #   while true:
  #     spawn f(a[i+1 .. i+3])
  #     spawn f(a[i+4 .. i+5])
  #     inc i, 4
  # Prove that i*k*stride + 3 != i*k'*stride + 5
  # For the correct example this amounts to
  #   i*k*2 != i*k'*2 + 1
  # which is true.
  # For now, we don't try to prove things like that at all, even though it'd
  # be feasible for many useful examples. Instead we attach the slice to
  # a spawn and if the attached spawns differ, we bail out:
  for i in 0 .. high(c.slices):
    for j in i+1 .. high(c.slices):
      let x = c.slices[i]
      let y = c.slices[j]
      # only slices of the same expression that belong to different spawns
      # can conflict:
      if x.spawnId != y.spawnId and guards.sameTree(x.x, y.x):
        if not x.inLoop or not y.inLoop:
          # XXX strictly speaking, 'or' is not correct here and it needs to
          # be 'and'. However this prevents too many obviously correct programs
          # like f(a[0..x]); for i in x+1 .. a.high: f(a[i])
          overlap(c.guards, x.a, x.b, y.a, y.b)
        elif (let k = simpleSlice(x.a, x.b); let m = simpleSlice(y.a, y.b);
              k >= 0 and m >= 0):
          # ah I cannot resist the temptation and add another sweet heuristic:
          # if both slices have the form (i+k)..(i+k)  and (i+m)..(i+m) we
          # check they are disjoint and k < stride and m < stride:
          overlap(c.guards, x.a, x.b, y.a, y.b)
          let stride = min(c.stride(x.a), c.stride(y.a))
          if k < stride and m < stride:
            discard
          else:
            localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
              [?x.a, ?x.b, ?y.a, ?y.b])
        else:
          # both slices are in a loop and not of the simple form: give up
          localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
            [?x.a, ?x.b, ?y.a, ?y.b])

proc analyse(c: var AnalysisCtx; n: PNode)

proc analyseSons(c: var AnalysisCtx; n: PNode) =
  ## Runs `analyse` on every son of `n`; nodes without sons are ignored.
  let count = safeLen(n)
  for idx in 0 .. <count: analyse(c, n[idx])

proc min(a, b: PNode): PNode =
  ## Returns the literal node with the smaller `intVal`. `a` may be nil, in
  ## which case `b` is returned; on equal values `b` wins.
  if not a.isNil and a.intVal < b.intVal:
    result = a
  else:
    result = b

proc fromSystem(op: PSym): bool =
  ## True if the module that `op` belongs to carries `sfSystemModule`.
  result = sfSystemModule in getModule(op).flags

proc analyseCall(c: var AnalysisCtx; n: PNode; op: PSym) =
  ## Analyses the direct call `n` of the routine `op`:
  ## - 'spawn' gets a fresh spawn id and its argument roots are gathered,
  ## - 'inc'/'+=' on a local with a positive literal records a stride,
  ## - system's `[]`/`[]=` register the accessed slice.
  ## In every case the sons of `n` are analysed as well.
  if op.magic == mSpawn:
    inc c.spawns
    let outerSpawnId = c.currentSpawnId
    c.currentSpawnId = c.spawns
    gatherArgs(c, n[1])
    analyseSons(c, n)
    # restore the id of the enclosing spawn (if any):
    c.currentSpawnId = outerSpawnId
    return

  if op.magic == mInc or (op.name.s == "+=" and op.fromSystem):
    if n[1].isLocal:
      let amount = n[2].skipConv
      if amount.kind in {nkCharLit..nkUInt32Lit} and amount.intVal > 0:
        # keep the smallest increment that was seen for this counter
        let slot = c.getSlot(n[1].sym)
        slot.stride = min(slot.stride, amount)
  elif (op.name.s == "[]" or op.name.s == "[]=") and op.fromSystem:
    c.addSlice(n, n[1], n[2][1], n[2][2])
  analyseSons(c, n)

proc analyseCase(c: var AnalysisCtx; n: PNode) =
  ## Analyses a 'case' statement: the selector first, then every branch
  ## under the facts that `addCaseBranchFacts` derives for it. The guards
  ## are reset to their initial length before each branch and at the end.
  analyse(c, n.sons[0])
  let savedFacts = c.guards.len
  for branchIdx in 1 .. <n.len:
    setLen(c.guards, savedFacts)
    addCaseBranchFacts(c.guards, n, branchIdx)
    let branch = n.sons[branchIdx]
    for k in 0 .. <branch.len:
      analyse(c, branch.sons[k])
  setLen(c.guards, savedFacts)

proc analyseIf(c: var AnalysisCtx; n: PNode) =
  ## Analyses an 'if' statement or expression. The first branch is analysed
  ## under its own condition; each later branch under the negation of all
  ## preceding conditions plus its own condition (if it has one). The
  ## guards are restored to their initial length at the end.
  let first = n.sons[0]
  analyse(c, first.sons[0])
  let savedFacts = c.guards.len
  addFact(c.guards, canon(first.sons[0]))
  analyse(c, first.sons[1])

  for branchIdx in 1 .. <n.len:
    let branch = n.sons[branchIdx]
    setLen(c.guards, savedFacts)
    # none of the earlier conditions held:
    for prev in 0 .. branchIdx-1:
      addFactNeg(c.guards, canon(n.sons[prev].sons[0]))
    # a branch with more than one son carries a condition of its own:
    if branch.len > 1:
      addFact(c.guards, canon(branch.sons[0]))
    for k in 0 .. <branch.len:
      analyse(c, branch.sons[k])
  setLen(c.guards, savedFacts)

proc analyse(c: var AnalysisCtx; n: PNode) =
  ## Main traversal of the first pass: walks `n` and updates `c` with the
  ## aliases, blacklisted variables, lower bounds, strides, slices and
  ## guard facts that it finds. Control flow that is not allowed inside
  ## 'parallel' is reported as an error.
  case n.kind
  of nkAsgn, nkFastAsgn:
    if n[0].isSingleAssignable and n[1].isLocal:
      # 'n[0]' becomes an alias for the local 'n[1]':
      let slot = c.getSlot(n[1].sym)
      slot.alias = n[0].sym
    elif n[0].isLocal:
      # since we already ensure sfAddrTaken is not in s.flags, we only need to
      # prevent direct assignments to the monotonic variable:
      let slot = c.getSlot(n[0].sym)
      slot.blacklisted = true
    # facts about the old value are dropped before the sons are analysed;
    # the fact for the assignment itself is added afterwards:
    invalidateFacts(c.guards, n[0])
    analyseSons(c, n)
    addAsgnFact(c.guards, n[0], n[1])
  of nkCallKinds:
    # direct call:
    if n[0].kind == nkSym: analyseCall(c, n, n[0].sym)
    else: analyseSons(c, n)
  of nkBracketExpr:
    # a plain index access is recorded as the slice 'idx..idx':
    c.addSlice(n, n[0], n[1], n[1])
    analyseSons(c, n)
  of nkReturnStmt, nkRaiseStmt, nkTryStmt:
    localError(n.info, "invalid control flow for 'parallel'")
    # 'break' that leaves the 'parallel' section is not valid either
    # or maybe we should generate a 'try' XXX
  of nkVarSection:
    for it in n:
      let value = it.lastSon
      if value.kind != nkEmpty:
        # the initial value is the lower bound of every declared local;
        # the last two sons of 'it' (type and value) are skipped:
        for j in 0 .. it.len-3:
          if it[j].isLocal:
            let slot = c.getSlot(it[j].sym)
            if slot.lower.isNil: slot.lower = value
            else: internalError(it.info, "slot already has a lower bound")
        analyse(c, value)
  of nkCaseStmt: analyseCase(c, n)
  of nkIfStmt, nkIfExpr: analyseIf(c, n)
  of nkWhileStmt:
    analyse(c, n.sons[0])
    # 'while true' loop?
    inc c.inLoop
    if isTrue(n.sons[0]):
      analyseSons(c, n.sons[1])
    else:
      # loop may never execute:
      # so the locals and facts gathered inside the body are discarded
      # again once the body has been analysed
      let oldState = c.locals.len
      let oldFacts = c.guards.len
      addFact(c.guards, canon(n.sons[0]))
      analyse(c, n.sons[1])
      setLen(c.locals, oldState)
      setLen(c.guards, oldFacts)
      # we know after the loop the negation holds:
      if not hasSubnodeWith(n.sons[1], nkBreakStmt):
        addFactNeg(c.guards, canon(n.sons[0]))
    dec c.inLoop
  of nkTypeSection, nkProcDef, nkConverterDef, nkMethodDef, nkIteratorDef,
      nkMacroDef, nkTemplateDef, nkConstSection, nkPragma:
    # declarations are not traversed
    discard
  else:
    analyseSons(c, n)

proc transformSlices(n: PNode): PNode =
  ## Returns `n` with every direct call of system's `[]` replaced by a call
  ## of `opSlice` that takes the sliced expression and the two operands of
  ## the slice argument. Other inner nodes are shallow-copied and rewritten
  ## recursively; nodes without sons are returned as they are.
  if n.kind in nkCallKinds and n[0].kind == nkSym:
    let callee = n[0].sym
    if callee.name.s == "[]" and callee.fromSystem:
      let sliceArg = n[2]
      result = copyNode(n)
      result.add opSlice.newSymNode
      result.add n[1]
      result.add sliceArg[1]
      result.add sliceArg[2]
      return
  if n.safeLen == 0:
    return n
  result = shallowCopy(n)
  for idx in 0 .. < n.len:
    result.sons[idx] = transformSlices(n.sons[idx])

proc transformSpawn(owner: PSym; n, barrier: PNode): PNode

proc transformSpawnSons(owner: PSym; n, barrier: PNode): PNode =
  ## Returns a shallow copy of `n` in which every son has been replaced by
  ## the result of `transformSpawn`.
  result = shallowCopy(n)
  for idx in 0 .. < len(n):
    let son = n.sons[idx]
    result.sons[idx] = transformSpawn(owner, son, barrier)

proc transformSpawn(owner: PSym; n, barrier: PNode): PNode =
  ## Second pass: rewrites every 'spawn' call found in `n` via
  ## `wrapProcForSpawn`, passing `barrier` along, after its slice
  ## arguments have been rewritten by `transformSlices`.
  ## Note that 'var' sections containing a spawn are modified in place.
  case n.kind
  of nkVarSection:
    result = nil
    for it in n:
      let b = it.lastSon
      if getMagic(b) == mSpawn:
        # only the form 'var x = spawn f()' (a single variable) is accepted
        if it.len != 3: localError(it.info, "invalid context for 'spawn'")
        let m = transformSlices(b)
        if result.isNil:
          # the section itself is emitted first so that statements
          # appended below come after the declarations
          result = newNodeI(nkStmtList, n.info)
          result.add n
        # 't' is the first son of the spawned callee's type
        let t = b[1][0].typ.sons[0]
        if spawnResult(t, true) == srByVar:
          # the wrapper receives the variable as destination; the
          # initializer in the section is emptied
          result.add wrapProcForSpawn(owner, m, b.typ, barrier, it[0])
          it.sons[it.len-1] = emptyNode
        else:
          it.sons[it.len-1] = wrapProcForSpawn(owner, m, b.typ, barrier, nil)
    if result.isNil: result = n
  of nkAsgn, nkFastAsgn:
    let b = n[1]
    if getMagic(b) == mSpawn and (let t = b[1][0].typ.sons[0];
        spawnResult(t, true) == srByVar):
      # 'x = spawn f()': the wrapper call replaces the whole assignment
      let m = transformSlices(b)
      return wrapProcForSpawn(owner, m, b.typ, barrier, n[0])
    result = transformSpawnSons(owner, n, barrier)
  of nkCallKinds:
    if getMagic(n) == mSpawn:
      result = transformSlices(n)
      return wrapProcForSpawn(owner, result, n.typ, barrier, nil)
    result = transformSpawnSons(owner, n, barrier)
  elif n.safeLen > 0:
    result = transformSpawnSons(owner, n, barrier)
  else:
    result = n

proc checkArgs(a: var AnalysisCtx; n: PNode) =
  # Placeholder: no check is performed yet. 'liftParallel' calls this after
  # the slices have been checked for disjointness.
  discard "too implement"

proc generateAliasChecks(a: AnalysisCtx; result: PNode) =
  # Placeholder: nothing is generated yet. 'liftParallel' calls this with
  # the statement list it is building, before the barrier code is added.
  discard "too implement"

proc liftParallel*(owner: PSym; n: PNode): PNode =
  ## Checks the 'parallel' statement `n` and returns the statement list
  ## that replaces it: the declaration of a barrier variable, a call to
  ## 'openBarrier', the body with its spawns transformed and a call to
  ## 'closeBarrier'.
  # this needs to be called after the 'for' loop elimination

  # first pass:
  # - detect monotonic local integer variables
  # - detect used slices
  # - detect used arguments
  #echo "PAR ", renderTree(n)

  var ctx = initAnalysisCtx()
  let body = n.lastSon
  analyse(ctx, body)
  if ctx.spawns == 0:
    localError(n.info, "'parallel' section without 'spawn'")
  checkSlicesAreDisjoint(ctx)
  checkArgs(ctx, body)

  # declare the temporary that holds the barrier:
  let barrierSection = newNodeI(nkVarSection, n.info)
  let barrierSym = newSym(skTemp, getIdent("barrier"), owner, n.info)
  barrierSym.typ = magicsys.getCompilerProc("Barrier").typ
  incl(barrierSym.flags, sfFromGeneric)
  let barrierVar = newSymNode(barrierSym)
  barrierSection.addVar barrierVar

  # second pass: assemble the replacement statement list
  let barrier = genAddrOf(barrierVar)
  result = newNodeI(nkStmtList, n.info)
  generateAliasChecks(ctx, result)
  result.add barrierSection
  result.add callCodegenProc("openBarrier", barrier)
  result.add transformSpawn(owner, body, barrier)
  result.add callCodegenProc("closeBarrier", barrier)