# this is a copy paste implementation of github.com/krux02/ast_pattern_matching
# Please provide bugfixes upstream first before adding them here.

import macros, strutils, tables

export macros

when isMainModule:
  template debug(args: varargs[untyped]): untyped =
    echo args
else:
  template debug(args: varargs[untyped]): untyped =
    discard

const
  nnkIntLiterals*    = nnkCharLit..nnkUInt64Lit
  nnkStringLiterals* = nnkStrLit..nnkTripleStrLit
  nnkFloatLiterals*  = nnkFloatLit..nnkFloat64Lit

proc newLit[T: enum](arg: T): NimNode =
  ## Represents an enum value as an identifier node spelled like the value.
  newIdentNode($arg)

proc newLit[T](arg: set[T]): NimNode =
  ## Represents a set as a curly-brace node of its elements.
  ## Does not work for empty sets.
  result = nnkCurly.newTree
  for x in arg:
    result.add newLit(x)

type SomeFloat = float | float32 | float64

proc len[T](arg: set[T]): int = card(arg)

type
  MatchingErrorKind* = enum
    NoError
    WrongKindLength
    WrongKindValue
    WrongIdent
    WrongCustomCondition

  MatchingError = object
    ## Result of a single matching check. `kind == NoError` means the
    ## check passed; every other kind carries the data needed to
    ## explain the failure.
    node*: NimNode                    ## the node that was checked
    expectedKind*: set[NimNodeKind]   ## empty set means "any kind"
    case kind*: MatchingErrorKind
    of NoError:
      discard
    of WrongKindLength:
      expectedLength*: int            ## negative means "any length"
    of WrongKindValue:
      expectedValue*: NimNode
    of WrongIdent, WrongCustomCondition:
      strVal*: string

proc `$`*(arg: MatchingError): string =
  ## Human readable description of a matching error; used as the
  ## compile error message by `failWithMatchingError`.
  let n = arg.node
  case arg.kind
  of NoError:
    "no error"
  of WrongKindLength:
    let k = arg.expectedKind
    let l = arg.expectedLength
    var msg = "expected "
    if k.len == 0:
      msg.add "any node"
    elif k.len == 1:
      for el in k:  # only one element but there is no index op for sets
        msg.add $el
    else:
      msg.add "a node in " & $k
    if l >= 0:
      msg.add " with " & $l & " child(ren)"
    msg.add ", but got " & $n.kind
    if l >= 0:
      msg.add " with " & $n.len & " child(ren)"
    msg
  of WrongKindValue:
    let k = $arg.expectedKind
    let v = arg.expectedValue.repr
    var msg = "expected " & k & " with value " & v & " but got " & n.lispRepr
    if n.kind in {nnkOpenSymChoice, nnkClosedSymChoice}:
      msg = msg & " (a sym-choice does not have a strVal member, maybe you should match with `ident`)"
    msg
  of WrongIdent:
    let prefix = "expected ident `" & arg.strVal & "` but got "
    case n.kind
    of nnkIdent, nnkSym:
      prefix & "`" & n.strVal & "`"
    of nnkOpenSymChoice, nnkClosedSymChoice:
      # a sym-choice has no `strVal` of its own; its children are the
      # candidate symbols, so report the name of the first one
      prefix & "`" & n[0].strVal & "`"
    else:
      prefix & $n.kind & " with " & $n.len & " child(ren)"
  of WrongCustomCondition:
    "custom condition check failed: " & arg.strVal

proc failWithMatchingError*(arg: MatchingError): void {.compileTime, noReturn.} =
  ## Aborts compilation with the description of `arg`, positioned at
  ## the offending node.
  error($arg, arg.node)

proc expectValue(arg: NimNode; value: SomeInteger): void {.compileTime.} =
  arg.expectKind nnkLiterals
  if arg.intVal != int(value):
    error("expected value " & $value & " but got " & arg.repr, arg)

proc expectValue(arg: NimNode; value: SomeFloat): void {.compileTime.} =
  arg.expectKind nnkLiterals
  if arg.floatVal != float(value):
    error("expected value " & $value & " but got " & arg.repr, arg)

