summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorJason Beetham <beefers331@gmail.com>2021-09-11 05:05:53 -0600
committerGitHub <noreply@github.com>2021-09-11 13:05:53 +0200
commit66e53bdd7b465edd9045314d4d6a60ef6e0b5e32 (patch)
tree1b1f5a032f8a558d09f0e100bbf7c86fab6e2b3f
parent1f68f71ec24d3ec6b8e83411a6f1604277f9d493 (diff)
downloadNim-66e53bdd7b465edd9045314d4d6a60ef6e0b5e32.tar.gz
Fixed type inference for 'set` and 'tuple' (#18827)
* improved built in typeclass inference

* Smarter logic to fit node

* Forgot the untyped check
-rw-r--r--compiler/sem.nim3
-rw-r--r--tests/metatype/typeclassinference.nim22
2 files changed, 24 insertions, 1 deletions
diff --git a/compiler/sem.nim b/compiler/sem.nim
index ad0ba1f7c..bdecbe602 100644
--- a/compiler/sem.nim
+++ b/compiler/sem.nim
@@ -77,7 +77,8 @@ template semIdeForTemplateOrGeneric(c: PContext; n: PNode;
 
 proc fitNodePostMatch(c: PContext, formal: PType, arg: PNode): PNode =
   let x = arg.skipConv
-  if x.kind in {nkPar, nkTupleConstr, nkCurly} and formal.kind != tyUntyped:
+  if (x.kind == nkCurly and formal.kind == tySet and formal.base.kind != tyGenericParam) or
+    (x.kind in {nkPar, nkTupleConstr}) and formal.kind notin {tyUntyped, tyBuiltInTypeClass}:
     changeType(c, x, formal, check=true)
   result = arg
   result = skipHiddenSubConv(result, c.graph, c.idgen)
diff --git a/tests/metatype/typeclassinference.nim b/tests/metatype/typeclassinference.nim
index c845e04f7..b3f197718 100644
--- a/tests/metatype/typeclassinference.nim
+++ b/tests/metatype/typeclassinference.nim
@@ -19,3 +19,25 @@ var ptr1: ptr = addr(str1)
 var str2: string = "hello, world!"
 var ptr2: ptr = str2
 
+block: # built in typeclass inference
+  proc tupleA(): tuple = return (1, 2)
+  proc tupleB(): tuple = (1f, 2f)
+  assert typeof(tupleA()) is (int, int)
+  assert typeof(tupleB()) is (float32, float32)
+
+  proc a(val: int or float): tuple = 
+    when typeof(val) is int:
+      (10, 10)
+    else:
+      (30f, 30f)
+
+  assert typeof(a(10)) is (int, int)
+  assert typeof(a(10.0)) is (float32, float32)
+
+  proc b(val: int or float): set = 
+    when typeof(val) is int:
+      {10u8, 3}
+    else:
+      {'a', 'b'}
+  assert typeof(b(10)) is set[uint8]
+  assert typeof(b(10.0)) is set[char]
\ No newline at end of file