summary refs log tree commit diff stats
path: root/compiler/guards.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/guards.nim')
-rw-r--r--compiler/guards.nim316
1 files changed, 297 insertions, 19 deletions
diff --git a/compiler/guards.nim b/compiler/guards.nim
index f475f5068..4cf06fe02 100644
--- a/compiler/guards.nim
+++ b/compiler/guards.nim
@@ -1,7 +1,7 @@
 #
 #
 #           The Nimrod Compiler
-#        (c) Copyright 2013 Andreas Rumpf
+#        (c) Copyright 2014 Andreas Rumpf
 #
 #    See the file "copying.txt", included in this
 #    distribution, for details about the copyright.
@@ -9,7 +9,8 @@
 
 ## This module implements the 'implies' relation for guards.
 
-import ast, astalgo, msgs, magicsys, nimsets, trees, types, renderer, idents
+import ast, astalgo, msgs, magicsys, nimsets, trees, types, renderer, idents,
+  saturate
 
 const
   someEq = {mEqI, mEqI64, mEqF64, mEqEnum, mEqCh, mEqB, mEqRef, mEqProc,
@@ -25,6 +26,17 @@ const
 
   someIn = {mInRange, mInSet}
 
+  someHigh = {mHigh}
+  # we don't list unsigned here because wrap around semantics suck for
+  # proving anything:
+  someAdd = {mAddI, mAddI64, mAddF64, mSucc}
+  someSub = {mSubI, mSubI64, mSubF64, mPred}
+  someMul = {mMulI, mMulI64, mMulF64}
+  someDiv = {mDivI, mDivI64, mDivF64}
+  someMod = {mModI, mModI64}
+  someMax = {mMaxI, mMaxI64, mMaxF64}
+  someMin = {mMinI, mMinI64, mMinF64}
+
 proc isValue(n: PNode): bool = n.kind in {nkCharLit..nkNilLit}
 proc isLocation(n: PNode): bool = not n.isValue
 
@@ -69,19 +81,24 @@ proc isLetLocation(m: PNode, isApprox: bool): bool =
 
 proc interestingCaseExpr*(m: PNode): bool = isLetLocation(m, true)
 
-proc getMagicOp(name: string, m: TMagic): PSym =
+proc createMagic*(name: string, m: TMagic): PSym =
   result = newSym(skProc, getIdent(name), nil, unknownLineInfo())
   result.magic = m
 
 let
-  opLe = getMagicOp("<=", mLeI)
-  opLt = getMagicOp("<", mLtI)
-  opAnd = getMagicOp("and", mAnd)
-  opOr = getMagicOp("or", mOr)
-  opNot = getMagicOp("not", mNot)
-  opIsNil = getMagicOp("isnil", mIsNil)
-  opContains = getMagicOp("contains", mInSet)
-  opEq = getMagicOp("==", mEqI)
+  opLe = createMagic("<=", mLeI)
+  opLt = createMagic("<", mLtI)
+  opAnd = createMagic("and", mAnd)
+  opOr = createMagic("or", mOr)
+  opNot = createMagic("not", mNot)
+  opIsNil = createMagic("isnil", mIsNil)
+  opContains = createMagic("contains", mInSet)
+  opEq = createMagic("==", mEqI)
+  opAdd = createMagic("+", mAddI)
+  opSub = createMagic("-", mSubI)
+  opMul = createMagic("*", mMulI)
+  opDiv = createMagic("div", mDivI)
+  opLen = createMagic("len", mLengthSeq)
 
 proc swapArgs(fact: PNode, newOp: PSym): PNode =
   result = newNodeI(nkCall, fact.info, 3)
@@ -137,17 +154,141 @@ proc neg(n: PNode): PNode =
     result.sons[0] = newSymNode(opNot)
     result.sons[1] = n
 
-proc buildIsNil(arg: PNode): PNode =
-  result = newNodeI(nkCall, arg.info, 2)
-  result.sons[0] = newSymNode(opIsNil)
-  result.sons[1] = arg
+proc buildCall(op: PSym; a: PNode): PNode =
+  result = newNodeI(nkCall, a.info, 2)
+  result.sons[0] = newSymNode(op)
+  result.sons[1] = a
+
+proc buildCall(op: PSym; a, b: PNode): PNode =
+  result = newNodeI(nkInfix, a.info, 3)
+  result.sons[0] = newSymNode(op)
+  result.sons[1] = a
+  result.sons[2] = b
+
+proc `|+|`(a, b: PNode): PNode =
+  result = copyNode(a)
+  if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal
+  else: result.floatVal = a.floatVal + b.floatVal
+
+proc `|*|`(a, b: PNode): PNode =
+  result = copyNode(a)
+  if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal
+  else: result.floatVal = a.floatVal * b.floatVal
+
+proc negate(a, b, res: PNode): PNode =
+  if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt):
+    var b = copyNode(b)
+    b.intVal = -b.intVal
+    if a.kind in {nkCharLit..nkUInt64Lit}:
+      b.intVal = b.intVal |+| a.intVal
+      result = b
+    else:
+      result = buildCall(opAdd, a, b)
+  elif b.kind in {nkFloatLit..nkFloat64Lit}:
+    var b = copyNode(b)
+    b.floatVal = -b.floatVal
+    result = buildCall(opAdd, a, b)
+  else:
+    result = res
+
+proc zero(): PNode = nkIntLit.newIntNode(0)
+proc one(): PNode = nkIntLit.newIntNode(1)
+proc minusOne(): PNode = nkIntLit.newIntNode(-1)
+
+proc lowBound*(x: PNode): PNode = 
+  result = nkIntLit.newIntNode(firstOrd(x.typ))
+  result.info = x.info
+
+proc highBound*(x: PNode): PNode =
+  result = if x.typ.skipTypes(abstractInst).kind == tyArray:
+             nkIntLit.newIntNode(lastOrd(x.typ))
+           else:
+             opAdd.buildCall(opLen.buildCall(x), minusOne())
+  result.info = x.info
+
+proc reassociation(n: PNode): PNode =
+  result = n
+  # (foo+5)+5 --> foo+10;  same for '*'
+  case result.getMagic
+  of someAdd:
+    if result[2].isValue and 
+        result[1].getMagic in someAdd and result[1][2].isValue:
+      result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
+  of someMul:
+    if result[2].isValue and 
+        result[1].getMagic in someMul and result[1][2].isValue:
+      result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2])
+  else: discard
+
+proc canon*(n: PNode): PNode =
+  # XXX for now only the new code in 'semparallel' uses this
+  if n.safeLen >= 1:
+    result = shallowCopy(n)
+    for i in 0 .. < n.len:
+      result.sons[i] = canon(n.sons[i])
+  else:
+    result = n
+  case result.getMagic
+  of someEq, someAdd, someMul, someMin, someMax:
+    # these are symmetric; put value as last:
+    if result.sons[1].isValue and not result.sons[2].isValue:
+      result = swapArgs(result, result.sons[0].sym)
+      # (4 + foo) + 2 --> (foo + 4) + 2
+  of someHigh:
+    # high == len+(-1)
+    result = opAdd.buildCall(opLen.buildCall(result[1]), minusOne())
+  of mUnaryMinusI, mUnaryMinusI64:
+    result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1))
+  of someSub:
+    # x - 4  -->  x + (-4)
+    result = negate(result[1], result[2], result)
+  of someLen:
+    result.sons[0] = opLen.newSymNode
+  else: discard
+
+  result = skipConv(result)
+  result = reassociation(result)
+  # most important rule: (x-4) < a.len -->  x < a.len+4
+  case result.getMagic
+  of someLe, someLt:
+    let x = result[1]
+    let y = result[2]
+    if x.kind in nkCallKinds and x.len == 3 and x[2].isValue and 
+        isLetLocation(x[1], true):
+      case x.getMagic
+      of someSub:
+        result = buildCall(result[0].sym, x[1], 
+                           reassociation(opAdd.buildCall(y, x[2])))
+      of someAdd:
+        # Rule A:
+        let plus = negate(y, x[2], nil).reassociation
+        if plus != nil: result = buildCall(result[0].sym, x[1], plus)
+      else: discard
+    elif y.kind in nkCallKinds and y.len == 3 and y[2].isValue and 
+        isLetLocation(y[1], true):
+      # a.len < x-3
+      case y.getMagic
+      of someSub:
+        result = buildCall(result[0].sym, y[1],
+                           reassociation(opAdd.buildCall(x, y[2])))
+      of someAdd:
+        let plus = negate(x, y[2], nil).reassociation
+        # ensure that Rule A will not trigger afterwards with the
+        # additional 'not isLetLocation' constraint:
+        if plus != nil and not isLetLocation(x, true):
+          result = buildCall(result[0].sym, plus, y[1])
+      else: discard
+  else: discard
+
+proc `+@`*(a: PNode; b: BiggestInt): PNode =
+  canon(if b != 0: opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a)
 
 proc usefulFact(n: PNode): PNode =
   case n.getMagic
   of someEq:
     if skipConv(n.sons[2]).kind == nkNilLit and (
         isLetLocation(n.sons[1], false) or isVar(n.sons[1])):