proc expectValue(arg: NimNode; value: string): void {.compileTime.} =
  arg.expectKind nnkLiterals
  if arg.strVal != value:
    error("expected value " & value & " but got " & arg.repr, arg)

proc expectValue(arg: NimNode; value: pointer): void {.compileTime.} =
  ## Only `nil` is supported as a pointer value.
  # (a generic parameter that appears in no argument could never be
  # inferred, which made this overload uncallable; it was dropped)
  arg.expectKind nnkLiterals
  if value != nil:
    error("Expect Value for pointers works only on `nil` when the argument is a pointer.", arg)
  arg.expectKind nnkNilLit

proc expectIdent(arg: NimNode; strVal: string): void {.compileTime.} =
  ## Compile error (positioned at `arg`) unless `arg` is the identifier
  ## `strVal`.
  if not arg.eqIdent(strVal):
    error("Expect ident `" & strVal & "` but got " & arg.repr, arg)

proc matchLengthKind*(arg: NimNode; kind: set[NimNodeKind]; length: int): MatchingError {.compileTime.} =
  ## Checks the kind (empty `kind` = any kind) and the number of
  ## children (negative `length` = any length) of `arg`.
  let kindFail = not(kind.card == 0 or arg.kind in kind)
  let lengthFail = not(length < 0 or length == arg.len)
  if kindFail or lengthFail:
    result = MatchingError(node: arg, expectedKind: kind,
                           kind: WrongKindLength, expectedLength: length)

proc matchLengthKind*(arg: NimNode; kind: NimNodeKind; length: int): MatchingError {.compileTime.} =
  matchLengthKind(arg, {kind}, length)

proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeInteger): MatchingError {.compileTime.} =
  ## Checks the kind (empty `kind` = any kind) and the integer value of
  ## `arg`.
  let kindFail = not(kind.card == 0 or arg.kind in kind)
  # `intVal` only exists on integer literals; test the kind first (the
  # `or` short-circuits) so that any other node is reported as a
  # mismatch instead of failing inside `intVal`.
  let valueFail = arg.kind notin nnkIntLiterals or arg.intVal != int(value)
  if kindFail or valueFail:
    result = MatchingError(node: arg, expectedKind: kind,
                           kind: WrongKindValue, expectedValue: newLit(value))

proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeInteger): MatchingError {.compileTime.} =
  matchValue(arg, {kind}, value)

proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeFloat): MatchingError {.compileTime.} =
  ## Checks the kind (empty `kind` = any kind) and the float value of
  ## `arg`.
  let kindFail = not(kind.card == 0 or arg.kind in kind)
  # same guard as for integers: `floatVal` only exists on float literals
  let valueFail = arg.kind notin {nnkFloatLit..nnkFloat128Lit} or
                  arg.floatVal != float(value)
  if kindFail or valueFail:
    result = MatchingError(node: arg, expectedKind: kind,
                           kind: WrongKindValue, expectedValue: newLit(value))

proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeFloat): MatchingError {.compileTime.} =
  matchValue(arg, {kind}, value)

const nnkStrValKinds = {nnkStrLit, nnkRStrLit, nnkTripleStrLit, nnkIdent, nnkSym}

proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: string): MatchingError {.compileTime.} =
  ## Checks the kind (empty `kind` = any kind) and the `strVal` of `arg`.
  let kindFail = not(kind.card == 0 or arg.kind in kind)
  # Only kinds that carry a `strVal` can match a string. An empty `kind`
  # (the `_` wildcard or a bare string-literal pattern) accepts all of
  # them; previously the value was not checked at all in that case, so
  # such patterns matched every node.
  let strKinds =
    if kind.card == 0: nnkStrValKinds
    else: kind * nnkStrValKinds
  let valueFail = arg.kind notin strKinds or arg.strVal != value
  if kindFail or valueFail:
    result = MatchingError(node: arg, expectedKind: kind,
                           kind: WrongKindValue, expectedValue: newLit(value))

