summary refs log tree commit diff stats
path: root/lib/pure/unittest.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/unittest.nim')
-rw-r--r--lib/pure/unittest.nim83
1 files changed, 42 insertions, 41 deletions
diff --git a/lib/pure/unittest.nim b/lib/pure/unittest.nim
index 28691fcb4..3772a213a 100644
--- a/lib/pure/unittest.nim
+++ b/lib/pure/unittest.nim
@@ -509,10 +509,6 @@ macro check*(conditions: untyped): untyped =
   ##    "AKB48".toLowerAscii() == "akb48"
   ##    'C' in teams
   let checked = callsite()[1]
-  var
-    argsAsgns = newNimNode(nnkStmtList)
-    argsPrintOuts = newNimNode(nnkStmtList)
-    counter = 0
 
   template asgn(a: untyped, value: typed) =
     var a = value # XXX: we need "var: var" here in order to
@@ -522,66 +518,71 @@ macro check*(conditions: untyped): untyped =
     when compiles(string($value)):
       checkpoint(name & " was " & $value)
 
-  proc inspectArgs(exp: NimNode): NimNode =
-    result = copyNimTree(exp)
+  proc inspectArgs(exp: NimNode): tuple[assigns, check, printOuts: NimNode] =
+    result.check = copyNimTree(exp)
+    result.assigns = newNimNode(nnkStmtList)
+    result.printOuts = newNimNode(nnkStmtList)
+
+    var counter = 0
+
     if exp[0].kind == nnkIdent and
-        $exp[0] in ["and", "or", "not", "in", "notin", "==", "<=",
+        $exp[0] in ["not", "in", "notin", "==", "<=",
                     ">=", "<", ">", "!=", "is", "isnot"]:
-      for i in countup(1, exp.len - 1):
+
+      for i in 1 ..< exp.len:
         if exp[i].kind notin nnkLiterals:
           inc counter
-          var arg = newIdentNode(":p" & $counter)
-          var argStr = exp[i].toStrLit
-          var paramAst = exp[i]
+          let argStr = exp[i].toStrLit
+          let paramAst = exp[i]
           if exp[i].kind == nnkIdent:
-            argsPrintOuts.add getAst(print(argStr, paramAst))
-          if exp[i].kind in nnkCallKinds:
-            var callVar = newIdentNode(":c" & $counter)
-            argsAsgns.add getAst(asgn(callVar, paramAst))
-            result[i] = callVar
-            argsPrintOuts.add getAst(print(argStr, callVar))
+            result.printOuts.add getAst(print(argStr, paramAst))
+          if exp[i].kind in nnkCallKinds + { nnkDotExpr, nnkBracketExpr }:
+            let callVar = newIdentNode(":c" & $counter)
+            result.assigns.add getAst(asgn(callVar, paramAst))
+            result.check[i] = callVar
+            result.printOuts.add getAst(print(argStr, callVar))
           if exp[i].kind == nnkExprEqExpr:
             # ExprEqExpr
             #   Ident !"v"
             #   IntLit 2
-            result[i] = exp[i][1]
+            result.check[i] = exp[i][1]
           if exp[i].typekind notin {ntyTypeDesc}:
-            argsAsgns.add getAst(asgn(arg, paramAst))
-            argsPrintOuts.add getAst(print(argStr, arg))
+            let arg = newIdentNode(":p" & $counter)
+            result.assigns.add getAst(asgn(arg, paramAst))
+            result.printOuts.add getAst(print(argStr, arg))
             if exp[i].kind != nnkExprEqExpr:
-              result[i] = arg
+              result.check[i] = arg
             else:
-              result[i][1] = arg
+              result.check[i][1] = arg
 
   case checked.kind
   of nnkCallKinds:
-    template rewrite(call, lineInfoLit, callLit,
-                     argAssgs, argPrintOuts) =
+
+    let (assigns, check, printOuts) = inspectArgs(checked)
+    let lineinfo = newStrLitNode(checked.lineinfo)
+    let callLit = checked.toStrLit
+    result = quote do:
       block:
-        argAssgs #all callables (and assignments) are run here
-        if not call:
-          checkpoint(lineInfoLit & ": Check failed: " & callLit)
-          argPrintOuts
+        `assigns`
+        if not `check`:
+          checkpoint(`lineinfo` & ": Check failed: " & `callLit`)
+          `printOuts`
           fail()
 
-    var checkedStr = checked.toStrLit
-    let parameterizedCheck = inspectArgs(checked)
-    result = getAst(rewrite(parameterizedCheck, checked.lineinfo, checkedStr,
-                            argsAsgns, argsPrintOuts))
-
   of nnkStmtList:
     result = newNimNode(nnkStmtList)
-    for i in countup(0, checked.len - 1):
-      if checked[i].kind != nnkCommentStmt:
-        result.add(newCall(!"check", checked[i]))
+    for node in checked:
+      if node.kind != nnkCommentStmt:
+        result.add(newCall(!"check", node))
 
   else:
-    template rewrite(exp, lineInfoLit, expLit) =
-      if not exp:
-        checkpoint(lineInfoLit & ": Check failed: " & expLit)
-        fail()
+    let lineinfo = newStrLitNode(checked.lineinfo)
+    let callLit = checked.toStrLit
 
-    result = getAst(rewrite(checked, checked.lineinfo, checked.toStrLit))
+    result = quote do:
+      if not `checked`:
+        checkpoint(`lineinfo` & ": Check failed: " & `callLit`)
+        fail()
 
 template require*(conditions: untyped) =
   ## Same as `check` except any failed test causes the program to quit