summary refs log tree commit diff stats
path: root/compiler/seminst.nim
blob: 14631a590ef73225b9c7bdfaa8ffee7f18cfa00d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
#
#
#           The Nim Compiler
#        (c) Copyright 2012 Andreas Rumpf
#
#    See the file "copying.txt", included in this
#    distribution, for details about the copyright.
#

# This module implements the instantiation of generic procs.
# included from sem.nim

proc addObjFieldsToLocalScope(c: PContext; n: PNode) =
  ## Walks the record AST `n` of an object type and adds every field symbol
  ## accepted by `fieldVisible` to the current scope's symbol table, marking
  ## it as used. Symbols already present keep precedence
  ## (`onConflictKeepOld`), so a parameter may shadow a field without error.
  if n.kind == nkSym:
    let field = n.sym
    if field.kind != skField or not fieldVisible(c, field): return
    c.currentScope.symbols.strTableIncl(field, onConflictKeepOld=true)
    field.flags.incl sfUsed
  elif n.kind == nkRecList:
    for idx in 0 .. <n.len:
      addObjFieldsToLocalScope(c, n[idx])
  elif n.kind == nkRecCase:
    # son 0 is the discriminator; the remaining sons are branches whose
    # last son holds the fields of that branch
    if n.len > 0: addObjFieldsToLocalScope(c, n.sons[0])
    for idx in 1 .. <n.len:
      let branch = n[idx]
      if branch.kind in {nkOfBranch, nkElse}:
        addObjFieldsToLocalScope(c, branch.lastSon)

proc rawPushProcCon(c: PContext, owner: PSym) =
  ## Pushes a fresh procedure context owned by `owner` onto the `c.p`
  ## stack (the previous context becomes its `next`). No `self` handling
  ## is done here.
  c.p = PProcCon(owner: owner, next: c.p)

proc rawHandleSelf(c: PContext; owner: PSym) =
  ## When a `self` name is configured (`c.selfName`) and the first formal
  ## parameter of the routine `owner` has that name, records it as the
  ## current context's `selfSym`, flags it `sfIsSelf`, and, if its type
  ## (with `abstractPtrs` skipped) is an object, adds the object's visible
  ## fields to the local scope.
  if c.selfName == nil: return
  if owner.kind notin {skProc, skMethod, skConverter, skIterator, skMacro}:
    return
  if owner.typ == nil: return
  let formals = owner.typ.n
  if formals.len <= 1: return
  let selfParam = formals[1].sym
  if selfParam.name.id != c.selfName.id: return
  c.p.selfSym = selfParam
  selfParam.flags.incl sfIsSelf
  let selfType = selfParam.typ.skipTypes(abstractPtrs)
  if selfType.kind == tyObject:
    addObjFieldsToLocalScope(c, selfType.n)

proc pushProcCon*(c: PContext; owner: PSym) =
  ## Opens a new procedure context owned by `owner`, then applies the
  ## `self` parameter handling to it.
  c.rawPushProcCon(owner)
  c.rawHandleSelf(owner)

iterator instantiateGenericParamList(c: PContext, n: PNode, pt: TIdTable): PSym =
  ## Walks the ``nkGenericParams`` node `n`. For every generic parameter
  ## whose type kind is typedesc, generic param, static, iter or a type
  ## class, it yields a fresh symbol bound to the concrete type that `pt`
  ## records for it. Other parameters are skipped. Static parameters yield
  ## `skConst` symbols (value stored in `ast`); all others yield `skType`.
  internalAssert n.kind == nkGenericParams
  for idx in countup(0, n.len - 1):
    let paramNode = n[idx]
    internalAssert paramNode.kind == nkSym
    let genParam = paramNode.sym
    let paramType = genParam.typ
    if paramType.kind in {tyTypeDesc, tyGenericParam, tyStatic, tyIter}+tyTypeClasses:
      let instKind = if paramType.kind == tyStatic: skConst else: skType
      let inst = newSym(instKind, genParam.name, getCurrOwner(), genParam.info)
      inst.flags = inst.flags + {sfUsed, sfFromGeneric}
      var bound = PType(idTableGet(pt, paramType))
      if bound == nil:
        if tfRetType notin paramType.flags:
          localError(paramNode.info, errCannotInstantiateX, inst.name.s)
          bound = errorType(c)
        else:
          # keep the generic type so that the return type can still be
          # bound later by semAsgn (return type inference)
          bound = paramType
      elif bound.kind == tyGenericParam:
        localError(paramNode.info, errCannotInstantiateX, genParam.name.s)
        bound = errorType(c)
      elif bound.kind == tyGenericInvocation:
        bound = generateTypeInstance(c, pt, paramNode, bound)
      inst.typ = bound
      if bound.kind == tyStatic: inst.ast = bound.n
      yield inst

proc sameInstantiation(a, b: TInstantiation): bool =
  ## True when `a` and `b` have the same number of concrete types and each
  ## pair compares equal under `ExactTypeDescValues`.
  let count = a.concreteTypes.len
  if count != b.concreteTypes.len: return false
  for idx in countup(0, count - 1):
    if not compareTypes(a.concreteTypes[idx], b.concreteTypes[idx],
                        flags = {ExactTypeDescValues}):
      return false
  result = true

proc genericCacheGet(genericSym: PSym, entry: TInstantiation;
                     id: CompilesId): PSym =
  ## Searches the instance cache of `genericSym` for an instance that was
  ## created under compiles-context `id` and whose concrete types match
  ## `entry`. Returns its symbol, or nil when there is none.
  if genericSym.procInstCache == nil: return nil
  for cached in genericSym.procInstCache:
    if cached.compilesId != id: continue
    if sameInstantiation(entry, cached[]): return cached.sym

proc removeDefaultParamValues(n: PNode) =
  ## Intended to strip default values from the formal parameters in `n`:
  ## defaults cannot be instantiated properly, and they are not needed for
  ## instantiation because every argument is already provided.
  ##
  ## NOTE: the whole body is under ``when false``, so this proc is
  ## currently a no-op.
  when false:
    # NOTE(review): if this is ever re-enabled, check that `IllFormedAst`
    # still resolves; Nim compares the first letter of an identifier
    # case-sensitively.
    for i in countup(1, sonsLen(n)-1):
      var a = n.sons[i]
      if a.kind != nkIdentDefs: IllFormedAst(a)
      var L = a.len
      if a.sons[L-1].kind != nkEmpty and a.sons[L-2].kind != nkEmpty:
        # ``param: typ = defaultVal``.
        # We don't need defaultVal for semantic checking, and it is wrong for
        # ``cmp: proc (a, b: T): int = cmp``. For ``cmp = cmp`` removing it
        # is not possible... XXX This issue is not solved here.
        a.sons[L-1] = ast.emptyNode

proc freshGenSyms(n: PNode, owner, orig: PSym, symMap: var TIdTable) =
  ## Recursively rewrites `n` in place. Every gensym'ed symbol owned by
  ## `orig` is replaced by a copy owned by `owner`. `symMap` memoizes the
  ## old -> new mapping, so each symbol is copied at most once; entries
  ## already present in `symMap` are reused.
  let isTarget = n.kind == nkSym and sfGenSym in n.sym.flags and
                 n.sym.owner == orig
  if not isTarget:
    for child in 0 .. <safeLen(n):
      freshGenSyms(n.sons[child], owner, orig, symMap)
    return
  let old = n.sym
  var fresh = PSym(idTableGet(symMap, old))
  if fresh == nil:
    fresh = copySym(old, false)
    fresh.owner = owner
    idTablePut(symMap, old, fresh)
  n.sym = fresh

proc addParamOrResult(c: PContext, param: PSym, kind: TSymKind)

proc addProcDecls(c: PContext, fn: PSym) =
  ## Declares the routine `fn` itself, so its body can refer to it (e.g. for
  ## recursion). It then declares each formal parameter, re-owned by `fn`,
  ## and finally lets `maybeAddResult` add the `result` symbol.
  addDecl(c, fn)
  let formalCount = fn.typ.n.len
  for idx in countup(1, formalCount - 1):
    let formal = fn.typ.n.sons[idx].sym
    formal.owner = fn
    c.addParamOrResult(formal, fn.kind)
  c.maybeAddResult(fn, fn.ast)

proc instantiateBody(c: PContext, n, params: PNode, result, orig: PSym) =
  ## Runs semantic checking, HLO and transformation on the body of the
  ## instantiated routine `result` (AST `n`) and stores the result back into
  ## `n`. `orig` is the generic routine. `params` (may be nil) holds the
  ## generic's formal parameters. Does nothing when the body is empty.
  if n.sons[bodyPos].kind == nkEmpty: return
  inc c.inGenericInst
  var symMap: TIdTable
  initIdTable symMap
  # map each gensym'ed generic parameter onto the instance's parameter at
  # the same position, so the body refers to the instance's symbols
  if params != nil:
    for idx in countup(1, params.len - 1):
      let genericParam = params[idx].sym
      if sfGenSym notin genericParam.flags: continue
      idTablePut(symMap, genericParam,
                 result.typ.n[genericParam.position+1].sym)
  var body = n.sons[bodyPos]
  freshGenSyms(body, result, orig, symMap)
  body = semProcBody(c, body)
  body = hloBody(c, body)
  n.sons[bodyPos] = transformBody(c.module, body, result)
  excl(result.flags, sfForward)
  dec c.inGenericInst

proc fixupInstantiatedSymbols(c: PContext, s: PSym) =
  ## Re-instantiates, from a fresh copy of `s`'s body, every instance of the
  ## generic routine `s` recorded in `c.generics`. The number of entries is
  ## taken once up front, so entries appended during the loop are not
  ## visited.
  let pending = c.generics.len
  var idx = 0
  while idx < pending:
    if c.generics[idx].genericSym.id == s.id:
      let instSym = c.generics[idx].inst.sym
      pushInfoContext(instSym.info)
      openScope(c)
      let instAst = instSym.ast
      instAst.sons[bodyPos] = copyTree(s.getBody)
      instantiateBody(c, instAst, nil, instSym, s)
      closeScope(c)
      popInfoContext()
    inc idx

proc sideEffectsCheck(c: PContext, s: PSym) =
  ## Reports `errXhasSideEffects` at `s` when it carries both
  ## `sfNoSideEffect` and `sfSideEffect` (presumably: declared free of side
  ## effects while its body has some).
  let flagged = sfNoSideEffect in s.flags and sfSideEffect in s.flags
  if flagged:
    localError(s.info, errXhasSideEffects, s.name.s)

proc instGenericContainer(c: PContext, info: TLineInfo, header: PType,
                          allowMetaTypes = false): PType =
  ## Instantiates the type `header` via `replaceTypeVarsT`. It uses a fresh
  ## replacement state with empty symbol/type maps and cache, reporting at
  ## `info`. `allowMetaTypes` is forwarded unchanged to that state.
  var cl: TReplTypeVars
  cl.c = c
  cl.info = info
  cl.allowMetaTypes = allowMetaTypes
  initIdTable(cl.symMap)
  initIdTable(cl.typeMap)
  initIdTable(cl.localCache)
  result = replaceTypeVarsT(cl, header)

proc instGenericContainer(c: PContext, n: PNode, header: PType): PType =
  ## Convenience overload: instantiates `header`, reporting at `n`'s
  ## position and without allowing meta types.
  let pos = n.info
  result = instGenericContainer(c, pos, header)

proc instantiateProcType(c: PContext, pt: TIdTable,
                          prc: PSym, info: TLineInfo) =
  ## Replaces `prc.typ` with an instance of its generic signature, with type
  ## variables substituted according to `pt`. As side effects, `prc` and
  ## each instantiated parameter are declared in the current scope, and
  ## `maybeAddResult` is called for `prc`. Errors are reported in the
  ## context of `info`.
  # XXX: Instantiates a generic proc signature, while at the same
  # time adding the instantiated proc params into the current scope.
  # This is necessary, because the instantiation process may refer to
  # these params in situations like this:
  # proc foo[Container](a: Container, b: a.type.Item): type(b.x)
  #
  # Alas, doing this here is probably not enough, because another
  # proc signature could appear in the params:
  # proc foo[T](a: proc (x: T, b: type(x.y))
  #
  # The solution would be to move this logic into semtypinst, but
  # at this point semtypinst would have to become part of sem, because it
  # would need to use openScope, addDecl, etc.
  addDecl(c, prc)

  pushInfoContext(info)
  var cl = initTypeVars(c, pt, info, nil)
  # `result` is an ordinary local here (this proc returns nothing): a copy
  # of the generic proc type that is instantiated in place below
  var result = instCopyType(cl, prc.typ)
  let originalParams = result.n
  # new formal-parameter node; every slot (0 and 1..len-1) is assigned below
  result.n = originalParams.shallowCopy

  # parameter types live at indices 1..len-1; index 0 (the return type) is
  # handled after the loop
  for i in 1 .. <result.len:
    # twrong_field_caching requires these 'resetIdTable' calls:
    if i > 1:
      resetIdTable(cl.symMap)
      resetIdTable(cl.localCache)
    result.sons[i] = replaceTypeVarsT(cl, result.sons[i])
    propagateToOwner(result, result.sons[i])
    internalAssert originalParams[i].kind == nkSym
    # only the `when true` branch is compiled; the `else` branch is an
    # alternative implementation kept for reference
    when true:
      # copy the generic's parameter symbol and give it the instantiated
      # type; the generic's own symbol is not modified
      let oldParam = originalParams[i].sym
      let param = copySym(oldParam)
      param.owner = prc
      param.typ = result.sons[i]
      # re-fit the default value (if any) to the instantiated type
      if oldParam.ast != nil:
        param.ast = fitNode(c, param.typ, oldParam.ast)

      # don't be lazy here and call replaceTypeVarsN(cl, originalParams[i])!
      result.n.sons[i] = newSymNode(param)
      addDecl(c, param)
    else:
      let param = replaceTypeVarsN(cl, originalParams[i])
      result.n.sons[i] = param
      param.sym.owner = prc
      addDecl(c, result.n.sons[i].sym)

  # the return type is instantiated with reset caches as well
  resetIdTable(cl.symMap)
  resetIdTable(cl.localCache)
  result.sons[0] = replaceTypeVarsT(cl, result.sons[0])
  result.n.sons[0] = originalParams[0].copyTree

  eraseVoidParams(result)
  skipIntLiteralParams(result)

  # install the instantiated signature, then let maybeAddResult handle the
  # `result` symbol for it
  prc.typ = result
  maybeAddResult(c, prc, prc.ast)
  popInfoContext()

proc generateInstance(c: PContext, fn: PSym, pt: TIdTable,
                      info: TLineInfo): PSym =
  ## Generates a new instance of a generic procedure.
  ## The `pt` parameter is a type-unsafe mapping table used to link generic
  ## parameters to their concrete types within the generic instance.
  ## Returns the freshly instantiated symbol, or a previously cached
  ## instance with the same concrete types created under the same
  ## compiles-context id.
  ##
  ## The context state modified here (`instCounter`, `friendModules`,
  ## `inTypeClass`, `currentScope`) is restored in a ``finally`` section.
  ## An error that unwinds through this proc (e.g. a recoverable error
  ## raised while checking a ``compiles`` expression) therefore cannot leave
  ## the nesting counter inflated. It also cannot leave the generic's module
  ## registered as a friend module, which would grant private field access.
  # no need to instantiate generic templates/macros:
  internalAssert fn.kind notin {skMacro, skTemplate}
  # generates an instantiated proc
  if c.instCounter > 1000: internalError(fn.ast.info, "nesting too deep")
  inc(c.instCounter)
  # careful! we copy the whole AST including the possibly nil body!
  var n = copyTree(fn.ast)
  # NOTE: for access of private fields within generics from a different module
  # we set the friend module:
  c.friendModules.add(getModule(fn))
  let oldInTypeClass = c.inTypeClass
  c.inTypeClass = 0
  let oldScope = c.currentScope
  try:
    while not isTopLevel(c): c.currentScope = c.currentScope.parent
    result = copySym(fn, false)
    incl(result.flags, sfFromGeneric)
    result.owner = fn
    result.ast = n
    pushOwner(result)

    openScope(c)
    let gp = n.sons[genericParamsPos]
    internalAssert gp.kind != nkEmpty
    n.sons[namePos] = newSymNode(result)
    pushInfoContext(info)
    var entry = TInstantiation.new
    entry.sym = result
    # we need to compare both the generic types and the concrete types:
    # generic[void](), generic[int]()
    # see ttypeor.nim test.
    var i = 0
    newSeq(entry.concreteTypes, fn.typ.len+gp.len-1)
    for s in instantiateGenericParamList(c, gp, pt):
      addDecl(c, s)
      entry.concreteTypes[i] = s.typ
      inc i
    rawPushProcCon(c, result)
    instantiateProcType(c, pt, result, info)
    for j in 1 .. result.typ.len-1:
      entry.concreteTypes[i] = result.typ.sons[j]
      inc i
    if tfTriggersCompileTime in result.typ.flags:
      incl(result.flags, sfCompileTime)
    n.sons[genericParamsPos] = ast.emptyNode
    var oldPrc = genericCacheGet(fn, entry[], c.compilesContextId)
    if oldPrc == nil:
      # we MUST not add potentially wrong instantiations to the caching
      # mechanism. This means recursive instantiations behave differently
      # when in a ``compiles`` context but this is the lesser evil. See
      # bug #1055 (tevilcompiles).
      #if c.compilesContextId == 0:
      rawHandleSelf(c, result)
      entry.compilesId = c.compilesContextId
      fn.procInstCache.safeAdd(entry)
      c.generics.add(makeInstPair(fn, entry))
      if n.sons[pragmasPos].kind != nkEmpty:
        pragma(c, result, n.sons[pragmasPos], allRoutinePragmas)
      if isNil(n.sons[bodyPos]):
        n.sons[bodyPos] = copyTree(fn.getBody)
      instantiateBody(c, n, fn.typ.n, result, fn)
      sideEffectsCheck(c, result)
      paramsTypeCheck(c, result.typ)
    else:
      result = oldPrc
    popProcCon(c)
    popInfoContext()
    closeScope(c)           # close scope for parameters
    popOwner()
  finally:
    # undone on both the normal and the exceptional path; the proc context,
    # info context and owner stack are only popped on the normal path above
    c.currentScope = oldScope
    discard c.friendModules.pop()
    dec(c.instCounter)
    c.inTypeClass = oldInTypeClass
  if result.kind == skMethod: finishMethod(c, result)
