diff options
-rw-r--r-- | compiler/options.nim | 3 | ||||
-rw-r--r-- | compiler/semcall.nim | 64 | ||||
-rw-r--r-- | compiler/semdata.nim | 2 | ||||
-rw-r--r-- | compiler/semexprs.nim | 10 | ||||
-rw-r--r-- | doc/manual_experimental.md | 81 | ||||
-rw-r--r-- | tests/generics/treturn_inference.nim | 139 |
6 files changed, 288 insertions, 11 deletions
diff --git a/compiler/options.nim b/compiler/options.nim index 8286a575d..5b61cb049 100644 --- a/compiler/options.nim +++ b/compiler/options.nim @@ -220,7 +220,8 @@ type unicodeOperators, # deadcode flexibleOptionalParams, strictDefs, - strictCaseObjects + strictCaseObjects, + inferGenericTypes LegacyFeature* = enum allowSemcheckedAstModification, diff --git a/compiler/semcall.nim b/compiler/semcall.nim index d2460ab06..f0d0f648a 100644 --- a/compiler/semcall.nim +++ b/compiler/semcall.nim @@ -562,8 +562,61 @@ proc getCallLineInfo(n: PNode): TLineInfo = discard result = n.info -proc semResolvedCall(c: PContext, x: TCandidate, - n: PNode, flags: TExprFlags): PNode = +proc inheritBindings(c: PContext, x: var TCandidate, expectedType: PType) = + ## Helper proc to inherit bound generic parameters from expectedType into x. + ## Does nothing if 'inferGenericTypes' isn't in c.features + if inferGenericTypes notin c.features: return + if expectedType == nil or x.callee[0] == nil: return # required for inference + + var + flatUnbound: seq[PType] + flatBound: seq[PType] + # seq[(result type, expected type)] + var typeStack = newSeq[(PType, PType)]() + + template stackPut(a, b) = + ## skips types and puts the skipped version on stack + # It might make sense to skip here one by one. It's not part of the main + # type reduction because the right side normally won't be skipped + const toSkip = { tyVar, tyLent, tyStatic, tyCompositeTypeClass } + let + x = a.skipTypes(toSkip) + y = if a.kind notin toSkip: b + else: b.skipTypes(toSkip) + typeStack.add((x, y)) + + stackPut(x.callee[0], expectedType) + + while typeStack.len() > 0: + let (t, u) = typeStack.pop() + if t == u or t == nil or u == nil or t.kind == tyAnything or u.kind == tyAnything: + continue + case t.kind + of ConcreteTypes, tyGenericInvocation, tyUncheckedArray: + # nested, add all the types to stack + let + startIdx = if u.kind in ConcreteTypes: 0 else: 1 + endIdx = min(u.sons.len() - startIdx, t.sons.len()) + + for i in startIdx ..< endIdx: + # early exit with current impl + if t[i] == nil or u[i] == nil: return + stackPut(t[i], u[i]) + of tyGenericParam: + if x.bindings.idTableGet(t) != nil: return + + # fully reduced generic param, bind it + if t notin flatUnbound: + flatUnbound.add(t) + flatBound.add(u) + else: + discard + for i in 0 ..< flatUnbound.len(): + x.bindings.idTablePut(flatUnbound[i], flatBound[i]) + +proc semResolvedCall(c: PContext, x: var TCandidate, + n: PNode, flags: TExprFlags; + expectedType: PType = nil): PNode = assert x.state == csMatch var finalCallee = x.calleeSym let info = getCallLineInfo(n) @@ -583,10 +636,12 @@ proc semResolvedCall(c: PContext, x: TCandidate, if x.calleeSym.magic in {mArrGet, mArrPut}: finalCallee = x.calleeSym else: + c.inheritBindings(x, expectedType) finalCallee = generateInstance(c, x.calleeSym, x.bindings, n.info) else: # For macros and templates, the resolved generic params # are added as normal params. + c.inheritBindings(x, expectedType) for s in instantiateGenericParamList(c, gp, x.bindings): case s.kind of skConst: @@ -615,7 +670,8 @@ proc tryDeref(n: PNode): PNode = result.add n proc semOverloadedCall(c: PContext, n, nOrig: PNode, - filter: TSymKinds, flags: TExprFlags): PNode = + filter: TSymKinds, flags: TExprFlags; + expectedType: PType = nil): PNode = var errors: CandidateErrors = @[] # if efExplain in flags: @[] else: nil var r = resolveOverloads(c, n, nOrig, filter, flags, errors, efExplain in flags) if r.state == csMatch: @@ -625,7 +681,7 @@ proc semOverloadedCall(c: PContext, n, nOrig: PNode, message(c.config, n.info, hintUserRaw, "Non-matching candidates for " & renderTree(n) & "\n" & candidates) - result = semResolvedCall(c, r, n, flags) + result = semResolvedCall(c, r, n, flags, expectedType) else: if efDetermineType in flags and c.inGenericContext > 0 and c.matchedConcept == nil: result = semGenericStmt(c, n) diff --git a/compiler/semdata.nim b/compiler/semdata.nim index ddd8d33ef..a85f8e638 100644 --- a/compiler/semdata.nim +++ b/compiler/semdata.nim @@ -135,7 +135,7 @@ type semOperand*: proc (c: PContext, n: PNode, flags: TExprFlags = {}): PNode {.nimcall.} semConstBoolExpr*: proc (c: PContext, n: PNode): PNode {.nimcall.} # XXX bite the bullet semOverloadedCall*: proc (c: PContext, n, nOrig: PNode, - filter: TSymKinds, flags: TExprFlags): PNode {.nimcall.} + filter: TSymKinds, flags: TExprFlags, expectedType: PType = nil): PNode {.nimcall.} semTypeNode*: proc(c: PContext, n: PNode, prev: PType): PType {.nimcall.} semInferredLambda*: proc(c: PContext, pt: TIdTable, n: PNode): PNode semGenerateInstance*: proc (c: PContext, fn: PSym, pt: TIdTable, diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim index c6be3e833..398424bbf 100644 --- a/compiler/semexprs.nim +++ b/compiler/semexprs.nim @@ -952,17 +952,17 @@ proc semStaticExpr(c: PContext, n: PNode; expectedType: PType = nil): PNode = result = fixupTypeAfterEval(c, result, a) proc semOverloadedCallAnalyseEffects(c: PContext, n: PNode, nOrig: PNode, - flags: TExprFlags): PNode = + flags: TExprFlags; expectedType: PType = nil): PNode = if flags*{efInTypeof, efWantIterator, efWantIterable} != {}: # consider: 'for x in pReturningArray()' --> we don't want the restriction # to 'skIterator' anymore; skIterator is preferred in sigmatch already # for typeof support. # for ``typeof(countup(1,3))``, see ``tests/ttoseq``. result = semOverloadedCall(c, n, nOrig, - {skProc, skFunc, skMethod, skConverter, skMacro, skTemplate, skIterator}, flags) + {skProc, skFunc, skMethod, skConverter, skMacro, skTemplate, skIterator}, flags, expectedType) else: result = semOverloadedCall(c, n, nOrig, - {skProc, skFunc, skMethod, skConverter, skMacro, skTemplate}, flags) + {skProc, skFunc, skMethod, skConverter, skMacro, skTemplate}, flags, expectedType) if result != nil: if result[0].kind != nkSym: @@ -1138,7 +1138,7 @@ proc semDirectOp(c: PContext, n: PNode, flags: TExprFlags; expectedType: PType = # this seems to be a hotspot in the compiler! let nOrig = n.copyTree #semLazyOpAux(c, n) - result = semOverloadedCallAnalyseEffects(c, n, nOrig, flags) + result = semOverloadedCallAnalyseEffects(c, n, nOrig, flags, expectedType) if result != nil: result = afterCallActions(c, result, nOrig, flags, expectedType) else: result = errorNode(c, n) @@ -3120,7 +3120,7 @@ proc semExpr(c: PContext, n: PNode, flags: TExprFlags = {}, expectedType: PType elif s.magic == mNone: result = semDirectOp(c, n, flags, expectedType) else: result = semMagic(c, n, s, flags, expectedType) of skProc, skFunc, skMethod, skConverter, skIterator: - if s.magic == mNone: result = semDirectOp(c, n, flags) + if s.magic == mNone: result = semDirectOp(c, n, flags, expectedType) else: result = semMagic(c, n, s, flags, expectedType) else: #liMessage(n.info, warnUser, renderTree(n)); diff --git a/doc/manual_experimental.md b/doc/manual_experimental.md index 602ca46a5..4ee035b65 100644 --- a/doc/manual_experimental.md +++ b/doc/manual_experimental.md @@ -124,6 +124,87 @@ would not match the type of the variable, and an error would be given. The extent of this varies, but there are some notable special cases. + +Inferred generic parameters +--------------------------- + +In expressions making use of generic procs or templates, the expected +(unbound) types are often able to be inferred based on context. +This feature has to be enabled via `{.experimental: "inferGenericTypes".}` + + ```nim test = "nim c $1" + {.experimental: "inferGenericTypes".} + + import std/options + + var x = newSeq[int](1) + # Do some work on 'x'... + + # Works! + # 'x' is 'seq[int]' so 'newSeq[int]' is implied + x = newSeq(10) + + # Works! + # 'T' of 'none' is bound to the 'T' of 'noneProducer', passing it along. + # Effectively 'none.T = noneProducer.T' + proc noneProducer[T](): Option[T] = none() + let myNone = noneProducer[int]() + + # Also works + # 'myOtherNone' binds its 'T' to 'float' and 'noneProducer' inherits it + # noneProducer.T = myOtherNone.T + let myOtherNone: Option[float] = noneProducer() + + # Works as well + # none.T = myOtherOtherNone.T + let myOtherOtherNone: Option[int] = none() + ``` + +This is achieved by reducing the types on the lhs and rhs until the *lhs* is left with only types such as `T`. +While lhs and rhs are reduced together, this does *not* mean that the *rhs* will also only be left +with a flat type `Z`, it may be of the form `MyType[Z]`. + +After the types have been reduced, the types `T` are bound to the types that are left on the rhs. + +If bindings *cannot be inferred*, compilation will fail and manual specification is required. + +An example for *failing inference* can be found when passing a generic expression +to a function/template call: + + ```nim test = "nim c $1" status = 1 + {.experimental: "inferGenericTypes".} + + proc myProc[T](a, b: T) = discard + + # Fails! Unable to infer that 'T' is supposed to be 'int' + myProc(newSeq[int](), newSeq(1)) + + # Works! Manual specification of 'T' as 'int' necessary + myProc(newSeq[int](), newSeq[int](1)) + ``` + +Combination of generic inference with the `auto` type is also unsupported: + + ```nim test = "nim c $1" status = 1 + {.experimental: "inferGenericTypes".} + + proc produceValue[T]: auto = default(T) + let a: int = produceValue() # 'auto' cannot be inferred here + ``` + +**Note**: The described inference does not permit the creation of overrides based on +the return type of a procedure. It is a mapping mechanism that does not attempt to +perform deeper inference, nor does it modify what is a valid override. + + ```nim test = "nim c $1" status = 1 + # Doesn't affect the following code, it is invalid either way + {.experimental: "inferGenericTypes".} + + proc a: int = 0 + proc a: float = 1.0 # Fails! Invalid code and not recommended + ``` + + Sequence literals ----------------- diff --git a/tests/generics/treturn_inference.nim b/tests/generics/treturn_inference.nim new file mode 100644 index 000000000..05d38cef4 --- /dev/null +++ b/tests/generics/treturn_inference.nim @@ -0,0 +1,139 @@ + +{.experimental: "inferGenericTypes".} + +import std/tables + +block: + type + MyOption[T, Z] = object + x: T + y: Z + + proc none[T, Z](): MyOption[T, Z] = + when T is int: + result.x = 22 + when Z is float: + result.y = 12.0 + + proc myGenericProc[T, Z](): MyOption[T, Z] = + none() # implied by return type + + let a = myGenericProc[int, float]() + doAssert a.x == 22 + doAssert a.y == 12.0 + + let b: MyOption[int, float] = none() # implied by type of b + doAssert b.x == 22 + doAssert b.y == 12.0 + +# Simple template based result with inferred type for errors +block: + type + ResultKind {.pure.} = enum + Ok + Err + + Result[T] = object + case kind: ResultKind + of Ok: + data: T + of Err: + errmsg: cstring + + template err[T](msg: static cstring): Result[T] = + Result[T](kind : ResultKind.Err, errmsg : msg) + + proc testproc(): Result[int] = + err("Inferred error!") # implied by proc return + let r = testproc() + doAssert r.kind == ResultKind.Err + doAssert r.errmsg == "Inferred error!" + +# Builtin seq +block: + let x: seq[int] = newSeq(1) + doAssert x is seq[int] + doAssert x.len() == 1 + + type + MyType[T, Z] = object + x: T + y: Z + + let y: seq[MyType[int, float]] = newSeq(2) + doAssert y is seq[MyType[int, float]] + doAssert y.len() == 2 + + let z = MyType[seq[float], string]( + x : newSeq(3), + y : "test" + ) + doAssert z.x is seq[float] + doAssert z.x.len() == 3 + doAssert z.y is string + doAssert z.y == "test" + +# array +block: + proc giveArray[N, T](): array[N, T] = + for i in 0 .. N.high: + result[i] = i + var x: array[2, int] = giveArray() + doAssert x == [0, 1] + +# tuples +block: + proc giveTuple[T, Z]: (T, Z, T) = discard + let x: (int, float, int) = giveTuple() + doAssert x is (int, float, int) + doAssert x == (0, 0.0, 0) + + proc giveNamedTuple[T, Z]: tuple[a: T, b: Z] = discard + let y: tuple[a: int, b: float] = giveNamedTuple() + doAssert y is (int, float) + doAssert y is tuple[a: int, b: float] + doAssert y == (0, 0.0) + + proc giveNestedTuple[T, Z]: ((T, Z), Z) = discard + let z: ((int, float), float) = giveNestedTuple() + doAssert z is ((int, float), float) + doAssert z == ((0, 0.0), 0.0) + + # nesting inside a generic type + type MyType[T] = object + x: T + let a = MyType[(int, MyType[float])](x : giveNamedTuple()) + doAssert a.x is (int, MyType[float]) + + +# basic constructors +block: + type MyType[T] = object + x: T + + proc giveValue[T](): T = + when T is int: + 12 + else: + default(T) + + let x = MyType[int](x : giveValue()) + doAssert x.x is int + doAssert x.x == 12 + + let y = MyType[MyType[float]](x : MyType[float](x : giveValue())) + doAssert y.x is MyType[float] + doAssert y.x.x is float + doAssert y.x.x == 0.0 + + # 'MyType[float]' is bound to 'T' directly + # instead of mapping 'T' to 'float' + let z = MyType[MyType[float]](x : giveValue()) + doAssert z.x is MyType[float] + doAssert z.x.x == 0.0 + + type Foo = object + x: Table[int, float] + + let a = Foo(x: initTable()) + doAssert a.x is Table[int, float] \ No newline at end of file |