summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAndreas Rumpf <rumpf_a@web.de>2019-09-02 15:47:56 +0200
committerGitHub <noreply@github.com>2019-09-02 15:47:56 +0200
commit7ef85db9a9e3df4d6630673ceac33e9fb986e2ed (patch)
tree529940ba28d24b3e396742352570cc742d92ab5b
parentfc7fe636e22f1a9a502e805256c01c70e8695f88 (diff)
downloadNim-7ef85db9a9e3df4d6630673ceac33e9fb986e2ed.tar.gz
fixes #12020 (#12106)
-rw-r--r--compiler/astalgo.nim42
-rw-r--r--compiler/sigmatch.nim2
-rw-r--r--tests/template/tparams_gensymed.nim19
3 files changed, 60 insertions, 3 deletions
diff --git a/compiler/astalgo.nim b/compiler/astalgo.nim
index 4100c3629..11144ebf4 100644
--- a/compiler/astalgo.nim
+++ b/compiler/astalgo.nim
@@ -77,7 +77,6 @@ proc idNodeTablePut*(t: var TIdNodeTable, key: PIdObj, val: PNode)
 
 # ---------------------------------------------------------------------------
 
-proc getSymFromList*(list: PNode, ident: PIdent, start: int = 0): PSym
 proc lookupInRecord*(n: PNode, field: PIdent): PSym
 proc mustRehash*(length, counter: int): bool
 proc nextTry*(h, maxHash: Hash): Hash {.inline.}
@@ -174,7 +173,7 @@ proc getModule*(s: PSym): PSym =
   assert((result.kind == skModule) or (result.owner != result))
   while result != nil and result.kind != skModule: result = result.owner
 
-proc getSymFromList(list: PNode, ident: PIdent, start: int = 0): PSym =
+proc getSymFromList*(list: PNode, ident: PIdent, start: int = 0): PSym =
   for i in start ..< sonsLen(list):
     if list.sons[i].kind == nkSym:
       result = list.sons[i].sym
@@ -182,6 +181,45 @@ proc getSymFromList(list: PNode, ident: PIdent, start: int = 0): PSym =
     else: return nil
   result = nil
 
+proc sameIgnoreBacktickGensymInfo(a, b: string): bool =
+  if a[0] != b[0]: return false
+  var last = a.len - 1
+  while last > 0 and a[last] != '`': dec(last)
+
+  var i = 1
+  var j = 1
+  while true:
+    while i < last and a[i] == '_': inc i
+    while j < b.len and b[j] == '_': inc j
+    var aa = if i < last: toLowerAscii(a[i]) else: '\0'
+    var bb = if j < b.len: toLowerAscii(b[j]) else: '\0'
+    if aa != bb: return false
+
+    # the characters are identical:
+    if i >= last:
+      # both cursors at the end:
+      if j >= b.len: return true
+      # not yet at the end of 'b':
+      return false
+    elif j >= b.len:
+      return false
+    inc i
+    inc j
+
+proc getNamedParamFromList*(list: PNode, ident: PIdent): PSym =
+  ## Named parameters are special because a named parameter can be
+  ## gensym'ed and then they have '`<number>' suffix that we need to
+  ## ignore, see compiler / evaltempl.nim, snippet:
+  ##
+  ## .. code-block:: nim
+  ##
+  ##    result.add newIdentNode(getIdent(c.ic, x.name.s & "`gensym" & $x.id),
+  ##            if c.instLines: actual.info else: templ.info)
+  for i in 1 ..< len(list):
+    let it = list[i].sym
+    if it.name.id == ident.id or
+        sameIgnoreBacktickGensymInfo(it.name.s, ident.s): return it
+
 proc hashNode(p: RootRef): Hash =
   result = hash(cast[pointer](p))
 
diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim
index ea1945a64..2c22620da 100644
--- a/compiler/sigmatch.nim
+++ b/compiler/sigmatch.nim
@@ -2363,7 +2363,7 @@ proc matchesAux(c: PContext, n, nOrig: PNode,
         localError(c.config, n.sons[a].info, "named parameter has to be an identifier")
         noMatch()
         return
-      formal = getSymFromList(m.callee.n, n.sons[a].sons[0].ident, 1)
+      formal = getNamedParamFromList(m.callee.n, n.sons[a].sons[0].ident)
       if formal == nil:
         # no error message!
         noMatch()
diff --git a/tests/template/tparams_gensymed.nim b/tests/template/tparams_gensymed.nim
index f7a02efa0..fe5608add 100644
--- a/tests/template/tparams_gensymed.nim
+++ b/tests/template/tparams_gensymed.nim
@@ -9,6 +9,11 @@ output: '''
 2
 3
 wth
+3
+2
+1
+0
+(total: 6)
 '''
 """
 # bug #1915
@@ -145,3 +150,17 @@ macro m(): untyped =
 
 let meh = m()
 meh("wth")
+
+
+macro foo(body: untyped): untyped =
+  result = body
+
+template baz(): untyped =
+  foo:
+    proc bar2(b: int): int =
+      echo b
+      if b > 0: b + bar2(b = b - 1)
+      else: 0
+  echo (total: bar2(3))
+
+baz()