diff options
Diffstat (limited to 'compiler/guards.nim')
-rw-r--r-- | compiler/guards.nim | 316 |
1 files changed, 297 insertions, 19 deletions
diff --git a/compiler/guards.nim b/compiler/guards.nim index f475f5068..4cf06fe02 100644 --- a/compiler/guards.nim +++ b/compiler/guards.nim @@ -1,7 +1,7 @@ # # # The Nimrod Compiler -# (c) Copyright 2013 Andreas Rumpf +# (c) Copyright 2014 Andreas Rumpf # # See the file "copying.txt", included in this # distribution, for details about the copyright. @@ -9,7 +9,8 @@ ## This module implements the 'implies' relation for guards. -import ast, astalgo, msgs, magicsys, nimsets, trees, types, renderer, idents +import ast, astalgo, msgs, magicsys, nimsets, trees, types, renderer, idents, + saturate const someEq = {mEqI, mEqI64, mEqF64, mEqEnum, mEqCh, mEqB, mEqRef, mEqProc, @@ -25,6 +26,17 @@ const someIn = {mInRange, mInSet} + someHigh = {mHigh} + # we don't list unsigned here because wrap around semantics suck for + # proving anything: + someAdd = {mAddI, mAddI64, mAddF64, mSucc} + someSub = {mSubI, mSubI64, mSubF64, mPred} + someMul = {mMulI, mMulI64, mMulF64} + someDiv = {mDivI, mDivI64, mDivF64} + someMod = {mModI, mModI64} + someMax = {mMaxI, mMaxI64, mMaxF64} + someMin = {mMinI, mMinI64, mMinF64} + proc isValue(n: PNode): bool = n.kind in {nkCharLit..nkNilLit} proc isLocation(n: PNode): bool = not n.isValue @@ -69,19 +81,24 @@ proc isLetLocation(m: PNode, isApprox: bool): bool = proc interestingCaseExpr*(m: PNode): bool = isLetLocation(m, true) -proc getMagicOp(name: string, m: TMagic): PSym = +proc createMagic*(name: string, m: TMagic): PSym = result = newSym(skProc, getIdent(name), nil, unknownLineInfo()) result.magic = m let - opLe = getMagicOp("<=", mLeI) - opLt = getMagicOp("<", mLtI) - opAnd = getMagicOp("and", mAnd) - opOr = getMagicOp("or", mOr) - opNot = getMagicOp("not", mNot) - opIsNil = getMagicOp("isnil", mIsNil) - opContains = getMagicOp("contains", mInSet) - opEq = getMagicOp("==", mEqI) + opLe = createMagic("<=", mLeI) + opLt = createMagic("<", mLtI) + opAnd = createMagic("and", mAnd) + opOr = createMagic("or", mOr) + opNot = createMagic("not", mNot) + opIsNil = createMagic("isnil", mIsNil) + opContains = createMagic("contains", mInSet) + opEq = createMagic("==", mEqI) + opAdd = createMagic("+", mAddI) + opSub = createMagic("-", mSubI) + opMul = createMagic("*", mMulI) + opDiv = createMagic("div", mDivI) + opLen = createMagic("len", mLengthSeq) proc swapArgs(fact: PNode, newOp: PSym): PNode = result = newNodeI(nkCall, fact.info, 3) @@ -137,17 +154,141 @@ proc neg(n: PNode): PNode = result.sons[0] = newSymNode(opNot) result.sons[1] = n -proc buildIsNil(arg: PNode): PNode = - result = newNodeI(nkCall, arg.info, 2) - result.sons[0] = newSymNode(opIsNil) - result.sons[1] = arg +proc buildCall(op: PSym; a: PNode): PNode = + result = newNodeI(nkCall, a.info, 2) + result.sons[0] = newSymNode(op) + result.sons[1] = a + +proc buildCall(op: PSym; a, b: PNode): PNode = + result = newNodeI(nkInfix, a.info, 3) + result.sons[0] = newSymNode(op) + result.sons[1] = a + result.sons[2] = b + +proc `|+|`(a, b: PNode): PNode = + result = copyNode(a) + if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal + else: result.floatVal = a.floatVal + b.floatVal + +proc `|*|`(a, b: PNode): PNode = + result = copyNode(a) + if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal + else: result.floatVal = a.floatVal * b.floatVal + +proc negate(a, b, res: PNode): PNode = + if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt): + var b = copyNode(b) + b.intVal = -b.intVal + if a.kind in {nkCharLit..nkUInt64Lit}: + b.intVal = b.intVal |+| a.intVal + result = b + else: + result = buildCall(opAdd, a, b) + elif b.kind in {nkFloatLit..nkFloat64Lit}: + var b = copyNode(b) + b.floatVal = -b.floatVal + result = buildCall(opAdd, a, b) + else: + result = res + +proc zero(): PNode = nkIntLit.newIntNode(0) +proc one(): PNode = nkIntLit.newIntNode(1) +proc minusOne(): PNode = nkIntLit.newIntNode(-1) + +proc lowBound*(x: PNode): PNode = + result = nkIntLit.newIntNode(firstOrd(x.typ)) + result.info = x.info + +proc highBound*(x: PNode): PNode = + result = if x.typ.skipTypes(abstractInst).kind == tyArray: + nkIntLit.newIntNode(lastOrd(x.typ)) + else: + opAdd.buildCall(opLen.buildCall(x), minusOne()) + result.info = x.info + +proc reassociation(n: PNode): PNode = + result = n + # (foo+5)+5 --> foo+10; same for '*' + case result.getMagic + of someAdd: + if result[2].isValue and + result[1].getMagic in someAdd and result[1][2].isValue: + result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2]) + of someMul: + if result[2].isValue and + result[1].getMagic in someMul and result[1][2].isValue: + result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2]) + else: discard + +proc canon*(n: PNode): PNode = + # XXX for now only the new code in 'semparallel' uses this + if n.safeLen >= 1: + result = shallowCopy(n) + for i in 0 .. < n.len: + result.sons[i] = canon(n.sons[i]) + else: + result = n + case result.getMagic + of someEq, someAdd, someMul, someMin, someMax: + # these are symmetric; put value as last: + if result.sons[1].isValue and not result.sons[2].isValue: + result = swapArgs(result, result.sons[0].sym) + # (4 + foo) + 2 --> (foo + 4) + 2 + of someHigh: + # high == len+(-1) + result = opAdd.buildCall(opLen.buildCall(result[1]), minusOne()) + of mUnaryMinusI, mUnaryMinusI64: + result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1)) + of someSub: + # x - 4 --> x + (-4) + result = negate(result[1], result[2], result) + of someLen: + result.sons[0] = opLen.newSymNode + else: discard + + result = skipConv(result) + result = reassociation(result) + # most important rule: (x-4) < a.len --> x < a.len+4 + case result.getMagic + of someLe, someLt: + let x = result[1] + let y = result[2] + if x.kind in nkCallKinds and x.len == 3 and x[2].isValue and + isLetLocation(x[1], true): + case x.getMagic + of someSub: + result = buildCall(result[0].sym, x[1], + reassociation(opAdd.buildCall(y, x[2]))) + of someAdd: + # Rule A: + let plus = negate(y, x[2], nil).reassociation + if plus != nil: result = buildCall(result[0].sym, x[1], plus) + else: discard + elif y.kind in nkCallKinds and y.len == 3 and y[2].isValue and + isLetLocation(y[1], true): + # a.len < x-3 + case y.getMagic + of someSub: + result = buildCall(result[0].sym, y[1], + reassociation(opAdd.buildCall(x, y[2]))) + of someAdd: + let plus = negate(x, y[2], nil).reassociation + # ensure that Rule A will not trigger afterwards with the + # additional 'not isLetLocation' constraint: + if plus != nil and not isLetLocation(x, true): + result = buildCall(result[0].sym, plus, y[1]) + else: discard + else: discard + +proc `+@`*(a: PNode; b: BiggestInt): PNode = + canon(if b != 0: opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a) proc usefulFact(n: PNode): PNode = case n.getMagic of someEq: if skipConv(n.sons[2]).kind == nkNilLit and ( isLetLocation(n.sons[1], false) or isVar(n.sons[1])): - result = buildIsNil(n.sons[1]) + result = opIsNil.buildCall(n.sons[1]) else: if isLetLocation(n.sons[1], true) or isLetLocation(n.sons[2], true): # XXX algebraic simplifications! 'i-1 < a.len' --> 'i < a.len+1' @@ -217,7 +358,7 @@ proc addFactNeg*(m: var TModel, n: PNode) = let n = n.neg if n != nil: addFact(m, n) -proc sameTree(a, b: PNode): bool = +proc sameTree*(a, b: PNode): bool = result = false if a == b: result = true @@ -484,7 +625,7 @@ proc factImplies(fact, prop: PNode): TImplication = # == not a or not b == not (a and b) let arg = fact.sons[1] case arg.getMagic - of mIsNil: + of mIsNil, mEqRef: return ~factImplies(arg, prop) of mAnd: # not (a and b) means not a or not b: @@ -519,7 +660,144 @@ proc doesImply*(facts: TModel, prop: PNode): TImplication = if result != impUnknown: return proc impliesNotNil*(facts: TModel, arg: PNode): TImplication = - result = doesImply(facts, buildIsNil(arg).neg) + result = doesImply(facts, opIsNil.buildCall(arg).neg) + +proc simpleSlice*(a, b: PNode): BiggestInt = + # returns 'c' if a..b matches (i+c)..(i+c), -1 otherwise. (i)..(i) is matched + # as if it is (i+0)..(i+0). + if guards.sameTree(a, b): + if a.getMagic in someAdd and a[2].kind in {nkCharLit..nkUInt64Lit}: + result = a[2].intVal + else: + result = 0 + else: + result = -1 + +proc pleViaModel(model: TModel; aa, bb: PNode): TImplication + +proc ple(m: TModel; a, b: PNode): TImplication = + template `<=?`(a,b): expr = ple(m,a,b) == impYes + # 0 <= 3 + if a.isValue and b.isValue: + return if leValue(a, b): impYes else: impNo + + # use type information too: x <= 4 iff high(x) <= 4 + if b.isValue and a.typ != nil and a.typ.isOrdinalType: + if lastOrd(a.typ) <= b.intVal: return impYes + # 3 <= x iff low(x) <= 3 + if a.isValue and b.typ != nil and b.typ.isOrdinalType: + if firstOrd(b.typ) <= a.intVal: return impYes + + # x <= x + if sameTree(a, b): return impYes + + # 0 <= x.len + if b.getMagic in someLen and a.isValue: + if a.intVal <= 0: return impYes + + # x <= y+c if 0 <= c and x <= y + if b.getMagic in someAdd and zero() <=? b[2] and a <=? b[1]: return impYes + + # x+c <= y if c <= 0 and x <= y + if a.getMagic in someAdd and a[2] <=? zero() and a[1] <=? b: return impYes + + # x <= y*c if 1 <= c and x <= y and 0 <= y + if b.getMagic in someMul: + if a <=? b[1] and one() <=? b[2] and zero() <=? b[1]: return impYes + + # x div c <= y if 1 <= c and 0 <= y and x <= y: + if a.getMagic in someDiv: + if one() <=? a[2] and zero() <=? b and a[1] <=? b: return impYes + + # slightly subtle: + # x <= max(y, z) iff x <= y or x <= z + # note that 'x <= max(x, z)' is a special case of the above rule + if b.getMagic in someMax: + if a <=? b[1] or a <=? b[2]: return impYes + + # min(x, y) <= z iff x <= z or y <= z + if a.getMagic in someMin: + if a[1] <=? b or a[2] <=? b: return impYes + + # use the knowledge base: + return pleViaModel(m, a, b) + #return doesImply(m, opLe.buildCall(a, b)) + +type TReplacements = seq[tuple[a,b: PNode]] + +proc replaceSubTree(n, x, by: PNode): PNode = + if sameTree(n, x): + result = by + elif hasSubTree(n, x): + result = shallowCopy(n) + for i in 0 .. safeLen(n)-1: + result.sons[i] = replaceSubTree(n.sons[i], x, by) + else: + result = n + +proc applyReplacements(n: PNode; rep: TReplacements): PNode = + result = n + for x in rep: result = result.replaceSubTree(x.a, x.b) + +proc pleViaModelRec(m: var TModel; a, b: PNode): TImplication = + # now check for inferrable facts: a <= b and b <= c implies a <= c + for i in 0..m.high: + let fact = m[i] + if fact != nil and fact.getMagic in someLe: + # x <= y implies a <= b if a <= x and y <= b + let x = fact[1] + let y = fact[2] + # mark as used: + m[i] = nil + if ple(m, a, x) == impYes: + if ple(m, y, b) == impYes: return impYes + #if pleViaModelRec(m, y, b): return impYes + # fact: 16 <= i + # x y + # question: i <= 15? no! + result = impliesLe(fact, a, b) + if result != impUnknown: return result + if sameTree(y, a): + result = ple(m, x, b) + if result != impUnknown: return result + +proc pleViaModel(model: TModel; aa, bb: PNode): TImplication = + # compute replacements: + var replacements: TReplacements = @[] + for fact in model: + if fact != nil and fact.getMagic in someEq: + let a = fact[1] + let b = fact[2] + if a.kind == nkSym: replacements.add((a,b)) + else: replacements.add((b,a)) + var m: TModel + var a = aa + var b = bb + if replacements.len > 0: + m = @[] + # make the other facts consistent: + for fact in model: + if fact != nil and fact.getMagic notin someEq: + # XXX 'canon' should not be necessary here, but it is + m.add applyReplacements(fact, replacements).canon + a = applyReplacements(aa, replacements) + b = applyReplacements(bb, replacements) + else: + # we have to make a copy here, because the model will be modified: + m = model + result = pleViaModelRec(m, a, b) + +proc proveLe*(m: TModel; a, b: PNode): TImplication = + let x = canon(opLe.buildCall(a, b)) + #echo "ROOT ", renderTree(x[1]), " <=? ", renderTree(x[2]) + result = ple(m, x[1], x[2]) + if result == impUnknown: + # try an alternative: a <= b iff not (b < a) iff not (b+1 <= a): + let y = canon(opLe.buildCall(opAdd.buildCall(b, one()), a)) + result = ~ple(m, y[1], y[2]) + +proc addFactLe*(m: var TModel; a, b: PNode) = + m.add canon(opLe.buildCall(a, b)) proc settype(n: PNode): PType = result = newType(tySet, n.typ.owner) |