summary refs log tree commit diff stats
path: root/compiler/semstmts.nim
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/semstmts.nim')
-rw-r--r--compiler/semstmts.nim88
1 files changed, 87 insertions, 1 deletions
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index e52a1b5cc..4c1fbede3 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -730,7 +730,9 @@ proc typeSectionLeftSidePass(c: PContext, n: PNode) =
     var s: PSym
     if name.kind == nkDotExpr:
       s = qualifiedLookUp(c, name, {checkUndeclared, checkModule})
-      if s.kind != skType or s.typ.skipTypes(abstractPtrs).kind != tyObject or tfPartial notin s.typ.skipTypes(abstractPtrs).flags:
+      if s.kind != skType or
+         s.typ.skipTypes(abstractPtrs).kind != tyObject or
+         tfPartial notin s.typ.skipTypes(abstractPtrs).flags:
         localError(name.info, "only .partial objects can be extended")
     else:
       s = semIdentDef(c, name, skType)
@@ -742,6 +744,87 @@ proc typeSectionLeftSidePass(c: PContext, n: PNode) =
       if sfGenSym notin s.flags: addInterfaceDecl(c, s)
     a.sons[0] = newSymNode(s)
 
+proc checkCovariantParamsUsages(genericType: PType) =
+  var body = genericType{-1}
+
+  proc traverseSubTypes(t: PType): bool =
+    template error(msg) = localError(genericType.sym.info, msg)
+
+    result = false
+
+    template subresult(r) =
+      let sub = r
+      result = result or sub
+
+    case t.kind
+    of tyGenericParam:
+      t.sym.flags.incl sfWeakCovariant
+      return true
+
+    of tyObject:
+      for field in t.n:
+        subresult traverseSubTypes(field.typ)
+
+    of tyArray:
+      return traverseSubTypes(t[1])
+
+    of tyProc:
+      for subType in t.sons:
+        if subType != nil:
+          subresult traverseSubTypes(subType)
+      if result:
+        error("non-invariant type param used in a proc type: " &  $t)
+
+    of tySequence:
+      return traverseSubTypes(t[0])
+
+    of tyGenericInvocation:
+      let targetBody = t[0]
+      for i in 1 .. <t.len:
+        let param = t[i]
+        if param.kind == tyGenericParam:
+          if sfCovariant in param.sym.flags:
+            let formalFlags = targetBody[i-1].sym.flags
+            if sfCovariant notin formalFlags:
+              error("covariant param '" & param.sym.name.s &
+                    "' used in a non-covariant position")
+            elif sfWeakCovariant in formalFlags:
+              param.sym.flags.incl sfWeakCovariant
+            result = true
+          elif sfContravariant in param.sym.flags:
+            let formalParam = targetBody[i-1].sym
+            if sfContravariant notin formalParam.flags:
+              error("contravariant param '" & param.sym.name.s &
+                    "' used in a non-contravariant position")
+            result = true
+        else:
+          subresult traverseSubTypes(param)
+
+    of tyAnd, tyOr, tyNot, tyStatic, tyBuiltInTypeClass, tyCompositeTypeClass:
+      error("non-invariant type parameters cannot be used with types such '" & $t & "'")
+
+    of tyUserTypeClass, tyUserTypeClassInst:
+      error("non-invariant type parameters are not supported in concepts")
+
+    of tyTuple:
+      for fieldType in t.sons:
+        subresult traverseSubTypes(fieldType)
+
+    of tyPtr, tyRef, tyVar:
+      if t.base.kind == tyGenericParam: return true
+      return traverseSubTypes(t.base)
+
+    of tyDistinct, tyAlias:
+      return traverseSubTypes(t.lastSon)
+
+    of tyGenericInst:
+      internalAssert false
+
+    else:
+      discard
+
+  discard traverseSubTypes(body)
+
 proc typeSectionRightSidePass(c: PContext, n: PNode) =
   for i in countup(0, sonsLen(n) - 1):
     var a = n.sons[i]
@@ -782,6 +865,9 @@ proc typeSectionRightSidePass(c: PContext, n: PNode) =
         body.sym = s
         body.size = -1 # could not be computed properly
         s.typ.sons[sonsLen(s.typ) - 1] = body
+        if sfCovariant in s.flags:
+          checkCovariantParamsUsages(s.typ)
+
       popOwner(c)
       closeScope(c)
     elif a.sons[2].kind != nkEmpty: