summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/vmhooks.nim3
-rw-r--r--compiler/vmops.nim9
-rw-r--r--tests/stdlib/tmath.nim58
3 files changed, 41 insertions, 29 deletions
diff --git a/compiler/vmhooks.nim b/compiler/vmhooks.nim
index d211d8343..9f68eb434 100644
--- a/compiler/vmhooks.nim
+++ b/compiler/vmhooks.nim
@@ -41,6 +41,9 @@ template getX(k, field) {.dirty.} =
   doAssert a.slots[i+a.rb+1].kind == k
   result = a.slots[i+a.rb+1].field
 
+proc numArgs*(a: VmArgs): int =
+  result = a.rc-1
+
 proc getInt*(a: VmArgs; i: Natural): BiggestInt = getX(rkInt, intVal)
 proc getBool*(a: VmArgs; i: Natural): bool = getInt(a, i) != 0
 proc getFloat*(a: VmArgs; i: Natural): BiggestFloat = getX(rkFloat, floatVal)
diff --git a/compiler/vmops.nim b/compiler/vmops.nim
index 504a352b5..ec44a7c08 100644
--- a/compiler/vmops.nim
+++ b/compiler/vmops.nim
@@ -52,6 +52,7 @@ template md5op(op) {.dirty.} =
 
 template wrap1f_math(op) {.dirty.} =
   proc `op Wrapper`(a: VmArgs) {.nimcall.} =
+    doAssert a.numArgs == 1
     setResult(a, op(getFloat(a, 0)))
   mathop op
 
@@ -157,7 +158,6 @@ proc registerAdditionalOps*(c: PCtx) =
   wrap1f_math(log10)
   wrap1f_math(log2)
   wrap1f_math(exp)
-  wrap1f_math(round)
   wrap1f_math(arccos)
   wrap1f_math(arcsin)
   wrap1f_math(arctan)
@@ -180,6 +180,13 @@ proc registerAdditionalOps*(c: PCtx) =
   when declared(signbit):
     wrap1f_math(signbit)
 
+  registerCallback c, "stdlib.math.round", proc (a: VmArgs) {.nimcall.} =
+    let n = a.numArgs
+    case n
+    of 1: setResult(a, round(getFloat(a, 0)))
+    of 2: setResult(a, round(getFloat(a, 0), getInt(a, 1).int))
+    else: doAssert false, $n
+
   wrap1s(getMD5, md5op)
 
   proc `mod Wrapper`(a: VmArgs) {.nimcall.} =
diff --git a/tests/stdlib/tmath.nim b/tests/stdlib/tmath.nim
index edab62a66..7af9eb73e 100644
--- a/tests/stdlib/tmath.nim
+++ b/tests/stdlib/tmath.nim
@@ -168,14 +168,6 @@ block:
       let x: seq[float] = @[]
       doAssert prod(x) == 1.0
 
-    block: # round() tests
-      # Round to 0 decimal places
-      doAssert round(54.652) == 55.0
-      doAssert round(54.352) == 54.0
-      doAssert round(-54.652) == -55.0
-      doAssert round(-54.352) == -54.0
-      doAssert round(0.0) == 0.0
-
     block: # splitDecimal() tests
       doAssert splitDecimal(54.674).intpart == 54.0
       doAssert splitDecimal(54.674).floatpart ==~ 0.674
@@ -372,26 +364,36 @@ template main =
     doAssert copySign(-NaN, -0.0).isNaN
 
     block: # round() tests
-      # Round to 0 decimal places
-      doAssert round(54.652) == 55.0
-      doAssert round(54.352) == 54.0
-      doAssert round(-54.652) == -55.0
-      doAssert round(-54.352) == -54.0
-      doAssert round(0.0) == 0.0
-      doAssert 1 / round(0.0) == Inf
-      doAssert 1 / round(-0.0) == -Inf
-      doAssert round(Inf) == Inf
-      doAssert round(-Inf) == -Inf
-      doAssert round(NaN).isNaN
-      doAssert round(-NaN).isNaN
-      doAssert round(-0.5) == -1.0
-      doAssert round(0.5) == 1.0
-      doAssert round(-1.5) == -2.0
-      doAssert round(1.5) == 2.0
-      doAssert round(-2.5) == -3.0
-      doAssert round(2.5) == 3.0
-      doAssert round(2.5'f32) == 3.0'f32
-      doAssert round(2.5'f64) == 3.0'f64
+      block: # Round to 0 decimal places
+        doAssert round(54.652) == 55.0
+        doAssert round(54.352) == 54.0
+        doAssert round(-54.652) == -55.0
+        doAssert round(-54.352) == -54.0
+        doAssert round(0.0) == 0.0
+        doAssert 1 / round(0.0) == Inf
+        doAssert 1 / round(-0.0) == -Inf
+        doAssert round(Inf) == Inf
+        doAssert round(-Inf) == -Inf
+        doAssert round(NaN).isNaN
+        doAssert round(-NaN).isNaN
+        doAssert round(-0.5) == -1.0
+        doAssert round(0.5) == 1.0
+        doAssert round(-1.5) == -2.0
+        doAssert round(1.5) == 2.0
+        doAssert round(-2.5) == -3.0
+        doAssert round(2.5) == 3.0
+        doAssert round(2.5'f32) == 3.0'f32
+        doAssert round(2.5'f64) == 3.0'f64
+      block: # func round*[T: float32|float64](x: T, places: int): T
+        doAssert round(54.345, 0) == 54.0
+        template fn(x) =
+          doAssert round(x, 2).almostEqual 54.35
+          doAssert round(x, 2).almostEqual 54.35
+          doAssert round(x, -1).almostEqual 50.0
+          doAssert round(x, -2).almostEqual 100.0
+          doAssert round(x, -3).almostEqual 0.0
+        fn(54.346)
+        fn(54.346'f32)
 
     when nimvm:
       discard