diff options
Diffstat (limited to 'compiler/guards.nim')
-rw-r--r-- | compiler/guards.nim | 136 |
1 files changed, 128 insertions, 8 deletions
diff --git a/compiler/guards.nim b/compiler/guards.nim index bc802ae33..5ad932e48 100644 --- a/compiler/guards.nim +++ b/compiler/guards.nim @@ -37,6 +37,7 @@ const someMod = {mModI} someMax = {mMaxI, mMaxF64} someMin = {mMinI, mMinF64} + someBinaryOp = someAdd+someSub+someMul+someMax+someMin proc isValue(n: PNode): bool = n.kind in {nkCharLit..nkNilLit} proc isLocation(n: PNode): bool = not n.isValue @@ -165,11 +166,21 @@ proc `|+|`(a, b: PNode): PNode = if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal else: result.floatVal = a.floatVal + b.floatVal +proc `|-|`(a, b: PNode): PNode = + result = copyNode(a) + if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |-| b.intVal + else: result.floatVal = a.floatVal - b.floatVal + proc `|*|`(a, b: PNode): PNode = result = copyNode(a) if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal else: result.floatVal = a.floatVal * b.floatVal +proc `|div|`(a, b: PNode): PNode = + result = copyNode(a) + if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal div b.intVal + else: result.floatVal = a.floatVal / b.floatVal + proc negate(a, b, res: PNode): PNode = if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt): var b = copyNode(b) @@ -213,10 +224,16 @@ proc reassociation(n: PNode): PNode = if result[2].isValue and result[1].getMagic in someAdd and result[1][2].isValue: result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2]) + if result[2].intVal == 0: + result = result[1] of someMul: if result[2].isValue and result[1].getMagic in someMul and result[1][2].isValue: - result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2]) + result = opMul.buildCall(result[1][1], result[1][2] |*| result[2]) + if result[2].intVal == 1: + result = result[1] + elif result[2].intVal == 0: + result = zero() else: discard proc pred(n: PNode): PNode = @@ -234,7 +251,7 @@ proc canon*(n: PNode): PNode = result.sons[i] = canon(n.sons[i]) elif n.kind == nkSym and n.sym.kind == skLet and n.sym.ast.getMagic in (someEq + someAdd + someMul + someMin + - someMax + someHigh + {mUnaryLt} + someSub + someLen): + someMax + someHigh + {mUnaryLt} + someSub + someLen + someDiv): result = n.sym.ast.copyTree else: result = n @@ -248,7 +265,7 @@ proc canon*(n: PNode): PNode = # high == len+(-1) result = opAdd.buildCall(opLen.buildCall(result[1]), minusOne()) of mUnaryLt: - result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1)) + result = buildCall(opAdd, result[1], minusOne()) of someSub: # x - 4 --> x + (-4) result = negate(result[1], result[2], result) @@ -294,6 +311,16 @@ proc canon*(n: PNode): PNode = if plus != nil and not isLetLocation(x, true): result = buildCall(result[0].sym, plus, y[1]) else: discard + elif x.isValue and y.getMagic in someAdd and y[2].isValue: + # 0 <= a.len + 3 + # -3 <= a.len + result.sons[1] = x |-| y[2] + result.sons[2] = y[1] + elif x.isValue and y.getMagic in someSub and y[2].isValue: + # 0 <= a.len - 3 + # 3 <= a.len + result.sons[1] = x |+| y[2] + result.sons[2] = y[1] else: discard proc `+@`*(a: PNode; b: BiggestInt): PNode = @@ -313,6 +340,9 @@ proc usefulFact(n: PNode): PNode = if isLetLocation(n.sons[1], true) or isLetLocation(n.sons[2], true): # XXX algebraic simplifications! 'i-1 < a.len' --> 'i < a.len+1' result = n + elif n[1].getMagic in someLen or n[2].getMagic in someLen: + # XXX Rethink this whole idea of 'usefulFact' for semparallel + result = n of mIsNil: if isLetLocation(n.sons[1], false) or isVar(n.sons[1]): result = n @@ -366,8 +396,8 @@ proc usefulFact(n: PNode): PNode = type TModel* = seq[PNode] # the "knowledge base" -proc addFact*(m: var TModel, n: PNode) = - let n = usefulFact(n) +proc addFact*(m: var TModel, nn: PNode) = + let n = usefulFact(nn) if n != nil: m.add n proc addFactNeg*(m: var TModel, n: PNode) = @@ -697,10 +727,57 @@ proc simpleSlice*(a, b: PNode): BiggestInt = else: result = -1 + +template isMul(x): expr = x.getMagic in someMul +template isDiv(x): expr = x.getMagic in someDiv +template isAdd(x): expr = x.getMagic in someAdd +template isSub(x): expr = x.getMagic in someSub +template isVal(x): expr = x.kind in {nkCharLit..nkUInt64Lit} +template isIntVal(x, y): expr = x.intVal == y + +import macros + +macro `=~`(x: PNode, pat: untyped): bool = + proc m(x, pat, conds: NimNode) = + case pat.kind + of nnkInfix: + case $pat[0] + of "*": conds.add getAst(isMul(x)) + of "/": conds.add getAst(isDiv(x)) + of "+": conds.add getAst(isAdd(x)) + of "-": conds.add getAst(isSub(x)) + else: + error("invalid pattern") + m(newTree(nnkBracketExpr, x, newLit(1)), pat[1], conds) + m(newTree(nnkBracketExpr, x, newLit(2)), pat[2], conds) + of nnkPar: + if pat.len == 1: + m(x, pat[0], conds) + else: + error("invalid pattern") + of nnkIdent: + let c = newTree(nnkStmtListExpr, newLetStmt(pat, x)) + conds.add c + if ($pat)[^1] == 'c': c.add(getAst(isVal(pat))) + else: c.add bindSym"true" + of nnkIntLit: + conds.add(getAst(isIntVal(pat.intVal))) + else: + error("invalid pattern") + + var conds = newTree(nnkBracket) + m(x, pat, conds) + result = nestList(!"and", conds) + + +proc isMinusOne(n: PNode): bool = + n.kind in {nkCharLit..nkUInt64Lit} and n.intVal == -1 + proc pleViaModel(model: TModel; aa, bb: PNode): TImplication proc ple(m: TModel; a, b: PNode): TImplication = template `<=?`(a,b): expr = ple(m,a,b) == impYes + template `>=?`(a,b): expr = ple(m, nkIntLit.newIntNode(b), a) == impYes # 0 <= 3 if a.isValue and b.isValue: @@ -721,6 +798,7 @@ proc ple(m: TModel; a, b: PNode): TImplication = if a.intVal <= 0: return impYes # x <= y+c if 0 <= c and x <= y + # x <= y+(-c) if c <= 0 and y >= x if b.getMagic in someAdd and zero() <=? b[2] and a <=? b[1]: return impYes # x+c <= y if c <= 0 and x <= y @@ -730,10 +808,44 @@ proc ple(m: TModel; a, b: PNode): TImplication = if b.getMagic in someMul: if a <=? b[1] and one() <=? b[2] and zero() <=? b[1]: return impYes + + if a.getMagic in someMul and a[2].isValue and a[1].getMagic in someDiv and + a[1][2].isValue: + # simplify (x div 4) * 2 <= y to x div (c div d) <= y + if ple(m, buildCall(opDiv, a[1][1], `|div|`(a[1][2], a[2])), b) == impYes: + return impYes + + # x*3 + x == x*4. It follows that: + # x*3 + y <= x*4 if y <= x and 3 <= 4 + if a =~ x*dc + y and b =~ x2*ec: + if sameTree(x, x2): + let ec1 = opAdd.buildCall(ec, minusOne()) + if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? x: + return impYes + elif a =~ x*dc and b =~ x2*ec + y: + #echo "BUG cam ehrer e ", a, " <=? ", b + if sameTree(x, x2): + let ec1 = opAdd.buildCall(ec, minusOne()) + if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? zero(): + return impYes + + # x+c <= x+d if c <= d. Same for *, - etc. + if a.getMagic in someBinaryOp and a.getMagic == b.getMagic: + if sameTree(a[1], b[1]) and a[2] <=? b[2]: return impYes + elif sameTree(a[2], b[2]) and a[1] <=? b[1]: return impYes + # x div c <= y if 1 <= c and 0 <= y and x <= y: if a.getMagic in someDiv: if one() <=? a[2] and zero() <=? b and a[1] <=? b: return impYes + # x div c <= x div d if d <= c + if b.getMagic in someDiv: + if sameTree(a[1], b[1]) and b[2] <=? a[2]: return impYes + + # x div z <= x - 1 if z <= x + if a[2].isValue and b.getMagic in someAdd and b[2].isMinusOne: + if a[2] <=? a[1] and sameTree(a[1], b[1]): return impYes + # slightly subtle: # x <= max(y, z) iff x <= y or x <= z # note that 'x <= max(x, z)' is a special case of the above rule @@ -769,11 +881,19 @@ proc pleViaModelRec(m: var TModel; a, b: PNode): TImplication = for i in 0..m.high: let fact = m[i] if fact != nil and fact.getMagic in someLe: - # x <= y implies a <= b if a <= x and y <= b - let x = fact[1] - let y = fact[2] # mark as used: m[i] = nil + # i <= len-100 + # i <=? len-1 + # --> true if (len-100) <= (len-1) + let x = fact[1] + let y = fact[2] + if sameTree(x, a) and y.getMagic in someAdd and b.getMagic in someAdd and + sameTree(y[1], b[1]): + if ple(m, b[2], y[2]) == impYes: + return impYes + + # x <= y implies a <= b if a <= x and y <= b if ple(m, a, x) == impYes: if ple(m, y, b) == impYes: return impYes |