summary refs log tree commit diff stats
path: root/compiler/sigmatch.nim
diff options
context:
space:
mode:
authorZahary Karadjov <zahary@gmail.com>2016-08-12 03:25:59 +0300
committerZahary Karadjov <zahary@gmail.com>2017-03-24 16:58:15 +0200
commit0f2c4be1299fc99aeea2011c57240c8cfabd83c3 (patch)
tree64fa8e58c3c9745b54c127f8fdc3b897ff3896a0 /compiler/sigmatch.nim
parent0b0a3e5f203f6b21f3790a6cd50ceeaa8786badc (diff)
downloadNim-0f2c4be1299fc99aeea2011c57240c8cfabd83c3.tar.gz
infer static parameters even when more complicated arithmetic is involved
Diffstat (limited to 'compiler/sigmatch.nim')
-rw-r--r--compiler/sigmatch.nim143
1 files changed, 114 insertions, 29 deletions
diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim
index 162385e6d..ca9cdcaf8 100644
--- a/compiler/sigmatch.nim
+++ b/compiler/sigmatch.nim
@@ -681,16 +681,125 @@ proc maybeSkipDistinct(t: PType, callee: PSym): PType =
   else:
     result = t
 
-proc tryResolvingStaticExpr(c: var TCandidate, n: PNode): PNode =
+proc tryResolvingStaticExpr(c: var TCandidate, n: PNode,
+                            allowUnresolved = false): PNode =
   # Consider this example:
   #   type Value[N: static[int]] = object
   #   proc foo[N](a: Value[N], r: range[0..(N-1)])
   # Here, N-1 will be initially nkStaticExpr that can be evaluated only after
   # N is bound to a concrete value during the matching of the first param.
   # This proc is used to evaluate such static expressions.
-  let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil)
+  let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil,
+                                        allowMetaTypes = allowUnresolved)
   result = c.c.semExpr(c.c, instantiated)
 
+proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType =
+  # This is a simple integer arithimetic equation solver,
+  # capable of deriving the value of a static parameter in
+  # expressions such as (N + 5) / 2 = rhs
+  #
+  # Preconditions:
+  #
+  #   * The input of this proc must be semantized 
+  #     - all templates should be expanded
+  #     - aby constant folding possible should already be performed
+  #
+  #   * There must be exactly one unresolved static parameter
+  #
+  # Result:
+  #
+  #   The proc will return the inferred static type with the `n` field
+  #   populated with the inferred value.
+  #
+  #   `nil` will be returned if the inference was not possible
+  #
+  if lhs.kind in nkCallKinds and lhs[0].kind == nkSym:
+    case lhs[0].sym.magic
+    of mUnaryLt:
+      return inferStaticParam(lhs[1], rhs + 1)
+
+    of mAddI, mAddU, mInc, mSucc:
+      if lhs[1].kind == nkIntLit:
+        return inferStaticParam(lhs[2], rhs - lhs[1].intVal)
+      elif lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], rhs - lhs[2].intVal)
+    
+    of mDec, mSubI, mSubU, mPred:
+      if lhs[1].kind == nkIntLit:
+        return inferStaticParam(lhs[2], lhs[1].intVal - rhs)
+      elif lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], rhs + lhs[2].intVal)
+    
+    of mMulI, mMulU:
+      if lhs[1].kind == nkIntLit:
+        if rhs mod lhs[1].intVal == 0:
+          return inferStaticParam(lhs[2], rhs div lhs[1].intVal)
+      elif lhs[2].kind == nkIntLit:
+        if rhs mod lhs[2].intVal == 0:
+          return inferStaticParam(lhs[1], rhs div lhs[2].intVal)
+    
+    of mDivI, mDivU:
+      if lhs[1].kind == nkIntLit:
+        if lhs[1].intVal mod rhs == 0:
+          return inferStaticParam(lhs[2], lhs[1].intVal div rhs)
+      elif lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], lhs[2].intVal * rhs)
+    
+    of mShlI:
+      if lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], rhs shr lhs[2].intVal)
+    
+    of mShrI:
+      if lhs[2].kind == nkIntLit:
+        return inferStaticParam(lhs[1], rhs shl lhs[2].intVal)
+    
+    of mUnaryMinusI:
+      return inferStaticParam(lhs[1], -rhs)
+    
+    of mUnaryPlusI, mToInt, mToBiggestInt:
+      return inferStaticParam(lhs[1], rhs)
+    
+    else: discard
+  
+  elif lhs.kind == nkSym and lhs.typ.kind == tyStatic and lhs.typ.n == nil:
+    lhs.typ.n = newIntNode(nkIntLit, rhs)
+    return lhs.typ
+  
+  return nil
+
+proc failureToInferStaticParam(n: PNode) =
+  let staticParam = n.findUnresolvedStatic
+  let name = if staticParam != nil: staticParam.sym.name.s
+             else: "unknown"
+  localError(n.info, errCannotInferStaticParam, name)
+
+proc inferStaticsInRange(c: var TCandidate,
+                         inferred, concrete: PType): TTypeRelation =
+  let lowerBound = tryResolvingStaticExpr(c, inferred.n[0],
+                                          allowUnresolved = true)
+  let upperBound = tryResolvingStaticExpr(c, inferred.n[1],
+                                          allowUnresolved = true)
+  
+  template doInferStatic(c: var TCandidate, e: PNode, r: BiggestInt) =
+    var exp = e
+    var rhs = r
+    var inferred = inferStaticParam(exp, rhs)
+    if inferred != nil:
+      put(c.bindings, inferred, inferred)
+      return isGeneric
+    else:
+      failureToInferStaticParam exp
+
+  if lowerBound.kind == nkIntLit:
+    if upperBound.kind == nkIntLit:
+      if lengthOrd(concrete) == upperBound.intVal - lowerBound.intVal + 1:
+        return isGeneric
+      else:
+        return isNone
+    doInferStatic(c, upperBound, lengthOrd(concrete) + lowerBound.intVal - 1)
+  elif upperBound.kind == nkIntLit:
+    doInferStatic(c, lowerBound, upperBound.intVal + 1 - lengthOrd(concrete))
+
 template subtypeCheck() =
   if result <= isSubrange and f.lastSon.skipTypes(abstractInst).kind in {tyRef, tyPtr, tyVar}:
     result = isNone
@@ -894,34 +1003,10 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation =
                           a.sons[1].skipTypes({tyTypeDesc}))
       if result < isGeneric: return isNone
       
-      proc inferStaticRange(c: var TCandidate, inferred, concrete: PType) =
-        var (staticT, offset) = inferred.findUnresolvedStaticInRange
-        var
-          replacementT = newTypeWithSons(c.c, tyStatic, @[tyInt.getSysType])
-          concreteUpperBound = concrete.n[1].intVal
-        # we must correct for the off-by-one discrepancy between
-        # ranges and static params:
-        replacementT.n = newIntNode(nkIntLit, concreteUpperBound + offset)
-        if tfInferrableStatic in staticT.flags:
-          staticT.n = replacementT.n
-        put(c.bindings, staticT, replacementT)
-
-      if rangeHasUnresolvedStatic(fRange):
-        if tfUnresolved in fRange.flags:
-          # This is a range from an array instantiated with a generic
-          # static param. We must extract the static param here and bind
-          # it to the size of the currently supplied array.
-          inferStaticRange(c, fRange, aRange)
-          return isGeneric
-
-        let len = tryResolvingStaticExpr(c, fRange.n[1])
-        if len.kind == nkIntLit and len.intVal+1 == lengthOrd(a):
-          return # if we get this far, the result is already good
-        else:
-          return isNone
+      if fRange.rangeHasUnresolvedStatic:
+        return inferStaticsInRange(c, fRange, a)
       elif c.c.inTypeClass > 0 and aRange.rangeHasUnresolvedStatic:
-        inferStaticRange(c, aRange, fRange)
-        return isGeneric
+        return inferStaticsInRange(c, aRange, f)
       elif lengthOrd(fRange) != lengthOrd(a):
         result = isNone
     else: discard