proc matchValue(arg: NimNode; kind: NimNodeKind; value: string): MatchingError {.compileTime.} =
  matchValue(arg, {kind}, value)

proc matchValue(arg: NimNode; value: pointer): MatchingError {.compileTime.} =
  ## Only `nil` is supported as a pointer value; matches `nnkNilLit`.
  if value != nil:
    error("Expect Value for pointers works only on `nil` when the argument is a pointer.", arg)
  arg.matchLengthKind(nnkNilLit, -1)

proc matchIdent*(arg: NimNode; value: string): MatchingError =
  ## Checks that `arg` is the identifier `value` (compared with `eqIdent`).
  if not arg.eqIdent(value):
    result = MatchingError(node: arg, kind: WrongIdent, strVal: value)

proc checkCustomExpr*(arg: NimNode; cond: bool, exprstr: string): MatchingError =
  ## Turns the already evaluated user condition `cond` into a matching
  ## result; `exprstr` is the source text used in the error message.
  if not cond:
    result = MatchingError(node: arg, kind: WrongCustomCondition, strVal: exprstr)
proc generateMatchingCode(astSym: NimNode, pattern: NimNode, depth: int,
                          blockLabel, errorSym, localsArraySym: NimNode;
                          dest: NimNode): int =
  ## Appends to `dest` the statements that check the node bound to
  ## `astSym` against `pattern`. Each failing check stores its result in
  ## `errorSym` and breaks out of `blockLabel`. Child nodes are stored
  ## in slots of the array `localsArraySym`.
  ##
  ## Returns the number of slots of `localsArraySym` the generated code
  ## uses.
  var currentLocalIndex = 0

  proc nodeVisiting(astSym: NimNode, pattern: NimNode, depth: int): void =
    let ind = " ".repeat(depth) # indentation of the debug output

    proc genMatchLogic(matchProc, argSym1, argSym2: NimNode): void =
      dest.add quote do:
        `errorSym` = `astSym`.`matchProc`(`argSym1`, `argSym2`)
        if `errorSym`.kind != NoError:
          break `blockLabel`

    proc genIdentMatchLogic(identValueLit: NimNode): void =
      dest.add quote do:
        `errorSym` = `astSym`.matchIdent(`identValueLit`)
        if `errorSym`.kind != NoError:
          break `blockLabel`

    proc genCustomMatchLogic(conditionExpr: NimNode): void =
      let exprStr = newLit(conditionExpr.repr)
      dest.add quote do:
        `errorSym` = `astSym`.checkCustomExpr(`conditionExpr`, `exprStr`)
        if `errorSym`.kind != NoError:
          break `blockLabel`

    # proc handleKindMatching(kindExpr: NimNode): void =
    #   if kindExpr.eqIdent("_"):
    #     # this is the wildcand that matches any kind
    #     return
    #   else:
    #     genMatchLogic(bindSym"matchKind", kindExpr)

    # generate recursively a matching expression
    if pattern.kind == nnkCall:
      pattern.expectMinLen(1)

      debug ind, pattern[0].repr, "("

      # `_` as the head of a call accepts any node kind
      let kindSet = if pattern[0].eqIdent("_"): nnkCurly.newTree else: pattern[0]
      # handleKindMatching(pattern[0])

      if pattern.len == 2 and pattern[1].kind == nnkExprEqExpr:
        # value pattern such as `nnkIntLit(intVal = 123)`
        if pattern[1][1].kind in nnkStringLiterals:
          pattern[1][0].expectIdent("strVal")
        elif pattern[1][1].kind in nnkIntLiterals:
          pattern[1][0].expectIdent("intVal")
        elif pattern[1][1].kind in nnkFloatLiterals:
          pattern[1][0].expectIdent("floatVal")

        genMatchLogic(bindSym"matchValue", kindSet, pattern[1][1])

      else:
        # structural pattern: check kind and exact child count, then
        # match every child against its sub-pattern
        let lengthLit = newLit(pattern.len - 1)
        genMatchLogic(bindSym"matchLengthKind", kindSet, lengthLit)

        for i in 1 ..< pattern.len:
          let childSym = nnkBracketExpr.newTree(localsArraySym, newLit(currentLocalIndex))
          currentLocalIndex += 1
          let indexLit = newLit(i - 1)
          dest.add quote do:
            `childSym` = `astSym`[`indexLit`]
          nodeVisiting(childSym, pattern[i], depth + 1)
      debug ind, ")"

    elif pattern.kind == nnkCallStrLit and pattern[0].eqIdent("ident"):
      genIdentMatchLogic(pattern[1])

    elif pattern.kind == nnkPar and pattern.len == 1:
      nodeVisiting(astSym, pattern[0], depth)

    elif pattern.kind == nnkPrefix:
      error("prefix patterns not implemented", pattern)

    elif pattern.kind == nnkAccQuoted:
      # `name` binds the current node to `name` without checking it
      debug ind, pattern.repr
      let matchedExpr = pattern[0]
      matchedExpr.expectKind nnkIdent
      dest.add quote do:
        let `matchedExpr` = `astSym`

    elif pattern.kind == nnkInfix and pattern[0].eqIdent("@"):
      # `name` @ subPattern binds the node and keeps matching through it
      pattern[1].expectKind nnkAccQuoted

      let matchedExpr = pattern[1][0]
      matchedExpr.expectKind nnkIdent
      dest.add quote do:
        let `matchedExpr` = `astSym`

      debug ind, pattern[1].repr, " = "
      nodeVisiting(matchedExpr, pattern[2], depth + 1)

    elif pattern.kind == nnkInfix and pattern[0].eqIdent("|="):
      # subPattern |= condition: the condition is only evaluated after
      # the sub pattern matched
      nodeVisiting(astSym, pattern[1], depth + 1)
      genCustomMatchLogic(pattern[2])

    elif pattern.kind in nnkCallKinds:
      error("only boring call syntax allowed, this is " & $pattern.kind & ".", pattern)

    elif pattern.kind in nnkLiterals:
      genMatchLogic(bindSym"matchValue", nnkCurly.newTree, pattern)

    elif not pattern.eqIdent("_"):
      # When it is not one of the other branches, it is simply treated
      # as an expression for the node kind, without checking child
      # nodes.
      debug ind, pattern.repr
      genMatchLogic(bindSym"matchLengthKind", pattern, newLit(-1))

  nodeVisiting(astSym, pattern, depth)

  return currentLocalIndex

macro matchAst*(astExpr: NimNode; args: varargs[untyped]): untyped =
  ## Matches `astExpr` against the pattern of each `of` branch in order
  ## and runs the code of the first branch that matches.
  ##
  ## An optional identifier before the branches receives the errors of
  ## all branches inside the `else` branch. Without an `else` branch a
  ## mismatch is a compile error.
  if args.len == 0:
    error("matchAst needs at least one `of` branch", astExpr)

  let astSym = genSym(nskLet, "ast")
  let beginBranches = if args[0].kind == nnkIdent: 1 else: 0
  let endBranches = if args[^1].kind == nnkElse: args.len - 1 else: args.len
  for i in beginBranches ..< endBranches:
    args[i].expectKind nnkOfBranch

  let outerErrorSym: NimNode =
    if beginBranches == 1:
      args[0].expectKind nnkIdent
      args[0]
    else:
      nil

  let elseBranch: NimNode =
    if endBranches == args.len - 1:
      args[^1].expectKind(nnkElse)
      args[^1][0]
    else:
      nil

  let outerBlockLabel = genSym(nskLabel, "matchingSection")
  let outerStmtList = newStmtList()
  let errorSymbols = nnkBracket.newTree

  # The VM only allows 255 local variables. To stay below that limit,
  # matched child nodes are not stored in individual locals but in the
  # slots of one array.
  let localsArraySym = genSym(nskVar, "locals")
  var localsArrayLen: int = 0

  for i in beginBranches ..< endBranches:
    let ofBranch = args[i]
    ofBranch.expectKind(nnkOfBranch)
    ofBranch.expectLen(2)
    let pattern = ofBranch[0]
    let code = ofBranch[1]
    code.expectKind nnkStmtList
    let stmtList = newStmtList()
    let blockLabel = genSym(nskLabel, "matchingBranch")
    let errorSym = genSym(nskVar, "branchError")

    errorSymbols.add errorSym

    let numLocalsUsed = generateMatchingCode(astSym, pattern, 0, blockLabel, errorSym, localsArraySym, stmtList)
    localsArrayLen = max(localsArrayLen, numLocalsUsed)
    stmtList.add code
    # maybe there is a better mechanism disable errors for statement after return
    if code[^1].kind != nnkReturnStmt:
      stmtList.add nnkBreakStmt.newTree(outerBlockLabel)

    outerStmtList.add quote do:
      var `errorSym`: MatchingError
      block `blockLabel`:
        `stmtList`

  if elseBranch != nil:
    if outerErrorSym != nil:
      outerStmtList.add quote do:
        let `outerErrorSym` = @`errorSymbols`
        `elseBranch`
    else:
      outerStmtList.add elseBranch

  else:
    if errorSymbols.len == 1:
      # there is only one of branch and no else branch
      # the error message can be very precise here.
      let errorSym = errorSymbols[0]
      outerStmtList.add quote do:
        failWithMatchingError(`errorSym`)
    else:
      var patterns: string = ""
      for i in beginBranches ..< endBranches:
        let ofBranch = args[i]
        let pattern = ofBranch[0]
        patterns.add pattern.repr
        patterns.add "\n"

      let patternsLit = newLit(patterns)
      outerStmtList.add quote do:
        error("Ast pattern mismatch: got " & `astSym`.lispRepr & "\nbut expected one of:\n" & `patternsLit`, `astSym`)

  let lengthLit = newLit(localsArrayLen)
  result = quote do:
    block `outerBlockLabel`:
      let `astSym` = `astExpr`
      var `localsArraySym`: array[`lengthLit`, NimNode]
      `outerStmtList`

  debug result.repr