an> typ.kind == tySequence and x.kind == nkSym and x.sym.kind == skConst: nkIntLit.newIntNode(x.sym.ast.len-1) else: o.opAdd.buildCall(o.opLen.buildCall(x), minusOne()) result.info = x.info proc reassociation(n: PNode; o: Operators): PNode = result = n # (foo+5)+5 --> foo+10; same for '*' case result.getMagic of someAdd: if result[2].isValue and result[1].getMagic in someAdd and result[1][2].isValue: result = o.opAdd.buildCall(result[1][1], result[1][2] |+| result[2]) if result[2].intVal == 0: result = result[1] of someMul: if result[2].isValue and result[1].getMagic in someMul and result[1][2].isValue: result = o.opMul.buildCall(result[1][1], result[1][2] |*| result[2]) if result[2].intVal == 1: result = result[1] elif result[2].intVal == 0: result = zero() else: discard proc pred(n: PNode): PNode = if n.kind in {nkCharLit..nkUInt64Lit} and n.intVal != low(BiggestInt): result = copyNode(n) dec result.intVal else: result = n proc canon*(n: PNode; o: Operators): PNode = # XXX for now only the new code in 'semparallel' uses this if n.safeLen >= 1: result = shallowCopy(n) for i in 0..<n.len: result[i] = canon(n[i], o) elif n.kind == nkSym and n.sym.kind == skLet and n.sym.astdef.getMagic in (someEq + someAdd + someMul + someMin + someMax + someHigh + {mUnaryLt} + someSub + someLen + someDiv): result = n.sym.astdef.copyTree else: result = n case result.getMagic of someEq, someAdd, someMul, someMin, someMax: # these are symmetric; put value as last: if result[1].isValue and not result[2].isValue: result = swapArgs(result, result[0].sym) # (4 + foo) + 2 --> (foo + 4) + 2 of someHigh: # high == len+(-1) result = o.opAdd.buildCall(o.opLen.buildCall(result[1]), minusOne()) of mUnaryLt: result = buildCall(o.opAdd, result[1], minusOne()) of someSub: # x - 4 --> x + (-4) result = negate(result[1], result[2], result, o) of someLen: result[0] = o.opLen.newSymNode of someLt: # x < y same as x <= y-1: let y = n[2].canon(o) let p = pred(y) let minus = if p != y: p else: 
o.opAdd.buildCall(y, minusOne()).canon(o) result = o.opLe.buildCall(n[1].canon(o), minus) else: discard result = skipConv(result) result = reassociation(result, o) # most important rule: (x-4) <= a.len --> x <= a.len+4 case result.getMagic of someLe: let x = result[1] let y = result[2] if x.kind in nkCallKinds and x.len == 3 and x[2].isValue and isLetLocation(x[1], true): case x.getMagic of someSub: result = buildCall(result[0].sym, x[1], reassociation(o.opAdd.buildCall(y, x[2]), o)) of someAdd: # Rule A: let plus = negate(y, x[2], nil, o).reassociation(o) if plus != nil: result = buildCall(result[0].sym, x[1], plus) else: discard elif y.kind in nkCallKinds and y.len == 3 and y[2].isValue and isLetLocation(y[1], true): # a.len < x-3 case y.getMagic of someSub: result = buildCall(result[0].sym, y[1], reassociation(o.opAdd.buildCall(x, y[2]), o)) of someAdd: let plus = negate(x, y[2], nil, o).reassociation(o) # ensure that Rule A will not trigger afterwards with the # additional 'not isLetLocation' constraint: if plus != nil and not isLetLocation(x, true): result = buildCall(result[0].sym, plus, y[1]) else: discard elif x.isValue and y.getMagic in someAdd and y[2].isValue: # 0 <= a.len + 3 # -3 <= a.len result[1] = x |-| y[2] result[2] = y[1] elif x.isValue and y.getMagic in someSub and y[2].isValue: # 0 <= a.len - 3 # 3 <= a.len result[1] = x |+| y[2] result[2] = y[1] else: discard proc buildAdd*(a: PNode; b: BiggestInt; o: Operators): PNode = canon(if b != 0: o.opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a, o) proc usefulFact(n: PNode; o: Operators): PNode = case n.getMagic of someEq: if skipConv(n[2]).kind == nkNilLit and ( isLetLocation(n[1], false) or isVar(n[1])): result = o.opIsNil.buildCall(n[1]) else: if isLetLocation(n[1], true) or isLetLocation(n[2], true): # XXX algebraic simplifications! 'i-1 < a.len' --> 'i < a.len+1' result = n of someLe+someLt: if isLetLocation(n[1], true) or isLetLocation(n[2], true): # XXX algebraic simplifications! 
'i-1 < a.len' --> 'i < a.len+1' result = n elif n[1].getMagic in someLen or n[2].getMagic in someLen: # XXX Rethink this whole idea of 'usefulFact' for semparallel result = n of mIsNil: if isLetLocation(n[1], false) or isVar(n[1]): result = n of someIn: if isLetLocation(n[1], true): result = n of mAnd: let a = usefulFact(n[1], o) b = usefulFact(n[2], o) if a != nil and b != nil: result = newNodeI(nkCall, n.info, 3) result[0] = newSymNode(o.opAnd) result[1] = a result[2] = b elif a != nil: result = a elif b != nil: result = b of mNot: let a = usefulFact(n[1], o) if a != nil: result = a.neg(o) of mOr: # 'or' sucks! (p.isNil or q.isNil) --> hard to do anything # with that knowledge... # DeMorgan helps a little though: # not a or not b --> not (a and b) # (x == 3) or (y == 2) ---> not ( not (x==3) and not (y == 2)) # not (x != 3 and y != 2) let a = usefulFact(n[1], o).neg(o) b = usefulFact(n[2], o).neg(o) if a != nil and b != nil: result = newNodeI(nkCall, n.info, 3) result[0] = newSymNode(o.opAnd) result[1] = a result[2] = b result = result.neg(o) elif n.kind == nkSym and n.sym.kind == skLet: # consider: # let a = 2 < x # if a: # ... 
# We make can easily replace 'a' by '2 < x' here: if n.sym.astdef != nil: result = usefulFact(n.sym.astdef, o) elif n.kind == nkStmtListExpr: result = usefulFact(n.lastSon, o) type TModel* = object s*: seq[PNode] # the "knowledge base" o*: Operators proc addFact*(m: var TModel, nn: PNode) = let n = usefulFact(nn, m.o) if n != nil: m.s.add n proc addFactNeg*(m: var TModel, n: PNode) = let n = n.neg(m.o) if n != nil: addFact(m, n) proc sameOpr(a, b: PSym): bool = case a.magic of someEq: result = b.magic in someEq of someLe: result = b.magic in someLe of someLt: result = b.magic in someLt of someLen: result = b.magic in someLen of someAdd: result = b.magic in someAdd of someSub: result = b.magic in someSub of someMul: result = b.magic in someMul of someDiv: result = b.magic in someDiv else: result = a == b proc sameTree*(a, b: PNode): bool = result = false if a == b: result = true elif a != nil and b != nil and a.kind == b.kind: case a.kind of nkSym: result = a.sym == b.sym if not result and a.sym.magic != mNone: result = a.sym.magic == b.sym.magic or sameOpr(a.sym, b.sym) of nkIdent: result = a.ident.id == b.ident.id of nkCharLit..nkUInt64Lit: result = a.intVal == b.intVal of nkFloatLit..nkFloat64Lit: result = a.floatVal == b.floatVal of nkStrLit..nkTripleStrLit: result = a.strVal == b.strVal of nkType: result = a.typ == b.typ of nkEmpty, nkNilLit: result = true else: if a.len == b.len: for i in 0..<a.len: if not sameTree(a[i], b[i]): return result = true proc hasSubTree(n, x: PNode): bool = if n.sameTree(x): result = true else: for i in 0..n.safeLen-1: if hasSubTree(n[i], x): return true proc invalidateFacts*(m: var TModel, n: PNode) = # We are able to guard local vars (as opposed to 'let' variables)! # 'while p != nil: f(p); p = p.next' # This is actually quite easy to do: # Re-assignments (incl. pass to a 'var' param) trigger an invalidation # of every fact that contains 'v'. 
# # if x < 4: # if y < 5 # x = unknown() # # we invalidate 'x' here but it's known that x >= 4 # # for the else anyway # else: # echo x # # The same mechanism could be used for more complex data stored on the heap; # procs that 'write: []' cannot invalidate 'n.kind' for instance. In fact, we # could CSE these expressions then and help C's optimizer. for i in 0..high(m.s): if m.s[i] != nil and m.s[i].hasSubTree(n): m.s[i] = nil proc valuesUnequal(a, b: PNode): bool = if a.isValue and b.isValue: result = not sameValue(a, b) proc impliesEq(fact, eq: PNode): TImplication = let (loc, val) = if isLocation(eq[1]): (1, 2) else: (2, 1) case fact[0].sym.magic of someEq: if sameTree(fact[1], eq[loc]): # this is not correct; consider: a == b; a == 1 --> unknown! if sameTree(fact[2], eq[val]): result = impYes elif valuesUnequal(fact[2], eq[val]): result = impNo elif sameTree(fact[2], eq[loc]): if sameTree(fact[1], eq[val]): result = impYes elif valuesUnequal(fact[1], eq[val]): result = impNo of mInSet: # remember: mInSet is 'contains' so the set comes first! if sameTree(fact[2], eq[loc]) and isValue(eq[val]): if inSet(fact[1], eq[val]): result = impYes else: result = impNo of mNot, mOr, mAnd: assert(false, "impliesEq") else: discard proc leImpliesIn(x, c, aSet: PNode): TImplication = if c.kind in {nkCharLit..nkUInt64Lit}: # fact: x <= 4; question x in {56}? # --> true if every value <= 4 is in the set {56} # var value = newIntNode(c.kind, firstOrd(nil, x.typ)) # don't iterate too often: if c.intVal - value.intVal < 1000: var i, pos, neg: int while value.intVal <= c.intVal: if inSet(aSet, value): inc pos else: inc neg inc i; inc value.intVal if pos == i: result = impYes elif neg == i: result = impNo proc geImpliesIn(x, c, aSet: PNode): TImplication = if c.kind in {nkCharLit..nkUInt64Lit}: # fact: x >= 4; question x in {56}? 
# --> true iff every value >= 4 is in the set {56} # var value = newIntNode(c.kind, c.intVal) let max = lastOrd(nil, x.typ) # don't iterate too often: if max - getInt(value) < toInt128(1000): var i, pos, neg: int while value.intVal <= max: if inSet(aSet, value): inc pos else: inc neg inc i; inc value.intVal if pos == i: result = impYes elif neg == i: result = impNo proc compareSets(a, b: PNode): TImplication = if equalSets(nil, a, b): result = impYes elif intersectSets(nil, a, b).len == 0: result = impNo proc impliesIn(fact, loc, aSet: PNode): TImplication = case fact[0].sym.magic of someEq: if sameTree(fact[1], loc): if inSet(aSet, fact[2]): result = impYes else: result = impNo elif sameTree(fact[2], loc): if inSet(aSet, fact[1]): result = impYes else: result = impNo of mInSet: if sameTree(fact[2], loc): result = compareSets(fact[1], aSet) of someLe: if sameTree(fact[1], loc): result = leImpliesIn(fact[1], fact[2], aSet) elif sameTree(fact[2], loc): result = geImpliesIn(fact[2], fact[1], aSet) of someLt: if sameTree(fact[1], loc): result = leImpliesIn(fact[1], fact[2].pred, aSet) elif sameTree(fact[2], loc): # 4 < x --> 3 <= x result = geImpliesIn(fact[2], fact[1].pred, aSet) of mNot, mOr, mAnd: assert(false, "impliesIn") else: discard proc valueIsNil(n: PNode): TImplication = if n.kind == nkNilLit: impYes elif n.kind in {nkStrLit..nkTripleStrLit, nkBracket, nkObjConstr}: impNo else: impUnknown proc impliesIsNil(fact, eq: PNode): TImplication = case fact[0].sym.magic of mIsNil: if sameTree(fact[1], eq[1]): result = impYes of someEq: if sameTree(fact[1], eq[1]): result = valueIsNil(fact[2].skipConv) elif sameTree(fact[2], eq[1]): result = valueIsNil(fact[1].skipConv) of mNot, mOr, mAnd: assert(false, "impliesIsNil") else: discard proc impliesGe(fact, x, c: PNode): TImplication = assert isLocation(x) case fact[0].sym.magic of someEq: if sameTree(fact[1], x): if isValue(fact[2]) and isValue(c): # fact: x = 4; question x >= 56? 
--> true iff 4 >= 56 if leValue(c, fact[2]): result = impYes else: result = impNo elif sameTree(fact[2], x): if isValue(fact[1]) and isValue(c): if leValue(c, fact[1]): result = impYes else: result = impNo of someLt: if sameTree(fact[1], x): if isValue(fact[2]) and isValue(c): # fact: x < 4; question N <= x? --> false iff N <= 4 if leValue(fact[2], c): result = impNo # fact: x < 4; question 2 <= x? --> we don't know elif sameTree(fact[2], x): # fact: 3 < x; question: N-1 < x ? --> true iff N-1 <= 3 if isValue(fact[1]) and isValue(c): if leValue(c.pred, fact[1]): result = impYes of someLe: if sameTree(fact[1], x): if isValue(fact[2]) and isValue(c): # fact: x <= 4; question x >= 56? --> false iff 4 <= 56 if leValue(fact[2], c): result = impNo # fact: x <= 4; question x >= 2? --> we don't know elif sameTree(fact[2], x): # fact: 3 <= x; question: x >= 2 ? --> true iff 2 <= 3 if isValue(fact[1]) and isValue(c): if leValue(c, fact[1]): result = impYes of mNot, mOr, mAnd: assert(false, "impliesGe") else: discard proc impliesLe(fact, x, c: PNode): TImplication = if not isLocation(x): return impliesGe(fact, c, x) case fact[0].sym.magic of someEq: if sameTree(fact[1], x): if isValue(fact[2]) and isValue(c): # fact: x = 4; question x <= 56? --> true iff 4 <= 56 if leValue(fact[2], c): result = impYes else: result = impNo elif sameTree(fact[2], x): if isValue(fact[1]) and isValue(c): if leValue(fact[1], c): result = impYes else: result = impNo of someLt: if sameTree(fact[1], x): if isValue(fact[2]) and isValue(c): # fact: x < 4; question x <= N? --> true iff N-1 <= 4 if leValue(fact[2], c.pred): result = impYes # fact: x < 4; question x <= 2? --> we don't know elif sameTree(fact[2], x): # fact: 3 < x; question: x <= 1 ? --> false iff 1 <= 3 if isValue(fact[1]) and isValue(c): if leValue(c, fact[1]): result = impNo of someLe: if sameTree(fact[1], x): if isValue(fact[2]) and isValue(c): # fact: x <= 4; question x <= 56? 
--> true iff 4 <= 56 if leValue(fact[2], c): result = impYes # fact: x <= 4; question x <= 2? --> we don't know elif sameTree(fact[2], x): # fact: 3 <= x; question: x <= 2 ? --> false iff 2 < 3 if isValue(fact[1]) and isValue(c): if leValue(c, fact[1].pred): result = impNo of mNot, mOr, mAnd: assert(false, "impliesLe") else: discard proc impliesLt(fact, x, c: PNode): TImplication = # x < 3 same as x <= 2: let p = c.pred if p != c: result = impliesLe(fact, x, p) else: # 4 < x same as 3 <= x let q = x.pred if q != x: result = impliesLe(fact, q, c) proc `~`(x: TImplication): TImplication = case x of impUnknown: impUnknown of impNo: impYes of impYes: impNo proc factImplies(fact, prop: PNode): TImplication = case fact.getMagic of mNot: # Consider: # enum nkBinary, nkTernary, nkStr # fact: not (k <= nkBinary) # question: k in {nkStr} # --> 'not' for facts is entirely different than 'not' for questions! # it's provably wrong if every value > 4 is in the set {56} # That's because we compute the implication and 'a -> not b' cannot # be treated the same as 'not a -> b' # (not a) -> b compute as not (a -> b) ??? 
# == not a or not b == not (a and b) let arg = fact[1] case arg.getMagic of mIsNil, mEqRef: return ~factImplies(arg, prop) of mAnd: # not (a and b) means not a or not b: # a or b --> both need to imply 'prop' let a = factImplies(arg[1], prop) let b = factImplies(arg[2], prop) if a == b: return ~a return impUnknown else: return impUnknown of mAnd: result = factImplies(fact[1], prop) if result != impUnknown: return result return factImplies(fact[2], prop) else: discard case prop[0].sym.magic of mNot: result = ~fact.factImplies(prop[1]) of mIsNil: result = impliesIsNil(fact, prop) of someEq: result = impliesEq(fact, prop) of someLe: result = impliesLe(fact, prop[1], prop[2]) of someLt: result = impliesLt(fact, prop[1], prop[2]) of mInSet: result = impliesIn(fact, prop[2], prop[1]) else: result = impUnknown proc doesImply*(facts: TModel, prop: PNode): TImplication = assert prop.kind in nkCallKinds for f in facts.s: # facts can be invalidated, in which case they are 'nil': if not f.isNil: result = f.factImplies(prop) if result != impUnknown: return proc impliesNotNil*(m: TModel, arg: PNode): TImplication = result = doesImply(m, m.o.opIsNil.buildCall(arg).neg(m.o)) proc simpleSlice*(a, b: PNode): BiggestInt = # returns 'c' if a..b matches (i+c)..(i+c), -1 otherwise. (i)..(i) is matched # as if it is (i+0)..(i+0). 
if guards.sameTree(a, b): if a.getMagic in someAdd and a[2].kind in {nkCharLit..nkUInt64Lit}: result = a[2].intVal else: result = 0 else: result = -1 template isMul(x): untyped = x.getMagic in someMul template isDiv(x): untyped = x.getMagic in someDiv template isAdd(x): untyped = x.getMagic in someAdd template isSub(x): untyped = x.getMagic in someSub template isVal(x): untyped = x.kind in {nkCharLit..nkUInt64Lit} template isIntVal(x, y): untyped = x.intVal == y import macros macro `=~`(x: PNode, pat: untyped): bool = proc m(x, pat, conds: NimNode) = case pat.kind of nnkInfix: case $pat[0] of "*": conds.add getAst(isMul(x)) of "/": conds.add getAst(isDiv(x)) of "+": conds.add getAst(isAdd(x)) of "-": conds.add getAst(isSub(x)) else: error("invalid pattern") m(newTree(nnkBracketExpr, x, newLit(1)), pat[1], conds) m(newTree(nnkBracketExpr, x, newLit(2)), pat[2], conds) of nnkPar: if pat.len == 1: m(x, pat[0], conds) else: error("invalid pattern") of nnkIdent: let c = newTree(nnkStmtListExpr, newLetStmt(pat, x)) conds.add c # XXX why is this 'isVal(pat)' and not 'isVal(x)'? 
if ($pat)[^1] == 'c': c.add(getAst(isVal(x))) else: c.add bindSym"true" of nnkIntLit: conds.add(getAst(isIntVal(x, pat.intVal))) else: error("invalid pattern") var conds = newTree(nnkBracket) m(x, pat, conds) when compiles(nestList(ident"and", conds)): result = nestList(ident"and", conds) #elif declared(macros.toNimIdent): # result = nestList(toNimIdent"and", conds) else: result = nestList(!"and", conds)

