summary refs log tree commit diff stats
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/ast.nim8
-rw-r--r--compiler/astalgo.nim42
-rw-r--r--compiler/lowerings.nim12
-rw-r--r--compiler/sem.nim18
-rw-r--r--compiler/semcall.nim4
-rw-r--r--compiler/seminst.nim26
-rw-r--r--compiler/semtypes.nim2
-rw-r--r--compiler/sigmatch.nim3
-rw-r--r--compiler/transf.nim46
9 files changed, 130 insertions, 31 deletions
diff --git a/compiler/ast.nim b/compiler/ast.nim
index 085a243b3..6302c21b9 100644
--- a/compiler/ast.nim
+++ b/compiler/ast.nim
@@ -293,6 +293,10 @@ const
     # the compiler will avoid printing such names
     # in user messages.
 
+  sfHoisted* = sfForward
+    # an expression was hoised to an anonymous variable.
+    # the flag is applied to the var/let symbol
+
   sfNoForward* = sfRegister
     # forward declarations are not required (per module)
   sfReorder* = sfForward
@@ -455,6 +459,8 @@ type
     nfBlockArg  # this a stmtlist appearing in a call (e.g. a do block)
     nfFromTemplate # a top-level node returned from a template
     nfDefaultParam # an automatically inserter default parameter
+    nfDefaultRefsParam # a default param value references another parameter
+                       # the flag is applied to proc default values and to calls
 
   TNodeFlags* = set[TNodeFlag]
   TTypeFlag* = enum   # keep below 32 for efficiency reasons (now: beyond that)
@@ -972,7 +978,7 @@ const
   PersistentNodeFlags*: TNodeFlags = {nfBase2, nfBase8, nfBase16,
                                       nfDotSetter, nfDotField,
                                       nfIsRef, nfPreventCg, nfLL,
-                                      nfFromTemplate}
+                                      nfFromTemplate, nfDefaultRefsParam}
   namePos* = 0
   patternPos* = 1    # empty except for term rewriting macros
   genericParamsPos* = 2
diff --git a/compiler/astalgo.nim b/compiler/astalgo.nim
index a4a14405e..290ac05ee 100644
--- a/compiler/astalgo.nim
+++ b/compiler/astalgo.nim
@@ -34,32 +34,36 @@ when declared(echo):
   template debug*(x: PSym|PType|PNode) {.deprecated.} =
     when compiles(c.config):
       debug(c.config, x)
+    elif compiles(c.graph.config):
+      debug(c.graph.config, x)
     else:
       error()
 
   template debug*(x: auto) {.deprecated.} =
     echo x
 
-template mdbg*: bool {.dirty.} =
-  when compiles(c.module):
-    c.module.fileIdx == c.config.projectMainIdx
-  elif compiles(c.c.module):
-    c.c.module.fileIdx == c.c.config.projectMainIdx
-  elif compiles(m.c.module):
-    m.c.module.fileIdx == m.c.config.projectMainIdx
-  elif compiles(cl.c.module):
-    cl.c.module.fileIdx == cl.c.config.projectMainIdx
-  elif compiles(p):
-    when compiles(p.lex):
-      p.lex.fileIdx == p.lex.config.projectMainIdx
+  template mdbg*: bool {.deprecated.} =
+    when compiles(c.graph):
+      c.module.fileIdx == c.graph.config.projectMainIdx
+    elif compiles(c.module):
+      c.module.fileIdx == c.config.projectMainIdx
+    elif compiles(c.c.module):
+      c.c.module.fileIdx == c.c.config.projectMainIdx
+    elif compiles(m.c.module):
+      m.c.module.fileIdx == m.c.config.projectMainIdx
+    elif compiles(cl.c.module):
+      cl.c.module.fileIdx == cl.c.config.projectMainIdx
+    elif compiles(p):
+      when compiles(p.lex):
+        p.lex.fileIdx == p.lex.config.projectMainIdx
+      else:
+        p.module.module.fileIdx == p.config.projectMainIdx
+    elif compiles(m.module.fileIdx):
+      m.module.fileIdx == m.config.projectMainIdx
+    elif compiles(L.fileIdx):
+      L.fileIdx == L.config.projectMainIdx
     else:
-      p.module.module.fileIdx == p.config.projectMainIdx
-  elif compiles(m.module.fileIdx):
-    m.module.fileIdx == m.config.projectMainIdx
-  elif compiles(L.fileIdx):
-    L.fileIdx == L.config.projectMainIdx
-  else:
-    error()
+      error()
 
 # --------------------------- ident tables ----------------------------------
 proc idTableGet*(t: TIdTable, key: PIdObj): RootRef
diff --git a/compiler/lowerings.nim b/compiler/lowerings.nim
index 24a4f186e..1b17f620c 100644
--- a/compiler/lowerings.nim
+++ b/compiler/lowerings.nim
@@ -336,6 +336,18 @@ proc typeNeedsNoDeepCopy(t: PType): bool =
   if t.kind in {tyVar, tyLent, tySequence}: t = t.lastSon
   result = not containsGarbageCollectedRef(t)
 