proc recursiveNodeVisiting*(arg: NimNode, callback: proc(arg: NimNode): bool) =
  ## if `callback` returns true, visitor continues to visit the
  ## children of `arg` otherwise it stops.
  if callback(arg):
    for child in arg:
      recursiveNodeVisiting(child, callback)

macro matchAstRecursive*(ast: NimNode; args: varargs[untyped]): untyped =
  ## Visits `ast` and its descendants and runs the code of the first
  ## `of` branch whose pattern matches each visited node.
  ## Does not recurse further on matched nodes.
  if args.len == 0:
    error("matchAstRecursive needs at least one `of` branch", ast)
  if args[^1].kind == nnkElse:
    error("Recursive matching with an else branch is pointless.", args[^1])

  let visitor = genSym(nskProc, "visitor")
  let visitorArg = genSym(nskParam, "arg")

  let visitorStmtList = newStmtList()

  let matchingSection = genSym(nskLabel, "matchingSection")

  let localsArraySym = genSym(nskVar, "locals")
  let branchError = genSym(nskVar, "branchError")
  var localsArrayLen = 0

  for ofBranch in args:
    ofBranch.expectKind(nnkOfBranch)
    ofBranch.expectLen(2)
    let pattern = ofBranch[0]
    let code = ofBranch[1]
    code.expectKind(nnkStmtList)
    let stmtList = newStmtList()
    let matchingBranch = genSym(nskLabel, "matchingBranch")

    let numLocalsUsed = generateMatchingCode(visitorArg, pattern, 0, matchingBranch, branchError, localsArraySym, stmtList)
    localsArrayLen = max(localsArrayLen, numLocalsUsed)

    stmtList.add code
    # Leaving `matchingSection` skips `result = true` below, so the
    # visitor returns false and the children of a matched node are not
    # visited.
    stmtList.add nnkBreakStmt.newTree(matchingSection)

    # Each branch starts from a clean error value. The whole object is
    # assigned instead of `branchError.kind = NoError`, because writing
    # the discriminant would switch the case-object branch of an already
    # initialized object.
    visitorStmtList.add quote do:
      `branchError` = MatchingError()
      block `matchingBranch`:
        `stmtList`

  let resultIdent = ident"result"

  let visitingProc = bindSym"recursiveNodeVisiting"
  let lengthLit = newLit(localsArrayLen)

  result = quote do:
    proc `visitor`(`visitorArg`: NimNode): bool =
      block `matchingSection`:
        var `localsArraySym`: array[`lengthLit`, NimNode]
        var `branchError`: MatchingError
        `visitorStmtList`
        `resultIdent` = true

    `visitingProc`(`ast`, `visitor`)

  debug result.repr