proc isMinusOne(n: PNode): bool =
  ## True iff `n` is an integer-literal node (nkCharLit..nkUInt64Lit)
  ## whose value is -1.
  n.kind in {nkCharLit..nkUInt64Lit} and n.intVal == -1

# Forward declaration: `ple` falls back to `pleViaModel`, which (through
# `pleViaModelRec`) calls `ple` again.
proc pleViaModel(model: TModel; aa, bb: PNode): TImplication

proc ple(m: TModel; a, b: PNode): TImplication =
  ## Tries to prove ``a <= b``. It uses, in order: literal comparison,
  ## ordinal type bounds, structural rewrite rules, and finally the facts
  ## stored in `m` (via `pleViaModel`). Returns `impYes` when a rule proves
  ## the inequality and `impNo` when `a` and `b` are both values with
  ## ``a > b``. Otherwise it returns whatever `pleViaModel` answers.
  ##
  ## A rule may only answer `impYes` when the inequality holds for *every*
  ## value satisfying its side conditions; rules must stay sound.

  # `x <=? y`: recursively prove x <= y.
  template `<=?`(a,b): untyped = ple(m,a,b) == impYes
  # `x >=? k`: prove that the literal integer k is <= x.
  template `>=?`(a,b): untyped = ple(m, nkIntLit.newIntNode(b), a) == impYes

  #   0 <= 3
  if a.isValue and b.isValue:
    return if leValue(a, b): impYes else: impNo

  # use type information too:  x <= 4  iff  high(x) <= 4
  if b.isValue and a.typ != nil and a.typ.isOrdinalType:
    if lastOrd(nil, a.typ) <= b.intVal: return impYes
  # 3 <= x   iff  low(x) <= 3
  if a.isValue and b.typ != nil and b.typ.isOrdinalType:
    if firstOrd(nil, b.typ) <= a.intVal: return impYes

  # x <= x
  if sameTree(a, b): return impYes

  # 0 <= x.len
  if b.getMagic in someLen and a.isValue:
    if a.intVal <= 0: return impYes

  #   x <= y+c  if 0 <= c and x <= y
  #   x <= y+(-c)  if c <= 0  and y >= x
  if b.getMagic in someAdd and zero() <=? b[2] and a <=? b[1]: return impYes

  #   x+c <= y  if c <= 0 and x <= y
  if a.getMagic in someAdd and a[2] <=? zero() and a[1] <=? b: return impYes

  #   x <= y*c  if  1 <= c and x <= y  and 0 <= y
  if b.getMagic in someMul:
    if a <=? b[1] and one() <=? b[2] and zero() <=? b[1]: return impYes

  if a.getMagic in someMul and a[2].isValue and a[1].getMagic in someDiv and
      a[1][2].isValue:
    # simplify   (x div 4) * 2 <= y   to  x div (c div d)  <= y
    let simplified = buildCall(m.o.opDiv, a[1][1], `|div|`(a[1][2], a[2]))
    if ple(m, simplified, b) == impYes:
      return impYes

  # x*3 + x == x*4. It follows that:
  # x*3 + y <= x*4  if  y <= x  and 3 <= 4
  if a =~ x*dc + y and b =~ x2*ec:
    if sameTree(x, x2):
      let ec1 = m.o.opAdd.buildCall(ec, minusOne())
      if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? x:
        return impYes
  elif a =~ x*dc and b =~ x2*ec + y:
    if sameTree(x, x2):
      let ec1 = m.o.opAdd.buildCall(ec, minusOne())
      if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? zero():
        return impYes

  #  x+c <= x+d  if c <= d. Same for *, - etc.
  if a.getMagic in someBinaryOp and a.getMagic == b.getMagic:
    if sameTree(a[1], b[1]) and a[2] <=? b[2]: return impYes
    elif sameTree(a[2], b[2]) and a[1] <=? b[1]: return impYes

  #   x div c <= y   if   1 <= c  and  0 <= y  and x <= y:
  if a.getMagic in someDiv:
    if one() <=? a[2] and zero() <=? b and a[1] <=? b: return impYes

    #  x div c <= x div d  if  d <= c  and  1 <= d  and  0 <= x
    # Both extra side conditions are required because `div` truncates:
    # -7 div 4 == -1 but -7 div 2 == -3, and 6 div 3 == 2 but 6 div -2 == -3.
    if b.getMagic in someDiv:
      if sameTree(a[1], b[1]) and b[2] <=? a[2] and one() <=? b[2] and
          zero() <=? a[1]:
        return impYes

    # x div z <= x - 1   if  2 <= z  and  z <= x
    # '2 <= z' is required: 'x div 1 <= x - 1' is false for every x.
    if a[2].isValue and b.getMagic in someAdd and b[2].isMinusOne:
      if a[2] >=? 2 and a[2] <=? a[1] and sameTree(a[1], b[1]):
        return impYes

  # slightly subtle:
  # x <= max(y, z)  iff x <= y or x <= z
  # note that 'x <= max(x, z)' is a special case of the above rule
  if b.getMagic in someMax:
    if a <=? b[1] or a <=? b[2]: return impYes

  # min(x, y) <= z  iff x <= z or y <= z
  if a.getMagic in someMin:
    if a[1] <=? b or a[2] <=? b: return impYes

  # use the knowledge base:
  return pleViaModel(m, a, b)
  #return doesImply(m, o.opLe.buildCall(a, b))

