diff options
author | Andreas Rumpf <rumpf_a@web.de> | 2020-04-15 20:03:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-15 20:03:25 +0200 |
commit | 3a2697dd731cb8fcfd0d279bb856090eca5028ee (patch) | |
tree | 971390193c83b0d14045f535a06bfc18071a741b /drnim/drnim.nim | |
parent | 04b6e9cf3e6e1113cb5989a82878e525a7f0891f (diff) | |
download | Nim-3a2697dd731cb8fcfd0d279bb856090eca5028ee.tar.gz |
drnim: tiny progress (#13882)
* drnim: tiny progress * refactoring complete * drnim: prove .ensures annotations * Moved code around to avoid code duplication * drnim: first implementation of the 'old' property * drnim: be precise about the assignment statement * first implementation of --assumeUnique * progress on forall/exists handling
Diffstat (limited to 'drnim/drnim.nim')
-rw-r--r-- | drnim/drnim.nim | 728 |
1 files changed, 594 insertions, 134 deletions
diff --git a/drnim/drnim.nim b/drnim/drnim.nim index e814496e4..7c79add26 100644 --- a/drnim/drnim.nim +++ b/drnim/drnim.nim @@ -9,53 +9,21 @@ #[ -- Most important bug: - - while i < x.len and use(s[i]): inc i # is safe - +- the analysis has to take 'break', 'continue' and 'raises' into account - We need to map arrays to Z3 and test for something like 'forall(i, (i in 3..4) -> (a[i] > 3))' -- forall/exists need syntactic sugar as the manual - We need teach DrNim what 'inc', 'dec' and 'swap' mean, for example 'x in n..m; inc x' implies 'x in n+1..m+1' -- We need an ``old`` annotation: - -proc f(x: var int; y: var int) {.ensures: x == old(x)+1 and y == old(y)+1 .} = - inc x - inc y - -var x = 3 -var y: range[N..M] -f(x, y) -{.assume: y in N+1 .. M+1.} -# --> y in N+1..M+1 - -proc myinc(x: var int) {.ensures: x-1 == old(x).} = - inc x - -facts(x) # x < 3 -myinc x -facts(x+1) - -We handle state transitions in this way: - - for every f in facts: - replace 'x' by 'old(x)' - facts.add ensuresClause - - # then we know: old(x) < 3; x-1 == old(x) - # we can conclude: x-1 < 3 but leave this task to Z3 - ]# import std / [ - parseopt, strutils, os, tables, times + parseopt, strutils, os, tables, times, intsets, hashes ] import ".." / compiler / [ - ast, types, renderer, + ast, astalgo, types, renderer, commands, options, msgs, - platform, + platform, trees, wordrecg, guards, idents, lineinfos, cmdlinehelper, modulegraphs, condsyms, pathutils, passes, passaux, sem, modules ] @@ -91,42 +59,174 @@ proc helpOnError(conf: ConfigRef) = type CannotMapToZ3Error = object of ValueError Z3Exception = object of ValueError + VersionScope = distinct int + DrnimContext = ref object + z3: Z3_context + graph: ModuleGraph + facts: seq[(PNode, VersionScope)] + varVersions: seq[int] # this maps variable IDs to their current version. + o: Operators + hasUnstructedCf: int + currOptions: TOptions + owner: PSym + mangler: seq[PSym] DrCon = object - z3: Z3_context graph: ModuleGraph mapping: Table[string, Z3_ast] canonParameterNames: bool - -proc stableName(result: var string; n: PNode) = + assumeUniqueness: bool + up: DrnimContext + +var + assumeUniqueness: bool + +proc echoFacts(c: DrnimContext) = + echo "FACTS:" + for i in 0 ..< c.facts.len: + let f = c.facts[i] + echo f[0], " version ", int(f[1]) + +proc isLoc(m: PNode; assumeUniqueness: bool): bool = + # We can reason about "locations" and map them to Z3 constants. + # For code that is full of "ref" (e.g. the Nim compiler itself) that + # is too limiting + proc isLet(n: PNode): bool = + if n.kind == nkSym: + if n.sym.kind in {skLet, skTemp, skForVar}: + result = true + elif n.sym.kind == skParam and skipTypes(n.sym.typ, + abstractInst).kind != tyVar: + result = true + + var n = m + while true: + case n.kind + of nkDotExpr, nkCheckedFieldExpr, nkObjUpConv, nkObjDownConv, nkHiddenDeref: + n = n[0] + of nkDerefExpr: + n = n[0] + if not assumeUniqueness: return false + of nkBracketExpr: + if isConstExpr(n[1]) or isLet(n[1]) or isConstExpr(n[1].skipConv): + n = n[0] + else: return + of nkHiddenStdConv, nkHiddenSubConv, nkConv: + n = n[1] + else: + break + if n.kind == nkSym: + case n.sym.kind + of skLet, skTemp, skForVar, skParam: + result = true + #of skParam: + # result = skipTypes(n.sym.typ, abstractInst).kind != tyVar + of skResult, skVar: + result = {sfAddrTaken} * n.sym.flags == {} + else: + discard + +proc varVersion(c: DrnimContext; s: PSym; begin: VersionScope): int = + result = 0 + for i in countdown(int(begin)-1, 0): + if c.varVersions[i] == s.id: inc result + +proc disamb(c: DrnimContext; s: PSym): int = + # we group by 's.name.s' to compute the stable name ID. + result = 0 + for i in 0 ..< c.mangler.len: + if s == c.mangler[i]: return result + if s.name.s == c.mangler[i].name.s: inc result + c.mangler.add s + +proc stableName(result: var string; c: DrnimContext; n: PNode; version: VersionScope; + isOld: bool) = # we can map full Nim expressions like 'f(a, b, c)' to Z3 variables. - # We must be carefult to select a unique, stable name for these expressions + # We must be careful to select a unique, stable name for these expressions # based on structural equality. 'stableName' helps us with this problem. + # In the future we will also use this string for the caching mechanism. case n.kind of nkEmpty, nkNilLit, nkType: discard of nkIdent: result.add n.ident.s of nkSym: result.add n.sym.name.s - result.add '_' - result.addInt n.sym.id + if n.sym.magic == mNone: + let d = disamb(c, n.sym) + if d != 0: + result.add "`scope=" + result.addInt d + let v = c.varVersion(n.sym, version) - ord(isOld) + assert v >= 0 + if v > 0: + result.add '`' + result.addInt v + else: + result.add "`magic=" + result.addInt ord(n.sym.magic) + of nkCharLit..nkUInt64Lit: result.addInt n.intVal of nkFloatLit..nkFloat64Lit: result.addFloat n.floatVal of nkStrLit..nkTripleStrLit: result.add strutils.escape n.strVal + of nkDotExpr: + stableName(result, c, n[0], version, isOld) + result.add '.' + stableName(result, c, n[1], version, isOld) + of nkBracketExpr: + stableName(result, c, n[0], version, isOld) + result.add '[' + stableName(result, c, n[1], version, isOld) + result.add ']' + of nkCallKinds: + if n.len == 2: + stableName(result, c, n[1], version, isOld) + result.add '.' + case getMagic(n) + of mLengthArray, mLengthOpenArray, mLengthSeq, mLengthStr: + result.add "len" + of mHigh: + result.add "high" + of mLow: + result.add "low" + else: + stableName(result, c, n[0], version, isOld) + elif n.kind == nkInfix and n.len == 3: + result.add '(' + stableName(result, c, n[1], version, isOld) + result.add ' ' + stableName(result, c, n[0], version, isOld) + result.add ' ' + stableName(result, c, n[2], version, isOld) + result.add ')' + else: + stableName(result, c, n[0], version, isOld) + result.add '(' + for i in 1..<n.len: + if i > 1: result.add ", " + stableName(result, c, n[i], version, isOld) + result.add ')' else: result.add $n.kind result.add '(' for i in 0..<n.len: if i > 0: result.add ", " - stableName(result, n[i]) + stableName(result, c, n[i], version, isOld) result.add ')' -proc stableName(n: PNode): string = stableName(result, n) +proc stableName(c: DrnimContext; n: PNode; version: VersionScope; + isOld = false): string = + stableName(result, c, n, version, isOld) + +template allScopes(c): untyped = VersionScope(c.varVersions.len) +template currentScope(c): untyped = VersionScope(c.varVersions.len) proc notImplemented(msg: string) {.noinline.} = + when defined(debug): + writeStackTrace() + echo msg raise newException(CannotMapToZ3Error, "; cannot map to Z3: " & msg) proc translateEnsures(e, x: PNode): PNode = @@ -138,7 +238,7 @@ proc translateEnsures(e, x: PNode): PNode = result[i] = translateEnsures(e[i], x) proc typeToZ3(c: DrCon; t: PType): Z3_sort = - template ctx: untyped = c.z3 + template ctx: untyped = c.up.z3 case t.skipTypes(abstractInst+{tyVar}).kind of tyEnum, tyInt..tyInt64: result = Z3_mk_int_sort(ctx) @@ -156,42 +256,64 @@ template binary(op, a, b): untyped = var arr = [a, b] op(ctx, cuint(2), addr(arr[0])) -proc nodeToZ3(c: var DrCon; n: PNode; vars: var seq[PNode]): Z3_ast +proc nodeToZ3(c: var DrCon; n: PNode; scope: VersionScope; vars: var seq[PNode]): Z3_ast + +proc nodeToDomain(c: var DrCon; n, q: PNode; opAnd: PSym): PNode = + assert n.kind == nkInfix + let opLe = createMagic(c.graph, "<=", mLeI) + case $n[0] + of "..": + result = buildCall(opAnd, buildCall(opLe, n[1], q), buildCall(opLe, q, n[2])) + of "..<": + let opLt = createMagic(c.graph, "<", mLtI) + result = buildCall(opAnd, buildCall(opLe, n[1], q), buildCall(opLt, q, n[2])) + else: + notImplemented($n) template quantorToZ3(fn) {.dirty.} = - template ctx: untyped = c.z3 - - var bound = newSeq[Z3_app](n.len-1) - for i in 0..n.len-2: - doAssert n[i].kind == nkSym - let v = n[i].sym + template ctx: untyped = c.up.z3 + + var bound = newSeq[Z3_app](n.len-2) + let opAnd = createMagic(c.graph, "and", mAnd) + var known: PNode + for i in 1..n.len-2: + let it = n[i] + doAssert it.kind == nkInfix + let v = it[1].sym let name = Z3_mk_string_symbol(ctx, v.name.s) let vz3 = Z3_mk_const(ctx, name, typeToZ3(c, v.typ)) - c.mapping[stableName(n[i])] = vz3 - bound[i] = Z3_to_app(ctx, vz3) + c.mapping[stableName(c.up, it[1], allScopes(c.up))] = vz3 + bound[i-1] = Z3_to_app(ctx, vz3) + let domain = nodeToDomain(c, it[2], it[1], opAnd) + if known == nil: + known = domain + else: + known = buildCall(opAnd, known, domain) var dummy: seq[PNode] - let x = nodeToZ3(c, n[^1], dummy) + assert known != nil + let x = nodeToZ3(c, buildCall(createMagic(c.graph, "->", mImplies), + known, n[^1]), scope, dummy) result = fn(ctx, 0, bound.len.cuint, addr(bound[0]), 0, nil, x) -proc forallToZ3(c: var DrCon; n: PNode): Z3_ast = quantorToZ3(Z3_mk_forall_const) -proc existsToZ3(c: var DrCon; n: PNode): Z3_ast = quantorToZ3(Z3_mk_exists_const) +proc forallToZ3(c: var DrCon; n: PNode; scope: VersionScope): Z3_ast = quantorToZ3(Z3_mk_forall_const) +proc existsToZ3(c: var DrCon; n: PNode; scope: VersionScope): Z3_ast = quantorToZ3(Z3_mk_exists_const) -proc paramName(n: PNode): string = +proc paramName(c: DrnimContext; n: PNode): string = case n.sym.kind of skParam: result = "arg" & $n.sym.position of skResult: result = "result" - else: result = stableName(n) + else: result = stableName(c, n, allScopes(c)) -proc nodeToZ3(c: var DrCon; n: PNode; vars: var seq[PNode]): Z3_ast = - template ctx: untyped = c.z3 - template rec(n): untyped = nodeToZ3(c, n, vars) +proc nodeToZ3(c: var DrCon; n: PNode; scope: VersionScope; vars: var seq[PNode]): Z3_ast = + template ctx: untyped = c.up.z3 + template rec(n): untyped = nodeToZ3(c, n, scope, vars) case n.kind of nkSym: - let key = if c.canonParameterNames: paramName(n) else: stableName(n) + let key = if c.canonParameterNames: paramName(c.up, n) else: stableName(c.up, n, scope) result = c.mapping.getOrDefault(key) if pointer(result) == nil: - let name = Z3_mk_string_symbol(ctx, n.sym.name.s) + let name = Z3_mk_string_symbol(ctx, key) result = Z3_mk_const(ctx, name, typeToZ3(c, n.sym.typ)) c.mapping[key] = result vars.add n @@ -222,17 +344,23 @@ proc nodeToZ3(c: var DrCon; n: PNode; vars: var seq[PNode]): Z3_ast = result = Z3_mk_lt(ctx, rec n[1], rec n[2]) of mLengthOpenArray, mLengthStr, mLengthArray, mLengthSeq: # len(x) needs the same logic as 'x' itself - if n[1].kind == nkSym: - let key = stableName(n) - let sym = n[1].sym + if isLoc(n[1], c.assumeUniqueness): + let key = stableName(c.up, n, scope) result = c.mapping.getOrDefault(key) if pointer(result) == nil: - let name = Z3_mk_string_symbol(ctx, sym.name.s & ".len") + let name = Z3_mk_string_symbol(ctx, key) result = Z3_mk_const(ctx, name, Z3_mk_int_sort(ctx)) c.mapping[key] = result vars.add n else: notImplemented(renderTree(n)) + of mHigh: + let addOpr = createMagic(c.graph, "+", mAddI) + let lenOpr = createMagic(c.graph, "len", mLengthOpenArray) + let asLenExpr = addOpr.buildCall(lenOpr.buildCall(n[1]), nkIntLit.newIntNode(-1)) + result = rec asLenExpr + of mLow: + result = rec lowBound(c.graph.config, n[1]) of mAddI, mSucc: result = binary(Z3_mk_add, rec n[1], rec n[2]) of mSubI, mPred: @@ -256,9 +384,12 @@ proc nodeToZ3(c: var DrCon; n: PNode; vars: var seq[PNode]): Z3_ast = of mLtU: result = Z3_mk_bvult(ctx, rec n[1], rec n[2]) of mAnd: - result = binary(Z3_mk_and, rec n[1], rec n[2]) + # 'a and b' <=> ite(a, b, false) + result = Z3_mk_ite(ctx, rec n[1], rec n[2], Z3_mk_false(ctx)) + #result = binary(Z3_mk_and, rec n[1], rec n[2]) of mOr: - result = binary(Z3_mk_or, rec n[1], rec n[2]) + result = Z3_mk_ite(ctx, rec n[1], Z3_mk_true(ctx), rec n[2]) + #result = binary(Z3_mk_or, rec n[1], rec n[2]) of mXor: result = Z3_mk_xor(ctx, rec n[1], rec n[2]) of mNot: @@ -268,9 +399,9 @@ proc nodeToZ3(c: var DrCon; n: PNode; vars: var seq[PNode]): Z3_ast = of mIff: result = Z3_mk_iff(ctx, rec n[1], rec n[2]) of mForall: - result = forallToZ3(c, n) + result = forallToZ3(c, n, scope) of mExists: - result = existsToZ3(c, n) + result = existsToZ3(c, n, scope) of mLeF64: result = Z3_mk_fpa_leq(ctx, rec n[1], rec n[2]) of mLtF64: @@ -299,11 +430,12 @@ proc nodeToZ3(c: var DrCon; n: PNode; vars: var seq[PNode]): Z3_ast = of mOrd, mChr: result = rec n[1] of mOld: - let key = (if c.canonParameterNames: paramName(n[1]) else: stableName(n[1])) & ".old" + let key = if c.canonParameterNames: (paramName(c.up, n[1]) & ".old") + else: stableName(c.up, n[1], scope, isOld = true) result = c.mapping.getOrDefault(key) if pointer(result) == nil: - let name = Z3_mk_string_symbol(ctx, $n) - result = Z3_mk_const(ctx, name, typeToZ3(c, n.typ)) + let name = Z3_mk_string_symbol(ctx, key) + result = Z3_mk_const(ctx, name, typeToZ3(c, n[1].typ)) c.mapping[key] = result # XXX change the logic in `addRangeInfo` for this #vars.add n @@ -318,10 +450,10 @@ proc nodeToZ3(c: var DrCon; n: PNode; vars: var seq[PNode]): Z3_ast = ensuresEffects < op.n[0].len: let ensures = op.n[0][ensuresEffects] if ensures != nil and ensures.kind != nkEmpty: - let key = stableName(n) + let key = stableName(c.up, n, scope) result = c.mapping.getOrDefault(key) if pointer(result) == nil: - let name = Z3_mk_string_symbol(ctx, $n) + let name = Z3_mk_string_symbol(ctx, key) result = Z3_mk_const(ctx, name, typeToZ3(c, n.typ)) c.mapping[key] = result vars.add n @@ -333,15 +465,24 @@ proc nodeToZ3(c: var DrCon; n: PNode; vars: var seq[PNode]): Z3_ast = for i in 0..n.len-2: isTrivial = isTrivial and n[i].kind in {nkEmpty, nkCommentStmt} if isTrivial: - result = nodeToZ3(c, n[^1], vars) + result = rec n[^1] else: notImplemented(renderTree(n)) of nkHiddenDeref: result = rec n[0] else: - notImplemented(renderTree(n)) + if isLoc(n, c.assumeUniqueness): + let key = stableName(c.up, n, scope) + result = c.mapping.getOrDefault(key) + if pointer(result) == nil: + let name = Z3_mk_string_symbol(ctx, key) + result = Z3_mk_const(ctx, name, typeToZ3(c, n.typ)) + c.mapping[key] = result + vars.add n + else: + notImplemented(renderTree(n)) -proc addRangeInfo(c: var DrCon, n: PNode, res: var seq[Z3_ast]) = +proc addRangeInfo(c: var DrCon, n: PNode; scope: VersionScope, res: var seq[Z3_ast]) = var cmpOp = mLeI if n.typ != nil: cmpOp = @@ -393,15 +534,15 @@ proc addRangeInfo(c: var DrCon, n: PNode, res: var seq[Z3_ast]) = let ensures = op.n[0][ensuresEffects] if ensures != nil and ensures.kind != nkEmpty: var dummy: seq[PNode] - res.add nodeToZ3(c, translateEnsures(ensures, n), dummy) + res.add nodeToZ3(c, translateEnsures(ensures, n), scope, dummy) return let x = newTree(nkInfix, newSymNode createMagic(c.graph, "<=", cmpOp), lowBound, n) let y = newTree(nkInfix, newSymNode createMagic(c.graph, "<=", cmpOp), n, highBound) var dummy: seq[PNode] - res.add nodeToZ3(c, x, dummy) - res.add nodeToZ3(c, y, dummy) + res.add nodeToZ3(c, x, scope, dummy) + res.add nodeToZ3(c, y, scope, dummy) proc on_err(ctx: Z3_context, e: Z3_error_code) {.nimcall.} = #writeStackTrace() @@ -423,18 +564,18 @@ proc conj(ctx: Z3_context; conds: seq[Z3_ast]): Z3_ast = else: result = Z3_mk_true(ctx) -proc proofEngineAux(c: var DrCon; assumptions: seq[PNode]; toProve: PNode): (bool, string) = - c.mapping = initTable[string, Z3_ast]() +proc setupZ3(): Z3_context = let cfg = Z3_mk_config() - Z3_set_param_value(cfg, "model", "true"); - let ctx = Z3_mk_context(cfg) - c.z3 = ctx - Z3_del_config(cfg) - Z3_set_error_handler(ctx, on_err) - when false: Z3_set_param_value(cfg, "timeout", "1000") + Z3_set_param_value(cfg, "model", "true") + result = Z3_mk_context(cfg) + Z3_del_config(cfg) + Z3_set_error_handler(result, on_err) +proc proofEngineAux(c: var DrCon; assumptions: seq[(PNode, VersionScope)]; + toProve: (PNode, VersionScope)): (bool, string) = + c.mapping = initTable[string, Z3_ast]() try: #[ @@ -455,20 +596,21 @@ proc proofEngineAux(c: var DrCon; assumptions: seq[PNode]; toProve: PNode): (boo var collectedVars: seq[PNode] + template ctx(): untyped = c.up.z3 + let solver = Z3_mk_solver(ctx) var lhs: seq[Z3_ast] - for assumption in assumptions: - if assumption != nil: - try: - let za = nodeToZ3(c, assumption, collectedVars) - #Z3_solver_assert ctx, solver, za - lhs.add za - except CannotMapToZ3Error: - discard "ignore a fact we cannot map to Z3" - - let z3toProve = nodeToZ3(c, toProve, collectedVars) + for assumption in items(assumptions): + try: + let za = nodeToZ3(c, assumption[0], assumption[1], collectedVars) + #Z3_solver_assert ctx, solver, za + lhs.add za + except CannotMapToZ3Error: + discard "ignore a fact we cannot map to Z3" + + let z3toProve = nodeToZ3(c, toProve[0], toProve[1], collectedVars) for v in collectedVars: - addRangeInfo(c, v, lhs) + addRangeInfo(c, v, toProve[1], lhs) # to make Z3 produce nice counterexamples, we try to prove the # negation of our conjecture and see if it's Z3_L_FALSE @@ -476,7 +618,8 @@ proc proofEngineAux(c: var DrCon; assumptions: seq[PNode]; toProve: PNode): (boo #Z3_mk_not(ctx, forall(ctx, collectedVars, conj(ctx, lhs), z3toProve)) - #echo "toProve: ", Z3_ast_to_string(ctx, fa), " ", c.graph.config $ toProve.info + when defined(dz3): + echo "toProve: ", Z3_ast_to_string(ctx, fa), " ", c.graph.config $ toProve[0].info, " ", int(toProve[1]) Z3_solver_assert ctx, solver, fa let z3res = Z3_solver_check(ctx, solver) @@ -489,18 +632,22 @@ proc proofEngineAux(c: var DrCon; assumptions: seq[PNode]; toProve: PNode): (boo except ValueError: result[0] = false result[1] = getCurrentExceptionMsg() - finally: - Z3_del_context(ctx) -proc proofEngine(graph: ModuleGraph; assumptions: seq[PNode]; toProve: PNode): (bool, string) = +proc proofEngine(ctx: DrnimContext; assumptions: seq[(PNode, VersionScope)]; + toProve: (PNode, VersionScope)): (bool, string) = var c: DrCon - c.graph = graph + c.graph = ctx.graph + c.assumeUniqueness = assumeUniqueness + c.up = ctx result = proofEngineAux(c, assumptions, toProve) +proc skipAddr(n: PNode): PNode {.inline.} = + (if n.kind == nkHiddenAddr: n[0] else: n) + proc translateReq(r, call: PNode): PNode = if r.kind == nkSym and r.sym.kind == skParam: if r.sym.position+1 < call.len: - result = call[r.sym.position+1] + result = call[r.sym.position+1].skipAddr else: notImplemented("no argument given for formal parameter: " & r.sym.name.s) else: @@ -508,11 +655,11 @@ proc translateReq(r, call: PNode): PNode = for i in 0 ..< safeLen(r): result[i] = translateReq(r[i], call) -proc requirementsCheck(graph: ModuleGraph; assumptions: seq[PNode]; - call, requirement: PNode): (bool, string) {.nimcall.} = +proc requirementsCheck(ctx: DrnimContext; assumptions: seq[(PNode, VersionScope)]; + call, requirement: PNode): (bool, string) = try: let r = translateReq(requirement, call) - result = proofEngine(graph, assumptions, r) + result = proofEngine(ctx, assumptions, (r, ctx.currentScope)) except ValueError: result[0] = false result[1] = getCurrentExceptionMsg() @@ -552,24 +699,347 @@ proc compatibleProps(graph: ModuleGraph; formal, actual: PType): bool {.nimcall. var c: DrCon c.graph = graph c.canonParameterNames = true - if not frequires.isEmpty: - result = not arequires.isEmpty and proofEngineAux(c, @[frequires], arequires)[0] - - if result: - if not fensures.isEmpty: - result = not aensures.isEmpty and proofEngineAux(c, @[aensures], fensures)[0] + try: + c.up = DrnimContext(z3: setupZ3(), o: initOperators(graph), graph: graph, owner: nil) + template zero: untyped = VersionScope(0) + if not frequires.isEmpty: + result = not arequires.isEmpty and proofEngineAux(c, @[(frequires, zero)], (arequires, zero))[0] + + if result: + if not fensures.isEmpty: + result = not aensures.isEmpty and proofEngineAux(c, @[(aensures, zero)], (fensures, zero))[0] + finally: + Z3_del_context(c.up.z3) else: # formal has requirements but 'actual' has none, so make it # incompatible. XXX What if the requirement only mentions that # we already know from the type system? result = frequires.isEmpty and fensures.isEmpty +template config(c: typed): untyped = c.graph.config + +proc addFact(c: DrnimContext; n: PNode) = + let v = c.currentScope + if n[0].kind == nkSym and n[0].sym.magic in {mOr, mAnd}: + c.facts.add((n[1], v)) + c.facts.add((n, v)) + +proc addFactNeg(c: DrnimContext; n: PNode) = + var neg = newNodeI(nkCall, n.info, 2) + neg[0] = newSymNode(c.o.opNot) + neg[1] = n + addFact(c, neg) + +proc prove(c: DrnimContext; prop: PNode): bool = + let (success, m) = proofEngine(c, c.facts, (prop, c.currentScope)) + if not success: + message(c.config, prop.info, warnStaticIndexCheck, "cannot prove: " & $prop & m) + result = success + +proc traversePragmaStmt(c: DrnimContext, n: PNode) = + for it in n: + if it.kind == nkExprColonExpr: + let pragma = whichPragma(it) + if pragma == wAssume: + addFact(c, it[1]) + elif pragma == wInvariant or pragma == wAssert: + if prove(c, it[1]): + addFact(c, it[1]) + +proc requiresCheck(c: DrnimContext, call: PNode; op: PType) = + assert op.n[0].kind == nkEffectList + if requiresEffects < op.n[0].len: + let requires = op.n[0][requiresEffects] + if requires != nil and requires.kind != nkEmpty: + # we need to map the call arguments to the formal parameters used inside + # 'requires': + let (success, m) = requirementsCheck(c, c.facts, call, requires) + if not success: + message(c.config, call.info, warnStaticIndexCheck, "cannot prove: " & $requires & m) + +proc freshVersion(c: DrnimContext; arg: PNode) = + let v = getRoot(arg) + if v != nil: + c.varVersions.add v.id + +proc translateEnsuresFromCall(c: DrnimContext, e, call: PNode): PNode = + if e.kind in nkCallKinds and e[0].kind == nkSym and e[0].sym.magic == mOld: + assert e[1].kind == nkSym and e[1].sym.kind == skParam + let param = e[1].sym + let arg = call[param.position+1].skipAddr + result = buildCall(e[0].sym, arg) + elif e.kind == nkSym and e.sym.kind == skParam: + let param = e.sym + let arg = call[param.position+1].skipAddr + result = arg + else: + result = shallowCopy(e) + for i in 0 ..< safeLen(e): result[i] = translateEnsuresFromCall(c, e[i], call) + +proc collectEnsuredFacts(c: DrnimContext, call: PNode; op: PType) = + assert op.n[0].kind == nkEffectList + for i in 1 ..< min(call.len, op.len): + if op[i].kind == tyVar: + freshVersion(c, call[i].skipAddr) + + if ensuresEffects < op.n[0].len: + let ensures = op.n[0][ensuresEffects] + if ensures != nil and ensures.kind != nkEmpty: + addFact(c, translateEnsuresFromCall(c, ensures, call)) + +proc checkLe(c: DrnimContext, a, b: PNode) = + var cmpOp = mLeI + if a.typ != nil: + case a.typ.skipTypes(abstractInst).kind + of tyFloat..tyFloat128: cmpOp = mLeF64 + of tyChar, tyUInt..tyUInt64: cmpOp = mLeU + else: discard + + let cmp = newTree(nkInfix, newSymNode createMagic(c.graph, "<=", cmpOp), a, b) + cmp.info = a.info + discard prove(c, cmp) + +proc checkBounds(c: DrnimContext; arr, idx: PNode) = + checkLe(c, lowBound(c.config, arr), idx) + checkLe(c, idx, highBound(c.config, arr, c.o)) + +proc checkRange(c: DrnimContext; value: PNode; typ: PType) = + let t = typ.skipTypes(abstractInst - {tyRange}) + if t.kind == tyRange: + let lowBound = copyTree(t.n[0]) + lowBound.info = value.info + let highBound = copyTree(t.n[1]) + highBound.info = value.info + checkLe(c, lowBound, value) + checkLe(c, value, highBound) + +proc addAsgnFact*(c: DrnimContext, key, value: PNode) = + var fact = newNodeI(nkCall, key.info, 3) + fact[0] = newSymNode(c.o.opEq) + fact[1] = key + fact[2] = value + c.facts.add((fact, c.currentScope)) + +proc traverse(c: DrnimContext; n: PNode) + +proc traverseTryStmt(c: DrnimContext; n: PNode) = + traverse(c, n[0]) + let oldFacts = c.facts.len + for i in 1 ..< n.len: + traverse(c, n[i].lastSon) + setLen(c.facts, oldFacts) + +proc traverseCase(c: DrnimContext; n: PNode) = + traverse(c, n[0]) + let oldFacts = c.facts.len + for i in 1 ..< n.len: + traverse(c, n[i].lastSon) + # XXX make this as smart as 'if elif' + setLen(c.facts, oldFacts) + +proc traverseIf(c: DrnimContext; n: PNode) = + traverse(c, n[0][0]) + let oldFacts = c.facts.len + addFact(c, n[0][0]) + + traverse(c, n[0][1]) + + for i in 1..<n.len: + let branch = n[i] + setLen(c.facts, oldFacts) + for j in 0..i-1: + addFactNeg(c, n[j][0]) + if branch.len > 1: + addFact(c, branch[0]) + for i in 0..<branch.len: + traverse(c, branch[i]) + setLen(c.facts, oldFacts) + +proc traverseBlock(c: DrnimContext; n: PNode) = + traverse(c, n) + +proc addFactLe(c: DrnimContext; a, b: PNode) = + c.addFact c.o.opLe.buildCall(a, b) + +proc addFactLt(c: DrnimContext; a, b: PNode) = + c.addFact c.o.opLt.buildCall(a, b) + +proc ensuresCheck(c: DrnimContext; owner: PSym) = + if owner.typ != nil and owner.typ.kind == tyProc and owner.typ.n != nil: + let n = owner.typ.n + if n.len > 0 and n[0].kind == nkEffectList and ensuresEffects < n[0].len: + let ensures = n[0][ensuresEffects] + if ensures != nil and ensures.kind != nkEmpty: + discard prove(c, ensures) + +proc traverseAsgn(c: DrnimContext; n: PNode) = + traverse(c, n[0]) + traverse(c, n[1]) + + proc replaceByOldParams(fact, le: PNode): PNode = + if guards.sameTree(fact, le): + result = newNodeIT(nkCall, fact.info, fact.typ) + result.add newSymNode createMagic(c.graph, "old", mOld) + result.add fact + else: + result = shallowCopy(fact) + for i in 0 ..< safeLen(fact): + result[i] = replaceByOldParams(fact[i], le) + + freshVersion(c, n[0]) + addAsgnFact(c, n[0], replaceByOldParams(n[1], n[0])) + when defined(debug): + echoFacts(c) + +proc traverse(c: DrnimContext; n: PNode) = + case n.kind + of nkEmpty..nkNilLit: + discard "nothing to do" + of nkRaiseStmt, nkBreakStmt, nkContinueStmt: + inc c.hasUnstructedCf + for i in 0..<n.safeLen: + traverse(c, n[i]) + of nkReturnStmt: + for i in 0 ..< n.safeLen: + traverse(c, n[i]) + ensuresCheck(c, c.owner) + of nkCallKinds: + # p's effects are ours too: + var a = n[0] + let op = a.typ + if op != nil and op.kind == tyProc and op.n[0].kind == nkEffectList: + requiresCheck(c, n, op) + collectEnsuredFacts(c, n, op) + if a.kind == nkSym: + case a.sym.magic + of mNew, mNewFinalize, mNewSeq: + # may not look like an assignment, but it is: + let arg = n[1] + freshVersion(c, arg) + traverse(c, arg) + addAsgnFact(c, arg, newNodeIT(nkObjConstr, arg.info, arg.typ)) + of mArrGet, mArrPut: + #if optStaticBoundsCheck in c.currOptions: checkBounds(c, n[1], n[2]) + discard + else: + discard + + for i in 0..<n.safeLen: + traverse(c, n[i]) + of nkDotExpr: + #guardDotAccess(c, n) + for i in 0..<n.len: traverse(c, n[i]) + of nkCheckedFieldExpr: + traverse(c, n[0]) + #checkFieldAccess(c.facts, n, c.config) + of nkTryStmt: traverseTryStmt(c, n) + of nkPragma: traversePragmaStmt(c, n) + of nkAsgn, nkFastAsgn: traverseAsgn(c, n) + of nkVarSection, nkLetSection: + for child in n: + let last = lastSon(child) + if last.kind != nkEmpty: traverse(c, last) + if child.kind == nkIdentDefs and last.kind != nkEmpty: + for i in 0..<child.len-2: + addAsgnFact(c, child[i], last) + elif child.kind == nkVarTuple and last.kind != nkEmpty: + for i in 0..<child.len-1: + if child[i].kind == nkEmpty or + child[i].kind == nkSym and child[i].sym.name.s == "_": + discard "anon variable" + elif last.kind in {nkPar, nkTupleConstr}: + addAsgnFact(c, child[i], last[i]) + of nkConstSection: + for child in n: + let last = lastSon(child) + traverse(c, last) + of nkCaseStmt: traverseCase(c, n) + of nkWhen, nkIfStmt, nkIfExpr: traverseIf(c, n) + of nkBlockStmt, nkBlockExpr: traverseBlock(c, n[1]) + of nkWhileStmt: + # 'while true' loop? + if isTrue(n[0]): + traverseBlock(c, n[1]) + else: + let oldFacts = c.facts.len + addFact(c, n[0]) + traverse(c, n[0]) + traverse(c, n[1]) + setLen(c.facts, oldFacts) + of nkForStmt, nkParForStmt: + # we are very conservative here and assume the loop is never executed: + let oldFacts = c.facts.len + let iterCall = n[n.len-2] + if optStaticBoundsCheck in c.currOptions and iterCall.kind in nkCallKinds: + let op = iterCall[0] + if op.kind == nkSym and fromSystem(op.sym): + let iterVar = n[0] + case op.sym.name.s + of "..", "countup", "countdown": + let lower = iterCall[1] + let upper = iterCall[2] + # for i in 0..n means 0 <= i and i <= n. Countdown is + # the same since only the iteration direction changes. + addFactLe(c, lower, iterVar) + addFactLe(c, iterVar, upper) + of "..<": + let lower = iterCall[1] + let upper = iterCall[2] + addFactLe(c, lower, iterVar) + addFactLt(c, iterVar, upper) + else: discard + + for i in 0..<n.len-2: + let it = n[i] + traverse(c, it) + let loopBody = n[^1] + traverse(c, iterCall) + traverse(c, loopBody) + setLen(c.facts, oldFacts) + of nkTypeSection, nkProcDef, nkConverterDef, nkMethodDef, nkIteratorDef, + nkMacroDef, nkTemplateDef, nkLambda, nkDo, nkFuncDef: + discard + of nkCast: + if n.len == 2: + traverse(c, n[1]) + of nkHiddenStdConv, nkHiddenSubConv, nkConv: + if n.len == 2: + traverse(c, n[1]) + if optStaticBoundsCheck in c.currOptions: + checkRange(c, n[1], n.typ) + of nkObjUpConv, nkObjDownConv, nkChckRange, nkChckRangeF, nkChckRange64: + if n.len == 1: + traverse(c, n[0]) + if optStaticBoundsCheck in c.currOptions: + checkRange(c, n[0], n.typ) + of nkBracketExpr: + if optStaticBoundsCheck in c.currOptions and n.len == 2: + if n[0].typ != nil and skipTypes(n[0].typ, abstractVar).kind != tyTuple: + checkBounds(c, n[0], n[1]) + for i in 0 ..< n.len: traverse(c, n[i]) + else: + for i in 0 ..< n.len: traverse(c, n[i]) + +proc strongSemCheck(graph: ModuleGraph; owner: PSym; n: PNode) = + var c = DrnimContext() + c.currOptions = graph.config.options + owner.options + if optStaticBoundsCheck in c.currOptions: + c.z3 = setupZ3() + c.o = initOperators(graph) + c.graph = graph + c.owner = owner + try: + traverse(c, n) + ensuresCheck(c, owner) + finally: + Z3_del_context(c.z3) + + proc mainCommand(graph: ModuleGraph) = let conf = graph.config conf.lastCmdTime = epochTime() - graph.proofEngine = proofEngine - graph.requirementsCheck = requirementsCheck + graph.strongSemCheck = strongSemCheck graph.compatibleProps = compatibleProps graph.config.errorMax = high(int) # do not stop after first error @@ -600,20 +1070,6 @@ proc mainCommand(graph: ModuleGraph) = "output", output, ]) -proc prependCurDir(f: AbsoluteFile): AbsoluteFile = - when defined(unix): - if os.isAbsolute(f.string): result = f - else: result = AbsoluteFile("./" & f.string) - else: - result = f - -proc addCmdPrefix(result: var string, kind: CmdLineKind) = - # consider moving this to std/parseopt - case kind - of cmdLongOption: result.add "--" - of cmdShortOption: result.add "-" - of cmdArgument, cmdEnd: discard - proc processCmdLine(pass: TCmdLinePass, cmd: string; config: ConfigRef) = var p = parseopt.initOptParser(cmd) var argsCount = 1 @@ -638,7 +1094,11 @@ proc processCmdLine(pass: TCmdLinePass, cmd: string; config: ConfigRef) = p.key = "-" if processArgument(pass, p, argsCount, config): break else: - processSwitch(pass, p, config) + case p.key.normalize + of "assumeunique": + assumeUniqueness = true + else: + processSwitch(pass, p, config) of cmdArgument: config.commandLine.add " " config.commandLine.add p.key.quoteShell |