diff options
author | Arne Döring <arne.doering@gmx.net> | 2019-03-19 11:45:29 +0100 |
---|---|---|
committer | Andreas Rumpf <rumpf_a@web.de> | 2019-03-19 11:45:29 +0100 |
commit | 389b140029577845b9a0e40b6fecc8ba78af679f (patch) | |
tree | 3d3dd6de889cb3434209bca793d024aa9e806143 /tests/ast_pattern_matching.nim | |
parent | 5c1c5902e20cad9d1f28923e7abba3282ad4f8a1 (diff) | |
download | Nim-389b140029577845b9a0e40b6fecc8ba78af679f.tar.gz |
add tastspec (and ast_pattern_matching) (#10863)
Diffstat (limited to 'tests/ast_pattern_matching.nim')
-rw-r--r-- | tests/ast_pattern_matching.nim | 584 |
1 files changed, 584 insertions, 0 deletions
diff --git a/tests/ast_pattern_matching.nim b/tests/ast_pattern_matching.nim new file mode 100644 index 000000000..c08234b9e --- /dev/null +++ b/tests/ast_pattern_matching.nim @@ -0,0 +1,584 @@ +# this is a copy paste implementation of github.com/krux02/ast_pattern_matching +# Please provide bugfixes upstream first before adding them here. + +import macros, strutils, tables + +export macros + +when isMainModule: + template debug(args: varargs[untyped]): untyped = + echo args +else: + template debug(args: varargs[untyped]): untyped = + discard + +const + nnkIntLiterals* = nnkCharLit..nnkUInt64Lit + nnkStringLiterals* = nnkStrLit..nnkTripleStrLit + nnkFloatLiterals* = nnkFloatLit..nnkFloat64Lit + +proc newLit[T: enum](arg: T): NimNode = + newIdentNode($arg) + +proc newLit[T](arg: set[T]): NimNode = + ## does not work for the empty sets + result = nnkCurly.newTree + for x in arg: + result.add newLit(x) + +type SomeFloat = float | float32 | float64 + +proc len[T](arg: set[T]): int = card(arg) + +type + MatchingErrorKind* = enum + NoError + WrongKindLength + WrongKindValue + WrongIdent + WrongCustomCondition + + MatchingError = object + node*: NimNode + expectedKind*: set[NimNodeKind] + case kind*: MatchingErrorKind + of NoError: + discard + of WrongKindLength: + expectedLength*: int + of WrongKindValue: + expectedValue*: NimNode + of WrongIdent, WrongCustomCondition: + strVal*: string + +proc `$`*(arg: MatchingError): string = + let n = arg.node + case arg.kind + of NoError: + "no error" + of WrongKindLength: + let k = arg.expectedKind + let l = arg.expectedLength + var msg = "expected " + if k.len == 0: + msg.add "any node" + elif k.len == 1: + for el in k: # only one element but there is no index op for sets + msg.add $el + else: + msg.add "a node in" & $k + + if l >= 0: + msg.add " with " & $l & " child(ren)" + msg.add ", but got " & $n.kind + if l >= 0: + msg.add " with " & $n.len & " child(ren)" + msg + of WrongKindValue: + let k = $arg.expectedKind + let v = arg.expectedValue.repr + var msg = "expected " & k & " with value " & v & " but got " & n.lispRepr + if n.kind in {nnkOpenSymChoice, nnkClosedSymChoice}: + msg = msg & " (a sym-choice does not have a strVal member, maybe you should match with `ident`)" + msg + of WrongIdent: + let prefix = "expected ident `" & arg.strVal & "` but got " + if n.kind in {nnkIdent, nnkSym, nnkOpenSymChoice, nnkClosedSymChoice}: + prefix & "`" & n.strVal & "`" + else: + prefix & $n.kind & " with " & $n.len & " child(ren)" + of WrongCustomCondition: + "custom condition check failed: " & arg.strVal + + +proc failWithMatchingError*(arg: MatchingError): void {.compileTime, noReturn.} = + error($arg, arg.node) + +proc expectValue(arg: NimNode; value: SomeInteger): void {.compileTime.} = + arg.expectKind nnkLiterals + if arg.intVal != int(value): + error("expected value " & $value & " but got " & arg.repr, arg) + +proc expectValue(arg: NimNode; value: SomeFloat): void {.compileTime.} = + arg.expectKind nnkLiterals + if arg.floatVal != float(value): + error("expected value " & $value & " but got " & arg.repr, arg) + +proc expectValue(arg: NimNode; value: string): void {.compileTime.} = + arg.expectKind nnkLiterals + if arg.strVal != value: + error("expected value " & value & " but got " & arg.repr, arg) + +proc expectValue[T](arg: NimNode; value: pointer): void {.compileTime.} = + arg.expectKind nnkLiterals + if value != nil: + error("Expect Value for pointers works only on `nil` when the argument is a pointer.") + arg.expectKind nnkNilLit + +proc expectIdent(arg: NimNode; strVal: string): void {.compileTime.} = + if not arg.eqIdent(strVal): + error("Expect ident `" & strVal & "` but got " & arg.repr) + +proc matchLengthKind*(arg: NimNode; kind: set[NimNodeKind]; length: int): MatchingError {.compileTime.} = + let kindFail = not(kind.card == 0 or arg.kind in kind) + let lengthFail = not(length < 0 or length == arg.len) + if kindFail or lengthFail: + result.node = arg + result.kind = WrongKindLength + result.expectedLength = length + result.expectedKind = kind + + +proc matchLengthKind*(arg: NimNode; kind: NimNodeKind; length: int): MatchingError {.compileTime.} = + matchLengthKind(arg, {kind}, length) + +proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeInteger): MatchingError {.compileTime.} = + let kindFail = not(kind.card == 0 or arg.kind in kind) + let valueFail = arg.intVal != int(value) + if kindFail or valueFail: + result.node = arg + result.kind = WrongKindValue + result.expectedKind = kind + result.expectedValue = newLit(value) + +proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeInteger): MatchingError {.compileTime.} = + matchValue(arg, {kind}, value) + +proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeFloat): MatchingError {.compileTime.} = + let kindFail = not(kind.card == 0 or arg.kind in kind) + let valueFail = arg.floatVal != float(value) + if kindFail or valueFail: + result.node = arg + result.kind = WrongKindValue + result.expectedKind = kind + result.expectedValue = newLit(value) + +proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeFloat): MatchingError {.compileTime.} = + matchValue(arg, {kind}, value) + +const nnkStrValKinds = {nnkStrLit, nnkRStrLit, nnkTripleStrLit, nnkIdent, nnkSym} + +proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: string): MatchingError {.compileTime.} = + # if kind * nnkStringLiterals TODO do something that ensures that here is only checked for string literals + let kindFail = not(kind.card == 0 or arg.kind in kind) + let valueFail = + if kind.card == 0: + false + else: + arg.kind notin (kind * nnkStrValKinds) or arg.strVal != value + if kindFail or valueFail: + result.node = arg + result.kind = WrongKindValue + result.expectedKind = kind + result.expectedValue = newLit(value) + +proc matchValue(arg: NimNode; kind: NimNodeKind; value: string): MatchingError {.compileTime.} = + matchValue(arg, {kind}, value) + +proc matchValue[T](arg: NimNode; value: pointer): MatchingError {.compileTime.} = + if value != nil: + error("Expect Value for pointers works only on `nil` when the argument is a pointer.") + arg.matchLengthKind(nnkNilLit, -1) + +proc matchIdent*(arg:NimNode; value: string): MatchingError = + if not arg.eqIdent(value): + result.node = arg + result.kind = Wrongident + result.strVal = value + +proc checkCustomExpr*(arg: NimNode; cond: bool, exprstr: string): MatchingError = + if not cond: + result.node = arg + result.kind = WrongCustomCondition + result.strVal = exprstr + +static: + var literals: array[19, NimNode] + var i = 0 + for litKind in nnkLiterals: + literals[i] = ident($litKind) + i += 1 + + var nameToKind = newTable[string, NimNodeKind]() + for kind in NimNodeKind: + nameToKind[ ($kind)[3..^1] ] = kind + + let identifierKinds = newLit({nnkSym, nnkIdent, nnkOpenSymChoice, nnkClosedSymChoice}) + +proc generateMatchingCode(astSym: NimNode, pattern: NimNode, depth: int, blockLabel, errorSym, localsArraySym: NimNode; dest: NimNode): int = + ## return the number of indices used in the array for local variables. + + var currentLocalIndex = 0 + + proc nodeVisiting(astSym: NimNode, pattern: NimNode, depth: int): void = + let ind = " ".repeat(depth) # indentation + + proc genMatchLogic(matchProc, argSym1, argSym2: NimNode): void = + dest.add quote do: + `errorSym` = `astSym`.`matchProc`(`argSym1`, `argSym2`) + if `errorSym`.kind != NoError: + break `blockLabel` + + proc genIdentMatchLogic(identValueLit: NimNode): void = + dest.add quote do: + `errorSym` = `astSym`.matchIdent(`identValueLit`) + if `errorSym`.kind != NoError: + break `blockLabel` + + proc genCustomMatchLogic(conditionExpr: NimNode): void = + let exprStr = newLit(conditionExpr.repr) + dest.add quote do: + `errorSym` = `astSym`.checkCustomExpr(`conditionExpr`, `exprStr`) + if `errorSym`.kind != NoError: + break `blockLabel` + + # proc handleKindMatching(kindExpr: NimNode): void = + # if kindExpr.eqIdent("_"): + # # this is the wildcand that matches any kind + # return + # else: + # genMatchLogic(bindSym"matchKind", kindExpr) + + # generate recursively a matching expression + if pattern.kind == nnkCall: + pattern.expectMinLen(1) + + debug ind, pattern[0].repr, "(" + + let kindSet = if pattern[0].eqIdent("_"): nnkCurly.newTree else: pattern[0] + # handleKindMatching(pattern[0]) + + if pattern.len == 2 and pattern[1].kind == nnkExprEqExpr: + if pattern[1][1].kind in nnkStringLiterals: + pattern[1][0].expectIdent("strVal") + elif pattern[1][1].kind in nnkIntLiterals: + pattern[1][0].expectIdent("intVal") + elif pattern[1][1].kind in nnkFloatLiterals: + pattern[1][0].expectIdent("floatVal") + + genMatchLogic(bindSym"matchValue", kindSet, pattern[1][1]) + + else: + let lengthLit = newLit(pattern.len - 1) + genMatchLogic(bindSym"matchLengthKind", kindSet, lengthLit) + + for i in 1 ..< pattern.len: + let childSym = nnkBracketExpr.newTree(localsArraySym, newLit(currentLocalIndex)) + currentLocalIndex += 1 + let indexLit = newLit(i - 1) + dest.add quote do: + `childSym` = `astSym`[`indexLit`] + nodeVisiting(childSym, pattern[i], depth + 1) + debug ind, ")" + elif pattern.kind == nnkCallStrLit and pattern[0].eqIdent("ident"): + genIdentMatchLogic(pattern[1]) + + elif pattern.kind == nnkPar and pattern.len == 1: + nodeVisiting(astSym, pattern[0], depth) + elif pattern.kind == nnkPrefix: + error("prefix patterns not implemented", pattern) + elif pattern.kind == nnkAccQuoted: + debug ind, pattern.repr + let matchedExpr = pattern[0] + matchedExpr.expectKind nnkIdent + dest.add quote do: + let `matchedExpr` = `astSym` + + elif pattern.kind == nnkInfix and pattern[0].eqIdent("@"): + pattern[1].expectKind nnkAccQuoted + + let matchedExpr = pattern[1][0] + matchedExpr.expectKind nnkIdent + dest.add quote do: + let `matchedExpr` = `astSym` + + debug ind, pattern[1].repr, " = " + nodeVisiting(matchedExpr, pattern[2], depth + 1) + + elif pattern.kind == nnkInfix and pattern[0].eqIdent("|="): + nodeVisiting(astSym, pattern[1], depth + 1) + genCustomMatchLogic(pattern[2]) + + elif pattern.kind in nnkCallKinds: + error("only boring call syntax allowed, this is " & $pattern.kind & ".", pattern) + elif pattern.kind in nnkLiterals: + genMatchLogic(bindSym"matchValue", nnkCurly.newTree, pattern) + elif not pattern.eqIdent("_"): + # When it is not one of the other branches, it is simply treated + # as an expression for the node kind, without checking child + # nodes. + debug ind, pattern.repr + genMatchLogic(bindSym"matchLengthKind", pattern, newLit(-1)) + + nodeVisiting(astSym, pattern, depth) + + return currentLocalIndex + +macro matchAst*(astExpr: NimNode; args: varargs[untyped]): untyped = + let astSym = genSym(nskLet, "ast") + let beginBranches = if args[0].kind == nnkIdent: 1 else: 0 + let endBranches = if args[^1].kind == nnkElse: args.len - 1 else: args.len + for i in beginBranches ..< endBranches: + args[i].expectKind nnkOfBranch + + let outerErrorSym: NimNode = + if beginBranches == 1: + args[0].expectKind nnkIdent + args[0] + else: + nil + + let elseBranch: NimNode = + if endBranches == args.len - 1: + args[^1].expectKind(nnkElse) + args[^1][0] + else: + nil + + let outerBlockLabel = genSym(nskLabel, "matchingSection") + let outerStmtList = newStmtList() + let errorSymbols = nnkBracket.newTree + + ## the vm only allows 255 local variables. This sucks a lot and I + ## have to work around it. So instead of creating a lot of local + ## variables, I just create one array of local variables. This is + ## just annoying. + let localsArraySym = genSym(nskVar, "locals") + var localsArrayLen: int = 0 + + for i in beginBranches ..< endBranches: + let ofBranch = args[i] + + ofBranch.expectKind(nnkOfBranch) + ofBranch.expectLen(2) + let pattern = ofBranch[0] + let code = ofBranch[1] + code.expectKind nnkStmtList + let stmtList = newStmtList() + let blockLabel = genSym(nskLabel, "matchingBranch") + let errorSym = genSym(nskVar, "branchError") + + errorSymbols.add errorSym + let numLocalsUsed = generateMatchingCode(astSym, pattern, 0, blockLabel, errorSym, localsArraySym, stmtList) + localsArrayLen = max(localsArrayLen, numLocalsUsed) + stmtList.add code + # maybe there is a better mechanism disable errors for statement after return + if code[^1].kind != nnkReturnStmt: + stmtList.add nnkBreakStmt.newTree(outerBlockLabel) + + outerStmtList.add quote do: + var `errorSym`: MatchingError + block `blockLabel`: + `stmtList` + + if elseBranch != nil: + if outerErrorSym != nil: + outerStmtList.add quote do: + let `outerErrorSym` = @`errorSymbols` + `elseBranch` + else: + outerStmtList.add elseBranch + + else: + if errorSymbols.len == 1: + # there is only one of branch and no else branch + # the error message can be very precise here. + let errorSym = errorSymbols[0] + outerStmtList.add quote do: + failWithMatchingError(`errorSym`) + else: + + var patterns: string = "" + for i in beginBranches ..< endBranches: + let ofBranch = args[i] + let pattern = ofBranch[0] + patterns.add pattern.repr + patterns.add "\n" + + let patternsLit = newLit(patterns) + outerStmtList.add quote do: + error("Ast pattern mismatch: got " & `astSym`.lispRepr & "\nbut expected one of:\n" & `patternsLit`, `astSym`) + + let lengthLit = newLit(localsArrayLen) + result = quote do: + block `outerBlockLabel`: + let `astSym` = `astExpr` + var `localsArraySym`: array[`lengthLit`, NimNode] + `outerStmtList` + + debug result.repr + +proc recursiveNodeVisiting*(arg: NimNode, callback: proc(arg: NimNode): bool) = + ## if `callback` returns true, visitor continues to visit the + ## children of `arg` otherwise it stops. + if callback(arg): + for child in arg: + recursiveNodeVisiting(child, callback) + +macro matchAstRecursive*(ast: NimNode; args: varargs[untyped]): untyped = + # Does not recurse further on matched nodes. + if args[^1].kind == nnkElse: + error("Recursive matching with an else branch is pointless.", args[^1]) + + let visitor = genSym(nskProc, "visitor") + let visitorArg = genSym(nskParam, "arg") + + let visitorStmtList = newStmtList() + + let matchingSection = genSym(nskLabel, "matchingSection") + + let localsArraySym = genSym(nskVar, "locals") + let branchError = genSym(nskVar, "branchError") + var localsArrayLen = 0 + + for ofBranch in args: + ofBranch.expectKind(nnkOfBranch) + ofBranch.expectLen(2) + let pattern = ofBranch[0] + let code = ofBranch[1] + code.expectkind(nnkStmtList) + + let stmtList = newStmtList() + let matchingBranch = genSym(nskLabel, "matchingBranch") + + let numLocalsUsed = generateMatchingCode(visitorArg, pattern, 0, matchingBranch, branchError, localsArraySym, stmtList) + localsArrayLen = max(localsArrayLen, numLocalsUsed) + + stmtList.add code + stmtList.add nnkBreakStmt.newTree(matchingSection) + + + visitorStmtList.add quote do: + `branchError`.kind = NoError + block `matchingBranch`: + `stmtList` + + let resultIdent = ident"result" + + let visitingProc = bindSym"recursiveNodeVisiting" + let lengthLit = newLit(localsArrayLen) + + result = quote do: + proc `visitor`(`visitorArg`: NimNode): bool = + block `matchingSection`: + var `localsArraySym`: array[`lengthLit`, NimNode] + var `branchError`: MatchingError + `visitorStmtList` + `resultIdent` = true + + `visitingProc`(`ast`, `visitor`) + + debug result.repr + +################################################################################ +################################# Example Code ################################# +################################################################################ + +when isMainModule: + static: + let mykinds = {nnkIdent, nnkCall} + + macro foo(arg: untyped): untyped = + matchAst(arg, matchError): + of nnkStmtList(nnkIdent, nnkIdent, nnkIdent): + echo(88*88+33*33) + of nnkStmtList( + _( + nnkIdentDefs( + ident"a", + nnkEmpty, nnkIntLit(intVal = 123) + ) + ), + _, + nnkForStmt( + nnkIdent(strVal = "i"), + nnkInfix, + `mysym` @ nnkStmtList + ) + ): + echo "The AST did match!!!" + echo "The matched sub tree is the following:" + echo mysym.lispRepr + #else: + # echo "sadly the AST did not match :(" + # echo arg.treeRepr + # failWithMatchingError(matchError[1]) + + foo: + let a = 123 + let b = 342 + for i in a ..< b: + echo "Hallo", i + + static: + + var ast = quote do: + type + A[T: static[int]] = object + + ast = ast[0] + ast.matchAst(err): # this is a sub ast for this a findAst or something like that is useful + of nnkTypeDef(_, nnkGenericParams( nnkIdentDefs( nnkIdent(strVal = "T"), `staticTy`, nnkEmpty )), _): + echo "`", staticTy.repr, "` used to be of nnkStaticTy, now it is ", staticTy.kind, " with ", staticTy[0].repr + ast = quote do: + if cond1: expr1 elif cond2: expr2 else: expr3 + + ast.matchAst: + of {nnkIfExpr, nnkIfStmt}( + {nnkElifExpr, nnkElifBranch}(`cond1`, `expr1`), + {nnkElifExpr, nnkElifBranch}(`cond2`, `expr2`), + {nnkElseExpr, nnkElse}(`expr3`) + ): + echo "ok" + + let ast2 = nnkStmtList.newTree( newLit(1) ) + + ast2.matchAst: + of nnkIntLit( 1 ): + echo "fail" + of nnkStmtList( 1 ): + echo "ok" + + ast = bindSym"[]" + ast.matchAst(errors): + of nnkClosedSymChoice(strVal = "[]"): + echo "fail, this is the wrong syntax, a sym choice does not have a `strVal` member." + of ident"[]": + echo "ok" + + const myConst = 123 + ast = newLit(123) + + ast.matchAst: + of _(intVal = myConst): + echo "ok" + + macro testRecCase(ast: untyped): untyped = + ast.matchAstRecursive: + of nnkIdentDefs(`a`,`b`,`c`): + echo "got ident defs a: ", a.repr, " b: ", b.repr, " c: ", c.repr + of ident"m": + echo "got the ident m" + + testRecCase: + type Obj[T] = object {.inheritable.} + name: string + case isFat: bool + of true: + m: array[100_000, T] + of false: + m: array[10, T] + + + macro testIfCondition(ast: untyped): untyped = + let literals = nnkBracket.newTree + ast.matchAstRecursive: + of `intLit` @ nnkIntLit |= intLit.intVal > 5: + literals.add intLit + + let literals2 = quote do: + [6,7,8,9] + + doAssert literals2 == literals + + testIfCondition([1,6,2,7,3,8,4,9,5,0,"123"]) |