type
  TReplacements = seq[tuple[a, b: PNode]]
    # each entry means "replace every subtree equal to `a` by `b`"

proc replaceSubTree(n, x, by: PNode): PNode =
  ## Returns `n` with every subtree that is `sameTree` to `x` replaced by
  ## `by`. Only nodes on a path containing `x` are (shallow-)copied; untouched
  ## subtrees are shared with `n`.
  if sameTree(n, x):
    result = by
  elif hasSubTree(n, x):
    result = shallowCopy(n)
    for i in 0..n.safeLen-1:
      result[i] = replaceSubTree(n[i], x, by)
  else:
    result = n

proc applyReplacements(n: PNode; rep: TReplacements): PNode =
  ## Applies every replacement of `rep` to `n`, in order.
  result = n
  for x in rep: result = result.replaceSubTree(x.a, x.b)

proc pleViaModelRec(m: var TModel; a, b: PNode): TImplication =
  ## Tries to prove ``a <= b`` from the `<=` facts in `m`. Each fact that is
  ## examined is set to nil in `m.s` before the recursive `ple` calls, so
  ## those calls cannot use it again (presumably to guarantee termination).
  # now check for inferrable facts: a <= b and b <= c implies a <= c
  for i in 0..m.s.high:
    let fact = m.s[i]
    if fact != nil and fact.getMagic in someLe:
      # mark as used:
      m.s[i] = nil
      # i <= len-100
      # i <=? len-1
      # --> true  if  (len-100) <= (len-1), i.e. if y[2] <= b[2]
      let x = fact[1]
      let y = fact[2]
      if sameTree(x, a) and y.getMagic in someAdd and b.getMagic in someAdd and
          sameTree(y[1], b[1]):
        if ple(m, y[2], b[2]) == impYes:
          return impYes

      # x <= y implies a <= b  if  a <= x and y <= b
      if ple(m, a, x) == impYes:
        if ple(m, y, b) == impYes:
          return impYes
        #if pleViaModelRec(m, y, b): return impYes
      # fact:    16 <= i
      #           x    y
      # question: i <= 15? no!
      result = impliesLe(fact, a, b)
      if result != impUnknown:
        return result
      when false:
        # given: x <= y;  y==a;  x <= a this means: a <= b  if  x <= b
        if sameTree(y, a):
          result = ple(m, b, x)
          if result != impUnknown:
            return result

proc pleViaModel(model: TModel; aa, bb: PNode): TImplication =
  ## Proves ``aa <= bb`` from the facts in `model`. Equality facts are first
  ## turned into substitutions (a symbol side is replaced by the other side),
  ## which are applied to the remaining facts and to both operands. The
  ## search then runs on a private copy of the model, because
  ## `pleViaModelRec` mutates it.
  # compute replacements:
  var replacements: TReplacements = @[]
  for fact in model.s:
    if fact != nil and fact.getMagic in someEq:
      let a = fact[1]
      let b = fact[2]
      if a.kind == nkSym: replacements.add((a,b))
      else: replacements.add((b,a))
  var m: TModel
  var a = aa
  var b = bb
  if replacements.len > 0:
    m.s = @[]
    m.o = model.o
    # make the other facts consistent:
    for fact in model.s:
      if fact != nil and fact.getMagic notin someEq:
        # XXX 'canon' should not be necessary here, but it is
        m.s.add applyReplacements(fact, replacements).canon(m.o)
    a = applyReplacements(aa, replacements)
    b = applyReplacements(bb, replacements)
  else:
    # we have to make a copy here, because the model will be modified:
    m = model
  result = pleViaModelRec(m, a, b)

proc proveLe*(m: TModel; a, b: PNode): TImplication =
  ## Tries to decide ``a <= b`` under the facts in `m`. If the direct attempt
  ## is inconclusive, it retries with the negated form ``b + 1 <= a`` and
  ## inverts that answer with `~`.
  let x = canon(m.o.opLe.buildCall(a, b), m.o)
  #echo "ROOT ", renderTree(x[1]), " <=? ", renderTree(x[2])
  result = ple(m, x[1], x[2])
  if result == impUnknown:
    # try an alternative:  a <= b  iff  not (b < a)  iff  not (b+1 <= a):
    let y = canon(m.o.opLe.buildCall(m.o.opAdd.buildCall(b, one()), a), m.o)
    result = ~ple(m, y[1], y[2])

