summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/guards.nim112
-rw-r--r--compiler/semparallel.nim41
-rw-r--r--tests/parallel/tdisjoint_slice1.nim21
-rw-r--r--tests/parallel/tdisjoint_slice2.nim21
-rw-r--r--tests/parallel/tinvalid_array_bounds.nim25
-rw-r--r--tests/parallel/tinvalid_counter_usage.nim26
-rw-r--r--tests/parallel/tnon_disjoint_slice1.nim25
7 files changed, 221 insertions, 50 deletions
diff --git a/compiler/guards.nim b/compiler/guards.nim
index 551a11256..de0ce1dcc 100644
--- a/compiler/guards.nim
+++ b/compiler/guards.nim
@@ -1,7 +1,7 @@
 #
 #
 #           The Nimrod Compiler
-#        (c) Copyright 2013 Andreas Rumpf
+#        (c) Copyright 2014 Andreas Rumpf
 #
 #    See the file "copying.txt", included in this
 #    distribution, for details about the copyright.
@@ -165,9 +165,6 @@ proc buildCall(op: PSym; a, b: PNode): PNode =
   result.sons[1] = a
   result.sons[2] = b
 
-proc `+@`*(a: PNode; b: BiggestInt): PNode =
-  (if b != 0: opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a)
-
 proc `|+|`(a, b: PNode): PNode =
   result = copyNode(a)
   if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal
@@ -178,22 +175,56 @@ proc `|*|`(a, b: PNode): PNode =
   if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal
   else: result.floatVal = a.floatVal * b.floatVal
 
+proc negate(a, b, res: PNode): PNode =
+  if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt):
+    var b = copyNode(b)
+    b.intVal = -b.intVal
+    if a.kind in {nkCharLit..nkUInt64Lit}:
+      b.intVal = b.intVal |+| a.intVal
+      result = b
+    else:
+      result = buildCall(opAdd, a, b)
+  elif b.kind in {nkFloatLit..nkFloat64Lit}:
+    var b = copyNode(b)
+    b.floatVal = -b.floatVal
+    result = buildCall(opAdd, a, b)
+  else:
+    result = res
+
 proc zero(): PNode = nkIntLit.newIntNode(0)
 proc one(): PNode = nkIntLit.newIntNode(1)
 proc minusOne(): PNode = nkIntLit.newIntNode(-1)
 
-proc lowBound*(x: PNode): PNode = nkIntLit.newIntNode(firstOrd(x.typ))
+proc lowBound*(x: PNode): PNode = 
+  result = nkIntLit.newIntNode(firstOrd(x.typ))
+  result.info = x.info
+
 proc highBound*(x: PNode): PNode =
-  if x.typ.skipTypes(abstractInst).kind == tyArray:
-    nkIntLit.newIntNode(lastOrd(x.typ))
-  else:
-    opAdd.buildCall(opLen.buildCall(x), minusOne())
+  result = if x.typ.skipTypes(abstractInst).kind == tyArray:
+             nkIntLit.newIntNode(lastOrd(x.typ))
+           else:
+             opAdd.buildCall(opLen.buildCall(x), minusOne())
+  result.info = x.info
+
+proc reassociation(n: PNode): PNode =
+  result = n
+  # (foo+5)+5 --> foo+10;  same for '*'
+  case result.getMagic
+  of someAdd:
+    if result[2].isValue and 
+        result[1].getMagic in someAdd and result[1][2].isValue:
+      result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
+  of someMul:
+    if result[2].isValue and 
+        result[1].getMagic in someMul and result[1][2].isValue:
+      result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2])
+  else: discard
 
 proc canon*(n: PNode): PNode =
   # XXX for now only the new code in 'semparallel' uses this
   if n.safeLen >= 1:
-    result = newNodeI(n.kind, n.info, n.len)
-    for i in 0 .. < n.safeLen:
+    result = shallowCopy(n)
+    for i in 0 .. < n.len:
       result.sons[i] = canon(n.sons[i])
   else:
     result = n
@@ -210,32 +241,12 @@ proc canon*(n: PNode): PNode =
     result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1))
   of someSub:
     # x - 4  -->  x + (-4)
