summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/unittest.nim83
-rw-r--r--tests/stdlib/tunittest.nim27
2 files changed, 63 insertions, 47 deletions
diff --git a/lib/pure/unittest.nim b/lib/pure/unittest.nim
index 28691fcb4..3772a213a 100644
--- a/lib/pure/unittest.nim
+++ b/lib/pure/unittest.nim
@@ -509,10 +509,6 @@ macro check*(conditions: untyped): untyped =
   ##    "AKB48".toLowerAscii() == "akb48"
   ##    'C' in teams
   let checked = callsite()[1]
-  var
-    argsAsgns = newNimNode(nnkStmtList)
-    argsPrintOuts = newNimNode(nnkStmtList)
-    counter = 0
 
   template asgn(a: untyped, value: typed) =
     var a = value # XXX: we need "var: var" here in order to
@@ -522,66 +518,71 @@ macro check*(conditions: untyped): untyped =
     when compiles(string($value)):
       checkpoint(name & " was " & $value)
 
-  proc inspectArgs(exp: NimNode): NimNode =
-    result = copyNimTree(exp)
+  proc inspectArgs(exp: NimNode): tuple[assigns, check, printOuts: NimNode] =
+    result.check = copyNimTree(exp)
+    result.assigns = newNimNode(nnkStmtList)
+    result.printOuts = newNimNode(nnkStmtList)
+
+    var counter = 0
+
     if exp[0].kind == nnkIdent and
-        $exp[0] in ["and", "or", "not", "in", "notin", "==", "<=",
+        $exp[0] in ["not", "in", "notin", "==", "<=",
                     ">=", "<", ">", "!=", "is", "isnot"]:
-      for i in countup(1, exp.len - 1):
+
+      for i in 1 ..< exp.len:
         if exp[i].kind notin nnkLiterals:
           inc counter
-          var arg = newIdentNode(":p" & $counter)
-          var argStr = exp[i].toStrLit
-          var paramAst = exp[i]
+          let argStr = exp[i].toStrLit
+          let paramAst = exp[i]
           if exp[i].kind == nnkIdent:
-            argsPrintOuts.add getAst(print(argStr, paramAst))
-          if exp[i].kind in nnkCallKinds:
-            var callVar = newIdentNode(":c" & $counter)
-            argsAsgns.add getAst(asgn(callVar, paramAst))
-            result[i] = callVar
-            argsPrintOuts.add getAst(print(argStr, callVar))
+            result.printOuts.add getAst(print(argStr, paramAst))
+          if exp[i].kind in nnkCallKinds + { nnkDotExpr, nnkBracketExpr }:
+            let callVar = newIdentNode(":c" & $counter)
+            result.assigns.add getAst(asgn(callVar, paramAst))
+            result.check[i] = callVar
+            result.printOuts.add getAst(print(argStr, callVar))
           if exp[i].kind == nnkExprEqExpr:
             # ExprEqExpr
             #   Ident !"v"
             #   IntLit 2
-            result[i] = exp[i][1]
+            result.check[i] = exp[i][1]
           if exp[i].typekind notin {ntyTypeDesc}:
-            argsAsgns.add getAst(asgn(arg, paramAst))
-            argsPrintOuts.add getAst(print(argStr, arg))
+            let arg = newIdentNode(":p" & $counter)
+            result.assigns.add getAst(asgn(arg, paramAst))
+            result.printOuts.add getAst(print(argStr, arg))
             if exp[i].kind != nnkExprEqExpr:
-              result[i] = arg
+              result.check[i] = arg
             else:
-              result[i][1] = arg
+              result.check[i][1] = arg
 
   case checked.kind
   of nnkCallKinds:
-    template rewrite(call, lineInfoLit, callLit,
-                     argAssgs, argPrintOuts) =
+
+    let (assigns, check, printOuts) = inspectArgs(checked)
+    let lineinfo = newStrLitNode(checked.lineinfo)
+    let callLit = checked.toStrLit
+    result = quote do:
       block:
-        argAssgs #all callables (and assignments) are run here
-        if not call:
-          checkpoint(lineInfoLit & ": Check failed: " & callLit)
-          argPrintOuts
+        `assigns`
+        if not `check`:
+          checkpoint(`lineinfo` & ": Check failed: " & `callLit`)
+          `printOuts`
           fail()
 
-    var checkedStr = checked.toStrLit
-    let parameterizedCheck = inspectArgs(checked)
-    result = getAst(rewrite(parameterizedCheck, checked.lineinfo, checkedStr,
-                            argsAsgns, argsPrintOuts))
-
   of nnkStmtList:
     result = newNimNode(nnkStmtList)
-    for i in countup(0, checked.len - 1):
-      if checked[i].kind != nnkCommentStmt:
-        result.add(newCall(!"check", checked[i]))
+    for node in checked:
+      if node.kind != nnkCommentStmt:
+        result.add(newCall(!"check", node))
 
   else:
-    template rewrite(exp, lineInfoLit, expLit) =
-      if not exp:
-        checkpoint(lineInfoLit & ": Check failed: " & expLit)
-        fail()
+    let lineinfo = newStrLitNode(checked.lineinfo)
+    let callLit = checked.toStrLit
 
-    result = getAst(rewrite(checked, checked.lineinfo, checked.toStrLit))
+    result = quote do:
+      if not `checked`:
+        checkpoint(`lineinfo` & ": Check failed: " & `callLit`)
+        fail()
 
 template require*(conditions: untyped) =
   ## Same as `check` except any failed test causes the program to quit
diff --git a/tests/stdlib/tunittest.nim b/tests/stdlib/tunittest.nim
index 674ce50dd..e4a801871 100644
--- a/tests/stdlib/tunittest.nim
+++ b/tests/stdlib/tunittest.nim
@@ -1,12 +1,23 @@
 discard """
-  nimout: "compile start\ncompile end"
+  output: '''[Suite] suite with only teardown
+
+[Suite] suite with only setup
+
+[Suite] suite with none
+
+[Suite] suite with both
+
+[Suite] bug #4494
+
+[Suite] bug #5571
+
+[Suite] bug #5784
+
+'''
 """
 
 import unittest, sequtils
 
-static:
-  echo "compile start"
-
 proc doThings(spuds: var int): int =
   spuds = 24
   return 99
@@ -103,5 +114,9 @@ suite "bug #5571":
       check: line == "a"
     doTest()
 
-static:
-  echo "compile end"
+suite "bug #5784":
+  test "`or` should short circuit":
+    type Obj = ref object
+      field: int
+    var obj: Obj
+    check obj.isNil or obj.field == 0