summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/options.nim3
-rw-r--r--compiler/semcall.nim64
-rw-r--r--compiler/semdata.nim2
-rw-r--r--compiler/semexprs.nim10
-rw-r--r--doc/manual_experimental.md81
-rw-r--r--tests/generics/treturn_inference.nim139
6 files changed, 288 insertions, 11 deletions
diff --git a/compiler/options.nim b/compiler/options.nim
index 8286a575d..5b61cb049 100644
--- a/compiler/options.nim
+++ b/compiler/options.nim
@@ -220,7 +220,8 @@ type
     unicodeOperators, # deadcode
     flexibleOptionalParams,
     strictDefs,
-    strictCaseObjects
+    strictCaseObjects,
+    inferGenericTypes
 
   LegacyFeature* = enum
     allowSemcheckedAstModification,
diff --git a/compiler/semcall.nim b/compiler/semcall.nim
index d2460ab06..f0d0f648a 100644
--- a/compiler/semcall.nim
+++ b/compiler/semcall.nim
@@ -562,8 +562,61 @@ proc getCallLineInfo(n: PNode): TLineInfo =
     discard
   result = n.info
 
-proc semResolvedCall(c: PContext, x: TCandidate,
-                     n: PNode, flags: TExprFlags): PNode =
+proc inheritBindings(c: PContext, x: var TCandidate, expectedType: PType) =
+  ## Helper proc to inherit bound generic parameters from expectedType into x.
+  ## Does nothing if 'inferGenericTypes' isn't in c.features
+  if inferGenericTypes notin c.features: return
+  if expectedType == nil or x.callee[0] == nil: return # required for inference
+
+  var
+    flatUnbound: seq[PType]
+    flatBound: seq[PType]
+  # seq[(result type, expected type)]
+  var typeStack = newSeq[(PType, PType)]()
+
+  template stackPut(a, b) =
+    ## skips types and puts the skipped version on stack
+    # It might make sense to skip here one by one. It's not part of the main
+    #  type reduction because the right side normally won't be skipped
+    const toSkip = { tyVar, tyLent, tyStatic, tyCompositeTypeClass }
+    let
+      x = a.skipTypes(toSkip)
+      y = if a.kind notin toSkip: b
+          else: b.skipTypes(toSkip)
+    typeStack.add((x, y))
+
+  stackPut(x.callee[0], expectedType)
+
+  while typeStack.len() > 0:
+    let (t, u) = typeStack.pop()
+    if t == u or t == nil or u == nil or t.kind == tyAnything or u.kind == tyAnything:
+      continue
+    case t.kind
+    of ConcreteTypes, tyGenericInvocation, tyUncheckedArray:
+      # nested, add all the types to stack
+      let
+        startIdx = if u.kind in ConcreteTypes: 0 else: 1
+        endIdx = min(u.sons.len() - startIdx, t.sons.len())
+
+      for i in startIdx ..< endIdx:
+        # early exit with current impl
+        if t[i] == nil or u[i] == nil: return
+        stackPut(t[i], u[i])
+    of tyGenericParam:
+      if x.bindings.idTableGet(t) != nil: return
+
+      # fully reduced generic param, bind it
+      if t notin flatUnbound:
+        flatUnbound.add(t)
+        flatBound.add(u)
+    else:
+      discard
+  for i in 0 ..< flatUnbound.len():
+    x.bindings.idTablePut(flatUnbound[i], flatBound[i])
+
+proc semResolvedCall(c: PContext, x: var TCandidate,
+                     n: PNode, flags: TExprFlags;
+                     expectedType: PType = nil): PNode =
   assert x.state == csMatch
   var finalCallee = x.calleeSym
   let info = getCallLineInfo(n)