-    var b = result[2]
-    if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt):
-      b = copyNode(b)
-      b.intVal = -b.intVal
-      result = buildCall(opAdd, result[1], b)
-    elif b.kind in {nkFloatLit..nkFloat64Lit}:
-      b = copyNode(b)
-      b.floatVal = -b.floatVal
-      result = buildCall(opAdd, result[1], b)    
+    result = negate(result[1], result[2], result)
   of someLen:
     result.sons[0] = opLen.newSymNode
   else: discard
 
-  # re-association:
-  # (foo+5)+5 --> foo+10;  same for '*'
-  case result.getMagic
-  of someAdd:
-    if result[2].isValue and 
-        result[1].getMagic in someAdd and result[1][2].isValue:
-      result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
-  of someMul:
-    if result[2].isValue and 
-        result[1].getMagic in someMul and result[1][2].isValue:
-      result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2])
-  else: discard
-
+  result = reassociation(result)
   # most important rule: (x-4) < a.len -->  x < a.len+4
   case result.getMagic
   of someLe, someLt:
@@ -245,21 +256,32 @@ proc canon*(n: PNode): PNode =
         isLetLocation(x[1], true):
       case x.getMagic
       of someSub:
-        result = buildCall(result[0].sym, x[1], opAdd.buildCall(y, x[2]))
+        result = buildCall(result[0].sym, x[1], 
+                           reassociation(opAdd.buildCall(y, x[2])))
       of someAdd:
-        result = buildCall(result[0].sym, x[1], opSub.buildCall(y, x[2]))
+        # Rule A:
+        let plus = negate(y, x[2], nil).reassociation
+        if plus != nil: result = buildCall(result[0].sym, x[1], plus)
       else: discard
     elif y.kind in nkCallKinds and y.len == 3 and y[2].isValue and 
         isLetLocation(y[1], true):
       # a.len < x-3
       case y.getMagic
       of someSub:
-        result = buildCall(result[0].sym, y[1], opAdd.buildCall(x, y[2]))
+        result = buildCall(result[0].sym, y[1],
+                           reassociation(opAdd.buildCall(x, y[2])))
       of someAdd:
-        result = buildCall(result[0].sym, y[1], opSub.buildCall(x, y[2]))
+        let plus = negate(x, y[2], nil).reassociation
+        # ensure that Rule A will not trigger afterwards with the
+        # additional 'not isLetLocation' constraint:
+        if plus != nil and not isLetLocation(x, true):
+          result = buildCall(result[0].sym, plus, y[1])
       else: discard
   else: discard
 
+proc `+@`*(a: PNode; b: BiggestInt): PNode =
+  canon(if b != 0: opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a)
+
 proc usefulFact(n: PNode): PNode =
   case n.getMagic
   of someEq:
@@ -639,8 +661,20 @@ proc doesImply*(facts: TModel, prop: PNode): TImplication =
 proc impliesNotNil*(facts: TModel, arg: PNode): TImplication =
   result = doesImply(facts, opIsNil.buildCall(arg).neg)
 
+proc simpleSlice*(a, b: PNode): BiggestInt =
+  # returns 'c' if a..b matches (i+c)..(i+c), -1 otherwise. (i)..(i) is matched
+  # as if it is (i+0)..(i+0).
+  if guards.sameTree(a, b):
+    if a.getMagic in someAdd and a[2].kind in {nkCharLit..nkUInt64Lit}:
+      result = a[2].intVal
+    else:
+      result = 0
+  else:
+    result = -1
+
 proc proveLe*(m: TModel; a, b: PNode): TImplication =
   let res = canon(opLe.buildCall(a, b))
+  #echo renderTree(res)
   # we hardcode lots of axioms here:
   let a = res[1]
   let b = res[2]
@@ -662,6 +696,10 @@ proc proveLe*(m: TModel; a, b: PNode): TImplication =
   if b.getMagic in someAdd and sameTree(a, b[1]):
     return proveLe(m, zero(), b[2])
 
+  #   x+c <= x  iff c <= 0
+  if a.getMagic in someAdd and sameTree(b, a[1]):
+    return proveLe(m, a[2], zero())
+
   #   x <= x*c  if  1 <= c and 0 <= x:
   if b.getMagic in someMul and sameTree(a, b[1]):
     if proveLe(m, one(), b[2]) == impYes and proveLe(m, zero(), a) == impYes:
