summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/math.nim93
1 files changed, 86 insertions, 7 deletions
diff --git a/lib/pure/math.nim b/lib/pure/math.nim
index 58d9879b2..4ef169b4f 100644
--- a/lib/pure/math.nim
+++ b/lib/pure/math.nim
@@ -199,13 +199,6 @@ when not defined(JS):
   proc tgamma*(x: float64): float64 {.importc: "tgamma", header: "<math.h>".}
     ## The gamma function
 
-  proc trunc*(x: float32): float32 {.importc: "truncf", header: "<math.h>".}
-  proc trunc*(x: float64): float64 {.importc: "trunc", header: "<math.h>".}
-    ## Truncates `x` to the decimal point
-    ##
-    ## .. code-block:: nim
-    ##  echo trunc(PI) # 3.0
-
   proc floor*(x: float32): float32 {.importc: "floorf", header: "<math.h>".}
   proc floor*(x: float64): float64 {.importc: "floor", header: "<math.h>".}
     ## Computes the floor function (i.e., the largest integer not greater than `x`)
@@ -221,6 +214,56 @@ when not defined(JS):
     ##  echo ceil(-2.1) ## -2.0
 
   when defined(windows) and defined(vcc):
+    # MSVC 2010 don't have trunc/truncf
+    # this implementation was inspired by Go-lang Math.Trunc
+    proc truncImpl(f: float64): float64 =
+      const
+        mask : uint64 = 0x7FF
+        shift: uint64 = 64 - 12
+        bias : uint64 = 0x3FF
+
+      if f < 1:
+        if f < 0: return -truncImpl(-f)
+        elif f == 0: return f # Return -0 when f == -0
+        else: return 0
+
+      var x = cast[uint64](f)
+      let e = (x shr shift) and mask - bias
+
+      # Keep the top 12+e bits, the integer part; clear the rest.
+      if e < 64-12:
+        x = x and (not (1'u64 shl (64'u64-12'u64-e) - 1'u64))
+
+      result = cast[float64](x)
+    
+    proc truncImpl(f: float32): float32 =
+      const
+        mask : uint32 = 0xFF
+        shift: uint32 = 32 - 9
+        bias : uint32 = 0x7F
+
+      if f < 1:
+        if f < 0: return -truncImpl(-f)
+        elif f == 0: return f # Return -0 when f == -0
+        else: return 0
+
+      var x = cast[uint32](f)
+      let e = (x shr shift) and mask - bias
+
+      # Keep the top 9+e bits, the integer part; clear the rest.
+      if e < 32-9:
+        x = x and (not (1'u32 shl (32'u32-9'u32-e) - 1'u32))
+
+      result = cast[float32](x)
+      
+    proc trunc*(x: float64): float64 =
+      if classify(x) in {fcZero, fcNegZero, fcNan, fcInf, fcNegInf}: return x
+      result = truncImpl(x)
+
+    proc trunc*(x: float32): float32 =
+      if classify(x) in {fcZero, fcNegZero, fcNan, fcInf, fcNegInf}: return x
+      result = truncImpl(x)
+
     proc round0[T: float32|float64](x: T): T =
       ## Windows compilers prior to MSVC 2012 do not implement 'round',
       ## 'roundl' or 'roundf'.
@@ -231,6 +274,13 @@ when not defined(JS):
       ## Rounds a float to zero decimal places.  Used internally by the round
       ## function when the specified number of places is 0.
 
+    proc trunc*(x: float32): float32 {.importc: "truncf", header: "<math.h>".}
+    proc trunc*(x: float64): float64 {.importc: "trunc", header: "<math.h>".}
+      ## Truncates `x` to the decimal point
+      ##
+      ## .. code-block:: nim
+      ##  echo trunc(PI) # 3.0
+
   proc fmod*(x, y: float32): float32 {.importc: "fmodf", header: "<math.h>".}
   proc fmod*(x, y: float64): float64 {.importc: "fmod", header: "<math.h>".}
     ## Computes the remainder of `x` divided by `y`
@@ -430,3 +480,32 @@ when isMainModule:
     doAssert splitDecimal(-693.4356).floatpart ==~ -0.4356
     doAssert splitDecimal(0.0).intpart ==~ 0.0
     doAssert splitDecimal(0.0).floatpart ==~ 0.0
+
+  block: # trunc tests for vcc
+    doAssert(trunc(-1.1) == -1)
+    doAssert(trunc(1.1) == 1)
+    doAssert(trunc(-0.1) == -0)
+    doAssert(trunc(0.1) == 0)
+
+    #special case
+    doAssert(classify(trunc(1e1000000)) == fcInf)
+    doAssert(classify(trunc(-1e1000000)) == fcNegInf)
+    doAssert(classify(trunc(0.0/0.0)) == fcNan)
+    doAssert(classify(trunc(0.0)) == fcZero)
+
+    #trick the compiler to produce signed zero
+    let
+      f_neg_one = -1.0
+      f_zero = 0.0
+      f_nan = f_zero / f_zero
+
+    doAssert(classify(trunc(f_neg_one*f_zero)) == fcNegZero)
+
+    doAssert(trunc(-1.1'f32) == -1)
+    doAssert(trunc(1.1'f32) == 1)
+    doAssert(trunc(-0.1'f32) == -0)
+    doAssert(trunc(0.1'f32) == 0)
+    doAssert(classify(trunc(1e1000000'f32)) == fcInf)
+    doAssert(classify(trunc(-1e1000000'f32)) == fcNegInf)
+    doAssert(classify(trunc(f_nan.float32)) == fcNan)
+    doAssert(classify(trunc(0.0'f32)) == fcZero)