summary refs log tree commit diff stats
path: root/compiler/patterns.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/patterns.nim')
-rw-r--r--compiler/patterns.nim200
1 files changed, 156 insertions, 44 deletions
diff --git a/compiler/patterns.nim b/compiler/patterns.nim
index 7109d9975..ceadfe350 100644
--- a/compiler/patterns.nim
+++ b/compiler/patterns.nim
@@ -10,13 +10,16 @@
 ## This module implements the pattern matching features for term rewriting
 ## macro support.
 
-import ast, astalgo, types, semdata, sigmatch, msgs, idents
+import
+  ast, astalgo, types, semdata, sigmatch, msgs, idents, aliases, parampatterns,
+  trees
 
 type
   TPatternContext = object
     owner: PSym
     mapping: TIdNodeTable  # maps formal parameters to nodes
     c: PContext
+    subMatch: bool         # subnode matches are special
   PPatternContext = var TPatternContext
 
 proc matches(c: PPatternContext, p, n: PNode): bool
@@ -53,38 +56,88 @@ proc sameTrees(a, b: PNode): bool =
         result = true
 
 proc inSymChoice(sc, x: PNode): bool =
-  if sc.kind in {nkOpenSymChoice, nkClosedSymChoice}:
+  if sc.kind == nkClosedSymChoice:
     for i in 0.. <sc.len:
       if sc.sons[i].sym == x.sym: return true
+  elif sc.kind == nkOpenSymChoice:
+    # same name suffices for open sym choices!
+    result = sc.sons[0].sym.name.id == x.sym.name.id
 
 proc checkTypes(c: PPatternContext, p: PSym, n: PNode): bool =
-  # XXX tyVarargs is special here; lots of other special cases
+  # check param constraints first here as this quite optimized:
+  if p.typ.constraint != nil:
+    result = matchNodeKinds(p.typ.constraint, n)
+    if not result: return
   if isNil(n.typ):
-    result = p.typ.kind == tyStmt
+    result = p.typ.kind in {tyEmpty, tyStmt}
   else:
     result = sigmatch.argtypeMatches(c.c, p.typ, n.typ)
 
+proc isPatternParam(c: PPatternContext, p: PNode): bool {.inline.} =
+  result = p.kind == nkSym and p.sym.kind == skParam and p.sym.owner == c.owner
+
+proc matchChoice(c: PPatternContext, p, n: PNode): bool =
+  for i in 1 .. <p.len:
+    if matches(c, p.sons[i], n): return true
+
+proc bindOrCheck(c: PPatternContext, param: PSym, n: PNode): bool =
+  var pp = IdNodeTableGetLazy(c.mapping, param)
+  if pp != nil:
+    # check if we got the same pattern (already unified):
+    result = sameTrees(pp, n) #matches(c, pp, n)
+  elif checkTypes(c, param, n) and 
+      (param.ast == nil or checkConstraints(c, param.ast, n)):
+    IdNodeTablePutLazy(c.mapping, param, n)
+    result = true
+
+proc matchStar(c: PPatternContext, p, n: PNode): bool =
+  # match ``op*param``
+  # this is quite hard: 
+  # match against: f(a, ..., f(b, c, f(...)))
+  # we have different semantics if there is a choice as left operand:
+
+  proc matchStarAux(c: PPatternContext, op, n, arglist: PNode) =
+    if n.kind in nkCallKinds and matches(c, op, n.sons[0]):
+      for i in 1..sonsLen(n)-1: matchStarAux(c, op, n.sons[i], arglist)
+    else:
+      add(arglist, n)
+
+  if n.kind notin nkCallKinds: return false
+  if p.sons[0].kind != nkPattern:
+    if matches(c, p.sons[0], n.sons[0]):
+      var arglist = newNodeI(nkArgList, n.info)
+      arglist.typ = p.sons[1].sym.typ
+      matchStarAux(c, p.sons[0], n, arglist)
+      result = bindOrCheck(c, p.sons[1].sym, arglist)
+  else:
+    # well it matches somehow ...
+    if matches(c, p.sons[0], n.sons[0]):
+      result = bindOrCheck(c, p.sons[1].sym, n)
+
 proc matches(c: PPatternContext, p, n: PNode): bool =