-      result = buildIsNil(n.sons[1])
+      result = opIsNil.buildCall(n.sons[1])
     else:
       if isLetLocation(n.sons[1], true) or isLetLocation(n.sons[2], true):
         # XXX algebraic simplifications!  'i-1 < a.len' --> 'i < a.len+1'
@@ -217,7 +358,7 @@ proc addFactNeg*(m: var TModel, n: PNode) =
   let n = n.neg
   if n != nil: addFact(m, n)
 
-proc sameTree(a, b: PNode): bool = 
+proc sameTree*(a, b: PNode): bool = 
   result = false
   if a == b:
     result = true
@@ -484,7 +625,7 @@ proc factImplies(fact, prop: PNode): TImplication =
     #  == not a or not b == not (a and b)
     let arg = fact.sons[1]
     case arg.getMagic
-    of mIsNil:
+    of mIsNil, mEqRef:
       return ~factImplies(arg, prop)
     of mAnd:
       # not (a and b)  means  not a or not b:
@@ -519,7 +660,144 @@ proc doesImply*(facts: TModel, prop: PNode): TImplication =
       if result != impUnknown: return
 
 proc impliesNotNil*(facts: TModel, arg: PNode): TImplication =
-  result = doesImply(facts, buildIsNil(arg).neg)
+  result = doesImply(facts, opIsNil.buildCall(arg).neg)
+
+proc simpleSlice*(a, b: PNode): BiggestInt =
+  # returns 'c' if a..b matches (i+c)..(i+c), -1 otherwise. (i)..(i) is matched
+  # as if it is (i+0)..(i+0).
+  if guards.sameTree(a, b):
+    if a.getMagic in someAdd and a[2].kind in {nkCharLit..nkUInt64Lit}:
+      result = a[2].intVal
+    else:
+      result = 0
+  else:
+    result = -1
+
+proc pleViaModel(model: TModel; aa, bb: PNode): TImplication
+
+proc ple(m: TModel; a, b: PNode): TImplication =
+  template `<=?`(a,b): expr = ple(m,a,b) == impYes
+  #   0 <= 3
+  if a.isValue and b.isValue:
+    return if leValue(a, b): impYes else: impNo
+
+  # use type information too:  x <= 4  iff  high(x) <= 4
+  if b.isValue and a.typ != nil and a.typ.isOrdinalType:
+    if lastOrd(a.typ) <= b.intVal: return impYes
+  # 3 <= x   iff  low(x) <= 3
+  if a.isValue and b.typ != nil and b.typ.isOrdinalType:
+    if firstOrd(b.typ) <= a.intVal: return impYes
+
+  # x <= x
+  if sameTree(a, b): return impYes
+
+  # 0 <= x.len
+  if b.getMagic in someLen and a.isValue:
+    if a.intVal <= 0: return impYes
+
+  #   x <= y+c  if 0 <= c and x <= y
+  if b.getMagic in someAdd and zero() <=? b[2] and a <=? b[1]: return impYes
+
+  #   x+c <= y  if c <= 0 and x <= y
+  if a.getMagic in someAdd and a[2] <=? zero() and a[1] <=? b: return impYes
+
+  #   x <= y*c  if  1 <= c and x <= y  and 0 <= y
+  if b.getMagic in someMul:
+    if a <=? b[1] and one() <=? b[2] and zero() <=? b[1]: return impYes
+
+  #   x div c <= y   if   1 <= c  and  0 <= y  and x <= y:
+  if a.getMagic in someDiv:
+    if one() <=? a[2] and zero() <=? b and a[1] <=? b: return impYes
+
+  # slightly subtle:
+  # x <= max(y, z)  iff x <= y or x <= z
+  # note that 'x <= max(x, z)' is a special case of the above rule
+  if b.getMagic in someMax:
+    if a <=? b[1] or a <=? b[2]: return impYes
+
+  # min(x, y) <= z  iff x <= z or y <= z
+  if a.getMagic in someMin:
+    if a[1] <=? b or a[2] <=? b: return impYes
+
+  # use the knowledge base:
+  return pleViaModel(m, a, b)
+  #return doesImply(m, opLe.buildCall(a, b))
+
+type TReplacements = seq[tuple[a,b: PNode]]
+
+proc replaceSubTree(n, x, by: PNode): PNode =
+  if sameTree(n, x):
+    result = by
+  elif hasSubTree(n, x):
+    result = shallowCopy(n)
+    for i in 0 .. safeLen(n)-1:
+      result.sons[i] = replaceSubTree(n.sons[i], x, by)
+  else:
+    result = n
+
+proc applyReplacements(n: PNode; rep: TReplacements): PNode =
+  result = n
+  for x in rep: result = result.replaceSubTree(x.a, x.b)
+
+proc pleViaModelRec(m: var TModel; a, b: PNode): TImplication =
+  # now check for inferrable facts: a <= b and b <= c  implies a <= c
+  for i in 0..m.high:
+    let fact = m[i]
+    if fact != nil and fact.getMagic in someLe:
+      # x <= y implies a <= b  if  a <= x and y <= b
+      let x = fact[1]
+      let y = fact[2]
+      # mark as used:
+      m[i] = nil
+      if ple(m, a, x) == impYes:
+        if ple(m, y, b) == impYes: return impYes
+        #if pleViaModelRec(m, y, b): return impYes
+      # fact:  16 <= i
+      #         x    y
+      # question: i <= 15? no!
+      result = impliesLe(fact, a, b)
+      if result != impUnknown: return result
+      if sameTree(y, a):
+        result = ple(m, x, b)
+        if result != impUnknown: return result
+
+proc pleViaModel(model: TModel; aa, bb: PNode): TImplication =
+  # compute replacements:
+  var replacements: TReplacements = @[]
+  for fact in model:
+    if fact != nil and fact.getMagic in someEq:
+      let a = fact[1]
+      let b = fact[2]
+      if a.kind == nkSym: replacements.add((a,b))
+      else: replacements.add((b,a))
+  var m: TModel
+  var a = aa
+  var b = bb
+  if replacements.len > 0:
+    m = @[]
+    # make the other facts consistent:
+    for fact in model:
+      if fact != nil and fact.getMagic notin someEq:
+        # XXX 'canon' should not be necessary here, but it is
+        m.add applyReplacements(fact, replacements).canon
+    a = applyReplacements(aa, replacements)
+    b = applyReplacements(bb, replacements)
+  else:
+    # we have to make a copy here, because the model will be modified:
+    m = model
+  result = pleViaModelRec(m, a, b)
+
+proc proveLe*(m: TModel; a, b: PNode): TImplication =
+  let x = canon(opLe.buildCall(a, b))
+  #echo "ROOT ", renderTree(x[1]), " <=? ", renderTree(x[2])
+  result = ple(m, x[1], x[2])
+  if result == impUnknown:
+    # try an alternative:  a <= b  iff  not (b < a)  iff  not (b+1 <= a):
+    let y = canon(opLe.buildCall(opAdd.buildCall(b, one()), a))
+    result = ~ple(m, y[1], y[2])
+
+proc addFactLe*(m: var TModel; a, b: PNode) =
+  m.add canon(opLe.buildCall(a, b))
 
 proc settype(n: PNode): PType =
   result = newType(tySet, n.typ.owner)