summary refs log tree commit diff stats
path: root/compiler/guards.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/guards.nim')
-rw-r--r--compiler/guards.nim136
1 files changed, 128 insertions, 8 deletions
diff --git a/compiler/guards.nim b/compiler/guards.nim
index bc802ae33..5ad932e48 100644
--- a/compiler/guards.nim
+++ b/compiler/guards.nim
@@ -37,6 +37,7 @@ const
   someMod = {mModI}
   someMax = {mMaxI, mMaxF64}
   someMin = {mMinI, mMinF64}
+  someBinaryOp = someAdd+someSub+someMul+someMax+someMin
 
 proc isValue(n: PNode): bool = n.kind in {nkCharLit..nkNilLit}
 proc isLocation(n: PNode): bool = not n.isValue
@@ -165,11 +166,21 @@ proc `|+|`(a, b: PNode): PNode =
   if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal
   else: result.floatVal = a.floatVal + b.floatVal
 
+proc `|-|`(a, b: PNode): PNode =
+  result = copyNode(a)
+  if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |-| b.intVal
+  else: result.floatVal = a.floatVal - b.floatVal
+
 proc `|*|`(a, b: PNode): PNode =
   result = copyNode(a)
   if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal
   else: result.floatVal = a.floatVal * b.floatVal
 
+proc `|div|`(a, b: PNode): PNode =
+  result = copyNode(a)
+  if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal div b.intVal
+  else: result.floatVal = a.floatVal / b.floatVal
+
 proc negate(a, b, res: PNode): PNode =
   if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt):
     var b = copyNode(b)
@@ -213,10 +224,16 @@ proc reassociation(n: PNode): PNode =
     if result[2].isValue and
         result[1].getMagic in someAdd and result[1][2].isValue:
       result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
+      if result[2].intVal == 0:
+        result = result[1]
   of someMul:
     if result[2].isValue and
         result[1].getMagic in someMul and result[1][2].isValue:
-      result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2])
+      result = opMul.buildCall(result[1][1], result[1][2] |*| result[2])
+      if result[2].intVal == 1:
+        result = result[1]
+      elif result[2].intVal == 0:
+        result = zero()
   else: discard
 
 proc pred(n: PNode): PNode =
@@ -234,7 +251,7 @@ proc canon*(n: PNode): PNode =
       result.sons[i] = canon(n.sons[i])
   elif n.kind == nkSym and n.sym.kind == skLet and
       n.sym.ast.getMagic in (someEq + someAdd + someMul + someMin +
-      someMax + someHigh + {mUnaryLt} + someSub + someLen):
+      someMax + someHigh + {mUnaryLt} + someSub + someLen + someDiv):
     result = n.sym.ast.copyTree
   else:
     result = n
@@ -248,7 +265,7 @@ proc canon*(n: PNode): PNode =
     # high == len+(-1)
     result = opAdd.buildCall(opLen.buildCall(result[1]), minusOne())
   of mUnaryLt:
-    result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1))
+    result = buildCall(opAdd, result[1], minusOne())
   of someSub:
     # x - 4  -->  x + (-4)
     result = negate(result[1], result[2], result)
@@ -294,6 +311,16 @@ proc canon*(n: PNode): PNode =
         if plus != nil and not isLetLocation(x, true):
           result = buildCall(result[0].sym, plus, y[1])
       else: discard
+    elif x.isValue and y.getMagic in someAdd and y[2].isValue:
+      # 0 <= a.len + 3
+      # -3 <= a.len
+      result.sons[1] = x |-| y[2]
+      result.sons[2] = y[1]
+    elif x.isValue and y.getMagic in someSub and y[2].isValue:
+      # 0 <= a.len - 3
+      # 3 <= a.len
+      result.sons[1] = x |+| y[2]
+      result.sons[2] = y[1]
   else: discard
 
 proc `+@`*(a: PNode; b: BiggestInt): PNode =
@@ -313,6 +340,9 @@ proc usefulFact(n: PNode): PNode =
     if isLetLocation(n.sons[1], true) or isLetLocation(n.sons[2], true):
       # XXX algebraic simplifications!  'i-1 < a.len' --> 'i < a.len+1'
       result = n
+    elif n[1].getMagic in someLen or n[2].getMagic in someLen:
+      # XXX Rethink this whole idea of 'usefulFact' for semparallel
+      result = n
   of mIsNil:
     if isLetLocation(n.sons[1], false) or isVar(n.sons[1]):
       result = n
@@ -366,8 +396,8 @@ proc usefulFact(n: PNode): PNode =
 type
   TModel* = seq[PNode] # the "knowledge base"
 
-proc addFact*(m: var TModel, n: PNode) =
-  let n = usefulFact(n)
+proc addFact*(m: var TModel, nn: PNode) =
+  let n = usefulFact(nn)
   if n != nil: m.add n
 
 proc addFactNeg*(m: var TModel, n: PNode) =
@@ -697,10 +727,57 @@ proc simpleSlice*(a, b: PNode): BiggestInt =
   else:
     result = -1
 