+proc hoistExpr*(varSection, expr: PNode, name: PIdent, owner: PSym): PSym =
+  result = newSym(skLet, name, owner, varSection.info, owner.options)
+  result.flags.incl sfHoisted
+  result.typ = expr.typ
+
+  var varDef = newNodeI(nkIdentDefs, varSection.info, 3)
+  varDef.sons[0] = newSymNode(result)
+  varDef.sons[1] = newNodeI(nkEmpty, varSection.info)
+  varDef.sons[2] = expr
+
+  varSection.add varDef
+
 proc addLocalVar(g: ModuleGraph; varSection, varInit: PNode; owner: PSym; typ: PType;
                  v: PNode; useShallowCopy=false): PSym =
   result = newSym(skTemp, getIdent(g.cache, genPrefix), owner, varSection.info,
diff --git a/compiler/sem.nim b/compiler/sem.nim
index afc794a37..299286545 100644
--- a/compiler/sem.nim
+++ b/compiler/sem.nim
@@ -73,6 +73,16 @@ template semIdeForTemplateOrGeneric(c: PContext; n: PNode;
       #  echo "passing to safeSemExpr: ", renderTree(n)
       discard safeSemExpr(c, n)
 
+proc fitNodePostMatch(c: PContext, formal: PType, arg: PNode): PNode =
+  result = arg
+  let x = result.skipConv
+  if x.kind in {nkPar, nkTupleConstr} and formal.kind != tyExpr:
+    changeType(c, x, formal, check=true)
+  else:
+    result = skipHiddenSubConv(result)
+    #result.typ = takeType(formal, arg.typ)
+    #echo arg.info, " picked ", result.typ.typeToString
+
 proc fitNode(c: PContext, formal: PType, arg: PNode; info: TLineInfo): PNode =
   if arg.typ.isNil:
     localError(c.config, arg.info, "expression has no type: " &
@@ -88,13 +98,7 @@ proc fitNode(c: PContext, formal: PType, arg: PNode; info: TLineInfo): PNode =
       result = copyTree(arg)
       result.typ = formal
     else:
-      let x = result.skipConv
-      if x.kind in {nkPar, nkTupleConstr} and formal.kind != tyExpr:
-        changeType(c, x, formal, check=true)
-      else:
-        result = skipHiddenSubConv(result)
-        #result.typ = takeType(formal, arg.typ)
-        #echo arg.info, " picked ", result.typ.typeToString
+      result = fitNodePostMatch(c, formal, result)
 
 proc inferWithMetatype(c: PContext, formal: PType,
                        arg: PNode, coerceDistincts = false): PNode
diff --git a/compiler/semcall.nim b/compiler/semcall.nim
index 0de22cfb3..67fe99232 100644
--- a/compiler/semcall.nim
+++ b/compiler/semcall.nim
@@ -402,7 +402,9 @@ proc updateDefaultParams(call: PNode) =
   for i in countdown(call.len - 1, 1):
     if nfDefaultParam notin call[i].flags:
       return
-    call[i] = calleeParams[i].sym.ast
+    let def = calleeParams[i].sym.ast
+    if nfDefaultRefsParam in def.flags: call.flags.incl nfDefaultRefsParam
+    call[i] = def
 
 proc semResolvedCall(c: PContext, x: TCandidate,
                      n: PNode, flags: TExprFlags): PNode =
diff --git a/compiler/seminst.nim b/compiler/seminst.nim
index 0ad1fb872..fac04e3a0 100644
--- a/compiler/seminst.nim
+++ b/compiler/seminst.nim
@@ -220,6 +220,14 @@ proc instGenericContainer(c: PContext, info: TLineInfo, header: PType,
   result = replaceTypeVarsT(cl, header)
   closeScope(c)
 
+proc referencesAnotherParam(n: PNode, p: PSym): bool =
+  if n.kind == nkSym:
+    return n.sym.kind == skParam and n.sym.owner == p
+  else:
+    for i in 0..<n.safeLen:
+      if referencesAnotherParam(n[i], p): return true
+    return false
+
 proc instantiateProcType(c: PContext, pt: TIdTable,
                          prc: PSym, info: TLineInfo) =
   # XXX: Instantiates a generic proc signature, while at the same
@@ -276,8 +284,22 @@ proc instantiateProcType(c: PContext, pt: TIdTable,
       if def.kind == nkCall:
         for i in 1 ..< def.len:
           def[i] = replaceTypeVarsN(cl, def[i])
-        def = semExprWithType(c, def)
-      param.ast = fitNode(c, typeToFit, def, def.info)
+
+      def = semExprWithType(c, def)
+      if def.referencesAnotherParam(getCurrOwner(c)):
+        def.flags.incl nfDefaultRefsParam
+
+      var converted = indexTypesMatch(c, typeToFit, def.typ, def)
+      if converted == nil:
+        # The default value doesn't match the final instantiated type.
+        # As an example of this, see:
+        # https://github.com/nim-lang/Nim/issues/1201
+        # We are replacing the default value with an error node in case
+        # the user calls an explicit instantiation of the proc (this is
+        # the only way the default value might be inserted).
+        param.ast = errorNode(c, def)
+      else:
+        param.ast = fitNodePostMatch(c, typeToFit, converted)
       param.typ = result[i]
 
     result.n[i] = newSymNode(param)
diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim
index ff2820ec8..f0f22e87c 100644
--- a/compiler/semtypes.nim
+++ b/compiler/semtypes.nim
@@ -1042,6 +1042,8 @@ proc semProcTypeNode(c: PContext, n, genericParams: PNode,
             break determineType
 
         def = semExprWithType(c, def, {efDetermineType})
+        if def.referencesAnotherParam(getCurrOwner(c)):
+          def.flags.incl nfDefaultRefsParam
 
       if typ == nil:
         typ = def.typ
diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim
index 537efd55c..0bb7c4fdd 100644
--- a/compiler/sigmatch.nim
+++ b/compiler/sigmatch.nim
@@ -2366,7 +2366,8 @@ proc matches*(c: PContext, n, nOrig: PNode, m: var TCandidate) =
           m.firstMismatch = f
           break
       else:
-        # use default value:
+        if nfDefaultRefsParam in formal.ast.flags:
+          m.call.flags.incl nfDefaultRefsParam
         var def = copyTree(formal.ast)
         if def.kind == nkNilLit:
           def = implicitConv(nkHiddenStdConv, formal.typ, def, m, c)
diff --git a/compiler/transf.nim b/compiler/transf.nim
index ad7f38b91..abe713eb8 100644
--- a/compiler/transf.nim
+++ b/compiler/transf.nim
@@ -780,6 +780,43 @@ proc commonOptimizations*(g: ModuleGraph; c: PSym, n: PNode): PNode =
     else:
       result = n
 
+proc hoistParamsUsedInDefault(c: PTransf, call, letSection, defExpr: PNode): PNode =
+  # This takes care of complicated signatures such as:
+  # proc foo(a: int, b = a)
+  # proc bar(a: int, b: int, c = a + b)
+  #
+  # The recursion may confuse you. It performs two duties:
+  #
+  # 1) extracting all referenced params from default expressions
+  #    into a let section preceeding the call
+  #
+  # 2) replacing the "references" within the default expression
+  #    with these extracted skLet symbols.
+  #
+  # The first duty is carried out directly in the code here, while the second
+  # duty is activated by returning a non-nil value. The caller is responsible
+  # for replacing the input to the function with the returned non-nil value.
+  # (which is the hoisted symbol)
+  if defExpr.kind == nkSym:
+    if defExpr.sym.kind == skParam and defExpr.sym.owner == call[0].sym:
+      let paramPos = defExpr.sym.position + 1
+
+      if call[paramPos].kind == nkSym and sfHoisted in call[paramPos].sym.flags:
+        # Already hoisted, we still need to return it in order to replace the
+        # placeholder expression in the default value.
+        return call[paramPos]
+
+      let hoistedVarSym = hoistExpr(letSection,
+                                    call[paramPos],
+                                    getIdent(c.graph.cache, genPrefix),
+                                    c.transCon.owner).newSymNode
+      call[paramPos] = hoistedVarSym
+      return hoistedVarSym
+  else:
+    for i in 0..<defExpr.safeLen:
+      let hoisted = hoistParamsUsedInDefault(c, call, letSection, defExpr[i])
+      if hoisted != nil: defExpr[i] = hoisted
+
 proc transform(c: PTransf, n: PNode): PTransNode =
   when false:
     var oldDeferAnchor: PNode
@@ -849,6 +886,15 @@ proc transform(c: PTransf, n: PNode): PTransNode =
   of nkBreakStmt: result = transformBreak(c, n)
   of nkCallKinds:
     result = transformCall(c, n)
+    var call = result.PNode
+    if nfDefaultRefsParam in call.flags:
+      # We've found a default value that references another param.
+      # See the notes in `hoistParamsUsedInDefault` for more details.
+      var hoistedParams = newNodeI(nkLetSection, call.info, 0)
+      for i in 1 ..< call.len:
+        let hoisted = hoistParamsUsedInDefault(c, call, hoistedParams, call[i])
+        if hoisted != nil: call[i] = hoisted
+      result = newTree(nkStmtListExpr, hoistedParams, call).PTransNode
   of nkAddr, nkHiddenAddr:
     result = transformAddrDeref(c, n, nkDerefExpr, nkHiddenDeref)
   of nkDerefExpr, nkHiddenDeref: