diff options
-rw-r--r-- | compiler/transf.nim | 55 | ||||
-rw-r--r-- | tests/stdlib/tyield.nim | 256 |
2 files changed, 300 insertions, 11 deletions
diff --git a/compiler/transf.nim b/compiler/transf.nim index fb59887b0..80794b581 100644 --- a/compiler/transf.nim +++ b/compiler/transf.nim @@ -364,7 +364,7 @@ proc transformYield(c: PTransf, n: PNode): PNode = if e.typ.isNil: return result # can happen in nimsuggest for unknown reasons if c.transCon.forStmt.len != 3: e = skipConv(e) - if e.kind in {nkPar, nkTupleConstr}: + if e.kind == nkTupleConstr: for i in 0..<e.len: var v = e[i] if v.kind == nkExprColonExpr: v = v[1] @@ -377,19 +377,51 @@ proc transformYield(c: PTransf, n: PNode): PNode = let lhs = c.transCon.forStmt[i] let rhs = transform(c, v) result.add(asgnTo(lhs, rhs)) + elif e.kind notin {nkAddr, nkHiddenAddr}: # no need to generate temp for address operation + # TODO do not use temp for nodes which cannot have side-effects + var tmp = newTemp(c, e.typ, e.info) + let v = newNodeI(nkVarSection, e.info) + v.addVar(tmp, e) + + result.add transform(c, v) + + for i in 0..<c.transCon.forStmt.len - 2: + let lhs = c.transCon.forStmt[i] + let rhs = transform(c, newTupleAccess(c.graph, tmp, i)) + result.add(asgnTo(lhs, rhs)) else: - # Unpack the tuple into the loop variables - # XXX: BUG: what if `n` is an expression with side-effects? for i in 0..<c.transCon.forStmt.len - 2: let lhs = c.transCon.forStmt[i] let rhs = transform(c, newTupleAccess(c.graph, e, i)) result.add(asgnTo(lhs, rhs)) else: if c.transCon.forStmt[0].kind == nkVarTuple: - for i in 0..<c.transCon.forStmt[0].len-1: - let lhs = c.transCon.forStmt[0][i] - let rhs = transform(c, newTupleAccess(c.graph, e, i)) - result.add(asgnTo(lhs, rhs)) + var notLiteralTuple = false # we don't generate temp for tuples with const value: (1, 2, 3) + let ev = e.skipConv + if ev.kind == nkTupleConstr: + for i in ev: + if not isConstExpr(i): + notLiteralTuple = true + break + else: + notLiteralTuple = true + + if e.kind notin {nkAddr, nkHiddenAddr} and notLiteralTuple: + # TODO do not use temp for nodes which cannot have side-effects + var tmp = newTemp(c, e.typ, e.info) + let v = newNodeI(nkVarSection, e.info) + v.addVar(tmp, e) + + result.add transform(c, v) + for i in 0..<c.transCon.forStmt[0].len-1: + let lhs = c.transCon.forStmt[0][i] + let rhs = transform(c, newTupleAccess(c.graph, tmp, i)) + result.add(asgnTo(lhs, rhs)) + else: + for i in 0..<c.transCon.forStmt[0].len-1: + let lhs = c.transCon.forStmt[0][i] + let rhs = transform(c, newTupleAccess(c.graph, e, i)) + result.add(asgnTo(lhs, rhs)) else: let lhs = c.transCon.forStmt[0] let rhs = transform(c, e) @@ -403,10 +435,11 @@ proc transformYield(c: PTransf, n: PNode): PNode = # we need to introduce new local variables: result.add(introduceNewLocalVars(c, c.transCon.forLoopBody)) if result.len > 0: - var changeNode = result[0] - changeNode.info = c.transCon.forStmt.info - for i, child in changeNode: - child.info = changeNode.info + for idx in 0 ..< result.len: + var changeNode = result[idx] + changeNode.info = c.transCon.forStmt.info + for i, child in changeNode: + child.info = changeNode.info proc transformAddrDeref(c: PTransf, n: PNode, a, b: TNodeKind): PNode = result = transformSons(c, n) diff --git a/tests/stdlib/tyield.nim b/tests/stdlib/tyield.nim new file mode 100644 index 000000000..85be97365 --- /dev/null +++ b/tests/stdlib/tyield.nim @@ -0,0 +1,256 @@ +discard """ + targets: "c cpp js" +""" + +import std/[sugar, algorithm] + +block: + var x = @[(6.0, 6, '6'), + (5.0, 5, '5'), + (4.0, 4, '4'), + (3.0, 3, '3'), + (2.0, 2, '2'), + (1.0, 1, '1')] + + let y = x.reversed + + block: + let res = collect: + for (f, i, c) in x: + (f, i, c) + + doAssert res == x + + iterator popAscending[T](q: var seq[T]): T = + while q.len > 0: yield q.pop + + block: + var res = collect: + for f, i, c in popAscending(x): + (f, i, c) + + doAssert res == y + + let z = reversed(res) + let res2 = collect: + for (f, i, c) in popAscending(res): + (f, i, c) + + doAssert res2 == z + + +block: + var visits = 0 + block: + proc bar(): (int, int) = + inc visits + (visits, visits) + + iterator foo(): (int, int) = + yield bar() + + for a, b in foo(): + doAssert a == b + + doAssert visits == 1 + + block: + proc iterAux(a: seq[int], i: var int): (int, string) = + result = (a[i], $a[i]) + inc i + + iterator pairs(a: seq[int]): (int, string) = + var i = 0 + while i < a.len: + yield iterAux(a, i) + + var x = newSeq[int](10) + for i in 0 ..< x.len: + x[i] = i + + let res = collect: + for k, v in x: + (k, v) + + let expected = collect: + for i in 0 ..< x.len: + (i, $i) + + doAssert res == expected + + block: + proc bar(): (int, int, int) = + inc visits + (visits, visits, visits) + + iterator foo(): (int, int, int) = + yield bar() + + for a, b, c in foo(): + doAssert a == b + + doAssert visits == 2 + + + block: + + proc bar(): int = + inc visits + visits + + proc car(): int = + inc visits + visits + + iterator foo(): (int, int) = + yield (bar(), car()) + yield (bar(), car()) + + for a, b in foo(): + doAssert b == a + 1 + + doAssert visits == 6 + + + block: + proc bar(): (int, int) = + inc visits + (visits, visits) + + proc t2(): int = 99 + + iterator foo(): (int, int) = + yield (12, t2()) + yield bar() + + let res = collect: + for (a, b) in foo(): + (a, b) + + doAssert res == @[(12, 99), (7, 7)] + doAssert visits == 7 + + block: + proc bar(): (int, int) = + inc visits + (visits, visits) + + proc t2(): int = 99 + + iterator foo(): (int, int) = + yield ((12, t2())) + yield (bar()) + + let res = collect: + for (a, b) in foo(): + (a, b) + + doAssert res == @[(12, 99), (8, 8)] + doAssert visits == 8 + + block: + proc bar(): (int, int) = + inc visits + (visits, visits) + + proc t1(): int = 99 + proc t2(): int = 99 + + iterator foo(): (int, int) = + yield (t1(), t2()) + yield bar() + + let res = collect: + for a, b in foo(): + (a, b) + + doAssert res == @[(99, 99), (9, 9)] + doAssert visits == 9 + + + block: + proc bar(): ((int, int), string) = + inc visits + ((visits, visits), $visits) + + proc t2(): int = 99 + + iterator foo(): ((int, int), string) = + yield ((1, 2), $t2()) + yield bar() + + let res = collect: + for a, b in foo(): + (a, b) + + doAssert res == @[((1, 2), "99"), ((10, 10), "10")] + doAssert visits == 10 + + + block: + proc bar(): (int, int) = + inc visits + (visits, visits) + + iterator foo(): (int, int) = + yield (for i in 0 ..< 10: discard bar(); bar()) + yield (bar()) + + let res = collect: + for (a, b) in foo(): + (a, b) + + doAssert res == @[(21, 21), (22, 22)] + + block: + proc bar(): (int, int) = + inc visits + (visits, visits) + + proc t2(): int = 99 + + iterator foo(): (int, int) = + yield if true: bar() else: (t2(), t2()) + yield (bar()) + + let res = collect: + for a, b in foo(): + (a, b) + + doAssert res == @[(23, 23), (24, 24)] + + +block: + iterator foo(): (int, int, int) = + var time = 777 + yield (1, time, 3) + + let res = collect: + for a, b, c in foo(): + (a, b, c) + + doAssert res == @[(1, 777, 3)] + +block: + iterator foo(): (int, int, int) = + var time = 777 + yield (1, time, 3) + + let res = collect: + for t in foo(): + (t[0], t[1], t[2]) + + doAssert res == @[(1, 777, 3)] + + +block: + proc bar(): (int, int, int) = + (1, 2, 3) + iterator foo(): (int, int, int) = + yield bar() + + let res = collect: + for a, b, c in foo(): + (a, b, c) + + doAssert res == @[(1, 2, 3)] |