summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--compiler/ast.nim20
-rw-r--r--compiler/semmagic.nim10
-rw-r--r--lib/std/importutils.nim32
-rw-r--r--tests/stdlib/mimportutils.nim15
-rw-r--r--tests/stdlib/timportutils.nim56
5 files changed, 118 insertions, 15 deletions
diff --git a/compiler/ast.nim b/compiler/ast.nim
index 96adb8c1f..8bb8de1d6 100644
--- a/compiler/ast.nim
+++ b/compiler/ast.nim
@@ -1887,6 +1887,26 @@ proc toObject*(typ: PType): PType =
   if t.kind == tyRef: t.lastSon
   else: typ
 
+proc toObjectFromRefPtrGeneric*(typ: PType): PType =
+  #[
+  See also `toObject`.
+  Finds the underlying `object`, even in cases like these:
+  type
+    B[T] = object f0: int
+    A1[T] = ref B[T]
+    A2[T] = ref object f1: int
+    A3 = ref object f2: int
+    A4 = object f3: int
+  ]#
+  result = typ
+  while true:
+    case result.kind
+    of tyGenericBody: result = result.lastSon
+    of tyRef, tyPtr, tyGenericInst, tyGenericInvocation, tyAlias: result = result[0]
+      # automatic dereferencing is deep, refs #18298.
+    else: break
+  assert result.sym != nil
+
 proc isImportedException*(t: PType; conf: ConfigRef): bool =
   assert t != nil
 
diff --git a/compiler/semmagic.nim b/compiler/semmagic.nim
index e4e007678..d3f26e630 100644
--- a/compiler/semmagic.nim
+++ b/compiler/semmagic.nim
@@ -457,6 +457,11 @@ proc semOld(c: PContext; n: PNode): PNode =
     localError(c.config, n[1].info, n[1].sym.name.s & " does not belong to " & getCurrOwner(c).name.s)
   result = n
 
+proc semPrivateAccess(c: PContext, n: PNode): PNode =
+  let t = n[1].typ[0].toObjectFromRefPtrGeneric
+  c.currentScope.allowPrivateAccess.add t.sym
+  result = newNodeIT(nkEmpty, n.info, getSysType(c.graph, n.info, tyVoid))
+
 proc magicsAfterOverloadResolution(c: PContext, n: PNode,
                                    flags: TExprFlags): PNode =
   ## This is the preferred code point to implement magics.
@@ -574,9 +579,6 @@ proc magicsAfterOverloadResolution(c: PContext, n: PNode,
       n[0].sym.magic = mSubU
     result = n
   of mPrivateAccess:
-    var t = n[1].typ[0]
-    if t.kind in {tyRef, tyPtr}: t = t[0]
-    c.currentScope.allowPrivateAccess.add t.sym
-    result = newNodeIT(nkEmpty, n.info, getSysType(c.graph, n.info, tyVoid))
+    result = semPrivateAccess(c, n)
   else:
     result = n
diff --git a/lib/std/importutils.nim b/lib/std/importutils.nim
index 0c0f546b9..d2da76ea8 100644
--- a/lib/std/importutils.nim
+++ b/lib/std/importutils.nim
@@ -13,22 +13,32 @@ Possible future APIs:
 ]#
 
 when defined(nimImportutilsExample):
-  type Foo = object
-    x1: int # private
+  type
+    Foo = object
+      f0: int # private
+    Goo*[T] = object
+      g0: int # private
   proc initFoo*(): auto = Foo()
 
-proc privateAccess*(t: typedesc[object|(ref object)|(ptr object)]) {.magic: "PrivateAccess".} =
+proc privateAccess*(t: typedesc) {.magic: "PrivateAccess".} =
   ## Enables access to private fields of `t` in current scope.
   runnableExamples("-d:nimImportutilsExample"):
     # here we're importing a module containing:
