diff options
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/ast.nim | 8 | ||||
-rw-r--r-- | compiler/astalgo.nim | 42 | ||||
-rw-r--r-- | compiler/lowerings.nim | 12 | ||||
-rw-r--r-- | compiler/sem.nim | 18 | ||||
-rw-r--r-- | compiler/semcall.nim | 4 | ||||
-rw-r--r-- | compiler/seminst.nim | 26 | ||||
-rw-r--r-- | compiler/semtypes.nim | 2 | ||||
-rw-r--r-- | compiler/sigmatch.nim | 3 | ||||
-rw-r--r-- | compiler/transf.nim | 46 |
9 files changed, 130 insertions, 31 deletions
diff --git a/compiler/ast.nim b/compiler/ast.nim index 085a243b3..6302c21b9 100644 --- a/compiler/ast.nim +++ b/compiler/ast.nim @@ -293,6 +293,10 @@ const # the compiler will avoid printing such names # in user messages. + sfHoisted* = sfForward + # an expression was hoised to an anonymous variable. + # the flag is applied to the var/let symbol + sfNoForward* = sfRegister # forward declarations are not required (per module) sfReorder* = sfForward @@ -455,6 +459,8 @@ type nfBlockArg # this a stmtlist appearing in a call (e.g. a do block) nfFromTemplate # a top-level node returned from a template nfDefaultParam # an automatically inserter default parameter + nfDefaultRefsParam # a default param value references another parameter + # the flag is applied to proc default values and to calls TNodeFlags* = set[TNodeFlag] TTypeFlag* = enum # keep below 32 for efficiency reasons (now: beyond that) @@ -972,7 +978,7 @@ const PersistentNodeFlags*: TNodeFlags = {nfBase2, nfBase8, nfBase16, nfDotSetter, nfDotField, nfIsRef, nfPreventCg, nfLL, - nfFromTemplate} + nfFromTemplate, nfDefaultRefsParam} namePos* = 0 patternPos* = 1 # empty except for term rewriting macros genericParamsPos* = 2 diff --git a/compiler/astalgo.nim b/compiler/astalgo.nim index a4a14405e..290ac05ee 100644 --- a/compiler/astalgo.nim +++ b/compiler/astalgo.nim @@ -34,32 +34,36 @@ when declared(echo): template debug*(x: PSym|PType|PNode) {.deprecated.} = when compiles(c.config): debug(c.config, x) + elif compiles(c.graph.config): + debug(c.graph.config, x) else: error() template debug*(x: auto) {.deprecated.} = echo x -template mdbg*: bool {.dirty.} = - when compiles(c.module): - c.module.fileIdx == c.config.projectMainIdx - elif compiles(c.c.module): - c.c.module.fileIdx == c.c.config.projectMainIdx - elif compiles(m.c.module): - m.c.module.fileIdx == m.c.config.projectMainIdx - elif compiles(cl.c.module): - cl.c.module.fileIdx == cl.c.config.projectMainIdx - elif compiles(p): - when compiles(p.lex): - p.lex.fileIdx == p.lex.config.projectMainIdx + template mdbg*: bool {.deprecated.} = + when compiles(c.graph): + c.module.fileIdx == c.graph.config.projectMainIdx + elif compiles(c.module): + c.module.fileIdx == c.config.projectMainIdx + elif compiles(c.c.module): + c.c.module.fileIdx == c.c.config.projectMainIdx + elif compiles(m.c.module): + m.c.module.fileIdx == m.c.config.projectMainIdx + elif compiles(cl.c.module): + cl.c.module.fileIdx == cl.c.config.projectMainIdx + elif compiles(p): + when compiles(p.lex): + p.lex.fileIdx == p.lex.config.projectMainIdx + else: + p.module.module.fileIdx == p.config.projectMainIdx + elif compiles(m.module.fileIdx): + m.module.fileIdx == m.config.projectMainIdx + elif compiles(L.fileIdx): + L.fileIdx == L.config.projectMainIdx else: - p.module.module.fileIdx == p.config.projectMainIdx - elif compiles(m.module.fileIdx): - m.module.fileIdx == m.config.projectMainIdx - elif compiles(L.fileIdx): - L.fileIdx == L.config.projectMainIdx - else: - error() + error() # --------------------------- ident tables ---------------------------------- proc idTableGet*(t: TIdTable, key: PIdObj): RootRef diff --git a/compiler/lowerings.nim b/compiler/lowerings.nim index 24a4f186e..1b17f620c 100644 --- a/compiler/lowerings.nim +++ b/compiler/lowerings.nim @@ -336,6 +336,18 @@ proc typeNeedsNoDeepCopy(t: PType): bool = if t.kind in {tyVar, tyLent, tySequence}: t = t.lastSon result = not containsGarbageCollectedRef(t) +proc hoistExpr*(varSection, expr: PNode, name: PIdent, owner: PSym): PSym = + result = newSym(skLet, name, owner, varSection.info, owner.options) + result.flags.incl sfHoisted + result.typ = expr.typ + + var varDef = newNodeI(nkIdentDefs, varSection.info, 3) + varDef.sons[0] = newSymNode(result) + varDef.sons[1] = newNodeI(nkEmpty, varSection.info) + varDef.sons[2] = expr + + varSection.add varDef + proc addLocalVar(g: ModuleGraph; varSection, varInit: PNode; owner: PSym; typ: PType; v: PNode; useShallowCopy=false): PSym = result = newSym(skTemp, getIdent(g.cache, genPrefix), owner, varSection.info, diff --git a/compiler/sem.nim b/compiler/sem.nim index afc794a37..299286545 100644 --- a/compiler/sem.nim +++ b/compiler/sem.nim @@ -73,6 +73,16 @@ template semIdeForTemplateOrGeneric(c: PContext; n: PNode; # echo "passing to safeSemExpr: ", renderTree(n) discard safeSemExpr(c, n) +proc fitNodePostMatch(c: PContext, formal: PType, arg: PNode): PNode = + result = arg + let x = result.skipConv + if x.kind in {nkPar, nkTupleConstr} and formal.kind != tyExpr: + changeType(c, x, formal, check=true) + else: + result = skipHiddenSubConv(result) + #result.typ = takeType(formal, arg.typ) + #echo arg.info, " picked ", result.typ.typeToString + proc fitNode(c: PContext, formal: PType, arg: PNode; info: TLineInfo): PNode = if arg.typ.isNil: localError(c.config, arg.info, "expression has no type: " & @@ -88,13 +98,7 @@ proc fitNode(c: PContext, formal: PType, arg: PNode; info: TLineInfo): PNode = result = copyTree(arg) result.typ = formal else: - let x = result.skipConv - if x.kind in {nkPar, nkTupleConstr} and formal.kind != tyExpr: - changeType(c, x, formal, check=true) - else: - result = skipHiddenSubConv(result) - #result.typ = takeType(formal, arg.typ) - #echo arg.info, " picked ", result.typ.typeToString + result = fitNodePostMatch(c, formal, result) proc inferWithMetatype(c: PContext, formal: PType, arg: PNode, coerceDistincts = false): PNode diff --git a/compiler/semcall.nim b/compiler/semcall.nim index 0de22cfb3..67fe99232 100644 --- a/compiler/semcall.nim +++ b/compiler/semcall.nim @@ -402,7 +402,9 @@ proc updateDefaultParams(call: PNode) = for i in countdown(call.len - 1, 1): if nfDefaultParam notin call[i].flags: return - call[i] = calleeParams[i].sym.ast + let def = calleeParams[i].sym.ast + if nfDefaultRefsParam in def.flags: call.flags.incl nfDefaultRefsParam + call[i] = def proc semResolvedCall(c: PContext, x: TCandidate, n: PNode, flags: TExprFlags): PNode = diff --git a/compiler/seminst.nim b/compiler/seminst.nim index 0ad1fb872..fac04e3a0 100644 --- a/compiler/seminst.nim +++ b/compiler/seminst.nim @@ -220,6 +220,14 @@ proc instGenericContainer(c: PContext, info: TLineInfo, header: PType, result = replaceTypeVarsT(cl, header) closeScope(c) +proc referencesAnotherParam(n: PNode, p: PSym): bool = + if n.kind == nkSym: + return n.sym.kind == skParam and n.sym.owner == p + else: + for i in 0..<n.safeLen: + if referencesAnotherParam(n[i], p): return true + return false + proc instantiateProcType(c: PContext, pt: TIdTable, prc: PSym, info: TLineInfo) = # XXX: Instantiates a generic proc signature, while at the same @@ -276,8 +284,22 @@ proc instantiateProcType(c: PContext, pt: TIdTable, if def.kind == nkCall: for i in 1 ..< def.len: def[i] = replaceTypeVarsN(cl, def[i]) - def = semExprWithType(c, def) - param.ast = fitNode(c, typeToFit, def, def.info) + + def = semExprWithType(c, def) + if def.referencesAnotherParam(getCurrOwner(c)): + def.flags.incl nfDefaultRefsParam + + var converted = indexTypesMatch(c, typeToFit, def.typ, def) + if converted == nil: + # The default value doesn't match the final instantiated type. + # As an example of this, see: + # https://github.com/nim-lang/Nim/issues/1201 + # We are replacing the default value with an error node in case + # the user calls an explicit instantiation of the proc (this is + # the only way the default value might be inserted). + param.ast = errorNode(c, def) + else: + param.ast = fitNodePostMatch(c, typeToFit, converted) param.typ = result[i] result.n[i] = newSymNode(param) diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim index ff2820ec8..f0f22e87c 100644 --- a/compiler/semtypes.nim +++ b/compiler/semtypes.nim @@ -1042,6 +1042,8 @@ proc semProcTypeNode(c: PContext, n, genericParams: PNode, break determineType def = semExprWithType(c, def, {efDetermineType}) + if def.referencesAnotherParam(getCurrOwner(c)): + def.flags.incl nfDefaultRefsParam if typ == nil: typ = def.typ diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim index 537efd55c..0bb7c4fdd 100644 --- a/compiler/sigmatch.nim +++ b/compiler/sigmatch.nim @@ -2366,7 +2366,8 @@ proc matches*(c: PContext, n, nOrig: PNode, m: var TCandidate) = m.firstMismatch = f break else: - # use default value: + if nfDefaultRefsParam in formal.ast.flags: + m.call.flags.incl nfDefaultRefsParam var def = copyTree(formal.ast) if def.kind == nkNilLit: def = implicitConv(nkHiddenStdConv, formal.typ, def, m, c) diff --git a/compiler/transf.nim b/compiler/transf.nim index ad7f38b91..abe713eb8 100644 --- a/compiler/transf.nim +++ b/compiler/transf.nim @@ -780,6 +780,43 @@ proc commonOptimizations*(g: ModuleGraph; c: PSym, n: PNode): PNode = else: result = n +proc hoistParamsUsedInDefault(c: PTransf, call, letSection, defExpr: PNode): PNode = + # This takes care of complicated signatures such as: + # proc foo(a: int, b = a) + # proc bar(a: int, b: int, c = a + b) + # + # The recursion may confuse you. It performs two duties: + # + # 1) extracting all referenced params from default expressions + # into a let section preceeding the call + # + # 2) replacing the "references" within the default expression + # with these extracted skLet symbols. + # + # The first duty is carried out directly in the code here, while the second + # duty is activated by returning a non-nil value. The caller is responsible + # for replacing the input to the function with the returned non-nil value. + # (which is the hoisted symbol) + if defExpr.kind == nkSym: + if defExpr.sym.kind == skParam and defExpr.sym.owner == call[0].sym: + let paramPos = defExpr.sym.position + 1 + + if call[paramPos].kind == nkSym and sfHoisted in call[paramPos].sym.flags: + # Already hoisted, we still need to return it in order to replace the + # placeholder expression in the default value. + return call[paramPos] + + let hoistedVarSym = hoistExpr(letSection, + call[paramPos], + getIdent(c.graph.cache, genPrefix), + c.transCon.owner).newSymNode + call[paramPos] = hoistedVarSym + return hoistedVarSym + else: + for i in 0..<defExpr.safeLen: + let hoisted = hoistParamsUsedInDefault(c, call, letSection, defExpr[i]) + if hoisted != nil: defExpr[i] = hoisted + proc transform(c: PTransf, n: PNode): PTransNode = when false: var oldDeferAnchor: PNode @@ -849,6 +886,15 @@ proc transform(c: PTransf, n: PNode): PTransNode = of nkBreakStmt: result = transformBreak(c, n) of nkCallKinds: result = transformCall(c, n) + var call = result.PNode + if nfDefaultRefsParam in call.flags: + # We've found a default value that references another param. + # See the notes in `hoistParamsUsedInDefault` for more details. + var hoistedParams = newNodeI(nkLetSection, call.info, 0) + for i in 1 ..< call.len: + let hoisted = hoistParamsUsedInDefault(c, call, hoistedParams, call[i]) + if hoisted != nil: call[i] = hoisted + result = newTree(nkStmtListExpr, hoistedParams, call).PTransNode of nkAddr, nkHiddenAddr: result = transformAddrDeref(c, n, nkDerefExpr, nkHiddenDeref) of nkDerefExpr, nkHiddenDeref: |