################################################################################
################################# Example Code #################################
################################################################################

when isMainModule:
  static:
    let mykinds = {nnkIdent, nnkCall}

  macro foo(arg: untyped): untyped =
    matchAst(arg, matchError):
    of nnkStmtList(nnkIdent, nnkIdent, nnkIdent):
      echo(88*88+33*33)
    of nnkStmtList(
      _(
        nnkIdentDefs(
          ident"a",
          nnkEmpty, nnkIntLit(intVal = 123)
        )
      ),
      _,
      nnkForStmt(
        nnkIdent(strVal = "i"),
        nnkInfix,
        `mysym` @ nnkStmtList
      )
    ):
      echo "The AST did match!!!"
      echo "The matched sub tree is the following:"
      echo mysym.lispRepr
    #else:
    #  echo "sadly the AST did not match :("
    #  echo arg.treeRepr
    #  failWithMatchingError(matchError[1])

  foo:
    let a = 123
    let b = 342
    for i in a ..< b:
      echo "Hallo", i

  static:
    var ast = quote do:
      type
        A[T: static[int]] = object

    ast = ast[0]
    ast.matchAst(err):  # this is a sub ast for this a findAst or something like that is useful
    of nnkTypeDef(_, nnkGenericParams( nnkIdentDefs( nnkIdent(strVal = "T"), `staticTy`, nnkEmpty )), _):
      echo "`", staticTy.repr, "` used to be of nnkStaticTy, now it is ", staticTy.kind, " with ", staticTy[0].repr

    ast = quote do:
      if cond1:
        expr1
      elif cond2:
        expr2
      else:
        expr3

    ast.matchAst:
    of {nnkIfExpr, nnkIfStmt}(
      {nnkElifExpr, nnkElifBranch}(`cond1`, `expr1`),
      {nnkElifExpr, nnkElifBranch}(`cond2`, `expr2`),
      {nnkElseExpr, nnkElse}(`expr3`)
    ):
      echo "ok"

    let ast2 = nnkStmtList.newTree( newLit(1) )

    ast2.matchAst:
    of nnkIntLit( 1 ):
      echo "fail"
    of nnkStmtList( 1 ):
      echo "ok"

    ast = bindSym"[]"
    ast.matchAst(errors):
    of nnkClosedSymChoice(strVal = "[]"):
      echo "fail, this is the wrong syntax, a sym choice does not have a `strVal` member."
    of ident"[]":
      echo "ok"

    const myConst = 123
    ast = newLit(123)

    ast.matchAst:
    of _(intVal = myConst):
      echo "ok"

  macro testRecCase(ast: untyped): untyped =
    ast.matchAstRecursive:
    of nnkIdentDefs(`a`,`b`,`c`):
      echo "got ident defs a: ", a.repr, " b: ", b.repr, " c: ", c.repr
    of ident"m":
      echo "got the ident m"

  testRecCase:
    type Obj[T] = object {.inheritable.}
      name: string
      case isFat: bool
      of true:
        m: array[100_000, T]
      of false:
        m: array[10, T]

  macro testIfCondition(ast: untyped): untyped =
    let literals = nnkBracket.newTree
    ast.matchAstRecursive:
    of `intLit` @ nnkIntLit |= intLit.intVal > 5:
      literals.add intLit

    let literals2 = quote do:
      [6,7,8,9]

    doAssert literals2 == literals

  testIfCondition([1,6,2,7,3,8,4,9,5,0,"123"])