diff options
-rw-r--r-- | compiler/guards.nim | 112 | ||||
-rw-r--r-- | compiler/semparallel.nim | 41 | ||||
-rw-r--r-- | tests/parallel/tdisjoint_slice1.nim | 21 | ||||
-rw-r--r-- | tests/parallel/tdisjoint_slice2.nim | 21 | ||||
-rw-r--r-- | tests/parallel/tinvalid_array_bounds.nim | 25 | ||||
-rw-r--r-- | tests/parallel/tinvalid_counter_usage.nim | 26 | ||||
-rw-r--r-- | tests/parallel/tnon_disjoint_slice1.nim | 25 |
7 files changed, 221 insertions, 50 deletions
diff --git a/compiler/guards.nim b/compiler/guards.nim index 551a11256..de0ce1dcc 100644 --- a/compiler/guards.nim +++ b/compiler/guards.nim @@ -1,7 +1,7 @@ # # # The Nimrod Compiler -# (c) Copyright 2013 Andreas Rumpf +# (c) Copyright 2014 Andreas Rumpf # # See the file "copying.txt", included in this # distribution, for details about the copyright. @@ -165,9 +165,6 @@ proc buildCall(op: PSym; a, b: PNode): PNode = result.sons[1] = a result.sons[2] = b -proc `+@`*(a: PNode; b: BiggestInt): PNode = - (if b != 0: opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a) - proc `|+|`(a, b: PNode): PNode = result = copyNode(a) if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal @@ -178,22 +175,56 @@ proc `|*|`(a, b: PNode): PNode = if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal else: result.floatVal = a.floatVal * b.floatVal +proc negate(a, b, res: PNode): PNode = + if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt): + var b = copyNode(b) + b.intVal = -b.intVal + if a.kind in {nkCharLit..nkUInt64Lit}: + b.intVal = b.intVal |+| a.intVal + result = b + else: + result = buildCall(opAdd, a, b) + elif b.kind in {nkFloatLit..nkFloat64Lit}: + var b = copyNode(b) + b.floatVal = -b.floatVal + result = buildCall(opAdd, a, b) + else: + result = res + proc zero(): PNode = nkIntLit.newIntNode(0) proc one(): PNode = nkIntLit.newIntNode(1) proc minusOne(): PNode = nkIntLit.newIntNode(-1) -proc lowBound*(x: PNode): PNode = nkIntLit.newIntNode(firstOrd(x.typ)) +proc lowBound*(x: PNode): PNode = + result = nkIntLit.newIntNode(firstOrd(x.typ)) + result.info = x.info + proc highBound*(x: PNode): PNode = - if x.typ.skipTypes(abstractInst).kind == tyArray: - nkIntLit.newIntNode(lastOrd(x.typ)) - else: - opAdd.buildCall(opLen.buildCall(x), minusOne()) + result = if x.typ.skipTypes(abstractInst).kind == tyArray: + nkIntLit.newIntNode(lastOrd(x.typ)) + else: + opAdd.buildCall(opLen.buildCall(x), minusOne()) + result.info = x.info + +proc reassociation(n: PNode): PNode = + result = n + # (foo+5)+5 --> foo+10; same for '*' + case result.getMagic + of someAdd: + if result[2].isValue and + result[1].getMagic in someAdd and result[1][2].isValue: + result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2]) + of someMul: + if result[2].isValue and + result[1].getMagic in someMul and result[1][2].isValue: + result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2]) + else: discard proc canon*(n: PNode): PNode = # XXX for now only the new code in 'semparallel' uses this if n.safeLen >= 1: - result = newNodeI(n.kind, n.info, n.len) - for i in 0 .. < n.safeLen: + result = shallowCopy(n) + for i in 0 .. < n.len: result.sons[i] = canon(n.sons[i]) else: result = n @@ -210,32 +241,12 @@ proc canon*(n: PNode): PNode = result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1)) of someSub: # x - 4 --> x + (-4) - var b = result[2] - if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt): - b = copyNode(b) - b.intVal = -b.intVal - result = buildCall(opAdd, result[1], b) - elif b.kind in {nkFloatLit..nkFloat64Lit}: - b = copyNode(b) - b.floatVal = -b.floatVal - result = buildCall(opAdd, result[1], b) + result = negate(result[1], result[2], result) of someLen: result.sons[0] = opLen.newSymNode else: discard - # re-association: - # (foo+5)+5 --> foo+10; same for '*' - case result.getMagic - of someAdd: - if result[2].isValue and - result[1].getMagic in someAdd and result[1][2].isValue: - result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2]) - of someMul: - if result[2].isValue and - result[1].getMagic in someMul and result[1][2].isValue: - result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2]) - else: discard - + result = reassociation(result) # most important rule: (x-4) < a.len --> x < a.len+4 case result.getMagic of someLe, someLt: @@ -245,21 +256,32 @@ proc canon*(n: PNode): PNode = isLetLocation(x[1], true): case x.getMagic of someSub: - result = buildCall(result[0].sym, x[1], opAdd.buildCall(y, x[2])) + result = buildCall(result[0].sym, x[1], + reassociation(opAdd.buildCall(y, x[2]))) of someAdd: - result = buildCall(result[0].sym, x[1], opSub.buildCall(y, x[2])) + # Rule A: + let plus = negate(y, x[2], nil).reassociation + if plus != nil: result = buildCall(result[0].sym, x[1], plus) else: discard elif y.kind in nkCallKinds and y.len == 3 and y[2].isValue and isLetLocation(y[1], true): # a.len < x-3 case y.getMagic of someSub: - result = buildCall(result[0].sym, y[1], opAdd.buildCall(x, y[2])) + result = buildCall(result[0].sym, y[1], + reassociation(opAdd.buildCall(x, y[2]))) of someAdd: - result = buildCall(result[0].sym, y[1], opSub.buildCall(x, y[2])) + let plus = negate(x, y[2], nil).reassociation + # ensure that Rule A will not trigger afterwards with the + # additional 'not isLetLocation' constraint: + if plus != nil and not isLetLocation(x, true): + result = buildCall(result[0].sym, plus, y[1]) else: discard else: discard +proc `+@`*(a: PNode; b: BiggestInt): PNode = + canon(if b != 0: opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a) + proc usefulFact(n: PNode): PNode = case n.getMagic of someEq: @@ -639,8 +661,20 @@ proc doesImply*(facts: TModel, prop: PNode): TImplication = proc impliesNotNil*(facts: TModel, arg: PNode): TImplication = result = doesImply(facts, opIsNil.buildCall(arg).neg) +proc simpleSlice*(a, b: PNode): BiggestInt = + # returns 'c' if a..b matches (i+c)..(i+c), -1 otherwise. (i)..(i) is matched + # as if it is (i+0)..(i+0). + if guards.sameTree(a, b): + if a.getMagic in someAdd and a[2].kind in {nkCharLit..nkUInt64Lit}: + result = a[2].intVal + else: + result = 0 + else: + result = -1 + proc proveLe*(m: TModel; a, b: PNode): TImplication = let res = canon(opLe.buildCall(a, b)) + #echo renderTree(res) # we hardcode lots of axioms here: let a = res[1] let b = res[2] @@ -662,6 +696,10 @@ proc proveLe*(m: TModel; a, b: PNode): TImplication = if b.getMagic in someAdd and sameTree(a, b[1]): return proveLe(m, zero(), b[2]) + # x+c <= x iff c <= 0 + if a.getMagic in someAdd and sameTree(b, a[1]): + return proveLe(m, a[2], zero()) + # x <= x*c if 1 <= c and 0 <= x: if b.getMagic in someMul and sameTree(a, b[1]): if proveLe(m, one(), b[2]) == impYes and proveLe(m, zero(), a) == impYes: diff --git a/compiler/semparallel.nim b/compiler/semparallel.nim index dd1584e7d..7917cab90 100644 --- a/compiler/semparallel.nim +++ b/compiler/semparallel.nim @@ -9,6 +9,8 @@ ## Semantic checking for 'parallel'. +# - codegen needs to support mSlice +# - lowerings must not perform unnecessary copies # - slices should become "nocopy" to openArray (+) # - need to perform bound checks (+) # @@ -153,6 +155,8 @@ proc addLowerBoundAsFacts(c: var AnalysisCtx) = proc addSlice(c: var AnalysisCtx; n: PNode; x, le, ri: PNode) = checkLocal(c, n) + let le = le.canon + let ri = ri.canon # perform static bounds checking here; and not later! let oldState = c.guards.len addLowerBoundAsFacts(c) @@ -166,16 +170,16 @@ proc overlap(m: TModel; x,y,c,d: PNode) = case proveLe(m, x, d) of impUnknown: localError(x.info, - "cannot prove: $# > $#; required for $#..$# disjoint from $#..$#" % + "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" % [?x, ?d, ?x, ?y, ?c, ?d]) of impYes: case proveLe(m, c, y) of impUnknown: localError(x.info, - "cannot prove: $# > $#; required for $#..$# disjoint from $#..$#" % + "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" % [?y, ?d, ?x, ?y, ?c, ?d]) of impYes: - localError(x.info, "$#..$# not disjoint from $#..$#" % [?x, ?y, ?c, ?d]) + localError(x.info, "($#)..($#) not disjoint from ($#)..($#)" % [?x, ?y, ?c, ?d]) of impNo: discard of impNo: discard @@ -220,14 +224,25 @@ proc checkSlicesAreDisjoint(c: var AnalysisCtx) = let x = c.slices[i] let y = c.slices[j] if x.spawnId != y.spawnId and guards.sameTree(x.x, y.x): - if not x.inLoop and not y.inLoop: + if not x.inLoop or not y.inLoop: + # XXX strictly speaking, 'or' is not correct here and it needs to + # be 'and'. However this prevents too many obviously correct programs + # like f(a[0..x]); for i in x+1 .. a.high: f(a[i]) overlap(c.guards, x.a, x.b, y.a, y.b) + elif (let k = simpleSlice(x.a, x.b); let m = simpleSlice(y.a, y.b); + k >= 0 and m >= 0): + # ah I cannot resist the temptation and add another sweet heuristic: + # if both slices have the form (i+k)..(i+k) and (i+m)..(i+m) we + # check they are disjoint and k < stride and m < stride: + overlap(c.guards, x.a, x.b, y.a, y.b) + let stride = min(c.stride(x.a), c.stride(y.a)) + if k < stride and m < stride: + discard + else: + localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" % + [?x.a, ?x.b, ?y.a, ?y.b]) else: - # ah I cannot resists the temptation and add another sweet heuristic: - # if both slices have the form (i+c)..(i+c) and (i+d)..(i+d) we - # check they are disjoint and c <= stride and d <= stride: - # XXX - localError(x.x.info, "cannot prove $#..$# disjoint from $#..$#" % + localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" % [?x.a, ?x.b, ?y.a, ?y.b]) proc analyse(c: var AnalysisCtx; n: PNode) @@ -369,9 +384,9 @@ proc transformSlices(n: PNode): PNode = result.add n[2][2] return result if n.safeLen > 0: - result = copyNode(n) + result = shallowCopy(n) for i in 0 .. < n.len: - result.add transformSlices(n.sons[i]) + result.sons[i] = transformSlices(n.sons[i]) else: result = n @@ -383,9 +398,9 @@ proc transformSpawn(owner: PSym; n, barrier: PNode): PNode = result = transformSlices(n) return wrapProcForSpawn(owner, result[1], barrier) elif n.safeLen > 0: - result = copyNode(n) + result = shallowCopy(n) for i in 0 .. < n.len: - result.add transformSpawn(owner, n.sons[i], barrier) + result.sons[i] = transformSpawn(owner, n.sons[i], barrier) else: result = n diff --git a/tests/parallel/tdisjoint_slice1.nim b/tests/parallel/tdisjoint_slice1.nim new file mode 100644 index 000000000..2ca96d6ae --- /dev/null +++ b/tests/parallel/tdisjoint_slice1.nim @@ -0,0 +1,21 @@ + +import threadpool + +proc f(a: openArray[int]) = + for x in a: echo x + +proc f(a: int) = echo a + +proc main() = + var a: array[0..30, int] + parallel: + #spawn f(a[0..15]) + #spawn f(a[16..30]) + var i = 0 + while i <= 29: + spawn f(a[i]) + spawn f(a[i+1]) + inc i, 2 + # is correct here + +main() diff --git a/tests/parallel/tdisjoint_slice2.nim b/tests/parallel/tdisjoint_slice2.nim new file mode 100644 index 000000000..b26559fc2 --- /dev/null +++ b/tests/parallel/tdisjoint_slice2.nim @@ -0,0 +1,21 @@ + +import threadpool + +proc f(a: openArray[int]) = + for x in a: echo x + +proc f(a: int) = echo a + +proc main() = + var a: array[0..30, int] + parallel: + spawn f(a[0..15]) + #spawn f(a[16..30]) + var i = 16 + while i <= 29: + spawn f(a[i]) + spawn f(a[i+1]) + inc i, 2 + # is correct here + +main() diff --git a/tests/parallel/tinvalid_array_bounds.nim b/tests/parallel/tinvalid_array_bounds.nim new file mode 100644 index 000000000..337fae729 --- /dev/null +++ b/tests/parallel/tinvalid_array_bounds.nim @@ -0,0 +1,25 @@ +discard """ + errormsg: "cannot prove: i + 1 <= 30" + line: 21 +""" + +import threadpool + +proc f(a: openArray[int]) = + for x in a: echo x + +proc f(a: int) = echo a + +proc main() = + var a: array[0..30, int] + parallel: + spawn f(a[0..15]) + spawn f(a[16..30]) + var i = 0 + while i <= 30: + spawn f(a[i]) + spawn f(a[i+1]) + inc i + #inc i # inc i, 2 would be correct here + +main() diff --git a/tests/parallel/tinvalid_counter_usage.nim b/tests/parallel/tinvalid_counter_usage.nim new file mode 100644 index 000000000..c6303c651 --- /dev/null +++ b/tests/parallel/tinvalid_counter_usage.nim @@ -0,0 +1,26 @@ +discard """ + errormsg: "invalid usage of counter after increment" + line: 21 +""" + +import threadpool + +proc f(a: openArray[int]) = + for x in a: echo x + +proc f(a: int) = echo a + +proc main() = + var a: array[0..30, int] + parallel: + spawn f(a[0..15]) + spawn f(a[16..30]) + var i = 0 + while i <= 30: + inc i + spawn f(a[i]) + inc i + #spawn f(a[i+1]) + #inc i # inc i, 2 would be correct here + +main() diff --git a/tests/parallel/tnon_disjoint_slice1.nim b/tests/parallel/tnon_disjoint_slice1.nim new file mode 100644 index 000000000..72d008bbd --- /dev/null +++ b/tests/parallel/tnon_disjoint_slice1.nim @@ -0,0 +1,25 @@ +discard """ + errormsg: "cannot prove (i)..(i) disjoint from (i + 1)..(i + 1)" + line: 20 +""" + +import threadpool + +proc f(a: openArray[int]) = + for x in a: echo x + +proc f(a: int) = echo a + +proc main() = + var a: array[0..30, int] + parallel: + #spawn f(a[0..15]) + #spawn f(a[16..30]) + var i = 0 + while i <= 29: + spawn f(a[i]) + spawn f(a[i+1]) + inc i + #inc i # inc i, 2 would be correct here + +main() |