summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/semparallel.nim32
-rw-r--r--tests/parallel/treadafterwrite.nim31
-rw-r--r--tests/parallel/tuseafterdef.nim31
3 files changed, 86 insertions, 8 deletions
diff --git a/compiler/semparallel.nim b/compiler/semparallel.nim
index c4546f616..bd3152b54 100644
--- a/compiler/semparallel.nim
+++ b/compiler/semparallel.nim
@@ -258,14 +258,18 @@ proc min(a, b: PNode): PNode =
 
 proc fromSystem(op: PSym): bool = sfSystemModule in getModule(op).flags
 
+template pushSpawnId(c: expr, body: stmt) {.immediate, dirty.} =
+  inc c.spawns
+  let oldSpawnId = c.currentSpawnId
+  c.currentSpawnId = c.spawns
+  body
+  c.currentSpawnId = oldSpawnId
+
 proc analyseCall(c: var AnalysisCtx; n: PNode; op: PSym) =
   if op.magic == mSpawn:
-    inc c.spawns
-    let oldSpawnId = c.currentSpawnId
-    c.currentSpawnId = c.spawns
-    gatherArgs(c, n[1])
-    analyseSons(c, n)
-    c.currentSpawnId = oldSpawnId
+    pushSpawnId(c):
+      gatherArgs(c, n[1])
+      analyseSons(c, n)
   elif op.magic == mInc or (op.name.s == "+=" and op.fromSystem):
     if n[1].isLocal:
       let incr = n[2].skipConv
@@ -322,7 +326,14 @@ proc analyse(c: var AnalysisCtx; n: PNode) =
       let slot = c.getSlot(n[0].sym)
       slot.blacklisted = true
     invalidateFacts(c.guards, n[0])
-    analyseSons(c, n)
+    let value = n[1]
+    if getMagic(value) == mSpawn:
+      pushSpawnId(c):
+        gatherArgs(c, value[1])
+        analyseSons(c, value[1])
+        analyse(c, n[0])
+    else:
+      analyseSons(c, n)
     addAsgnFact(c.guards, n[0], n[1])
   of nkCallKinds:
     # direct call:
@@ -338,13 +349,18 @@ proc analyse(c: var AnalysisCtx; n: PNode) =
   of nkVarSection, nkLetSection:
     for it in n:
       let value = it.lastSon
+      let isSpawned = getMagic(value) == mSpawn
+      if isSpawned:
+        pushSpawnId(c):
+          gatherArgs(c, value[1])
+          analyseSons(c, value[1])
       if value.kind != nkEmpty:
         for j in 0 .. it.len-3:
           if it[j].isLocal:
             let slot = c.getSlot(it[j].sym)
             if slot.lower.isNil: slot.lower = value
             else: internalError(it.info, "slot already has a lower bound")
-        analyse(c, value)
+        if not isSpawned: analyse(c, value)
   of nkCaseStmt: analyseCase(c, n)
   of nkIfStmt, nkIfExpr: analyseIf(c, n)
   of nkWhileStmt:
diff --git a/tests/parallel/treadafterwrite.nim b/tests/parallel/treadafterwrite.nim
new file mode 100644
index 000000000..f59ad5ae0
--- /dev/null
+++ b/tests/parallel/treadafterwrite.nim
@@ -0,0 +1,31 @@
+discard """
+  errormsg: "'foo' not disjoint from 'foo'"
+  line: 23
+  disabled: "true"
+"""
+
+# bug #1597
+
+import strutils, math, threadpool
+
+type
+  BoxedFloat = object
+    value: float
+
+proc term(k: float): ptr BoxedFloat = 
+  var temp = 4 * math.pow(-1, k) / (2*k + 1)
+  result = cast[ptr BoxedFloat](allocShared(sizeof(BoxedFloat)))
+  result.value = temp
+
+proc pi(n: int): float =
+  var ch = newSeq[ptr BoxedFloat](n+1)
+  parallel:
+    for k in 0..ch.high:
+      let foo = (spawn term(float(k)))
+      assert foo != nil
+  for k in 0..ch.high:
+    var temp = ch[k][].value
+    result += temp
+    deallocShared(ch[k])
+
+echo formatFloat(pi(5000))
diff --git a/tests/parallel/tuseafterdef.nim b/tests/parallel/tuseafterdef.nim
new file mode 100644
index 000000000..95123e886
--- /dev/null
+++ b/tests/parallel/tuseafterdef.nim
@@ -0,0 +1,31 @@
+discard """
+  errormsg: "(k)..(k) not disjoint from (k)..(k)"
+  line: 23
+"""
+
+# bug #1597
+
+import strutils, math, threadpool
+
+type 
+  BoxedFloat = object
+    value: float
+
+proc term(k: float): ptr BoxedFloat = 
+  var temp = 4 * math.pow(-1, k) / (2*k + 1)
+  result = cast[ptr BoxedFloat](allocShared(sizeof(BoxedFloat)))
+  result.value = temp
+
+proc pi(n: int): float =
+  var ch = newSeq[ptr BoxedFloat](n+1)
+  parallel:
+    for k in 0..ch.high:
+      ch[k] = (spawn term(float(k)))
+      assert ch[k] != nil
+  for k in 0..ch.high:
+    var temp = ch[k][].value
+    result += temp
+    deallocShared(ch[k])
+
+
+echo formatFloat(pi(5000))