summary refs log tree commit diff stats
path: root/compiler/semexprs.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/semexprs.nim')
-rw-r--r--compiler/semexprs.nim16
1 files changed, 14 insertions, 2 deletions
diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim
index 203a51816..c16ab9b87 100644
--- a/compiler/semexprs.nim
+++ b/compiler/semexprs.nim
@@ -1119,6 +1119,9 @@ proc asgnToResultVar(c: PContext, n, le, ri: PNode) {.inline.} =
       n.sons[0] = x # 'result[]' --> 'result'
       n.sons[1] = takeImplicitAddr(c, ri)
 
+template resultTypeIsInferrable(typ: PType): expr =
+  typ.isMetaType and typ.kind != tyTypeDesc
+
 proc semAsgn(c: PContext, n: PNode): PNode =
   checkSonsLen(n, 2)
   var a = n.sons[0]
@@ -1170,7 +1173,7 @@ proc semAsgn(c: PContext, n: PNode): PNode =
         if lhsIsResult: {efAllowDestructor} else: {})
     if lhsIsResult:
       n.typ = enforceVoidContext
-      if lhs.sym.typ.isMetaType and lhs.sym.typ.kind != tyTypeDesc:
+      if resultTypeIsInferrable(lhs.sym.typ):
         if cmpTypes(c, lhs.typ, rhs.typ) == isGeneric:
           internalAssert c.p.resultSym != nil
           lhs.typ = rhs.typ
@@ -1259,12 +1262,21 @@ proc semYield(c: PContext, n: PNode): PNode =
     localError(n.info, errYieldNotAllowedInTryStmt)
   elif n.sons[0].kind != nkEmpty:
     n.sons[0] = semExprWithType(c, n.sons[0]) # check for type compatibility:
-    var restype = c.p.owner.typ.sons[0]
+    var iterType = c.p.owner.typ
+    var restype = iterType.sons[0]
     if restype != nil:
       let adjustedRes = if c.p.owner.kind == skIterator: restype.base
                         else: restype
       n.sons[0] = fitNode(c, adjustedRes, n.sons[0])
       if n.sons[0].typ == nil: internalError(n.info, "semYield")
+      
+      if resultTypeIsInferrable(adjustedRes):
+        let inferred = n.sons[0].typ
+        if c.p.owner.kind == skIterator:
+          iterType.sons[0].sons[0] = inferred
+        else:
+          iterType.sons[0] = inferred
+      
       semYieldVarResult(c, n, adjustedRes)
     else:
       localError(n.info, errCannotReturnExpr)