-  # XXX special treatment: statement list,
-  # ignore comments, nkPar, hidden conversions
-  # f(..X) ~> how can 'X' stand for all remaining parameters? -> introduce
-  # a new local node kind (alias of nkReturnToken or something)
-  if p.kind == nkSym and p.sym.kind == skParam and p.sym.owner == c.owner:
-    var pp = IdNodeTableGetLazy(c.mapping, p.sym)
-    if pp != nil:
-      # check if we got the same pattern (already unified):
-      result = matches(c, pp, n)
-    elif checkTypes(c, p.sym, n) and 
-        (p.sym.ast == nil or checkConstraints(c, p.sym.ast, n)):
-      IdNodeTablePutLazy(c.mapping, p.sym, n)
-      result = true
+  # hidden conversions (?)
+  if isPatternParam(c, p):
+    result = bindOrCheck(c, p.sym, n)
   elif n.kind == nkSym and inSymChoice(p, n):
     result = true
   elif n.kind == nkSym and n.sym.kind == skConst:
     # try both:
-    if sameTrees(p, n): result = true
-    elif matches(c, p, n.sym.ast):
-      result = true
+    if p.kind == nkSym: result = p.sym == n.sym
+    elif matches(c, p, n.sym.ast): result = true
+  elif p.kind == nkPattern:
+    # pattern operators: | *
+    let opr = p.sons[0].ident.s
+    case opr
+    of "|": result = matchChoice(c, p, n)
+    of "*": result = matchStar(c, p, n)
+    of "~": result = not matches(c, p.sons[1], n)
+    else: InternalError(p.info, "invalid pattern")
+    # template {add(a, `&` * b)}(a: string{noalias}, b: varargs[string]) = 
+    #   add(a, b)
+  elif p.kind == nkCurlyExpr:
+    assert isPatternParam(c, p.sons[1])
+    if matches(c, p.sons[0], n):
+      result = bindOrCheck(c, p.sons[1].sym, n)
   elif sameKinds(p, n):
     case p.kind
     of nkSym: result = p.sym == n.sym
@@ -92,21 +145,52 @@ proc matches(c: PPatternContext, p, n: PNode): bool =
     of nkCharLit..nkInt64Lit: result = p.intVal == n.intVal
     of nkFloatLit..nkFloat64Lit: result = p.floatVal == n.floatVal
     of nkStrLit..nkTripleStrLit: result = p.strVal == n.strVal
-    of nkEmpty, nkNilLit, nkType: 
+    of nkEmpty, nkNilLit, nkType:
       result = true
-      # of nkStmtList:
-      # both are statement lists; we need to ignore comment statements and
-      # 'nil' statements and check whether p <: n which is however trivially
-      # checked as 'applyRule' is checked after every created statement
-      # already; We need to ensure that the matching span is passed to the
-      # macro and NOT simply 'n'!
-      # XXX
     else:
-      if sonsLen(p) == sonsLen(n):
+      var plen = sonsLen(p)
+      # special rule for p(X) ~ f(...); this also works for stuff like
+      # partial case statements, etc! - Not really ... :-/
+      if plen <= sonsLen(n):
+        let v = lastSon(p)
+        if isPatternParam(c, v) and v.sym.typ.kind == tyVarargs:
+          for i in countup(0, plen - 2):
+            if not matches(c, p.sons[i], n.sons[i]): return
+          var arglist = newNodeI(nkArgList, n.info, sonsLen(n) - plen + 1)
+          # f(1, 2, 3)
+          # p(X)
+          for i in countup(0, sonsLen(n) - plen):
+            arglist.sons[i] = n.sons[i + plen - 1]
+          # check or bind 'X':
+          return bindOrCheck(c, v.sym, arglist)
+      if plen == sonsLen(n):
         for i in countup(0, sonsLen(p) - 1):
           if not matches(c, p.sons[i], n.sons[i]): return
         result = true
 