proc addFactLe*(m: var TModel; a, b: PNode) =
  ## Records the canonicalized fact ``a <= b`` in `m`.
  m.s.add canon(m.o.opLe.buildCall(a, b), m.o)

proc settype(n: PNode): PType =
  ## Creates a new `tySet` type whose element type is `n.typ` (added via
  ## `addSonSkipIntLit`), owned by `n.typ.owner`.
  result = newType(tySet, n.typ.owner)
  addSonSkipIntLit(result, n.typ)

proc buildOf(it, loc: PNode; o: Operators): PNode =
  ## Builds ``contains({values of branch it}, loc)`` for an `of` branch.
  ## The last son of `it` is excluded from the set, since it is not one of
  ## the branch's values (presumably the branch body).
  var s = newNodeI(nkCurly, it.info, it.len-1)
  s.typ = settype(loc)
  for i in 0..<it.len-1: s[i] = it[i]
  result = newNodeI(nkCall, it.info, 3)
  result[0] = newSymNode(o.opContains)
  result[1] = s
  result[2] = loc

proc buildElse(n: PNode; o: Operators): PNode =
  ## Builds ``contains({values of all of-branches}, n[0])`` for the case
  ## statement `n`, where `n[0]` is the selector. Sons 1 .. n.len-2 are
  ## scanned; the last son (presumably the `else` branch) is skipped.
  var s = newNodeIT(nkCurly, n.info, settype(n[0]))
  for i in 1..<n.len-1:
    let branch = n[i]
    assert branch.kind != nkElse
    if branch.kind == nkOfBranch:
      for j in 0..<branch.len-1:
        s.add(branch[j])
  result = newNodeI(nkCall, n.info, 3)
  result[0] = newSymNode(o.opContains)
  result[1] = s
  result[2] = n[0]

proc addDiscriminantFact*(m: var TModel, n: PNode) =
  ## Records the fact ``n[0] == n[1]`` in `m`, with `n.info` as position.
  var fact = newNodeI(nkCall, n.info, 3)
  fact[0] = newSymNode(m.o.opEq)
  fact[1] = n[0]
  fact[2] = n[1]
  m.s.add fact

proc addAsgnFact*(m: var TModel, key, value: PNode) =
  ## Records the fact ``key == value`` in `m`, with `key.info` as position.
  var fact = newNodeI(nkCall, key.info, 3)
  fact[0] = newSymNode(m.o.opEq)
  fact[1] = key
  fact[2] = value
  m.s.add fact

proc sameSubexprs*(m: TModel; a, b: PNode): bool =
  ## True iff the facts in `m` imply ``a == b``.
  # This should be used to check whether two *path expressions* refer to the
  # same memory location according to 'm'. This is tricky:
  # lock a[i].guard:
  #   ...
  #   access a[i].guarded
  #
  # Here a[i] is the same as a[i] iff 'i' and 'a' are not changed via '...'.
  # However, nil checking requires exactly the same mechanism! But for now
  # we simply use sameTree and live with the unsoundness of the analysis.
  var check = newNodeI(nkCall, a.info, 3)
  check[0] = newSymNode(m.o.opEq)
  check[1] = a
  check[2] = b
  result = m.doesImply(check) == impYes

proc addCaseBranchFacts*(m: var TModel, n: PNode, i: int) =
  ## Records the fact that holds inside branch `i` of the case statement `n`.
  ## An `of` branch adds "selector is one of its values". Any other branch
  ## adds the negation of "selector is one of the values of the of-branches".
  let branch = n[i]
  if branch.kind == nkOfBranch:
    m.s.add buildOf(branch, n[0], m.o)
  else:
    m.s.add n.buildElse(m.o).neg(m.o)

proc buildProperFieldCheck(access, check: PNode; o: Operators): PNode =
  ## Rewrites a field check so it refers to the object actually accessed.
  ## For a ``contains({...}, discriminant)`` check, the location (son 2) is
  ## replaced by a copy of `access` whose son 1 is the checked discriminant.
  ## Any other check must be a `not` and is handled recursively, then negated.
  if check[1].kind == nkCurly:
    result = copyTree(check)
    if access.kind == nkDotExpr:
      var a = copyTree(access)
      a[1] = check[2]
      result[2] = a
      # 'access.kind != nkDotExpr' can happen for object constructors
      # which we don't check yet
  else:
    # it is some 'not'
    assert check.getMagic == mNot
    result = buildProperFieldCheck(access, check[1], o).neg(o)

proc checkFieldAccess*(m: TModel, n: PNode; conf: ConfigRef) =
  ## Emits `warnProveField` for the field access `n[0]` when one of its
  ## checks `n[1..^1]` is not implied by `m`. At most one warning is
  ## issued per access.
  for i in 1..<n.len:
    let check = buildProperFieldCheck(n[0], n[i], m.o)
    if check != nil and m.doesImply(check) != impYes:
      message(conf, n.info, warnProveField, renderTree(n[0]))
      break