summary refs log tree commit diff stats
path: root/compiler/patterns.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/patterns.nim')
-rw-r--r--compiler/patterns.nim330
1 files changed, 330 insertions, 0 deletions
diff --git a/compiler/patterns.nim b/compiler/patterns.nim
new file mode 100644
index 000000000..32ec7fb53
--- /dev/null
+++ b/compiler/patterns.nim
@@ -0,0 +1,330 @@
+#
+#
+#           The Nim Compiler
+#        (c) Copyright 2012 Andreas Rumpf
+#
+#    See the file "copying.txt", included in this
+#    distribution, for details about the copyright.
+#
+
+## This module implements the pattern matching features for term rewriting
+## macro support.
+
+import
+  ast, types, semdata, sigmatch, idents, aliases, parampatterns, trees
+
+when defined(nimPreviewSlimSystem):
+  import std/assertions
+
+type
+  TPatternContext = object
+    owner: PSym
+    mapping: seq[PNode]  # maps formal parameters to nodes
+    formals: int
+    c: PContext
+    subMatch: bool       # subnode matches are special
+    mappingIsFull: bool
+  PPatternContext = var TPatternContext
+
+proc getLazy(c: PPatternContext, sym: PSym): PNode =
+  if c.mappingIsFull:
+    result = c.mapping[sym.position]
+  else:
+    result = nil
+
+proc putLazy(c: PPatternContext, sym: PSym, n: PNode) =
+  if not c.mappingIsFull:
+    newSeq(c.mapping, c.formals)
+    c.mappingIsFull = true
+  c.mapping[sym.position] = n
+
+proc matches(c: PPatternContext, p, n: PNode): bool
+
+proc canonKind(n: PNode): TNodeKind =
+  ## nodekind canonicalization for pattern matching
+  result = n.kind
+  case result
+  of nkCallKinds: result = nkCall
+  of nkStrLit..nkTripleStrLit: result = nkStrLit
+  of nkFastAsgn, nkSinkAsgn: result = nkAsgn
+  else: discard
+
+proc sameKinds(a, b: PNode): bool {.inline.} =
+  result = a.kind == b.kind or a.canonKind == b.canonKind
+
+proc sameTrees*(a, b: PNode): bool =
+  if sameKinds(a, b):
+    case a.kind
+    of nkSym: result = a.sym == b.sym
+    of nkIdent: result = a.ident.id == b.ident.id
+    of nkCharLit..nkInt64Lit: result = a.intVal == b.intVal
+    of nkFloatLit..nkFloat64Lit: result = a.floatVal == b.floatVal
+    of nkStrLit..nkTripleStrLit: result = a.strVal == b.strVal
+    of nkEmpty, nkNilLit: result = true
+    of nkType: result = sameTypeOrNil(a.typ, b.typ)
+    else:
+      if a.len == b.len:
+        for i in 0..<a.len:
+          if not sameTrees(a[i], b[i]): return
+        result = true
+      else:
+        result = false
+  else:
+    result = false
+
+proc inSymChoice(sc, x: PNode): bool =
+  if sc.kind == nkClosedSymChoice:
+    result = false
+    for i in 0..<sc.len:
+      if sc[i].sym == x.sym: return true
+  elif sc.kind == nkOpenSymChoice:
+    # same name suffices for open sym choices!
+    result = sc[0].sym.name.id == x.sym.name.id
+  else:
+    result = false
+
+proc checkTypes(c: PPatternContext, p: PSym, n: PNode): bool =
+  # check param constraints first here as this is quite optimized:
+  if p.constraint != nil:
+    result = matchNodeKinds(p.constraint, n)
+    if not result: return
+  if isNil(n.typ):
+    result = p.typ.kind in {tyVoid, tyTyped}
+  else:
+    result = sigmatch.argtypeMatches(c.c, p.typ, n.typ, fromHlo = true)
+
+proc isPatternParam(c: PPatternContext, p: PNode): bool {.inline.} =
+  result = p.kind == nkSym and p.sym.kind == skParam and p.sym.owner == c.owner
+
+proc matchChoice(c: PPatternContext, p, n: PNode): bool =
+  result = false
+  for i in 1..<p.len:
+    if matches(c, p[i], n): return true
+
+proc bindOrCheck(c: PPatternContext, param: PSym, n: PNode): bool =
+  var pp = getLazy(c, param)
+  if pp != nil:
+    # check if we got the same pattern (already unified):
+    result = sameTrees(pp, n) #matches(c, pp, n)
+  elif n.kind == nkArgList or checkTypes(c, param, n):
+    putLazy(c, param, n)
+    result = true
+  else:
+    result = false
+
+proc gather(c: PPatternContext, param: PSym, n: PNode) =
+  var pp = getLazy(c, param)
+  if pp != nil and pp.kind == nkArgList:
+    pp.add(n)
+  else:
+    pp = newNodeI(nkArgList, n.info, 1)
+    pp[0] = n
+    putLazy(c, param, pp)
+
+proc matchNested(c: PPatternContext, p, n: PNode, rpn: bool): bool =
+  # match ``op * param`` or ``op *| param``
+  proc matchStarAux(c: PPatternContext, op, n, arglist: PNode,
+                    rpn: bool): bool =
+    result = true
+    if n.kind in nkCallKinds and matches(c, op[1], n[0]):
+      for i in 1..<n.len:
+        if not matchStarAux(c, op, n[i], arglist, rpn): return false
+      if rpn: arglist.add(n[0])
+    elif n.kind == nkHiddenStdConv and n[1].kind == nkBracket:
+      let n = n[1]
+      for i in 0..<n.len:
+        if not matchStarAux(c, op, n[i], arglist, rpn): return false
+    elif checkTypes(c, p[2].sym, n):
+      arglist.add(n)
+    else:
+      result = false
+
+  if n.kind notin nkCallKinds: return false
+  if matches(c, p[1], n[0]):
+    var arglist = newNodeI(nkArgList, n.info)
+    if matchStarAux(c, p, n, arglist, rpn):
+      result = bindOrCheck(c, p[2].sym, arglist)
+    else:
+      result = false
+  else:
+    result = false
+
+proc matches(c: PPatternContext, p, n: PNode): bool =
+  let n = skipHidden(n)
+  if nfNoRewrite in n.flags:
+    result = false
+  elif isPatternParam(c, p):
+    result = bindOrCheck(c, p.sym, n)
+  elif n.kind == nkSym and p.kind == nkIdent:
+    result = p.ident.id == n.sym.name.id
+  elif n.kind == nkSym and inSymChoice(p, n):
+    result = true
+  elif n.kind == nkSym and n.sym.kind == skConst:
+    # try both:
+    if p.kind == nkSym: result = p.sym == n.sym
+    elif matches(c, p, n.sym.astdef): result = true
+    else: result = false
+  elif p.kind == nkPattern:
+    # pattern operators: | *
+    let opr = p[0].ident.s
+    case opr
+    of "|": result = matchChoice(c, p, n)
+    of "*": result = matchNested(c, p, n, rpn=false)
+    of "**": result = matchNested(c, p, n, rpn=true)
+    of "~": result = not matches(c, p[1], n)
+    else:
+      result = false
+      doAssert(false, "invalid pattern")
+    # template {add(a, `&` * b)}(a: string{noalias}, b: varargs[string]) =
+    #   a.add(b)
+  elif p.kind == nkCurlyExpr:
+    if p[1].kind == nkPrefix:
+      if matches(c, p[0], n):
+        gather(c, p[1][1].sym, n)
+        result = true
+      else:
+        result = false
+    else:
+      assert isPatternParam(c, p[1])
+      if matches(c, p[0], n):
+        result = bindOrCheck(c, p[1].sym, n)
+      else:
+        result = false
+  elif sameKinds(p, n):
+    case p.kind
+    of nkSym: result = p.sym == n.sym
+    of nkIdent: result = p.ident.id == n.ident.id
+    of nkCharLit..nkInt64Lit: result = p.intVal == n.intVal
+    of nkFloatLit..nkFloat64Lit: result = p.floatVal == n.floatVal
+    of nkStrLit..nkTripleStrLit: result = p.strVal == n.strVal
+    of nkEmpty, nkNilLit, nkType:
+      result = true
+    else:
+      # special rule for p(X) ~ f(...); this also works for stuff like
+      # partial case statements, etc! - Not really ... :-/
+      result = false
+      let v = lastSon(p)
+      if isPatternParam(c, v) and v.sym.typ.kind == tyVarargs:
+        var arglist: PNode
+        if p.len <= n.len:
+          for i in 0..<p.len - 1:
+            if not matches(c, p[i], n[i]): return
+          if p.len == n.len and lastSon(n).kind == nkHiddenStdConv and
+              lastSon(n)[1].kind == nkBracket:
+            # unpack varargs:
+            let n = lastSon(n)[1]
+            arglist = newNodeI(nkArgList, n.info, n.len)
+            for i in 0..<n.len: arglist[i] = n[i]
+          else:
+            arglist = newNodeI(nkArgList, n.info, n.len - p.len + 1)
+            # f(1, 2, 3)
+            # p(X)
+            for i in 0..n.len - p.len:
+              arglist[i] = n[i + p.len - 1]
+          return bindOrCheck(c, v.sym, arglist)
+        elif p.len-1 == n.len:
+          for i in 0..<p.len - 1:
+            if not matches(c, p[i], n[i]): return
+          arglist = newNodeI(nkArgList, n.info)
+          return bindOrCheck(c, v.sym, arglist)
+      if p.len == n.len:
+        for i in 0..<p.len:
+          if not matches(c, p[i], n[i]): return
+        result = true
+  else:
+    result = false
+
+proc matchStmtList(c: PPatternContext, p, n: PNode): PNode =
+  proc matchRange(c: PPatternContext, p, n: PNode, i: int): bool =
+    for j in 0..<p.len:
+      if not matches(c, p[j], n[i+j]):
+        # we need to undo any bindings:
+        c.mapping = @[]
+        c.mappingIsFull = false
+        return false
+    result = true
+
+  if p.kind == nkStmtList and n.kind == p.kind and p.len < n.len:
+    result = nil
+    let n = flattenStmts(n)
+    # no need to flatten 'p' here as that has already been done
+    for i in 0..n.len - p.len:
+      if matchRange(c, p, n, i):
+        c.subMatch = true
+        result = newNodeI(nkStmtList, n.info, 3)
+        result[0] = extractRange(nkStmtList, n, 0, i-1)
+        result[1] = extractRange(nkStmtList, n, i, i+p.len-1)
+        result[2] = extractRange(nkStmtList, n, i+p.len, n.len-1)
+        break
+  elif matches(c, p, n):
+    result = n
+  else:
+    result = nil
+
+proc aliasAnalysisRequested(params: PNode): bool =
+  result = false
+  if params.len >= 2:
+    for i in 1..<params.len:
+      let param = params[i].sym
+      if whichAlias(param) != aqNone: return true
+
+proc addToArgList(result, n: PNode) =
+  if n.typ != nil and n.typ.kind != tyTyped:
+    if n.kind != nkArgList: result.add(n)
+    else:
+      for i in 0..<n.len: result.add(n[i])
+
+proc applyRule*(c: PContext, s: PSym, n: PNode): PNode =
+  ## returns a tree to semcheck if the rule triggered; nil otherwise
+  var ctx = TPatternContext(owner: s, c: c, formals: s.typ.paramsLen)
+  var m = matchStmtList(ctx, s.ast[patternPos], n)
+  if isNil(m): return nil
+  # each parameter should have been bound; we simply setup a call and
+  # let semantic checking deal with the rest :-)
+  result = newNodeI(nkCall, n.info)
+  result.add(newSymNode(s, n.info))
+  let params = s.typ.n
+  let requiresAA = aliasAnalysisRequested(params)
+  var args: PNode =
+    if requiresAA:
+      newNodeI(nkArgList, n.info)
+    else:
+      nil
+  for i in 1..<params.len:
+    let param = params[i].sym
+    let x = getLazy(ctx, param)
+    # couldn't bind parameter:
+    if isNil(x): return nil
+    result.add(x)
+    if requiresAA: addToArgList(args, x)
+  # perform alias analysis here:
+  if requiresAA:
+    for i in 1..<params.len:
+      var rs = result[i]
+      let param = params[i].sym
+      case whichAlias(param)
+      of aqNone: discard
+      of aqShouldAlias:
+        # it suffices that it aliases for sure with *some* other param:
+        var ok = false
+        for arg in items(args):
+          if arg != rs and aliases.isPartOf(rs, arg) == arYes:
+            ok = true
+            break
+        # constraint not fulfilled:
+        if not ok: return nil
+      of aqNoAlias:
+        # it MUST not alias with any other param:
+        var ok = true
+        for arg in items(args):
+          if arg != rs and aliases.isPartOf(rs, arg) != arNo:
+            ok = false
+            break
+        # constraint not fulfilled:
+        if not ok: return nil
+
+  markUsed(c, n.info, s)
+  if ctx.subMatch:
+    assert m.len == 3
+    m[1] = result
+    result = m