@@ -583,10 +636,12 @@ proc semResolvedCall(c: PContext, x: TCandidate,
       if x.calleeSym.magic in {mArrGet, mArrPut}:
         finalCallee = x.calleeSym
       else:
+        c.inheritBindings(x, expectedType)
         finalCallee = generateInstance(c, x.calleeSym, x.bindings, n.info)
     else:
       # For macros and templates, the resolved generic params
       # are added as normal params.
+      c.inheritBindings(x, expectedType)
       for s in instantiateGenericParamList(c, gp, x.bindings):
         case s.kind
         of skConst:
@@ -615,7 +670,8 @@ proc tryDeref(n: PNode): PNode =
   result.add n
 
 proc semOverloadedCall(c: PContext, n, nOrig: PNode,
-                       filter: TSymKinds, flags: TExprFlags): PNode =
+                       filter: TSymKinds, flags: TExprFlags;
+                       expectedType: PType = nil): PNode =
   var errors: CandidateErrors = @[] # if efExplain in flags: @[] else: nil
   var r = resolveOverloads(c, n, nOrig, filter, flags, errors, efExplain in flags)
   if r.state == csMatch:
@@ -625,7 +681,7 @@ proc semOverloadedCall(c: PContext, n, nOrig: PNode,
       message(c.config, n.info, hintUserRaw,
               "Non-matching candidates for " & renderTree(n) & "\n" &
               candidates)
-    result = semResolvedCall(c, r, n, flags)
+    result = semResolvedCall(c, r, n, flags, expectedType)
   else:
     if efDetermineType in flags and c.inGenericContext > 0 and c.matchedConcept == nil:
       result = semGenericStmt(c, n)
diff --git a/compiler/semdata.nim b/compiler/semdata.nim
index ddd8d33ef..a85f8e638 100644
--- a/compiler/semdata.nim
+++ b/compiler/semdata.nim
@@ -135,7 +135,7 @@ type
     semOperand*: proc (c: PContext, n: PNode, flags: TExprFlags = {}): PNode {.nimcall.}
     semConstBoolExpr*: proc (c: PContext, n: PNode): PNode {.nimcall.} # XXX bite the bullet
     semOverloadedCall*: proc (c: PContext, n, nOrig: PNode,
-                              filter: TSymKinds, flags: TExprFlags): PNode {.nimcall.}
+                              filter: TSymKinds, flags: TExprFlags, expectedType: PType = nil): PNode {.nimcall.}
     semTypeNode*: proc(c: PContext, n: PNode, prev: PType): PType {.nimcall.}
     semInferredLambda*: proc(c: PContext, pt: TIdTable, n: PNode): PNode
     semGenerateInstance*: proc (c: PContext, fn: PSym, pt: TIdTable,
diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim
index c6be3e833..398424bbf 100644
--- a/compiler/semexprs.nim
+++ b/compiler/semexprs.nim
@@ -952,17 +952,17 @@ proc semStaticExpr(c: PContext, n: PNode; expectedType: PType = nil): PNode =
     result = fixupTypeAfterEval(c, result, a)
 
 proc semOverloadedCallAnalyseEffects(c: PContext, n: PNode, nOrig: PNode,
-                                     flags: TExprFlags): PNode =
+                                     flags: TExprFlags; expectedType: PType = nil): PNode =
   if flags*{efInTypeof, efWantIterator, efWantIterable} != {}:
     # consider: 'for x in pReturningArray()' --> we don't want the restriction
     # to 'skIterator' anymore; skIterator is preferred in sigmatch already
     # for typeof support.
     # for ``typeof(countup(1,3))``, see ``tests/ttoseq``.
     result = semOverloadedCall(c, n, nOrig,
-      {skProc, skFunc, skMethod, skConverter, skMacro, skTemplate, skIterator}, flags)
+      {skProc, skFunc, skMethod, skConverter, skMacro, skTemplate, skIterator}, flags, expectedType)
   else:
     result = semOverloadedCall(c, n, nOrig,
-      {skProc, skFunc, skMethod, skConverter, skMacro, skTemplate}, flags)
+      {skProc, skFunc, skMethod, skConverter, skMacro, skTemplate}, flags, expectedType)
 
   if result != nil:
     if result[0].kind != nkSym:
@@ -1138,7 +1138,7 @@ proc semDirectOp(c: PContext, n: PNode, flags: TExprFlags; expectedType: PType =
   # this seems to be a hotspot in the compiler!
   let nOrig = n.copyTree
   #semLazyOpAux(c, n)
-  result = semOverloadedCallAnalyseEffects(c, n, nOrig, flags)
+  result = semOverloadedCallAnalyseEffects(c, n, nOrig, flags, expectedType)
   if result != nil: result = afterCallActions(c, result, nOrig, flags, expectedType)
   else: result = errorNode(c, n)
 
@@ -3120,7 +3120,7 @@ proc semExpr(c: PContext, n: PNode, flags: TExprFlags = {}, expectedType: PType
         elif s.magic == mNone: result = semDirectOp(c, n, flags, expectedType)
         else: result = semMagic(c, n, s, flags, expectedType)
       of skProc, skFunc, skMethod, skConverter, skIterator:
-        if s.magic == mNone: result = semDirectOp(c, n, flags)
+        if s.magic == mNone: result = semDirectOp(c, n, flags, expectedType)
         else: result = semMagic(c, n, s, flags, expectedType)
       else:
         #liMessage(n.info, warnUser, renderTree(n));
diff --git a/doc/manual_experimental.md b/doc/manual_experimental.md
index 602ca46a5..4ee035b65 100644
--- a/doc/manual_experimental.md
+++ b/doc/manual_experimental.md
@@ -124,6 +124,87 @@ would not match the type of the variable, and an error would be given.
 
 The extent of this varies, but there are some notable special cases.
 
+
+Inferred generic parameters
+---------------------------
+
+In expressions making use of generic procs or templates, the expected
+(unbound) types are often able to be inferred based on context.
+This feature has to be enabled via `{.experimental: "inferGenericTypes".}`
+
+  ```nim  test = "nim c $1"
+  {.experimental: "inferGenericTypes".}
+
+  import std/options
+
+  var x = newSeq[int](1)
+  # Do some work on 'x'...
+
+  # Works!
+  # 'x' is 'seq[int]' so 'newSeq[int]' is implied
+  x = newSeq(10)
+
+  # Works!
+  # 'T' of 'none' is bound to the 'T' of 'noneProducer', passing it along.
+  # Effectively 'none.T = noneProducer.T'
+  proc noneProducer[T](): Option[T] = none()
+  let myNone = noneProducer[int]()
+
+  # Also works
+  # 'myOtherNone' binds its 'T' to 'float' and 'noneProducer' inherits it
+  # noneProducer.T = myOtherNone.T
+  let myOtherNone: Option[float] = noneProducer()
+
+  # Works as well
+  # none.T = myOtherOtherNone.T
+  let myOtherOtherNone: Option[int] = none()
+  ```
+
+This is achieved by reducing the types on the lhs and rhs until the *lhs* is left with only types such as `T`.
+While lhs and rhs are reduced together, this does *not* mean that the *rhs* will also only be left
+with a flat type `Z`, it may be of the form `MyType[Z]`.
+
+After the types have been reduced, the types `T` are bound to the types that are left on the rhs.
+
+If bindings *cannot be inferred*, compilation will fail and manual specification is required.
+
+An example for *failing inference* can be found when passing a generic expression
+to a function/template call:
+
+  ```nim  test = "nim c $1"  status = 1
+  {.experimental: "inferGenericTypes".}
+
+  proc myProc[T](a, b: T) = discard
+
+  # Fails! Unable to infer that 'T' is supposed to be 'int'
+  myProc(newSeq[int](), newSeq(1))
+
+  # Works! Manual specification of 'T' as 'int' necessary
+  myProc(newSeq[int](), newSeq[int](1))
+  ```
+
+Combination of generic inference with the `auto` type is also unsupported:
+
+  ```nim  test = "nim c $1"  status = 1
+  {.experimental: "inferGenericTypes".}
+
+  proc produceValue[T]: auto = default(T)
+  let a: int = produceValue() # 'auto' cannot be inferred here
+  ```
+
+**Note**: The described inference does not permit the creation of overrides based on
+the return type of a procedure. It is a mapping mechanism that does not attempt to 
+perform deeper inference, nor does it modify what is a valid override.
+
+  ```nim  test = "nim c $1"  status = 1
+  # Doesn't affect the following code, it is invalid either way
+  {.experimental: "inferGenericTypes".}
+
+  proc a: int = 0
+  proc a: float = 1.0 # Fails! Invalid code and not recommended
+  ```
+
+
 Sequence literals
 -----------------
 
diff --git a/tests/generics/treturn_inference.nim b/tests/generics/treturn_inference.nim
new file mode 100644
index 000000000..05d38cef4
--- /dev/null
+++ b/tests/generics/treturn_inference.nim
@@ -0,0 +1,139 @@
+
+{.experimental: "inferGenericTypes".}
+
+import std/tables
+
+block:
+  type
+    MyOption[T, Z] = object
+      x: T
+      y: Z
+
+  proc none[T, Z](): MyOption[T, Z] =
+    when T is int:
+      result.x = 22
+    when Z is float:
+      result.y = 12.0
+
+  proc myGenericProc[T, Z](): MyOption[T, Z] =
+    none() # implied by return type
+
+  let a = myGenericProc[int, float]()
+  doAssert a.x == 22
+  doAssert a.y == 12.0
+
+  let b: MyOption[int, float] = none() # implied by type of b
+  doAssert b.x == 22
+  doAssert b.y == 12.0
+
+# Simple template based result with inferred type for errors
+block:
+  type
+    ResultKind {.pure.} = enum
+      Ok
+      Err
+
+    Result[T] = object
+      case kind: ResultKind
+      of Ok:
+        data: T
+      of Err:
+        errmsg: cstring
+
+  template err[T](msg: static cstring): Result[T] =
+    Result[T](kind : ResultKind.Err, errmsg : msg)
+
+  proc testproc(): Result[int] =
+    err("Inferred error!") # implied by proc return
+  let r = testproc()
+  doAssert r.kind == ResultKind.Err
+  doAssert r.errmsg == "Inferred error!"
+
+# Builtin seq
+block:
+  let x: seq[int] = newSeq(1)
+  doAssert x is seq[int]
+  doAssert x.len() == 1
+
+  type
+    MyType[T, Z] = object
+      x: T
+      y: Z
+
+  let y: seq[MyType[int, float]] = newSeq(2)
+  doAssert y is seq[MyType[int, float]]
+  doAssert y.len() == 2
+
+  let z = MyType[seq[float], string](
+    x : newSeq(3),
+    y : "test"
+  )
+  doAssert z.x is seq[float]
+  doAssert z.x.len() == 3
+  doAssert z.y is string
+  doAssert z.y == "test"
+
+# array
+block:
+  proc giveArray[N, T](): array[N, T] =
+    for i in 0 .. N.high:
+      result[i] = i
+  var x: array[2, int] = giveArray()
+  doAssert x == [0, 1]
+
+# tuples
+block:
+  proc giveTuple[T, Z]: (T, Z, T) = discard
+  let x: (int, float, int) = giveTuple()
+  doAssert x is (int, float, int)
+  doAssert x == (0, 0.0, 0)
+
+  proc giveNamedTuple[T, Z]: tuple[a: T, b: Z] = discard
+  let y: tuple[a: int, b: float] = giveNamedTuple()
+  doAssert y is (int, float)
+  doAssert y is tuple[a: int, b: float]
+  doAssert y == (0, 0.0)
+
+  proc giveNestedTuple[T, Z]: ((T, Z), Z) = discard
+  let z: ((int, float), float) = giveNestedTuple()
+  doAssert z is ((int, float), float)
+  doAssert z == ((0, 0.0), 0.0)
+
+  # nesting inside a generic type
+  type MyType[T] = object
+    x: T
+  let a = MyType[(int, MyType[float])](x : giveNamedTuple())
+  doAssert a.x is (int, MyType[float])
+
+
+# basic constructors
+block:
+  type MyType[T] = object
+    x: T
+
+  proc giveValue[T](): T =
+    when T is int:
+      12
+    else:
+      default(T)
+
+  let x = MyType[int](x : giveValue())
+  doAssert x.x is int
+  doAssert x.x == 12
+
+  let y = MyType[MyType[float]](x : MyType[float](x : giveValue()))
+  doAssert y.x is MyType[float]
+  doAssert y.x.x is float
+  doAssert y.x.x == 0.0
+
+  # 'MyType[float]' is bound to 'T' directly
+  #  instead of mapping 'T' to 'float'
+  let z = MyType[MyType[float]](x : giveValue())
+  doAssert z.x is MyType[float]
+  doAssert z.x.x == 0.0
+
+  type Foo = object
+    x: Table[int, float]
+
+  let a = Foo(x: initTable())
+  doAssert a.x is Table[int, float]
\ No newline at end of file