+proc matchStmtList(c: PPatternContext, p, n: PNode): PNode =
+  proc matchRange(c: PPatternContext, p, n: PNode, i: int): bool =
+    for j in 0 .. <p.len:
+      if not matches(c, p.sons[j], n.sons[i+j]):
+        # we need to undo any bindings:
+        if not isNil(c.mapping.data): reset(c.mapping)
+        return false
+    result = true
+  
+  if p.kind == nkStmtList and n.kind == p.kind and p.len < n.len:
+    let n = flattenStmts(n)
+    # no need to flatten 'p' here as that has already been done
+    for i in 0 .. n.len - p.len:
+      if matchRange(c, p, n, i):
+        c.subMatch = true
+        result = newNodeI(nkStmtList, n.info, 3)
+        result.sons[0] = extractRange(nkStmtList, n, 0, i-1)
+        result.sons[1] = extractRange(nkStmtList, n, i, i+p.len-1)
+        result.sons[2] = extractRange(nkStmtList, n, i+p.len, n.len-1)
+        break
+  elif matches(c, p, n):
+    result = n
+
 # writeln(X, a); writeln(X, b); --> writeln(X, a, b)
 
 proc applyRule*(c: PContext, s: PSym, n: PNode): PNode =
@@ -115,20 +199,48 @@ proc applyRule*(c: PContext, s: PSym, n: PNode): PNode =
   ctx.owner = s
   ctx.c = c
   # we perform 'initIdNodeTable' lazily for performance
-  if matches(ctx, s.ast.sons[patternPos], n):
-    # each parameter should have been bound; we simply setup a call and
-    # let semantic checking deal with the rest :-)
-    # this also saves type checking if we allow for type checking errors
-    # as in 'system.compiles' and simply discard the results. But an error
-    # may have been desired in the first place! Meh, it's good enough for
-    # a first implementation:
-    result = newNodeI(nkCall, n.info)
-    result.add(newSymNode(s, n.info))
-    let params = s.typ.n
+  var m = matchStmtList(ctx, s.ast.sons[patternPos], n)
+  if isNil(m): return nil
+  # each parameter should have been bound; we simply setup a call and
+  # let semantic checking deal with the rest :-)
+  result = newNodeI(nkCall, n.info)
+  result.add(newSymNode(s, n.info))
+  let params = s.typ.n
+  for i in 1 .. < params.len:
+    let param = params.sons[i].sym
+    let x = IdNodeTableGetLazy(ctx.mapping, param)
+    # couldn't bind parameter:
+    if isNil(x): return nil
+    result.add(x)
+  # perform alias analysis here:
+  if params.len >= 2:
     for i in 1 .. < params.len:
       let param = params.sons[i].sym
-      let x = IdNodeTableGetLazy(ctx.mapping, param)
-      # couldn't bind parameter:
-      if isNil(x): return nil
-      result.add(x)
-    markUsed(n, s)
+      case whichAlias(param)
+      of aqNone: nil
+      of aqShouldAlias:
+        # it suffices that it aliases for sure with *some* other param:
+        var ok = false
+        for j in 1 .. < result.len:
+          if j != i and result.sons[j].typ != nil:
+            if aliases.isPartOf(result[i], result[j]) == arYes:
+              ok = true
+              break
+        # constraint not fullfilled:
+        if not ok: return nil
+      of aqNoAlias:
+        # it MUST not alias with any other param:
+        var ok = true
+        for j in 1 .. < result.len:
+          if j != i and result.sons[j].typ != nil:
+            if aliases.isPartOf(result[i], result[j]) != arNo:
+              ok = false
+              break
+        # constraint not fullfilled:
+        if not ok: return nil
+
+  markUsed(n, s)
+  if ctx.subMatch:
+    assert m.len == 3
+    m.sons[1] = result
+    result = m