summary refs log tree commit diff stats
path: root/tests/ast_pattern_matching.nim
diff options
context:
space:
mode:
authorArne Döring <arne.doering@gmx.net>2019-03-19 11:45:29 +0100
committerAndreas Rumpf <rumpf_a@web.de>2019-03-19 11:45:29 +0100
commit389b140029577845b9a0e40b6fecc8ba78af679f (patch)
tree3d3dd6de889cb3434209bca793d024aa9e806143 /tests/ast_pattern_matching.nim
parent5c1c5902e20cad9d1f28923e7abba3282ad4f8a1 (diff)
downloadNim-389b140029577845b9a0e40b6fecc8ba78af679f.tar.gz
add tastspec (and ast_pattern_matching) (#10863)
Diffstat (limited to 'tests/ast_pattern_matching.nim')
-rw-r--r--tests/ast_pattern_matching.nim584
1 files changed, 584 insertions, 0 deletions
diff --git a/tests/ast_pattern_matching.nim b/tests/ast_pattern_matching.nim
new file mode 100644
index 000000000..c08234b9e
--- /dev/null
+++ b/tests/ast_pattern_matching.nim
@@ -0,0 +1,584 @@
+# this is a copy paste implementation of github.com/krux02/ast_pattern_matching
+# Please provide bugfixes upstream first before adding them here.
+
+import macros, strutils, tables
+
+export macros
+
+when isMainModule:
+  template debug(args: varargs[untyped]): untyped =
+    echo args
+else:
+  template debug(args: varargs[untyped]): untyped =
+    discard
+
+const
+  nnkIntLiterals*   = nnkCharLit..nnkUInt64Lit
+  nnkStringLiterals* = nnkStrLit..nnkTripleStrLit
+  nnkFloatLiterals* = nnkFloatLit..nnkFloat64Lit
+
+proc newLit[T: enum](arg: T): NimNode =
+  newIdentNode($arg)
+
+proc newLit[T](arg: set[T]): NimNode =
+  ## does not work for the empty sets
+  result = nnkCurly.newTree
+  for x in arg:
+    result.add newLit(x)
+
+type SomeFloat = float | float32 | float64
+
+proc len[T](arg: set[T]): int = card(arg)
+
+type
+  MatchingErrorKind* = enum
+    NoError
+    WrongKindLength
+    WrongKindValue
+    WrongIdent
+    WrongCustomCondition
+
+  MatchingError = object
+    node*: NimNode
+    expectedKind*: set[NimNodeKind]
+    case kind*: MatchingErrorKind
+    of NoError:
+      discard
+    of WrongKindLength:
+      expectedLength*: int
+    of WrongKindValue:
+      expectedValue*: NimNode
+    of WrongIdent, WrongCustomCondition:
+      strVal*: string
+
+proc `$`*(arg: MatchingError): string =
+  let n = arg.node
+  case arg.kind
+  of NoError:
+    "no error"
+  of WrongKindLength:
+    let k = arg.expectedKind
+    let l = arg.expectedLength
+    var msg = "expected "
+    if k.len == 0:
+      msg.add "any node"
+    elif k.len == 1:
+      for el in k:  # only one element but there is no index op for sets
+        msg.add $el
+    else:
+      msg.add "a node in" & $k
+
+    if l >= 0:
+      msg.add " with " & $l & " child(ren)"
+    msg.add ", but got " & $n.kind
+    if l >= 0:
+      msg.add " with " & $n.len & " child(ren)"
+    msg
+  of WrongKindValue:
+    let k = $arg.expectedKind
+    let v = arg.expectedValue.repr
+    var msg = "expected " & k & " with value " & v & " but got " & n.lispRepr
+    if n.kind in {nnkOpenSymChoice, nnkClosedSymChoice}:
+      msg = msg & " (a sym-choice does not have a strVal member, maybe you should match with `ident`)"
+    msg
+  of WrongIdent:
+    let prefix = "expected ident `" & arg.strVal & "` but got "
+    if n.kind in {nnkIdent, nnkSym, nnkOpenSymChoice, nnkClosedSymChoice}:
+      prefix & "`" & n.strVal & "`"
+    else:
+      prefix & $n.kind & " with " & $n.len & " child(ren)"
+  of WrongCustomCondition:
+    "custom condition check failed: " & arg.strVal
+
+
+proc failWithMatchingError*(arg: MatchingError): void {.compileTime, noReturn.} =
+  error($arg, arg.node)
+
+proc expectValue(arg: NimNode; value: SomeInteger): void {.compileTime.} =
+  arg.expectKind nnkLiterals
+  if arg.intVal != int(value):
+    error("expected value " & $value & " but got " & arg.repr, arg)
+
+proc expectValue(arg: NimNode; value: SomeFloat): void {.compileTime.} =
+  arg.expectKind nnkLiterals
+  if arg.floatVal != float(value):
+    error("expected value " & $value & " but got " & arg.repr, arg)
+
+proc expectValue(arg: NimNode; value: string): void {.compileTime.} =
+  arg.expectKind nnkLiterals
+  if arg.strVal != value:
+    error("expected value " & value & " but got " & arg.repr, arg)
+
+proc expectValue[T](arg: NimNode; value: pointer): void {.compileTime.} =
+  arg.expectKind nnkLiterals
+  if value != nil:
+    error("Expect Value for pointers works only on `nil` when the argument is a pointer.")
+  arg.expectKind nnkNilLit
+
+proc expectIdent(arg: NimNode; strVal: string): void {.compileTime.} =
+  if not arg.eqIdent(strVal):
+    error("Expect ident `" & strVal & "` but got " & arg.repr)
+
+proc matchLengthKind*(arg: NimNode; kind: set[NimNodeKind]; length: int): MatchingError {.compileTime.} =
+  let kindFail   = not(kind.card == 0 or arg.kind in kind)
+  let lengthFail = not(length < 0 or length == arg.len)
+  if kindFail or lengthFail:
+    result.node = arg
+    result.kind = WrongKindLength
+    result.expectedLength = length
+    result.expectedKind   = kind
+
+
+proc matchLengthKind*(arg: NimNode; kind: NimNodeKind; length: int): MatchingError {.compileTime.} =
+  matchLengthKind(arg, {kind}, length)
+
+proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeInteger): MatchingError {.compileTime.} =
+  let kindFail   = not(kind.card == 0 or arg.kind in kind)
+  let valueFail  = arg.intVal != int(value)
+  if kindFail or valueFail:
+    result.node = arg
+    result.kind = WrongKindValue
+    result.expectedKind  = kind
+    result.expectedValue = newLit(value)
+
+proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeInteger): MatchingError {.compileTime.} =
+  matchValue(arg, {kind}, value)
+
+proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeFloat): MatchingError {.compileTime.} =
+  let kindFail   = not(kind.card == 0 or arg.kind in kind)
+  let valueFail  = arg.floatVal != float(value)
+  if kindFail or valueFail:
+    result.node = arg
+    result.kind = WrongKindValue
+    result.expectedKind  = kind
+    result.expectedValue = newLit(value)
+
+proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeFloat): MatchingError {.compileTime.} =
+  matchValue(arg, {kind}, value)
+
+const nnkStrValKinds = {nnkStrLit, nnkRStrLit, nnkTripleStrLit, nnkIdent, nnkSym}
+
+proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: string): MatchingError {.compileTime.} =
+  # if kind * nnkStringLiterals TODO do something that ensures that here is only checked for string literals
+  let kindFail   = not(kind.card == 0 or arg.kind in kind)
+  let valueFail  =
+    if kind.card == 0:
+      false
+    else:
+      arg.kind notin (kind * nnkStrValKinds) or arg.strVal != value
+  if kindFail or valueFail:
+    result.node = arg
+    result.kind = WrongKindValue
+    result.expectedKind  = kind
+    result.expectedValue = newLit(value)
+
+proc matchValue(arg: NimNode; kind: NimNodeKind; value: string): MatchingError {.compileTime.} =
+  matchValue(arg, {kind}, value)
+
+proc matchValue[T](arg: NimNode; value: pointer): MatchingError {.compileTime.} =
+  if value != nil:
+    error("Expect Value for pointers works only on `nil` when the argument is a pointer.")
+  arg.matchLengthKind(nnkNilLit, -1)
+
+proc matchIdent*(arg:NimNode; value: string): MatchingError =
+  if not arg.eqIdent(value):
+    result.node = arg
+    result.kind = Wrongident
+    result.strVal = value
+
+proc checkCustomExpr*(arg: NimNode; cond: bool, exprstr: string): MatchingError =
+  if not cond:
+    result.node = arg
+    result.kind = WrongCustomCondition
+    result.strVal = exprstr
+
+static:
+  var literals: array[19, NimNode]
+  var i = 0
+  for litKind in nnkLiterals:
+    literals[i] = ident($litKind)
+    i += 1
+
+  var nameToKind = newTable[string, NimNodeKind]()
+  for kind in NimNodeKind:
+    nameToKind[ ($kind)[3..^1] ] = kind
+
+  let identifierKinds = newLit({nnkSym, nnkIdent, nnkOpenSymChoice, nnkClosedSymChoice})
+
+proc generateMatchingCode(astSym: NimNode, pattern: NimNode, depth: int, blockLabel, errorSym, localsArraySym: NimNode; dest: NimNode): int =
+  ## return the number of indices used in the array for local variables.
+
+  var currentLocalIndex = 0
+
+  proc nodeVisiting(astSym: NimNode, pattern: NimNode, depth: int): void =
+    let ind = "  ".repeat(depth) # indentation
+
+    proc genMatchLogic(matchProc, argSym1, argSym2: NimNode): void =
+      dest.add quote do:
+        `errorSym` = `astSym`.`matchProc`(`argSym1`, `argSym2`)
+        if `errorSym`.kind != NoError:
+          break `blockLabel`
+
+    proc genIdentMatchLogic(identValueLit: NimNode): void =
+      dest.add quote do:
+        `errorSym` = `astSym`.matchIdent(`identValueLit`)
+        if `errorSym`.kind != NoError:
+          break `blockLabel`
+
+    proc genCustomMatchLogic(conditionExpr: NimNode): void =
+      let exprStr = newLit(conditionExpr.repr)
+      dest.add quote do:
+        `errorSym` = `astSym`.checkCustomExpr(`conditionExpr`, `exprStr`)
+        if `errorSym`.kind != NoError:
+          break `blockLabel`
+
+    # proc handleKindMatching(kindExpr: NimNode): void =
+    #   if kindExpr.eqIdent("_"):
+    #     # this is the wildcand that matches any kind
+    #     return
+    #   else:
+    #     genMatchLogic(bindSym"matchKind", kindExpr)
+
+    # generate recursively a matching expression
+    if pattern.kind == nnkCall:
+      pattern.expectMinLen(1)
+
+      debug ind, pattern[0].repr, "("
+
+      let kindSet = if pattern[0].eqIdent("_"): nnkCurly.newTree else: pattern[0]
+      # handleKindMatching(pattern[0])
+
+      if pattern.len == 2 and pattern[1].kind == nnkExprEqExpr:
+        if pattern[1][1].kind in nnkStringLiterals:
+          pattern[1][0].expectIdent("strVal")
+        elif pattern[1][1].kind in nnkIntLiterals:
+          pattern[1][0].expectIdent("intVal")
+        elif pattern[1][1].kind in nnkFloatLiterals:
+          pattern[1][0].expectIdent("floatVal")
+
+        genMatchLogic(bindSym"matchValue", kindSet, pattern[1][1])
+
+      else:
+        let lengthLit = newLit(pattern.len - 1)
+        genMatchLogic(bindSym"matchLengthKind", kindSet, lengthLit)
+
+        for i in 1 ..< pattern.len:
+          let childSym = nnkBracketExpr.newTree(localsArraySym, newLit(currentLocalIndex))
+          currentLocalIndex += 1
+          let indexLit = newLit(i - 1)
+          dest.add quote do:
+            `childSym` = `astSym`[`indexLit`]
+          nodeVisiting(childSym, pattern[i], depth + 1)
+      debug ind, ")"
+    elif pattern.kind == nnkCallStrLit and pattern[0].eqIdent("ident"):
+      genIdentMatchLogic(pattern[1])
+
+    elif pattern.kind == nnkPar and pattern.len == 1:
+      nodeVisiting(astSym, pattern[0], depth)
+    elif pattern.kind == nnkPrefix:
+      error("prefix patterns not implemented", pattern)
+    elif pattern.kind == nnkAccQuoted:
+      debug ind, pattern.repr
+      let matchedExpr = pattern[0]
+      matchedExpr.expectKind nnkIdent
+      dest.add quote do:
+        let `matchedExpr` = `astSym`
+
+    elif pattern.kind == nnkInfix and pattern[0].eqIdent("@"):
+      pattern[1].expectKind nnkAccQuoted
+
+      let matchedExpr = pattern[1][0]
+      matchedExpr.expectKind nnkIdent
+      dest.add quote do:
+        let `matchedExpr` = `astSym`
+
+      debug ind, pattern[1].repr, " = "
+      nodeVisiting(matchedExpr, pattern[2], depth + 1)
+
+    elif pattern.kind == nnkInfix and pattern[0].eqIdent("|="):
+      nodeVisiting(astSym, pattern[1], depth + 1)
+      genCustomMatchLogic(pattern[2])
+
+    elif pattern.kind in nnkCallKinds:
+      error("only boring call syntax allowed, this is " & $pattern.kind & ".", pattern)
+    elif pattern.kind in nnkLiterals:
+      genMatchLogic(bindSym"matchValue", nnkCurly.newTree, pattern)
+    elif not pattern.eqIdent("_"):
+      # When it is not one of the other branches, it is simply treated
+      # as an expression for the node kind, without checking child
+      # nodes.
+      debug ind, pattern.repr
+      genMatchLogic(bindSym"matchLengthKind", pattern, newLit(-1))
+
+  nodeVisiting(astSym, pattern, depth)
+
+  return currentLocalIndex
+
+macro matchAst*(astExpr: NimNode; args: varargs[untyped]): untyped =
+  let astSym = genSym(nskLet, "ast")
+  let beginBranches = if args[0].kind == nnkIdent: 1 else: 0
+  let endBranches   = if args[^1].kind == nnkElse: args.len - 1 else: args.len
+  for i in beginBranches ..< endBranches:
+    args[i].expectKind nnkOfBranch
+
+  let outerErrorSym: NimNode =
+    if beginBranches == 1:
+      args[0].expectKind nnkIdent
+      args[0]
+    else:
+      nil
+
+  let elseBranch: NimNode =
+    if endBranches == args.len - 1:
+      args[^1].expectKind(nnkElse)
+      args[^1][0]
+    else:
+      nil
+
+  let outerBlockLabel = genSym(nskLabel, "matchingSection")
+  let outerStmtList = newStmtList()
+  let errorSymbols = nnkBracket.newTree
+
+  ## the vm only allows 255 local variables. This sucks a lot and I
+  ## have to work around it.  So instead of creating a lot of local
+  ## variables, I just create one array of local variables. This is
+  ## just annoying.
+  let localsArraySym = genSym(nskVar, "locals")
+  var localsArrayLen: int = 0
+
+  for i in beginBranches ..< endBranches:
+    let ofBranch = args[i]
+
+    ofBranch.expectKind(nnkOfBranch)
+    ofBranch.expectLen(2)
+    let pattern = ofBranch[0]
+    let code = ofBranch[1]
+    code.expectKind nnkStmtList
+    let stmtList = newStmtList()
+    let blockLabel = genSym(nskLabel, "matchingBranch")
+    let errorSym = genSym(nskVar, "branchError")
+
+    errorSymbols.add errorSym
+    let numLocalsUsed = generateMatchingCode(astSym, pattern, 0, blockLabel, errorSym, localsArraySym, stmtList)
+    localsArrayLen = max(localsArrayLen, numLocalsUsed)
+    stmtList.add code
+    # maybe there is a better mechanism disable errors for statement after return
+    if code[^1].kind != nnkReturnStmt:
+      stmtList.add nnkBreakStmt.newTree(outerBlockLabel)
+
+    outerStmtList.add quote do:
+      var `errorSym`: MatchingError
+      block `blockLabel`:
+        `stmtList`
+
+  if elseBranch != nil:
+    if outerErrorSym != nil:
+      outerStmtList.add quote do:
+        let `outerErrorSym` = @`errorSymbols`
+        `elseBranch`
+    else:
+      outerStmtList.add elseBranch
+
+  else:
+    if errorSymbols.len == 1:
+      # there is only one of branch and no else branch
+      # the error message can be very precise here.
+      let errorSym = errorSymbols[0]
+      outerStmtList.add quote do:
+        failWithMatchingError(`errorSym`)
+    else:
+
+      var patterns: string = ""
+      for i in beginBranches ..< endBranches:
+        let ofBranch = args[i]
+        let pattern = ofBranch[0]
+        patterns.add pattern.repr
+        patterns.add "\n"
+
+      let patternsLit = newLit(patterns)
+      outerStmtList.add quote do:
+        error("Ast pattern mismatch: got " & `astSym`.lispRepr & "\nbut expected one of:\n" & `patternsLit`, `astSym`)
+
+  let lengthLit = newLit(localsArrayLen)
+  result = quote do:
+    block `outerBlockLabel`:
+      let `astSym` = `astExpr`
+      var `localsArraySym`: array[`lengthLit`, NimNode]
+      `outerStmtList`
+
+  debug result.repr
+
+proc recursiveNodeVisiting*(arg: NimNode, callback: proc(arg: NimNode): bool) =
+  ## if `callback` returns true, visitor continues to visit the
+  ## children of `arg` otherwise it stops.
+  if callback(arg):
+    for child in arg:
+      recursiveNodeVisiting(child, callback)
+
+macro matchAstRecursive*(ast: NimNode; args: varargs[untyped]): untyped =
+  # Does not recurse further on matched nodes.
+  if args[^1].kind == nnkElse:
+    error("Recursive matching with an else branch is pointless.", args[^1])
+
+  let visitor = genSym(nskProc, "visitor")
+  let visitorArg = genSym(nskParam, "arg")
+
+  let visitorStmtList = newStmtList()
+
+  let matchingSection = genSym(nskLabel, "matchingSection")
+
+  let localsArraySym = genSym(nskVar, "locals")
+  let branchError = genSym(nskVar, "branchError")
+  var localsArrayLen = 0
+
+  for ofBranch in args:
+    ofBranch.expectKind(nnkOfBranch)
+    ofBranch.expectLen(2)
+    let pattern = ofBranch[0]
+    let code = ofBranch[1]
+    code.expectkind(nnkStmtList)
+
+    let stmtList = newStmtList()
+    let matchingBranch = genSym(nskLabel, "matchingBranch")
+
+    let numLocalsUsed = generateMatchingCode(visitorArg, pattern, 0, matchingBranch, branchError, localsArraySym, stmtList)
+    localsArrayLen = max(localsArrayLen, numLocalsUsed)
+
+    stmtList.add code
+    stmtList.add nnkBreakStmt.newTree(matchingSection)
+
+
+    visitorStmtList.add quote do:
+      `branchError`.kind = NoError
+      block `matchingBranch`:
+        `stmtList`
+
+  let resultIdent = ident"result"
+
+  let visitingProc = bindSym"recursiveNodeVisiting"
+  let lengthLit = newLit(localsArrayLen)
+
+  result = quote do:
+    proc `visitor`(`visitorArg`: NimNode): bool =
+      block `matchingSection`:
+        var `localsArraySym`: array[`lengthLit`, NimNode]
+        var `branchError`: MatchingError
+        `visitorStmtList`
+        `resultIdent` = true
+
+    `visitingProc`(`ast`, `visitor`)
+
+  debug result.repr
+
+################################################################################
+################################# Example Code #################################
+################################################################################
+
+when isMainModule:
+  static:
+    let mykinds = {nnkIdent, nnkCall}
+
+  macro foo(arg: untyped): untyped =
+    matchAst(arg, matchError):
+    of nnkStmtList(nnkIdent, nnkIdent, nnkIdent):
+      echo(88*88+33*33)
+    of nnkStmtList(
+      _(
+        nnkIdentDefs(
+          ident"a",
+          nnkEmpty, nnkIntLit(intVal = 123)
+        )
+      ),
+      _,
+      nnkForStmt(
+        nnkIdent(strVal = "i"),
+        nnkInfix,
+        `mysym` @ nnkStmtList
+      )
+    ):
+      echo "The AST did match!!!"
+      echo "The matched sub tree is the following:"
+      echo mysym.lispRepr
+    #else:
+    #  echo "sadly the AST did not match :("
+    #  echo arg.treeRepr
+    #  failWithMatchingError(matchError[1])
+
+  foo:
+    let a = 123
+    let b = 342
+    for i in a ..< b:
+      echo "Hallo", i
+
+  static:
+
+    var ast = quote do:
+      type
+        A[T: static[int]] = object
+
+    ast = ast[0]
+    ast.matchAst(err):  # this is a sub ast for this a findAst or something like that is useful
+    of nnkTypeDef(_, nnkGenericParams( nnkIdentDefs( nnkIdent(strVal = "T"), `staticTy`, nnkEmpty )), _):
+      echo "`", staticTy.repr, "` used to be of nnkStaticTy, now it is ", staticTy.kind, " with ", staticTy[0].repr
+    ast = quote do:
+      if cond1: expr1 elif cond2: expr2 else: expr3
+
+    ast.matchAst:
+    of {nnkIfExpr, nnkIfStmt}(
+      {nnkElifExpr, nnkElifBranch}(`cond1`, `expr1`),
+      {nnkElifExpr, nnkElifBranch}(`cond2`, `expr2`),
+      {nnkElseExpr, nnkElse}(`expr3`)
+    ):
+      echo "ok"
+
+    let ast2 = nnkStmtList.newTree( newLit(1) )
+
+    ast2.matchAst:
+    of nnkIntLit( 1 ):
+      echo "fail"
+    of nnkStmtList( 1 ):
+      echo "ok"
+
+    ast = bindSym"[]"
+    ast.matchAst(errors):
+    of nnkClosedSymChoice(strVal = "[]"):
+      echo "fail, this is the wrong syntax, a sym choice does not have a `strVal` member."
+    of ident"[]":
+      echo "ok"
+
+    const myConst = 123
+    ast = newLit(123)
+
+    ast.matchAst:
+    of _(intVal = myConst):
+      echo "ok"
+
+    macro testRecCase(ast: untyped): untyped =
+      ast.matchAstRecursive:
+      of nnkIdentDefs(`a`,`b`,`c`):
+        echo "got ident defs a: ", a.repr, " b: ", b.repr, " c: ", c.repr
+      of ident"m":
+        echo "got the ident m"
+
+    testRecCase:
+      type Obj[T] = object {.inheritable.}
+        name: string
+        case isFat: bool
+        of true:
+          m: array[100_000, T]
+        of false:
+          m: array[10, T]
+
+
+    macro testIfCondition(ast: untyped): untyped =
+      let literals = nnkBracket.newTree
+      ast.matchAstRecursive:
+      of `intLit` @ nnkIntLit |= intLit.intVal > 5:
+        literals.add intLit
+
+      let literals2 = quote do:
+        [6,7,8,9]
+
+      doAssert literals2 == literals
+
+    testIfCondition([1,6,2,7,3,8,4,9,5,0,"123"])