+
+template isMul(x): expr = x.getMagic in someMul
+template isDiv(x): expr = x.getMagic in someDiv
+template isAdd(x): expr = x.getMagic in someAdd
+template isSub(x): expr = x.getMagic in someSub
+template isVal(x): expr = x.kind in {nkCharLit..nkUInt64Lit}
+template isIntVal(x, y): expr = x.intVal == y
+
+import macros
+
+macro `=~`(x: PNode, pat: untyped): bool =
+  proc m(x, pat, conds: NimNode) =
+    case pat.kind
+    of nnkInfix:
+      case $pat[0]
+      of "*": conds.add getAst(isMul(x))
+      of "/": conds.add getAst(isDiv(x))
+      of "+": conds.add getAst(isAdd(x))
+      of "-": conds.add getAst(isSub(x))
+      else:
+        error("invalid pattern")
+      m(newTree(nnkBracketExpr, x, newLit(1)), pat[1], conds)
+      m(newTree(nnkBracketExpr, x, newLit(2)), pat[2], conds)
+    of nnkPar:
+      if pat.len == 1:
+        m(x, pat[0], conds)
+      else:
+        error("invalid pattern")
+    of nnkIdent:
+      let c = newTree(nnkStmtListExpr, newLetStmt(pat, x))
+      conds.add c
+      if ($pat)[^1] == 'c': c.add(getAst(isVal(pat)))
+      else: c.add bindSym"true"
+    of nnkIntLit:
+      conds.add(getAst(isIntVal(pat.intVal)))
+    else:
+      error("invalid pattern")
+
+  var conds = newTree(nnkBracket)
+  m(x, pat, conds)
+  result = nestList(!"and", conds)
+
+
+proc isMinusOne(n: PNode): bool =
+  n.kind in {nkCharLit..nkUInt64Lit} and n.intVal == -1
+
 proc pleViaModel(model: TModel; aa, bb: PNode): TImplication
 
 proc ple(m: TModel; a, b: PNode): TImplication =
   template `<=?`(a,b): expr = ple(m,a,b) == impYes
+  template `>=?`(a,b): expr = ple(m, nkIntLit.newIntNode(b), a) == impYes
 
   #   0 <= 3
   if a.isValue and b.isValue:
@@ -721,6 +798,7 @@ proc ple(m: TModel; a, b: PNode): TImplication =
     if a.intVal <= 0: return impYes
 
   #   x <= y+c  if 0 <= c and x <= y
+  #   x <= y+(-c)  if c <= 0  and y >= x
   if b.getMagic in someAdd and zero() <=? b[2] and a <=? b[1]: return impYes
 
   #   x+c <= y  if c <= 0 and x <= y
@@ -730,10 +808,44 @@ proc ple(m: TModel; a, b: PNode): TImplication =
   if b.getMagic in someMul:
     if a <=? b[1] and one() <=? b[2] and zero() <=? b[1]: return impYes
 
+
+  if a.getMagic in someMul and a[2].isValue and a[1].getMagic in someDiv and
+      a[1][2].isValue:
+    # simplify   (x div 4) * 2 <= y   to  x div (c div d)  <= y
+    if ple(m, buildCall(opDiv, a[1][1], `|div|`(a[1][2], a[2])), b) == impYes:
+      return impYes
+
+  # x*3 + x == x*4. It follows that:
+  # x*3 + y <= x*4  if  y <= x  and 3 <= 4
+  if a =~ x*dc + y and b =~ x2*ec:
+    if sameTree(x, x2):
+      let ec1 = opAdd.buildCall(ec, minusOne())
+      if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? x:
+        return impYes
+  elif a =~ x*dc and b =~ x2*ec + y:
+    #echo "BUG cam ehrer e ", a, " <=? ", b
+    if sameTree(x, x2):
+      let ec1 = opAdd.buildCall(ec, minusOne())
+      if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? zero():
+        return impYes
+
+  #  x+c <= x+d  if c <= d. Same for *, - etc.
+  if a.getMagic in someBinaryOp and a.getMagic == b.getMagic:
+    if sameTree(a[1], b[1]) and a[2] <=? b[2]: return impYes
+    elif sameTree(a[2], b[2]) and a[1] <=? b[1]: return impYes
+
   #   x div c <= y   if   1 <= c  and  0 <= y  and x <= y:
   if a.getMagic in someDiv:
     if one() <=? a[2] and zero() <=? b and a[1] <=? b: return impYes
 
+    #  x div c <= x div d  if d <= c
+    if b.getMagic in someDiv:
+      if sameTree(a[1], b[1]) and b[2] <=? a[2]: return impYes
+
+    # x div z <= x - 1   if  z <= x
+    if a[2].isValue and b.getMagic in someAdd and b[2].isMinusOne:
+      if a[2] <=? a[1] and sameTree(a[1], b[1]): return impYes
+
   # slightly subtle:
   # x <= max(y, z)  iff x <= y or x <= z
   # note that 'x <= max(x, z)' is a special case of the above rule
@@ -769,11 +881,19 @@ proc pleViaModelRec(m: var TModel; a, b: PNode): TImplication =
   for i in 0..m.high:
     let fact = m[i]
     if fact != nil and fact.getMagic in someLe:
-      # x <= y implies a <= b  if  a <= x and y <= b
-      let x = fact[1]
-      let y = fact[2]
       # mark as used:
       m[i] = nil
+      # i <= len-100
+      # i <=? len-1
+      # --> true  if  (len-100) <= (len-1)
+      let x = fact[1]
+      let y = fact[2]
+      if sameTree(x, a) and y.getMagic in someAdd and b.getMagic in someAdd and
+         sameTree(y[1], b[1]):
+        if ple(m, b[2], y[2]) == impYes:
+          return impYes
+
+      # x <= y implies a <= b  if  a <= x and y <= b
       if ple(m, a, x) == impYes:
         if ple(m, y, b) == impYes:
           return impYes