-    # type Foo = object
-    #   x1: int # private
+    # type
+    #   Foo = object
+    #     f0: int # private
+    #   Goo*[T] = object
+    #     g0: int # private
     # proc initFoo*(): auto = Foo()
-    var a = initFoo()
+    var f = initFoo()
     block:
-      assert not compiles(a.x1)
-      privateAccess(a.type)
-      a.x1 = 1 # accessible in this scope
+      assert not compiles(f.f0)
+      privateAccess(f.type)
+      f.f0 = 1 # accessible in this scope
       block:
-        assert a.x1 == 1 # still in scope
-    assert not compiles(a.x1)
+        assert f.f0 == 1 # still in scope
+    assert not compiles(f.f0)
+
+    # this also works with generics
+    privateAccess(Goo)
+    assert Goo[float](g0: 1).g0 == 1
diff --git a/tests/stdlib/mimportutils.nim b/tests/stdlib/mimportutils.nim
index d2b185cd3..e89d58d27 100644
--- a/tests/stdlib/mimportutils.nim
+++ b/tests/stdlib/mimportutils.nim
@@ -13,5 +13,20 @@ type
     hd1: float
   PA* = ref A
   PtA* = ptr A
+  E*[T] = object
+    he1: int
+  FSub[T1, T2] = object
+    h3: T1
+    h4: T2
+  F*[T1, T2] = ref FSub[T1, T2]
+  G*[T] = ref E[T]
+  H3*[T] = object
+    h5: T
+  H2*[T] = H3[T]
+  H1*[T] = ref H2[T]
+  H*[T] = H1[T]
+
+type BAalias* = typeof(B.default)
+  # typeof is not a transparent abstraction, creates a `tyAlias`
 
 proc initB*(): B = B()
diff --git a/tests/stdlib/timportutils.nim b/tests/stdlib/timportutils.nim
index 37e2b7102..be912e702 100644
--- a/tests/stdlib/timportutils.nim
+++ b/tests/stdlib/timportutils.nim
@@ -38,6 +38,55 @@ template main =
 
     block:
       assertAll:
+        not compiles(E[int](he1: 1))
+        privateAccess E[int]
+        var e = E[int](he1: 1)
+        e.he1 == 1
+        e.he1 = 2
+        e.he1 == 2
+        e.he1 += 3
+        e.he1 == 5
+        # xxx caveat: this currently compiles but in future, we may want
+        # to make `privateAccess E[int]` only affect a specific instantiation;
+        # note that `privateAccess E` does work to cover all instantiations.
+        var e2 = E[float](he1: 1)
+
+    block:
+      assertAll:
+        not compiles(E[int](he1: 1))
+        privateAccess E
+        var e = E[int](he1: 1)
+        e.he1 == 1
+
+    block:
+      assertAll:
+        not compiles(F[int, int](h3: 1))
+        privateAccess F[int, int]
+        var e = F[int, int](h3: 1)
+        e.h3 == 1
+
+    block:
+      assertAll:
+        not compiles(F[int, int](h3: 1))
+        privateAccess F[int, int].default[].typeof
+        var e = F[int, int](h3: 1)
+        e.h3 == 1
+
+    block:
+      assertAll:
+        var a = G[int]()
+        var b = a.addr
+        privateAccess b.type
+        discard b.he1
+        discard b[][].he1
+
+    block:
+      assertAll:
+        privateAccess H[int]
+        var a = H[int](h5: 2)
+
+    block:
+      assertAll:
         privateAccess PA
         var pa = PA(a0: 1, ha1: 2)
         pa.ha1 == 2
@@ -46,6 +95,13 @@ template main =
 
     block:
       assertAll:
+        var b = BAalias()
+        not compiles(b.hb1)
+        privateAccess BAalias
+        discard b.hb1
+
+    block:
+      assertAll:
         var a = A(a0: 1)
         var a2 = a.addr
         not compiles(a2.ha1)