diff options
-rw-r--r-- | tests/ast_pattern_matching.nim | 584 | ||||
-rw-r--r-- | tests/astspec/tastspec.nim | 1016 |
2 files changed, 1600 insertions, 0 deletions
diff --git a/tests/ast_pattern_matching.nim b/tests/ast_pattern_matching.nim new file mode 100644 index 000000000..c08234b9e --- /dev/null +++ b/tests/ast_pattern_matching.nim @@ -0,0 +1,584 @@ +# this is a copy paste implementation of github.com/krux02/ast_pattern_matching +# Please provide bugfixes upstream first before adding them here. + +import macros, strutils, tables + +export macros + +when isMainModule: + template debug(args: varargs[untyped]): untyped = + echo args +else: + template debug(args: varargs[untyped]): untyped = + discard + +const + nnkIntLiterals* = nnkCharLit..nnkUInt64Lit + nnkStringLiterals* = nnkStrLit..nnkTripleStrLit + nnkFloatLiterals* = nnkFloatLit..nnkFloat64Lit + +proc newLit[T: enum](arg: T): NimNode = + newIdentNode($arg) + +proc newLit[T](arg: set[T]): NimNode = + ## does not work for the empty sets + result = nnkCurly.newTree + for x in arg: + result.add newLit(x) + +type SomeFloat = float | float32 | float64 + +proc len[T](arg: set[T]): int = card(arg) + +type + MatchingErrorKind* = enum + NoError + WrongKindLength + WrongKindValue + WrongIdent + WrongCustomCondition + + MatchingError = object + node*: NimNode + expectedKind*: set[NimNodeKind] + case kind*: MatchingErrorKind + of NoError: + discard + of WrongKindLength: + expectedLength*: int + of WrongKindValue: + expectedValue*: NimNode + of WrongIdent, WrongCustomCondition: + strVal*: string + +proc `$`*(arg: MatchingError): string = + let n = arg.node + case arg.kind + of NoError: + "no error" + of WrongKindLength: + let k = arg.expectedKind + let l = arg.expectedLength + var msg = "expected " + if k.len == 0: + msg.add "any node" + elif k.len == 1: + for el in k: # only one element but there is no index op for sets + msg.add $el + else: + msg.add "a node in" & $k + + if l >= 0: + msg.add " with " & $l & " child(ren)" + msg.add ", but got " & $n.kind + if l >= 0: + msg.add " with " & $n.len & " child(ren)" + msg + of WrongKindValue: + let k = $arg.expectedKind + let v = arg.expectedValue.repr + var msg = "expected " & k & " with value " & v & " but got " & n.lispRepr + if n.kind in {nnkOpenSymChoice, nnkClosedSymChoice}: + msg = msg & " (a sym-choice does not have a strVal member, maybe you should match with `ident`)" + msg + of WrongIdent: + let prefix = "expected ident `" & arg.strVal & "` but got " + if n.kind in {nnkIdent, nnkSym, nnkOpenSymChoice, nnkClosedSymChoice}: + prefix & "`" & n.strVal & "`" + else: + prefix & $n.kind & " with " & $n.len & " child(ren)" + of WrongCustomCondition: + "custom condition check failed: " & arg.strVal + + +proc failWithMatchingError*(arg: MatchingError): void {.compileTime, noReturn.} = + error($arg, arg.node) + +proc expectValue(arg: NimNode; value: SomeInteger): void {.compileTime.} = + arg.expectKind nnkLiterals + if arg.intVal != int(value): + error("expected value " & $value & " but got " & arg.repr, arg) + +proc expectValue(arg: NimNode; value: SomeFloat): void {.compileTime.} = + arg.expectKind nnkLiterals + if arg.floatVal != float(value): + error("expected value " & $value & " but got " & arg.repr, arg) + +proc expectValue(arg: NimNode; value: string): void {.compileTime.} = + arg.expectKind nnkLiterals + if arg.strVal != value: + error("expected value " & value & " but got " & arg.repr, arg) + +proc expectValue[T](arg: NimNode; value: pointer): void {.compileTime.} = + arg.expectKind nnkLiterals + if value != nil: + error("Expect Value for pointers works only on `nil` when the argument is a pointer.") + arg.expectKind nnkNilLit + +proc expectIdent(arg: NimNode; strVal: string): void {.compileTime.} = + if not arg.eqIdent(strVal): + error("Expect ident `" & strVal & "` but got " & arg.repr) + +proc matchLengthKind*(arg: NimNode; kind: set[NimNodeKind]; length: int): MatchingError {.compileTime.} = + let kindFail = not(kind.card == 0 or arg.kind in kind) + let lengthFail = not(length < 0 or length == arg.len) + if kindFail or lengthFail: + result.node = arg + result.kind = WrongKindLength + result.expectedLength = length + result.expectedKind = kind + + +proc matchLengthKind*(arg: NimNode; kind: NimNodeKind; length: int): MatchingError {.compileTime.} = + matchLengthKind(arg, {kind}, length) + +proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeInteger): MatchingError {.compileTime.} = + let kindFail = not(kind.card == 0 or arg.kind in kind) + let valueFail = arg.intVal != int(value) + if kindFail or valueFail: + result.node = arg + result.kind = WrongKindValue + result.expectedKind = kind + result.expectedValue = newLit(value) + +proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeInteger): MatchingError {.compileTime.} = + matchValue(arg, {kind}, value) + +proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeFloat): MatchingError {.compileTime.} = + let kindFail = not(kind.card == 0 or arg.kind in kind) + let valueFail = arg.floatVal != float(value) + if kindFail or valueFail: + result.node = arg + result.kind = WrongKindValue + result.expectedKind = kind + result.expectedValue = newLit(value) + +proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeFloat): MatchingError {.compileTime.} = + matchValue(arg, {kind}, value) + +const nnkStrValKinds = {nnkStrLit, nnkRStrLit, nnkTripleStrLit, nnkIdent, nnkSym} + +proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: string): MatchingError {.compileTime.} = + # if kind * nnkStringLiterals TODO do something that ensures that here is only checked for string literals + let kindFail = not(kind.card == 0 or arg.kind in kind) + let valueFail = + if kind.card == 0: + false + else: + arg.kind notin (kind * nnkStrValKinds) or arg.strVal != value + if kindFail or valueFail: + result.node = arg + result.kind = WrongKindValue + result.expectedKind = kind + result.expectedValue = newLit(value) + +proc matchValue(arg: NimNode; kind: NimNodeKind; value: string): MatchingError {.compileTime.} = + matchValue(arg, {kind}, value) + +proc matchValue[T](arg: NimNode; value: pointer): MatchingError {.compileTime.} = + if value != nil: + error("Expect Value for pointers works only on `nil` when the argument is a pointer.") + arg.matchLengthKind(nnkNilLit, -1) + +proc matchIdent*(arg:NimNode; value: string): MatchingError = + if not arg.eqIdent(value): + result.node = arg + result.kind = Wrongident + result.strVal = value + +proc checkCustomExpr*(arg: NimNode; cond: bool, exprstr: string): MatchingError = + if not cond: + result.node = arg + result.kind = WrongCustomCondition + result.strVal = exprstr + +static: + var literals: array[19, NimNode] + var i = 0 + for litKind in nnkLiterals: + literals[i] = ident($litKind) + i += 1 + + var nameToKind = newTable[string, NimNodeKind]() + for kind in NimNodeKind: + nameToKind[ ($kind)[3..^1] ] = kind + + let identifierKinds = newLit({nnkSym, nnkIdent, nnkOpenSymChoice, nnkClosedSymChoice}) + +proc generateMatchingCode(astSym: NimNode, pattern: NimNode, depth: int, blockLabel, errorSym, localsArraySym: NimNode; dest: NimNode): int = + ## return the number of indices used in the array for local variables. + + var currentLocalIndex = 0 + + proc nodeVisiting(astSym: NimNode, pattern: NimNode, depth: int): void = + let ind = " ".repeat(depth) # indentation + + proc genMatchLogic(matchProc, argSym1, argSym2: NimNode): void = + dest.add quote do: + `errorSym` = `astSym`.`matchProc`(`argSym1`, `argSym2`) + if `errorSym`.kind != NoError: + break `blockLabel` + + proc genIdentMatchLogic(identValueLit: NimNode): void = + dest.add quote do: + `errorSym` = `astSym`.matchIdent(`identValueLit`) + if `errorSym`.kind != NoError: + break `blockLabel` + + proc genCustomMatchLogic(conditionExpr: NimNode): void = + let exprStr = newLit(conditionExpr.repr) + dest.add quote do: + `errorSym` = `astSym`.checkCustomExpr(`conditionExpr`, `exprStr`) + if `errorSym`.kind != NoError: + break `blockLabel` + + # proc handleKindMatching(kindExpr: NimNode): void = + # if kindExpr.eqIdent("_"): + # # this is the wildcand that matches any kind + # return + # else: + # genMatchLogic(bindSym"matchKind", kindExpr) + + # generate recursively a matching expression + if pattern.kind == nnkCall: + pattern.expectMinLen(1) + + debug ind, pattern[0].repr, "(" + + let kindSet = if pattern[0].eqIdent("_"): nnkCurly.newTree else: pattern[0] + # handleKindMatching(pattern[0]) + + if pattern.len == 2 and pattern[1].kind == nnkExprEqExpr: + if pattern[1][1].kind in nnkStringLiterals: + pattern[1][0].expectIdent("strVal") + elif pattern[1][1].kind in nnkIntLiterals: + pattern[1][0].expectIdent("intVal") + elif pattern[1][1].kind in nnkFloatLiterals: + pattern[1][0].expectIdent("floatVal") + + genMatchLogic(bindSym"matchValue", kindSet, pattern[1][1]) + + else: + let lengthLit = newLit(pattern.len - 1) + genMatchLogic(bindSym"matchLengthKind", kindSet, lengthLit) + + for i in 1 ..< pattern.len: + let childSym = nnkBracketExpr.newTree(localsArraySym, newLit(currentLocalIndex)) + currentLocalIndex += 1 + let indexLit = newLit(i - 1) + dest.add quote do: + `childSym` = `astSym`[`indexLit`] + nodeVisiting(childSym, pattern[i], depth + 1) + debug ind, ")" + elif pattern.kind == nnkCallStrLit and pattern[0].eqIdent("ident"): + genIdentMatchLogic(pattern[1]) + + elif pattern.kind == nnkPar and pattern.len == 1: + nodeVisiting(astSym, pattern[0], depth) + elif pattern.kind == nnkPrefix: + error("prefix patterns not implemented", pattern) + elif pattern.kind == nnkAccQuoted: + debug ind, pattern.repr + let matchedExpr = pattern[0] + matchedExpr.expectKind nnkIdent + dest.add quote do: + let `matchedExpr` = `astSym` + + elif pattern.kind == nnkInfix and pattern[0].eqIdent("@"): + pattern[1].expectKind nnkAccQuoted + + let matchedExpr = pattern[1][0] + matchedExpr.expectKind nnkIdent + dest.add quote do: + let `matchedExpr` = `astSym` + + debug ind, pattern[1].repr, " = " + nodeVisiting(matchedExpr, pattern[2], depth + 1) + + elif pattern.kind == nnkInfix and pattern[0].eqIdent("|="): + nodeVisiting(astSym, pattern[1], depth + 1) + genCustomMatchLogic(pattern[2]) + + elif pattern.kind in nnkCallKinds: + error("only boring call syntax allowed, this is " & $pattern.kind & ".", pattern) + elif pattern.kind in nnkLiterals: + genMatchLogic(bindSym"matchValue", nnkCurly.newTree, pattern) + elif not pattern.eqIdent("_"): + # When it is not one of the other branches, it is simply treated + # as an expression for the node kind, without checking child + # nodes. + debug ind, pattern.repr + genMatchLogic(bindSym"matchLengthKind", pattern, newLit(-1)) + + nodeVisiting(astSym, pattern, depth) + + return currentLocalIndex + +macro matchAst*(astExpr: NimNode; args: varargs[untyped]): untyped = + let astSym = genSym(nskLet, "ast") + let beginBranches = if args[0].kind == nnkIdent: 1 else: 0 + let endBranches = if args[^1].kind == nnkElse: args.len - 1 else: args.len + for i in beginBranches ..< endBranches: + args[i].expectKind nnkOfBranch + + let outerErrorSym: NimNode = + if beginBranches == 1: + args[0].expectKind nnkIdent + args[0] + else: + nil + + let elseBranch: NimNode = + if endBranches == args.len - 1: + args[^1].expectKind(nnkElse) + args[^1][0] + else: + nil + + let outerBlockLabel = genSym(nskLabel, "matchingSection") + let outerStmtList = newStmtList() + let errorSymbols = nnkBracket.newTree + + ## the vm only allows 255 local variables. This sucks a lot and I + ## have to work around it. So instead of creating a lot of local + ## variables, I just create one array of local variables. This is + ## just annoying. + let localsArraySym = genSym(nskVar, "locals") + var localsArrayLen: int = 0 + + for i in beginBranches ..< endBranches: + let ofBranch = args[i] + + ofBranch.expectKind(nnkOfBranch) + ofBranch.expectLen(2) + let pattern = ofBranch[0] + let code = ofBranch[1] + code.expectKind nnkStmtList + let stmtList = newStmtList() + let blockLabel = genSym(nskLabel, "matchingBranch") + let errorSym = genSym(nskVar, "branchError") + + errorSymbols.add errorSym + let numLocalsUsed = generateMatchingCode(astSym, pattern, 0, blockLabel, errorSym, localsArraySym, stmtList) + localsArrayLen = max(localsArrayLen, numLocalsUsed) + stmtList.add code + # maybe there is a better mechanism disable errors for statement after return + if code[^1].kind != nnkReturnStmt: + stmtList.add nnkBreakStmt.newTree(outerBlockLabel) + + outerStmtList.add quote do: + var `errorSym`: MatchingError + block `blockLabel`: + `stmtList` + + if elseBranch != nil: + if outerErrorSym != nil: + outerStmtList.add quote do: + let `outerErrorSym` = @`errorSymbols` + `elseBranch` + else: + outerStmtList.add elseBranch + + else: + if errorSymbols.len == 1: + # there is only one of branch and no else branch + # the error message can be very precise here. + let errorSym = errorSymbols[0] + outerStmtList.add quote do: + failWithMatchingError(`errorSym`) + else: + + var patterns: string = "" + for i in beginBranches ..< endBranches: + let ofBranch = args[i] + let pattern = ofBranch[0] + patterns.add pattern.repr + patterns.add "\n" + + let patternsLit = newLit(patterns) + outerStmtList.add quote do: + error("Ast pattern mismatch: got " & `astSym`.lispRepr & "\nbut expected one of:\n" & `patternsLit`, `astSym`) + + let lengthLit = newLit(localsArrayLen) + result = quote do: + block `outerBlockLabel`: + let `astSym` = `astExpr` + var `localsArraySym`: array[`lengthLit`, NimNode] + `outerStmtList` + + debug result.repr + +proc recursiveNodeVisiting*(arg: NimNode, callback: proc(arg: NimNode): bool) = + ## if `callback` returns true, visitor continues to visit the + ## children of `arg` otherwise it stops. + if callback(arg): + for child in arg: + recursiveNodeVisiting(child, callback) + +macro matchAstRecursive*(ast: NimNode; args: varargs[untyped]): untyped = + # Does not recurse further on matched nodes. + if args[^1].kind == nnkElse: + error("Recursive matching with an else branch is pointless.", args[^1]) + + let visitor = genSym(nskProc, "visitor") + let visitorArg = genSym(nskParam, "arg") + + let visitorStmtList = newStmtList() + + let matchingSection = genSym(nskLabel, "matchingSection") + + let localsArraySym = genSym(nskVar, "locals") + let branchError = genSym(nskVar, "branchError") + var localsArrayLen = 0 + + for ofBranch in args: + ofBranch.expectKind(nnkOfBranch) + ofBranch.expectLen(2) + let pattern = ofBranch[0] + let code = ofBranch[1] + code.expectkind(nnkStmtList) + + let stmtList = newStmtList() + let matchingBranch = genSym(nskLabel, "matchingBranch") + + let numLocalsUsed = generateMatchingCode(visitorArg, pattern, 0, matchingBranch, branchError, localsArraySym, stmtList) + localsArrayLen = max(localsArrayLen, numLocalsUsed) + + stmtList.add code + stmtList.add nnkBreakStmt.newTree(matchingSection) + + + visitorStmtList.add quote do: + `branchError`.kind = NoError + block `matchingBranch`: + `stmtList` + + let resultIdent = ident"result" + + let visitingProc = bindSym"recursiveNodeVisiting" + let lengthLit = newLit(localsArrayLen) + + result = quote do: + proc `visitor`(`visitorArg`: NimNode): bool = + block `matchingSection`: + var `localsArraySym`: array[`lengthLit`, NimNode] + var `branchError`: MatchingError + `visitorStmtList` + `resultIdent` = true + + `visitingProc`(`ast`, `visitor`) + + debug result.repr + +################################################################################ +################################# Example Code ################################# +################################################################################ + +when isMainModule: + static: + let mykinds = {nnkIdent, nnkCall} + + macro foo(arg: untyped): untyped = + matchAst(arg, matchError): + of nnkStmtList(nnkIdent, nnkIdent, nnkIdent): + echo(88*88+33*33) + of nnkStmtList( + _( + nnkIdentDefs( + ident"a", + nnkEmpty, nnkIntLit(intVal = 123) + ) + ), + _, + nnkForStmt( + nnkIdent(strVal = "i"), + nnkInfix, + `mysym` @ nnkStmtList + ) + ): + echo "The AST did match!!!" + echo "The matched sub tree is the following:" + echo mysym.lispRepr + #else: + # echo "sadly the AST did not match :(" + # echo arg.treeRepr + # failWithMatchingError(matchError[1]) + + foo: + let a = 123 + let b = 342 + for i in a ..< b: + echo "Hallo", i + + static: + + var ast = quote do: + type + A[T: static[int]] = object + + ast = ast[0] + ast.matchAst(err): # this is a sub ast for this a findAst or something like that is useful + of nnkTypeDef(_, nnkGenericParams( nnkIdentDefs( nnkIdent(strVal = "T"), `staticTy`, nnkEmpty )), _): + echo "`", staticTy.repr, "` used to be of nnkStaticTy, now it is ", staticTy.kind, " with ", staticTy[0].repr + ast = quote do: + if cond1: expr1 elif cond2: expr2 else: expr3 + + ast.matchAst: + of {nnkIfExpr, nnkIfStmt}( + {nnkElifExpr, nnkElifBranch}(`cond1`, `expr1`), + {nnkElifExpr, nnkElifBranch}(`cond2`, `expr2`), + {nnkElseExpr, nnkElse}(`expr3`) + ): + echo "ok" + + let ast2 = nnkStmtList.newTree( newLit(1) ) + + ast2.matchAst: + of nnkIntLit( 1 ): + echo "fail" + of nnkStmtList( 1 ): + echo "ok" + + ast = bindSym"[]" + ast.matchAst(errors): + of nnkClosedSymChoice(strVal = "[]"): + echo "fail, this is the wrong syntax, a sym choice does not have a `strVal` member." + of ident"[]": + echo "ok" + + const myConst = 123 + ast = newLit(123) + + ast.matchAst: + of _(intVal = myConst): + echo "ok" + + macro testRecCase(ast: untyped): untyped = + ast.matchAstRecursive: + of nnkIdentDefs(`a`,`b`,`c`): + echo "got ident defs a: ", a.repr, " b: ", b.repr, " c: ", c.repr + of ident"m": + echo "got the ident m" + + testRecCase: + type Obj[T] = object {.inheritable.} + name: string + case isFat: bool + of true: + m: array[100_000, T] + of false: + m: array[10, T] + + + macro testIfCondition(ast: untyped): untyped = + let literals = nnkBracket.newTree + ast.matchAstRecursive: + of `intLit` @ nnkIntLit |= intLit.intVal > 5: + literals.add intLit + + let literals2 = quote do: + [6,7,8,9] + + doAssert literals2 == literals + + testIfCondition([1,6,2,7,3,8,4,9,5,0,"123"]) diff --git a/tests/astspec/tastspec.nim b/tests/astspec/tastspec.nim new file mode 100644 index 000000000..82c32f130 --- /dev/null +++ b/tests/astspec/tastspec.nim @@ -0,0 +1,1016 @@ +discard """ +action: compile +""" + +# this test should ensure that the AST doesn't change slighly without it getting noticed. + +import ../ast_pattern_matching + +macro testAddrAst(arg: typed): bool = + arg.expectKind nnkStmtListExpr + arg[0].expectKind(nnkVarSection) + arg[1].expectKind({nnkAddr, nnkCall}) + result = newLit(arg[1].kind == nnkCall) + +const newAddrAst: bool = testAddrAst((var x: int; addr(x))) + +static: + echo "new addr ast: ", newAddrAst + +# TODO test on matching failures + +proc peelOff*(arg: NimNode, kinds: set[NimNodeKind]): NimNode {.compileTime.} = + ## Peel off nodes of a specific kinds. + if arg.len == 1 and arg.kind in kinds: + arg[0].peelOff(kinds) + else: + arg + +proc peelOff*(arg: NimNode, kind: NimNodeKind): NimNode {.compileTime.} = + ## Peel off nodes of a specific kind. + if arg.len == 1 and arg.kind == kind: + arg[0].peelOff(kind) + else: + arg + +static: + template testPattern(pattern, astArg: untyped): untyped = + let ast = quote do: `astArg` + ast.matchAst: + of `pattern`: + echo "ok" + + template testPatternFail(pattern, astArg: untyped): untyped = + let ast = quote do: `astArg` + ast.matchAst: + of `pattern`: + error("this should not match", ast) + else: + echo "OK" + + + testPattern nnkIntLit(intVal = 42) , 42 + testPattern nnkInt8Lit(intVal = 42) , 42'i8 + testPattern nnkInt16Lit(intVal = 42) , 42'i16 + testPattern nnkInt32Lit(intVal = 42) , 42'i32 + testPattern nnkInt64Lit(intVal = 42) , 42'i64 + testPattern nnkUInt8Lit(intVal = 42) , 42'u8 + testPattern nnkUInt16Lit(intVal = 42) , 42'u16 + testPattern nnkUInt32Lit(intVal = 42) , 42'u32 + testPattern nnkUInt64Lit(intVal = 42) , 42'u64 + #testPattern nnkFloat64Lit(floatVal = 42.0) , 42.0 + testPattern nnkFloat32Lit(floatVal = 42.0) , 42.0'f32 + #testPattern nnkFloat64Lit(floatVal = 42.0) , 42.0'f64 + testPattern nnkStrLit(strVal = "abc") , "abc" + testPattern nnkRStrLit(strVal = "abc") , r"abc" + testPattern nnkTripleStrLit(strVal = "abc") , """abc""" + testPattern nnkCharLit(intVal = 32) , ' ' + testPattern nnkNilLit() , nil + testPattern nnkIdent(strVal = "myIdentifier") , myIdentifier + + testPatternFail nnkInt8Lit(intVal = 42) , 42'i16 + testPatternFail nnkInt16Lit(intVal = 42) , 42'i8 + + +# this should be just `block` but it doesn't work that way anymore because of VM. +macro scope(arg: untyped): untyped = + let procSym = genSym(nskProc) + result = quote do: + proc `procSym`(): void {.compileTime.} = + `arg` + + `procSym`() + +static: + ## Command call + scope: + + let ast = quote do: + echo "abc", "xyz" + + ast.matchAst: + of nnkCommand(ident"echo", "abc", "xyz"): + echo "ok" + + ## Call with ``()`` + + scope: + let ast = quote do: + echo("abc", "xyz") + + ast.matchAst: + of nnkCall(ident"echo", "abc", "xyz"): + echo "ok" + + ## Infix operator call + + macro testInfixOperatorCall(ast: untyped): untyped = + ast.matchAst(errorSym): + of nnkInfix( + ident"&", + nnkStrLit(strVal = "abc"), + nnkStrLit(strVal = "xyz") + ): + echo "ok1" + of nnkInfix( + ident"+", + nnkIntLit(intVal = 5), + nnkInfix( + ident"*", + nnkIntLit(intVal = 3), + nnkIntLit(intVal = 4) + ) + ): + echo "ok2" + of nnkCall( + nnkAccQuoted( + ident"+" + ), + nnkIntLit(intVal = 3), + nnkIntLit(intVal = 4) + ): + echo "ok3" + + testInfixOperatorCall("abc" & "xyz") + testInfixOperatorCall(5 + 3 * 4) + testInfixOperatorCall(`+`(3, 4)) + + + ## Prefix operator call + + scope: + + let ast = quote do: + ? "xyz" + + ast.matchAst(err): + of nnkPrefix( + ident"?", + nnkStrLit(strVal = "xyz") + ): + echo "ok" + + + ## Postfix operator call + + scope: + + let ast = quote do: + proc identifier* + + ast[0].matchAst(err): + of nnkPostfix( + ident"*", + ident"identifier" + ): + echo "ok" + + + ## Call with named arguments + + macro testCallWithNamedArguments(ast: untyped): untyped = + ast.peelOff(nnkStmtList).matchAst: + of nnkCall( + ident"writeLine", + nnkExprEqExpr( + ident"file", + ident"stdout" + ), + nnkStrLit(strVal = "hallo") + ): + echo "ok" + + testCallWithNamedArguments: + writeLine(file=stdout, "hallo") + + ## Call with raw string literal + scope: + let ast = quote do: + echo"abc" + + + ast.matchAst(err): + of nnkCallStrLit( + ident"echo", + nnkRStrLit(strVal = "abc") + ): + echo "ok" + + ## Dereference operator ``[]`` + + scope: + # The dereferece operator exists only on a typed ast. + macro testDereferenceOperator(ast: typed): untyped = + ast.matchAst(err): + of nnkDerefExpr(_): + echo "ok" + + var x: ptr int + testDereferenceOperator(x[]) + + + + ## Addr operator + + scope: + # The addr operator exists only on a typed ast. + macro testAddrOperator(ast: untyped): untyped = + echo ast.treeRepr + ast.matchAst(err): + of nnkAddr(ident"x"): + echo "old nim" + of nnkCall(ident"addr", ident"x"): + echo "ok" + + var x: int + testAddrOperator(addr(x)) + + + ## Cast operator + + scope: + + let ast = quote do: + cast[T](x) + + ast.matchAst: + of nnkCast(ident"T", ident"x"): + echo "ok" + + + ## Object access operator ``.`` + + scope: + + let ast = quote do: + x.y + + ast.matchAst: + of nnkDotExpr(ident"x", ident"y"): + echo "ok" + + ## Array access operator ``[]`` + + macro testArrayAccessOperator(ast: untyped): untyped = + ast.matchAst: + of nnkBracketExpr(ident"x", ident"y"): + echo "ok" + + testArrayAccessOperator(x[y]) + + + + ## Parentheses + + scope: + + let ast = quote do: + (1, 2, (3)) + + ast.matchAst: + of nnkPar(nnkIntLit(intVal = 1), nnkIntLit(intVal = 2), nnkPar(nnkIntLit(intVal = 3))): + echo "ok" + + + ## Curly braces + + scope: + + let ast = quote do: + {1, 2, 3} + + ast.matchAst: + of nnkCurly(nnkIntLit(intVal = 1), nnkIntLit(intVal = 2), nnkIntLit(intVal = 3)): + echo "ok" + + scope: + + let ast = quote do: + {a: 3, b: 5} + + ast.matchAst: + of nnkTableConstr( + nnkExprColonExpr(ident"a", nnkIntLit(intVal = 3)), + nnkExprColonExpr(ident"b", nnkIntLit(intVal = 5)) + ): + echo "ok" + + + ## Brackets + + scope: + + let ast = quote do: + [1, 2, 3] + + ast.matchAst: + of nnkBracket(nnkIntLit(intVal = 1), nnkIntLit(intVal = 2), nnkIntLit(intVal = 3)): + echo "ok" + + + ## Ranges + + scope: + + let ast = quote do: + 1..3 + + ast.matchAst: + of nnkInfix( + ident"..", + nnkIntLit(intVal = 1), + nnkIntLit(intVal = 3) + ): + echo "ok" + + + ## If expression + + scope: + + let ast = quote do: + if cond1: expr1 elif cond2: expr2 else: expr3 + + ast.matchAst: + of {nnkIfExpr, nnkIfStmt}( + {nnkElifExpr, nnkElifBranch}(`cond1`, `expr1`), + {nnkElifExpr, nnkElifBranch}(`cond2`, `expr2`), + {nnkElseExpr, nnkElse}(`expr3`) + ): + echo "ok" + + ## Documentation Comments + + scope: + + let ast = quote do: + ## This is a comment + ## This is part of the first comment + stmt1 + ## Yet another + + ast.matchAst: + of nnkStmtList( + nnkCommentStmt(), + `stmt1`, + nnkCommentStmt() + ): + echo "ok" + else: + echo "NOT OK!!!" + echo ast.treeRepr + echo "TEST causes no fail, because of a regression in Nim." + + scope: + let ast = quote do: + {.emit: "#include <stdio.h>".} + + ast.matchAst: + of nnkPragma( + nnkExprColonExpr( + ident"emit", + nnkStrLit(strVal = "#include <stdio.h>") # the "argument" + ) + ): + echo "ok" + + scope: + let ast = quote do: + {.pragma: cdeclRename, cdecl.} + + ast.matchAst: + of nnkPragma( + nnkExprColonExpr( + ident"pragma", # this is always first when declaring a new pragma + ident"cdeclRename" # the name of the pragma + ), + ident"cdecl" + ): + echo "ok" + + + + scope: + let ast = quote do: + if cond1: + stmt1 + elif cond2: + stmt2 + elif cond3: + stmt3 + else: + stmt4 + + ast.matchAst: + of nnkIfStmt( + nnkElifBranch(`cond1`, `stmt1`), + nnkElifBranch(`cond2`, `stmt2`), + nnkElifBranch(`cond3`, `stmt3`), + nnkElse(`stmt4`) + ): + echo "ok" + + + + scope: + let ast = quote do: + x = 42 + + ast.matchAst: + of nnkAsgn(ident"x", nnkIntLit(intVal = 42)): + echo "ok" + + + + scope: + let ast = quote do: + stmt1 + stmt2 + stmt3 + + ast.matchAst: + of nnkStmtList(`stmt1`, `stmt2`, `stmt3`): + assert stmt1.strVal == "stmt1" + assert stmt2.strVal == "stmt2" + assert stmt3.strVal == "stmt3" + echo "ok" + + ## Case statement + + scope: + + let ast = quote do: + case expr1 + of expr2, expr3..expr4: + stmt1 + of expr5: + stmt2 + elif cond1: + stmt3 + else: + stmt4 + + ast.matchAst: + of nnkCaseStmt( + `expr1`, + nnkOfBranch(`expr2`, {nnkRange, nnkInfix}(_, `expr3`, `expr4`), `stmt1`), + nnkOfBranch(`expr5`, `stmt2`), + nnkElifBranch(`cond1`, `stmt3`), + nnkElse(`stmt4`) + ): + echo "ok" + + ## While statement + + scope: + + let ast = quote do: + while expr1: + stmt1 + + ast.matchAst: + of nnkWhileStmt(`expr1`, `stmt1`): + echo "ok" + + + ## For statement + + scope: + + let ast = quote do: + for ident1, ident2 in expr1: + stmt1 + + ast.matchAst: + of nnkForStmt(`ident1`, `ident2`, `expr1`, `stmt1`): + echo "ok" + + + ## Try statement + + scope: + + let ast = quote do: + try: + stmt1 + except e1, e2: + stmt2 + except e3: + stmt3 + except: + stmt4 + finally: + stmt5 + + ast.matchAst: + of nnkTryStmt( + `stmt1`, + nnkExceptBranch(`e1`, `e2`, `stmt2`), + nnkExceptBranch(`e3`, `stmt3`), + nnkExceptBranch(`stmt4`), + nnkFinally(`stmt5`) + ): + echo "ok" + + + ## Return statement + + scope: + + let ast = quote do: + return expr1 + + ast.matchAst: + of nnkReturnStmt(`expr1`): + echo "ok" + + + ## Continue statement + + scope: + let ast = quote do: + continue + + ast.matchAst: + of nnkContinueStmt: + echo "ok" + + ## Break statement + + scope: + + let ast = quote do: + break otherLocation + + ast.matchAst: + of nnkBreakStmt(ident"otherLocation"): + echo "ok" + + ## Block statement + + scope: + + let ast = quote do: + block name: + discard + + ast.matchAst: + of nnkBlockStmt(ident"name", nnkStmtList): + echo "ok" + + ## Asm statement + + scope: + + let ast = quote do: + asm """some asm""" + + ast.matchAst: + of nnkAsmStmt( + nnkEmpty(), # for pragmas + nnkTripleStrLit(strVal = "some asm"), + ): + echo "ok" + + ## Import section + + scope: + + let ast = quote do: + import math + + ast.matchAst: + of nnkImportStmt(ident"math"): + echo "ok" + + scope: + + let ast = quote do: + import math except pow + + ast.matchAst: + of nnkImportExceptStmt(ident"math",ident"pow"): + echo "ok" + + scope: + + let ast = quote do: + import strutils as su + + ast.matchAst: + of nnkImportStmt( + nnkInfix( + ident"as", + ident"strutils", + ident"su" + ) + ): + echo "ok" + + ## From statement + + scope: + + let ast = quote do: + from math import pow + + ast.matchAst: + of nnkFromStmt(ident"math", ident"pow"): + echo "ok" + + ## Export statement + + scope: + + let ast = quote do: + export unsigned + + ast.matchAst: + of nnkExportStmt(ident"unsigned"): + echo "ok" + + scope: + + let ast = quote do: + export math except pow # we're going to implement our own exponentiation + + ast.matchAst: + of nnkExportExceptStmt(ident"math",ident"pow"): + echo "ok" + + ## Include statement + + scope: + + let ast = quote do: + include blocks + + ast.matchAst: + of nnkIncludeStmt(ident"blocks"): + echo "ok" + + ## Var section + + scope: + + let ast = quote do: + var a = 3 + + ast.matchAst: + of nnkVarSection( + nnkIdentDefs( + ident"a", + nnkEmpty(), # or nnkIdent(...) if the variable declares the type + nnkIntLit(intVal = 3), + ) + ): + echo "ok" + + ## Let section + + scope: + + let ast = quote do: + let a = 3 + + ast.matchAst: + of nnkLetSection( + nnkIdentDefs( + ident"a", + nnkEmpty(), # or nnkIdent(...) for the type + nnkIntLit(intVal = 3), + ) + ): + echo "ok" + + ## Const section + + scope: + + let ast = quote do: + const a = 3 + + ast.matchAst: + of nnkConstSection( + nnkConstDef( # not nnkConstDefs! + ident"a", + nnkEmpty(), # or nnkIdent(...) if the variable declares the type + nnkIntLit(intVal = 3), # required in a const declaration! + ) + ): + echo "ok" + + ## Type section + + scope: + + let ast = quote do: + type A = int + + ast.matchAst: + of nnkTypeSection( + nnkTypeDef( + ident"A", + nnkEmpty(), + ident"int" + ) + ): + echo "ok" + + scope: + + let ast = quote do: + type MyInt = distinct int + + ast.peelOff({nnkTypeSection}).matchAst: + of# ... + nnkTypeDef( + ident"MyInt", + nnkEmpty(), + nnkDistinctTy( + ident"int" + ) + ): + echo "ok" + + scope: + + let ast = quote do: + type A[T] = expr1 + + ast.matchAst: + of nnkTypeSection( + nnkTypeDef( + ident"A", + nnkGenericParams( + nnkIdentDefs( + ident"T", + nnkEmpty(), # if the type is declared with options, like + # ``[T: SomeInteger]``, they are given here + nnkEmpty() + ) + ), + `expr1` + ) + ): + echo "ok" + + scope: + + let ast = quote do: + type IO = object of RootObj + + ast.peelOff(nnkTypeSection).matchAst: + of nnkTypeDef( + ident"IO", + nnkEmpty(), + nnkObjectTy( + nnkEmpty(), # no pragmas here + nnkOfInherit( + ident"RootObj" # inherits from RootObj + ), + nnkEmpty() + ) + ): + echo "ok" + + scope: + macro testRecCase(ast: untyped): untyped = + ast.peelOff({nnkStmtList, nnkTypeSection})[2].matchAst: + of nnkObjectTy( + nnkPragma( + ident"inheritable" + ), + nnkEmpty(), + nnkRecList( # list of object parameters + nnkIdentDefs( + ident"name", + ident"string", + nnkEmpty() + ), + nnkRecCase( # case statement within object (not nnkCaseStmt) + nnkIdentDefs( + ident"isFat", + ident"bool", + nnkEmpty() + ), + nnkOfBranch( + ident"true", + nnkRecList( # again, a list of object parameters + nnkIdentDefs( + ident"m", + nnkBracketExpr( + ident"array", + nnkIntLit(intVal = 100000), + ident"T" + ), + nnkEmpty() + ) + ) + ), + nnkOfBranch( + ident"false", + nnkRecList( + nnkIdentDefs( + ident"m", + nnkBracketExpr( + ident"array", + nnkIntLit(intVal = 10), + ident"T" + ), + nnkEmpty() + ) + ) + ) + ) + ) + ): + echo "ok" + + + + testRecCase: + type Obj[T] = object {.inheritable.} + name: string + case isFat: bool + of true: + m: array[100_000, T] + of false: + m: array[10, T] + + scope: + + let ast = quote do: + type X = enum + First + + ast.peelOff({nnkStmtList, nnkTypeSection})[2].matchAst: + of nnkEnumTy( + nnkEmpty(), + ident"First" # you need at least one nnkIdent or the compiler complains + ): + echo "ok" + + scope: + + let ast = quote do: + type Con = concept x,y,z + (x & y & z) is string + + ast.peelOff({nnkStmtList, nnkTypeSection}).matchAst: + of nnkTypeDef(_, _, nnkTypeClassTy(nnkArgList, _, _, nnkStmtList)): + # note this isn't nnkConceptTy! + echo "ok" + + + scope: + + let astX = quote do: + type + A[T: static[int]] = object + + let ast = astX.peelOff({nnkStmtList, nnkTypeSection}) + + ast.matchAst(err): # this is a sub ast for this a findAst or something like that is useful + of nnkTypeDef(_, nnkGenericParams( nnkIdentDefs( ident"T", nnkCall( ident"[]", ident"static", _ ), _ )), _): + echo "ok" + else: + echo "foobar" + echo ast.treeRepr + + + scope: + let ast = quote do: + type MyProc[T] = proc(x: T) + + ast.peelOff({nnkStmtList, nnkTypeSection}).matchAst(err): + of nnkTypeDef( + ident"MyProc", + nnkGenericParams, # here, not with the proc + nnkProcTy( # behaves like a procedure declaration from here on + nnkFormalParams, _ + ) + ): + echo "ok" + + ## Mixin statement + + macro testMixinStatement(ast: untyped): untyped = + ast.peelOff(nnkStmtList).matchAst: + of nnkMixinStmt(ident"x"): + echo "ok" + + testMixinStatement: + mixin x + + ## Bind statement + + + macro testBindStmt(ast: untyped): untyped = + ast[0].matchAst: + of `node` @ nnkBindStmt(ident"x"): + echo "ok" + + testBindStmt: + bind x + + ## Procedure declaration + + macro testProcedureDeclaration(ast: untyped): untyped = + # NOTE this is wrong in astdef + + ast.peelOff(nnkStmtList).matchAst: + of nnkProcDef( + nnkPostfix(ident"*", ident"hello"), # the exported proc name + nnkEmpty, # patterns for term rewriting in templates and macros (not procs) + nnkGenericParams( # generic type parameters, like with type declaration + nnkIdentDefs( + ident"T", + ident"SomeInteger", _ + ) + ), + nnkFormalParams( + ident"int", # the first FormalParam is the return type. nnkEmpty if there is none + nnkIdentDefs( + ident"x", + ident"int", # type type (required for procs, not for templates) + nnkIntLit(intVal = 3) # a default value + ), + nnkIdentDefs( + ident"y", + ident"float32", + nnkEmpty + ) + ), + nnkPragma(ident"inline"), + nnkEmpty, # reserved slot for future use + `meat` @ nnkStmtList # the meat of the proc + ): + echo "ok got meat: ", meat.lispRepr + + testProcedureDeclaration: + proc hello*[T: SomeInteger](x: int = 3, y: float32): int {.inline.} = discard + + scope: + var ast = quote do: + proc foobar(a, b: int): void + + ast = ast[3] + + ast.matchAst: # sub expression + of nnkFormalParams( + _, # return would be here + nnkIdentDefs( + ident"a", # the first parameter + ident"b", # directly to the second parameter + ident"int", # their shared type identifier + nnkEmpty, # default value would go here + ) + ): + echo "ok" + + scope: + + let ast = quote do: + proc hello(): var int + + ast[3].matchAst: # subAst + of nnkFormalParams( + nnkVarTy( + ident"int" + ) + ): + echo "ok" + + ## Iterator declaration + + scope: + + let ast = quote do: + iterator nonsense[T](x: seq[T]): float {.closure.} = + discard + + ast.matchAst: + of nnkIteratorDef(ident"nonsense", nnkEmpty, _, _, _, _, _): + echo "ok" + + ## Converter declaration + + scope: + + let ast = quote do: + converter toBool(x: float): bool + + ast.matchAst: + of nnkConverterDef(ident"toBool",_,_,_,_,_,_): + echo "ok" + + ## Template declaration + + scope: + let ast = quote do: + template optOpt{expr1}(a: int): int + + ast.matchAst: + of nnkTemplateDef(ident"optOpt", nnkStmtList(`expr1`), _, _, _, _, _): + echo "ok" |