diff --git a/compiler/semparallel.nim b/compiler/semparallel.nim
index dd1584e7d..7917cab90 100644
--- a/compiler/semparallel.nim
+++ b/compiler/semparallel.nim
@@ -9,6 +9,8 @@
 
 ## Semantic checking for 'parallel'.
 
+# - codegen needs to support mSlice
+# - lowerings must not perform unnecessary copies
 # - slices should become "nocopy" to openArray (+)
 #   - need to perform bound checks (+)
 #
@@ -153,6 +155,8 @@ proc addLowerBoundAsFacts(c: var AnalysisCtx) =
 
 proc addSlice(c: var AnalysisCtx; n: PNode; x, le, ri: PNode) =
   checkLocal(c, n)
+  let le = le.canon
+  let ri = ri.canon
   # perform static bounds checking here; and not later!
   let oldState = c.guards.len
   addLowerBoundAsFacts(c)
@@ -166,16 +170,16 @@ proc overlap(m: TModel; x,y,c,d: PNode) =
   case proveLe(m, x, d)
   of impUnknown:
     localError(x.info,
-      "cannot prove: $# > $#; required for $#..$# disjoint from $#..$#" %
+      "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
         [?x, ?d, ?x, ?y, ?c, ?d])
   of impYes:
     case proveLe(m, c, y)
     of impUnknown:
       localError(x.info,
-        "cannot prove: $# > $#; required for $#..$# disjoint from $#..$#" %
+        "cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
           [?y, ?d, ?x, ?y, ?c, ?d])
     of impYes:
-      localError(x.info, "$#..$# not disjoint from $#..$#" % [?x, ?y, ?c, ?d])
+      localError(x.info, "($#)..($#) not disjoint from ($#)..($#)" % [?x, ?y, ?c, ?d])
     of impNo: discard
   of impNo: discard
 
@@ -220,14 +224,25 @@ proc checkSlicesAreDisjoint(c: var AnalysisCtx) =
       let x = c.slices[i]
       let y = c.slices[j]
       if x.spawnId != y.spawnId and guards.sameTree(x.x, y.x):
-        if not x.inLoop and not y.inLoop:
+        if not x.inLoop or not y.inLoop:
+          # XXX strictly speaking, 'or' is not correct here and it needs to
+          # be 'and'. However this prevents too many obviously correct programs
+          # like f(a[0..x]); for i in x+1 .. a.high: f(a[i])
           overlap(c.guards, x.a, x.b, y.a, y.b)
+        elif (let k = simpleSlice(x.a, x.b); let m = simpleSlice(y.a, y.b);
+              k >= 0 and m >= 0):
+          # ah I cannot resist the temptation and add another sweet heuristic:
+          # if both slices have the form (i+k)..(i+k)  and (i+m)..(i+m) we
+          # check they are disjoint and k < stride and m < stride:
+          overlap(c.guards, x.a, x.b, y.a, y.b)
+          let stride = min(c.stride(x.a), c.stride(y.a))
+          if k < stride and m < stride:
+            discard
+          else:
+            localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
+              [?x.a, ?x.b, ?y.a, ?y.b])
         else:
-          # ah I cannot resists the temptation and add another sweet heuristic:
-          # if both slices have the form (i+c)..(i+c)  and (i+d)..(i+d) we
-          # check they are disjoint and c <= stride and d <= stride:
-          # XXX
-          localError(x.x.info, "cannot prove $#..$# disjoint from $#..$#" %
+          localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
             [?x.a, ?x.b, ?y.a, ?y.b])
 
 proc analyse(c: var AnalysisCtx; n: PNode)
@@ -369,9 +384,9 @@ proc transformSlices(n: PNode): PNode =
       result.add n[2][2]
       return result
   if n.safeLen > 0:
-    result = copyNode(n)
+    result = shallowCopy(n)
     for i in 0 .. < n.len:
-      result.add transformSlices(n.sons[i])
+      result.sons[i] = transformSlices(n.sons[i])
   else:
     result = n
 
@@ -383,9 +398,9 @@ proc transformSpawn(owner: PSym; n, barrier: PNode): PNode =
         result = transformSlices(n)
         return wrapProcForSpawn(owner, result[1], barrier)
   elif n.safeLen > 0:
-    result = copyNode(n)
+    result = shallowCopy(n)
     for i in 0 .. < n.len:
