-rw-r--r--  compiler/semfold.nim     75
-rw-r--r--  tests/arithm/tcast.nim   43
2 files changed, 86 insertions, 32 deletions
diff --git a/compiler/semfold.nim b/compiler/semfold.nim
index 27a6af1f4..0018f0755 100644
--- a/compiler/semfold.nim
+++ b/compiler/semfold.nim
@@ -214,7 +214,24 @@ proc evalIs(n: PNode, lhs: PSym, g: ModuleGraph): PNode =
   result = newIntNode(nkIntLit, ord(res))
   result.typ = n.typ
 
+proc fitLiteral(c: ConfigRef, n: PNode): PNode =
+  # Trim the literal value in order to make it fit in the destination type
+  if n == nil:
+    # `n` may be nil if the overflow check kicks in
+    return
+
+  doAssert n.kind in {nkIntLit, nkCharLit}
+
+  result = n
+
+  let typ = n.typ.skipTypes(abstractRange)
+  if typ.kind in tyUInt..tyUint32:
+    result.intVal = result.intVal and lastOrd(c, typ, fixedUnsigned=true)
+
 proc evalOp(m: TMagic, n, a, b, c: PNode; g: ModuleGraph): PNode =
+  template doAndFit(op: untyped): untyped =
+    # Implements wrap-around behaviour for unsigned types
+    fitLiteral(g.config, op)
   # b and c may be nil
   result = nil
   case m
@@ -224,12 +241,7 @@ proc evalOp(m: TMagic, n, a, b, c: PNode; g: ModuleGraph): PNode =
   of mUnaryMinusF64: result = newFloatNodeT(- getFloat(a), n, g)
   of mNot: result = newIntNodeT(1 - getInt(a), n, g)
   of mCard: result = newIntNodeT(nimsets.cardSet(g.config, a), n, g)
-  of mBitnotI:
-    case skipTypes(n.typ, abstractRange).kind
-    of tyUInt..tyUInt64:
-      result = newIntNodeT((not getInt(a)) and lastOrd(g.config, a.typ, fixedUnsigned=true), n, g)
-    else:
-      result = newIntNodeT(not getInt(a), n, g)
+  of mBitnotI: result = doAndFit(newIntNodeT(not getInt(a), n, g))
   of mLengthArray: result = newIntNodeT(lengthOrd(g.config, a.typ), n, g)
   of mLengthSeq, mLengthOpenArray, mXLenSeq, mLengthStr, mXLenStr:
     if a.kind == nkNilLit:
@@ -251,9 +263,9 @@ proc evalOp(m: TMagic, n, a, b, c: PNode; g: ModuleGraph): PNode =
   of mToU8: result = newIntNodeT(getInt(a) and 0x000000FF, n, g)
   of mToU16: result = newIntNodeT(getInt(a) and 0x0000FFFF, n, g)
   of mToU32: result = newIntNodeT(getInt(a) and 0x00000000FFFFFFFF'i64, n, g)
-  of mUnaryLt: result = foldSub(getOrdValue(a), 1, n, g)
-  of mSucc: result = foldAdd(getOrdValue(a), getInt(b), n, g)
-  of mPred: result = foldSub(getOrdValue(a), getInt(b), n, g)
+  of mUnaryLt: result = doAndFit(foldSub(getOrdValue(a), 1, n, g))
+  of mSucc: result = doAndFit(foldAdd(getOrdValue(a), getInt(b), n, g))
+  of mPred: result = doAndFit(foldSub(getOrdValue(a), getInt(b), n, g))
   of mAddI: result = foldAdd(getInt(a), getInt(b), n, g)
   of mSubI: result = foldSub(getInt(a), getInt(b), n, g)
   of mMulI: result = foldMul(getInt(a), getInt(b), n, g)
@@ -271,7 +283,7 @@ proc evalOp(m: TMagic, n, a, b, c: PNode; g: ModuleGraph): PNode =
     of tyInt64, tyInt:
       result = newIntNodeT(`shl`(getInt(a), getInt(b)), n, g)
     of tyUInt..tyUInt64:
-      result = newIntNodeT(`shl`(getInt(a), getInt(b)) and lastOrd(g.config, a.typ, fixedUnsigned=true), n, g)
+      result = doAndFit(newIntNodeT(`shl`(getInt(a), getInt(b)), n, g))
     else: internalError(g.config, n.info, "constant folding for shl")
   of mShrI:
     case skipTypes(n.typ, abstractRange).kind
@@ -324,14 +336,14 @@ proc evalOp(m: TMagic, n, a, b, c: PNode; g: ModuleGraph): PNode =
     result = newIntNodeT(ord(`<%`(getOrdValue(a), getOrdValue(b))), n, g)
   of mLeU, mLeU64:
     result = newIntNodeT(ord(`<=%`(getOrdValue(a), getOrdValue(b))), n, g)
-  of mBitandI, mAnd: result = newIntNodeT(a.getInt and b.getInt, n, g)
-  of mBitorI, mOr: result = newIntNodeT(getInt(a) or getInt(b), n, g)
-  of mBitxorI, mXor: result = newIntNodeT(a.getInt xor b.getInt, n, g)
-  of mAddU: result = newIntNodeT(`+%`(getInt(a), getInt(b)), n, g)
-  of mSubU: result = newIntNodeT(`-%`(getInt(a), getInt(b)), n, g)
-  of mMulU: result = newIntNodeT(`*%`(getInt(a), getInt(b)), n, g)
-  of mModU: result = foldModU(getInt(a), getInt(b), n, g)
-  of mDivU: result = foldDivU(getInt(a), getInt(b), n, g)
+  of mBitandI, mAnd: result = doAndFit(newIntNodeT(a.getInt and b.getInt, n, g))
+  of mBitorI, mOr: result = doAndFit(newIntNodeT(getInt(a) or getInt(b), n, g))
+  of mBitxorI, mXor: result = doAndFit(newIntNodeT(a.getInt xor b.getInt, n, g))
+  of mAddU: result = doAndFit(newIntNodeT(`+%`(getInt(a), getInt(b)), n, g))
+  of mSubU: result = doAndFit(newIntNodeT(`-%`(getInt(a), getInt(b)), n, g))
+  of mMulU: result = doAndFit(newIntNodeT(`*%`(getInt(a), getInt(b)), n, g))
+  of mModU: result = doAndFit(foldModU(getInt(a), getInt(b), n, g))
+  of mDivU: result = doAndFit(foldDivU(getInt(a), getInt(b), n, g))
   of mLeSet: result = newIntNodeT(ord(containsSets(g.config, a, b)), n, g)
   of mEqSet: result = newIntNodeT(ord(equalSets(g.config, a, b)), n, g)
   of mLtSet:
@@ -462,17 +474,24 @@ proc foldConv(n, a: PNode; g: ModuleGraph; check = false): PNode =
       result = newIntNodeT(int(getFloat(a)), n, g)
     of tyChar:
       result = newIntNodeT(getOrdValue(a), n, g)
-    of tyUInt8..tyUInt32, tyInt8..tyInt32:
-      let fromSigned = srcTyp.kind in tyInt..tyInt64
+    of tyUInt..tyUInt64, tyInt..tyInt64:
       let toSigned = dstTyp.kind in tyInt..tyInt64
-
-      let mask = lastOrd(g.config, dstTyp, fixedUnsigned=true)
-
-      var val =
-        if toSigned:
-          a.getOrdValue mod mask
-        else:
-          a.getOrdValue and mask
+      var val = a.getOrdValue
+
+      if dstTyp.kind in {tyInt, tyInt64, tyUint, tyUInt64}:
+        # No narrowing needed
+        discard
+      elif dstTyp.kind in {tyInt..tyInt64}:
+        # Signed type: Overflow check (if requested) and conversion
+        if check: rangeCheck(n, val, g)
+        let mask = (`shl`(1, getSize(g.config, dstTyp) * 8) - 1)
+        let valSign = val < 0
+        val = abs(val) and mask
+        if valSign: val = -val
+      else:
+        # Unsigned type: Conversion
+        let mask = (`shl`(1, getSize(g.config, dstTyp) * 8) - 1)
+        val = val and mask
 
       result = newIntNodeT(val, n, g)
     else:
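
The new `fitLiteral` helper and the rewritten `foldConv` branch both come down to the same bit-level trick: a folded literal lives in a signed 64-bit `intVal`, so narrowing to a smaller destination type is emulated by masking (plus sign restoration for signed targets). Below is a minimal standalone sketch of that arithmetic, assuming nothing beyond Nim's system module; `wrapToWidth` and `wrapSigned` are illustrative names, not compiler code.

proc wrapToWidth(val: int64, bits: int): int64 =
  ## Unsigned destination: keep only the low `bits` bits, i.e. the same
  ## wrap-around the program would observe at run time.
  ## Assumes bits < 64; the patch skips masking for full-width types.
  let mask = (1'i64 shl int64(bits)) - 1
  result = val and mask

proc wrapSigned(val: int64, bits: int): int64 =
  ## Signed destination: mask the magnitude and re-apply the sign,
  ## mirroring the abs/negate steps in the foldConv change above.
  let mask = (1'i64 shl int64(bits)) - 1
  result = abs(val) and mask
  if val < 0: result = -result

when isMainModule:
  doAssert wrapToWidth(0'i64 - 1, 8) == 255      # sub1(0'u8) in the test below
  doAssert wrapToWidth(255'i64 + 1, 8) == 0      # add1(255'u8) wraps to 0
  doAssert wrapToWidth(not 0'i64, 16) == 65535   # mBitnotI on a uint16 operand
  doAssert wrapSigned(-1'i64, 8) == -1           # all-ones uint64 narrowed to int8
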
diff --git a/tests/arithm/tcast.nim b/tests/arithm/tcast.nim
index 954e2e677..4017ed1c5 100644
--- a/tests/arithm/tcast.nim
+++ b/tests/arithm/tcast.nim
@@ -4,6 +4,9 @@ B0
 B1
 B2
 B3
+B4
+B5
+B6
 '''
 """
 
@@ -14,6 +17,14 @@ template crossCheck(ty: untyped, exp: untyped) =
     echo "Got ", ct
     echo "Expected ", rt
 
+template add1(x: uint8): untyped = x + 1
+template add1(x: uint16): untyped = x + 1
+template add1(x: uint32): untyped = x + 1
+
+template sub1(x: uint8): untyped = x - 1
+template sub1(x: uint16): untyped = x - 1
+template sub1(x: uint32): untyped = x - 1
+
 block:
   when true:
     echo "B0"
@@ -34,10 +45,34 @@ block:
     crossCheck(uint64, (-1).uint64 + 5'u64)
 
     echo "B3"
-    crossCheck(int8, 0'u8 - 5'u8)
-    crossCheck(int16, 0'u16 - 5'u16)
-    crossCheck(int32, 0'u32 - 5'u32)
-    crossCheck(int64, 0'u64 - 5'u64)
+    doAssert $sub1(0'u8) == "255"
+    doAssert $sub1(0'u16) == "65535"
+    doAssert $sub1(0'u32) == "4294967295"
+
+    echo "B4"
+    doAssert $add1(255'u8) == "0"
+    doAssert $add1(65535'u16) == "0"
+    doAssert $add1(4294967295'u32) == "0"
+
+    echo "B5"
+    crossCheck(int32, high(int32))
+    crossCheck(int32, high(int32).int32)
+    crossCheck(int32, low(int32))
+    crossCheck(int32, low(int32).int32)
+    crossCheck(int64, high(int8).int16.int32.int64)
+    crossCheck(int64, low(int8).int16.int32.int64)
+
+    echo "B6"
+    crossCheck(int64, 0xFFFFFFFFFFFFFFFF'u64)
+    crossCheck(int32, 0xFFFFFFFFFFFFFFFF'u64)
+    crossCheck(int16, 0xFFFFFFFFFFFFFFFF'u64)
+    crossCheck(int8 , 0xFFFFFFFFFFFFFFFF'u64)
+
+    # Out of range conversion, caught for `let`s only
+    # crossCheck(int8, 0'u8 - 5'u8)
+    # crossCheck(int16, 0'u16 - 5'u16)
+    # crossCheck(int32, 0'u32 - 5'u32)
+    # crossCheck(int64, 0'u64 - 5'u64)
 
   # crossCheck(int8, 0'u16 - 129'u16)
   # crossCheck(uint8, 0'i16 + 257'i16)
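
For context, `crossCheck` (whose body is only partly visible in the hunk above) pits the constant-folded value of an expression against its run-time value. A rough sketch of that idea follows, with illustrative names and a guessed body rather than the test's actual code.

template myCrossCheck(ty, exp: untyped) =
  ## Fold `exp` at compile time (exercising semfold) and evaluate the same
  ## expression at run time, then report any disagreement.
  const ct = ty(exp)   # constant-folded by the compiler
  let rt = ty(exp)     # evaluated at run time (range-checked for `let`)
  if $ct != $rt:
    echo "Got ", ct
    echo "Expected ", rt

myCrossCheck(uint8, 0'u8 - 5'u8)   # both sides should agree on 251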