-      result.add transformSpawn(owner, n.sons[i], barrier)
+      result.sons[i] = transformSpawn(owner, n.sons[i], barrier)
   else:
     result = n
 
diff --git a/tests/parallel/tdisjoint_slice1.nim b/tests/parallel/tdisjoint_slice1.nim
new file mode 100644
index 000000000..2ca96d6ae
--- /dev/null
+++ b/tests/parallel/tdisjoint_slice1.nim
@@ -0,0 +1,21 @@
+
+import threadpool
+
+proc f(a: openArray[int]) =
+  for x in a: echo x
+
+proc f(a: int) = echo a
+
+proc main() =
+  var a: array[0..30, int]
+  parallel:
+    #spawn f(a[0..15])
+    #spawn f(a[16..30])
+    var i = 0
+    while i <= 29:
+      spawn f(a[i])
+      spawn f(a[i+1])
+      inc i, 2
+      # is correct here
+
+main()
diff --git a/tests/parallel/tdisjoint_slice2.nim b/tests/parallel/tdisjoint_slice2.nim
new file mode 100644
index 000000000..b26559fc2
--- /dev/null
+++ b/tests/parallel/tdisjoint_slice2.nim
@@ -0,0 +1,21 @@
+
+import threadpool
+
+proc f(a: openArray[int]) =
+  for x in a: echo x
+
+proc f(a: int) = echo a
+
+proc main() =
+  var a: array[0..30, int]
+  parallel:
+    spawn f(a[0..15])
+    #spawn f(a[16..30])
+    var i = 16
+    while i <= 29:
+      spawn f(a[i])
+      spawn f(a[i+1])
+      inc i, 2
+      # is correct here
+
+main()
diff --git a/tests/parallel/tinvalid_array_bounds.nim b/tests/parallel/tinvalid_array_bounds.nim
new file mode 100644
index 000000000..337fae729
--- /dev/null
+++ b/tests/parallel/tinvalid_array_bounds.nim
@@ -0,0 +1,25 @@
+discard """
+  errormsg: "cannot prove: i + 1 <= 30"
+  line: 21
+"""
+
+import threadpool
+
+proc f(a: openArray[int]) =
+  for x in a: echo x
+
+proc f(a: int) = echo a
+
+proc main() =
+  var a: array[0..30, int]
+  parallel:
+    spawn f(a[0..15])
+    spawn f(a[16..30])
+    var i = 0
+    while i <= 30:
+      spawn f(a[i])
+      spawn f(a[i+1])
+      inc i
+      #inc i  # inc i, 2  would be correct here
+
+main()
diff --git a/tests/parallel/tinvalid_counter_usage.nim b/tests/parallel/tinvalid_counter_usage.nim
new file mode 100644
index 000000000..c6303c651
--- /dev/null
+++ b/tests/parallel/tinvalid_counter_usage.nim
@@ -0,0 +1,26 @@
+discard """
+  errormsg: "invalid usage of counter after increment"
+  line: 21
+"""
+
+import threadpool
+
+proc f(a: openArray[int]) =
+  for x in a: echo x
+
+proc f(a: int) = echo a
+
+proc main() =
+  var a: array[0..30, int]
+  parallel:
+    spawn f(a[0..15])
+    spawn f(a[16..30])
+    var i = 0
+    while i <= 30:
+      inc i
+      spawn f(a[i])
+      inc i
+      #spawn f(a[i+1])
+      #inc i  # inc i, 2  would be correct here
+
+main()
diff --git a/tests/parallel/tnon_disjoint_slice1.nim b/tests/parallel/tnon_disjoint_slice1.nim
new file mode 100644
index 000000000..72d008bbd
--- /dev/null
+++ b/tests/parallel/tnon_disjoint_slice1.nim
@@ -0,0 +1,25 @@
+discard """
+  errormsg: "cannot prove (i)..(i) disjoint from (i + 1)..(i + 1)"
+  line: 20
+"""
+
+import threadpool
+
+proc f(a: openArray[int]) =
+  for x in a: echo x
+
+proc f(a: int) = echo a
+
+proc main() =
+  var a: array[0..30, int]
+  parallel:
+    #spawn f(a[0..15])
+    #spawn f(a[16..30])
+    var i = 0
+    while i <= 29:
+      spawn f(a[i])
+      spawn f(a[i+1])
+      inc i
+      #inc i  # inc i, 2  would be correct here
+
+main()