summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorJuan M Gómez <info@jmgomez.me>2023-05-17 10:44:42 +0100
committerGitHub <noreply@github.com>2023-05-17 11:44:42 +0200
commit02a10ec379d427f27f471d489247aa586078354b (patch)
tree8a92c0455785aa14320d8437a52691ee308451d8
parent1314ea75169b877f458e8b4eb1455d5f6428227b (diff)
downloadNim-02a10ec379d427f27f471d489247aa586078354b.tar.gz
Cpp Vfunctions draft (#21790)
* introduces virtual pragma, modifies proc def, prevents proc decl

* marks virtual procs as infix

* forward declare vfuncs inside the typedef

* adds naked callConv to virtual

* virtual proc error if not defined in the same top level scope as the type

* first param is now this. extracts genvirtualheaderproc

* WIP syntax

* supports obj. Removes the need for the prefix

* parameter count starts as this. Cleanup

* clean up

* sem tests

* adds integration tests

* uses constraint to store the virtual content

* introduces genVirtualProcParams

---------

Co-authored-by: Andreas Rumpf <rumpf_a@web.de>
-rw-r--r--compiler/ast.nim5
-rw-r--r--compiler/ccgtypes.nim155
-rw-r--r--compiler/cgen.nim7
-rw-r--r--compiler/modulegraphs.nim1
-rw-r--r--compiler/pragmas.nim17
-rw-r--r--compiler/sempass2.nim1
-rw-r--r--compiler/semstmts.nim19
-rw-r--r--tests/cpp/tvirtual.nim68
8 files changed, 260 insertions, 13 deletions
diff --git a/compiler/ast.nim b/compiler/ast.nim
index 3ed9bf2b2..7e92cd140 100644
--- a/compiler/ast.nim
+++ b/compiler/ast.nim
@@ -231,7 +231,7 @@ type
   TNodeKinds* = set[TNodeKind]
 
 type
-  TSymFlag* = enum    # 49 flags!
+  TSymFlag* = enum    # 50 flags!
     sfUsed,           # read access of sym (for warnings) or simply used
     sfExported,       # symbol is exported from module
     sfFromGeneric,    # symbol is instantiation of a generic; this is needed
@@ -312,6 +312,7 @@ type
                       #
                       # This is disallowed but can cause the typechecking to go into
                       # an infinite loop, this flag is used as a sentinel to stop it.
+    sfVirtual         # proc is a C++ virtual function
 
   TSymFlags* = set[TSymFlag]
 
@@ -929,7 +930,7 @@ type
       cname*: string          # resolved C declaration name in importc decl, e.g.:
                               # proc fun() {.importc: "$1aux".} => cname = funaux
     constraint*: PNode        # additional constraints like 'lit|result'; also
-                              # misused for the codegenDecl pragma in the hope
+                              # misused for the codegenDecl and virtual pragmas in the hope
                               # it won't cause problems
                               # for skModule the string literal to output for
                               # deprecated modules.
diff --git a/compiler/ccgtypes.nim b/compiler/ccgtypes.nim
index 2669dec24..7bd4dac81 100644
--- a/compiler/ccgtypes.nim
+++ b/compiler/ccgtypes.nim
@@ -12,6 +12,7 @@
 # ------------------------- Name Mangling --------------------------------
 
 import sighashes, modulegraphs
+import strscans
 import ../dist/checksums/src/checksums/md5
 
 proc isKeyword(w: PIdent): bool =
@@ -424,9 +425,105 @@ proc paramStorageLoc(param: PSym): TStorageLoc =
   else:
     result = OnUnknown
 
+macro unrollChars(x: static openArray[char], name, body: untyped) =
+  result = newStmtList()
+  for a in x:
+    result.add(newBlockStmt(newStmtList(
+      newConstStmt(name, newLit(a)),
+      copy body
+    )))
+
+proc multiFormat*(frmt: var string, chars : static openArray[char], args: openArray[seq[string]]) =
+  var res : string
+  unrollChars(chars, c):
+    res = ""
+    let arg = args[find(chars, c)]
+    var i = 0
+    var num = 0
+    while i < frmt.len:
+      if frmt[i] == c:
+        inc(i)                  
+        case frmt[i]
+        of c:
+          res.add(c)
+          inc(i)
+        of '0'..'9':
+          var j = 0
+          while true:
+            j = j * 10 + ord(frmt[i]) - ord('0')
+            inc(i)
+            if i >= frmt.len or frmt[i] notin {'0'..'9'}: break
+          num = j
+          if j > high(arg) + 1:
+            doAssert false, "invalid format string: " & frmt
+          else:
+            res.add(arg[j-1])
+        else:
+          doAssert false, "invalid format string: " & frmt
+      var start = i
+      while i < frmt.len:
+        if frmt[i] != c: inc(i)
+        else: break
+      if i - 1 >= start:
+        res.add(substr(frmt, start, i - 1))
+    frmt = res
+
+proc genVirtualProcParams(m: BModule; t: PType, rettype, params: var string,
+                   check: var IntSet, declareEnvironment=true;
+                   weakDep=false;) =
+  if t[0] == nil or isInvalidReturnType(m.config, t):
+    rettype = "void"
+  else:
+    if rettype == "":
+      rettype = getTypeDescAux(m, t[0], check, skResult)
+    else:
+      rettype = runtimeFormat(rettype.replace("'0", "$1"), [getTypeDescAux(m, t[0], check, skResult)])
+  var this = t.n[1].sym
+  fillParamName(m, this)
+  fillLoc(this.loc, locParam, t.n[1],
+          this.paramStorageLoc)
+  if this.typ.kind == tyPtr:
+    this.loc.r = "this"
+  else:
+    this.loc.r = "(*this)"
+  
+  var types = @[getTypeDescWeak(m, this.typ, check, skParam)]
+  var names = @[this.loc.r]
+
+  for i in 2..<t.n.len: 
+    if t.n[i].kind != nkSym: internalError(m.config, t.n.info, "genVirtualProcParams")
+    var param = t.n[i].sym
+    var typ, name : string
+    fillParamName(m, param)
+    fillLoc(param.loc, locParam, t.n[i],
+            param.paramStorageLoc)
+    if ccgIntroducedPtr(m.config, param, t[0]):
+      typ = getTypeDescWeak(m, param.typ, check, skParam) & "*"
+      incl(param.loc.flags, lfIndirect)
+      param.loc.storage = OnUnknown
+    elif weakDep:
+      typ = getTypeDescWeak(m, param.typ, check, skParam)
+    else:
+      typ = getTypeDescAux(m, param.typ, check, skParam)
+    if sfNoalias in param.flags:
+      typ.add("NIM_NOALIAS ")
+
+    name = param.loc.r
+    types.add typ
+    names.add name
+  multiFormat(params, @['\'', '#'], [types, names])
+  if params == "()":
+    params = "(void)"
+  if tfVarargs in t.flags:
+    if params != "(":
+      params[^1] = ','
+    else:
+      params.delete(params.len()-1..params.len()-1)
+    params.add("...)")
+
 proc genProcParams(m: BModule; t: PType, rettype, params: var Rope,
                    check: var IntSet, declareEnvironment=true;
-                   weakDep=false) =
+                   weakDep=false;) =
   params = "("
   if t[0] == nil or isInvalidReturnType(m.config, t):
     rettype = "void"
@@ -564,10 +661,18 @@ proc genRecordFieldsAux(m: BModule; n: PNode,
           result.addf("\t$1$3 $2;$n", [getTypeDescAux(m, field.loc.t, check, skField), sname, noAlias])
   else: internalError(m.config, n.info, "genRecordFieldsAux()")
 
+proc genVirtualProcHeader(m: BModule; prc: PSym; result: var Rope; asPtr: bool = false, isFwdDecl:bool = false)
+
 proc getRecordFields(m: BModule; typ: PType, check: var IntSet): Rope =
   result = newRopeAppender()
   genRecordFieldsAux(m, typ.n, typ, check, result)
-
+  if typ.itemId in m.g.graph.virtualProcsPerType:
+    let procs = m.g.graph.virtualProcsPerType[typ.itemId]
+    for prc in procs:
+      var header: Rope
+      genVirtualProcHeader(m, prc, header, false, true)
+      result.add "\t" & header & ";\n"
+    
 proc fillObjectFields*(m: BModule; typ: PType) =
   # sometimes generic objects are not consistently merged. We patch over
   # this fact here.
@@ -971,19 +1076,58 @@ proc isReloadable(m: BModule; prc: PSym): bool =
 proc isNonReloadable(m: BModule; prc: PSym): bool =
   return m.hcrOn and sfNonReloadable in prc.flags
 
+proc parseVFunctionDecl(val: string; name, params, retType: var string; isFnConst, isOverride: var bool) =
+  var afterParams: string
+  if scanf(val, "$*($*)$s$*", name, params, afterParams):
+    isFnConst = afterParams.find("const") > -1
+    isOverride = afterParams.find("override") > -1
+    discard scanf(afterParams, "->$s$* ", retType)
+  params = "(" & params & ")"
+
+proc genVirtualProcHeader(m: BModule; prc: PSym; result: var Rope; asPtr: bool = false, isFwdDecl : bool = false) =
+  assert sfVirtual in prc.flags
+  # using static is needed for inline procs
+  var check = initIntSet()
+  fillBackendName(m, prc)
+  fillLoc(prc.loc, locProc, prc.ast[namePos], OnUnknown)
+  var typ = prc.typ.n[1].sym.typ
+  var memberOp = "#."
+  if typ.kind == tyPtr:
+    typ = typ[0]
+    memberOp = "#->"
+  var typDesc = getTypeDescWeak(m, typ, check, skParam)
+  let asPtrStr = rope(if asPtr: "_PTR" else: "")
+  var name, params, rettype: string
+  var isFnConst, isOverride: bool
+  parseVFunctionDecl(prc.constraint.strVal, name, params, rettype, isFnConst, isOverride)
+  genVirtualProcParams(m, prc.typ, rettype, params, check, true, false) 
+  var fnConst, override: string
+  if isFnConst:
+    fnConst = " const"
+  if isFwdDecl:
+    rettype = "virtual " & rettype
+    if isOverride: 
+      override = " override"
+  else: 
+    prc.loc.r = "$1 $2 (@)" % [memberOp, name]
+    name = "$1::$2" % [typDesc, name]
+  
+  result.add "N_LIB_PRIVATE "
+  result.addf("$1$2($3, $4)$5$6$7",
+        [rope(CallingConvToStr[prc.typ.callConv]), asPtrStr, rettype, name,
+        params, fnConst, override])
+  
 proc genProcHeader(m: BModule; prc: PSym; result: var Rope; asPtr: bool = false) =
   # using static is needed for inline procs
   var check = initIntSet()
   fillBackendName(m, prc)
   fillLoc(prc.loc, locProc, prc.ast[namePos], OnUnknown)
   var rettype, params: Rope
-  genProcParams(m, prc.typ, rettype, params, check)
+  genProcParams(m, prc.typ, rettype, params, check, true, false)
   # handle the 2 options for hotcodereloading codegen - function pointer
   # (instead of forward declaration) or header for function body with "_actual" postfix
   let asPtrStr = rope(if asPtr: "_PTR" else: "")
   var name = prc.loc.r
-  if isReloadable(m, prc) and not asPtr:
-    name.add("_actual")
   # careful here! don't access ``prc.ast`` as that could reload large parts of
   # the object graph!
   if prc.constraint.isNil:
@@ -1003,6 +1147,7 @@ proc genProcHeader(m: BModule; prc: PSym; result: var Rope; asPtr: bool = false)
     let asPtrStr = if asPtr: (rope("(*") & name & ")") else: name
     result.add runtimeFormat(prc.cgDeclFrmt, [rettype, asPtrStr, params])
 
+
 # ------------------ type info generation -------------------------------------
 
 proc genTypeInfoV1(m: BModule; t: PType; info: TLineInfo): Rope
diff --git a/compiler/cgen.nim b/compiler/cgen.nim
index 17b0350b6..107af373b 100644
--- a/compiler/cgen.nim
+++ b/compiler/cgen.nim
@@ -1128,7 +1128,10 @@ proc isNoReturn(m: BModule; s: PSym): bool {.inline.} =
 proc genProcAux*(m: BModule, prc: PSym) =
   var p = newProc(prc, m)
   var header = newRopeAppender()
-  genProcHeader(m, prc, header)
+  if m.config.backend == backendCpp and sfVirtual in prc.flags:
+    genVirtualProcHeader(m, prc, header)
+  else:
+    genProcHeader(m, prc, header)
   var returnStmt: Rope = ""
   assert(prc.ast != nil)
 
@@ -1234,7 +1237,7 @@ proc requiresExternC(m: BModule; sym: PSym): bool {.inline.} =
 
 proc genProcPrototype(m: BModule, sym: PSym) =
   useHeader(m, sym)
-  if lfNoDecl in sym.loc.flags: return
+  if lfNoDecl in sym.loc.flags or sfVirtual in sym.flags: return
   if lfDynamicLib in sym.loc.flags:
     if sym.itemId.module != m.module.position and
         not containsOrIncl(m.declaredThings, sym.id):
diff --git a/compiler/modulegraphs.nim b/compiler/modulegraphs.nim
index 5cb6a1c34..de97ced99 100644
--- a/compiler/modulegraphs.nim
+++ b/compiler/modulegraphs.nim
@@ -79,6 +79,7 @@ type
     procInstCache*: Table[ItemId, seq[LazyInstantiation]] # A symbol's ItemId.
     attachedOps*: array[TTypeAttachedOp, Table[ItemId, LazySym]] # Type ID, destructors, etc.
     methodsPerType*: Table[ItemId, seq[(int, LazySym)]] # Type ID, attached methods
+    virtualProcsPerType*: Table[ItemId, seq[PSym]] # Type ID, attached virtual procs
     enumToStringProcs*: Table[ItemId, LazySym]
     emittedTypeInfo*: Table[string, FileIndex]
 
diff --git a/compiler/pragmas.nim b/compiler/pragmas.nim
index 6fe09921d..be8e83d25 100644
--- a/compiler/pragmas.nim
+++ b/compiler/pragmas.nim
@@ -34,7 +34,7 @@ const
     wAsmNoStackFrame, wDiscardable, wNoInit, wCodegenDecl,
     wGensym, wInject, wRaises, wEffectsOf, wTags, wForbids, wLocks, wDelegator, wGcSafe,
     wConstructor, wLiftLocals, wStackTrace, wLineTrace, wNoDestroy,
-    wRequires, wEnsures, wEnforceNoRaises, wSystemRaisesDefect}
+    wRequires, wEnsures, wEnforceNoRaises, wSystemRaisesDefect, wVirtual}
   converterPragmas* = procPragmas
   methodPragmas* = procPragmas+{wBase}-{wImportCpp}
   templatePragmas* = {wDeprecated, wError, wGensym, wInject, wDirty,
@@ -211,9 +211,9 @@ proc processImportObjC(c: PContext; s: PSym, extname: string, info: TLineInfo) =
   let m = s.getModule()
   incl(m.flags, sfCompileToObjc)
 
-proc newEmptyStrNode(c: PContext; n: PNode): PNode {.noinline.} =
+proc newEmptyStrNode(c: PContext; n: PNode, strVal: string = ""): PNode {.noinline.} =
   result = newNodeIT(nkStrLit, n.info, getSysType(c.graph, n.info, tyString))
-  result.strVal = ""
+  result.strVal = strVal
 
 proc getStrLitNode(c: PContext, n: PNode): PNode =
   if n.kind notin nkPragmaCallKinds or n.len != 2:
@@ -245,6 +245,14 @@ proc getOptionalStr(c: PContext, n: PNode, defaultStr: string): string =
   if n.kind in nkPragmaCallKinds: result = expectStrLit(c, n)
   else: result = defaultStr
 
+proc processVirtual(c: PContext, n: PNode, s: PSym) =
+  s.constraint = newEmptyStrNode(c, n, getOptionalStr(c, n, "$1"))
+  s.constraint.strVal = s.constraint.strVal % s.name.s
+  s.flags.incl {sfVirtual, sfInfixCall, sfExportc, sfMangleCpp}
+  
+  s.typ.callConv = ccNoConvention
+  incl c.config.globalOptions, optMixedMode
+
 proc processCodegenDecl(c: PContext, n: PNode, sym: PSym) =
   sym.constraint = getStrLitNode(c, n)
 
@@ -1263,6 +1271,9 @@ proc singlePragma(c: PContext, sym: PSym, n: PNode, i: var int,
         sym.flags.incl sfNeverRaises
       of wSystemRaisesDefect:
         sym.flags.incl sfSystemRaisesDefect
+      of wVirtual:
+          processVirtual(c, it, sym)
+      
       else: invalidPragma(c, it)
     elif comesFromPush and whichKeyword(ident) != wInvalid:
       discard "ignore the .push pragma; it doesn't apply"
diff --git a/compiler/sempass2.nim b/compiler/sempass2.nim
index 7024c99fe..fc9755aa2 100644
--- a/compiler/sempass2.nim
+++ b/compiler/sempass2.nim
@@ -833,7 +833,6 @@ proc trackCall(tracked: PEffects; n: PNode) =
       # and it's not a recursive call:
       if not (a.kind == nkSym and a.sym == tracked.owner):
         markSideEffect(tracked, a, n.info)
-
   # p's effects are ours too:
   var a = n[0]
   #if canRaise(a):
diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim
index f81423915..579af973e 100644
--- a/compiler/semstmts.nim
+++ b/compiler/semstmts.nim
@@ -2188,6 +2188,25 @@ proc semProcAux(c: PContext, n: PNode, kind: TSymKind,
 
   if sfBorrow in s.flags and c.config.cmd notin cmdDocLike:
     result[bodyPos] = c.graph.emptyNode
+  
+  if sfVirtual in s.flags:
+    if c.config.backend == backendCpp:
+      for son in s.typ.sons:
+        if son!=nil and son.isMetaType:
+          localError(c.config, n.info, "virtual unsupported for generic routine")
+
+      var typ = s.typ.sons[1]
+      if typ.kind == tyPtr:
+        typ = typ[0]
+      if typ.kind != tyObject:
+        localError(c.config, n.info, "virtual must be a non ref object type")
+      if typ.owner.id == s.owner.id and c.module.id == s.owner.id:
+        c.graph.virtualProcsPerType.mgetOrPut(typ.itemId, @[]).add s
+      else:
+        localError(c.config, n.info, 
+          "virtual procs must be defined in the same scope as the type they are virtual for and it must be a top level scope")
+    else:
+      localError(c.config, n.info, "virtual procs are only supported in C++")
 
   if n[bodyPos].kind != nkEmpty and sfError notin s.flags:
     # for DLL generation we allow sfImportc to have a body, for use in VM
diff --git a/tests/cpp/tvirtual.nim b/tests/cpp/tvirtual.nim
new file mode 100644
index 000000000..d7dd6a7c4
--- /dev/null
+++ b/tests/cpp/tvirtual.nim
@@ -0,0 +1,68 @@
+discard """
+  targets: "cpp"
+  cmd: "nim cpp $file"
+  output: '''
+hello foo
+hello boo
+hello boo
+Const Message: hello world
+NimPrinter: hello world
+NimPrinterConstRef: hello world
+'''
+"""
+
+{.emit:"""/*TYPESECTION*/
+#include <iostream>
+  class CppPrinter {
+  public:
+    
+    virtual void printConst(char* message) const {
+        std::cout << "Const Message: " << message << std::endl;
+    }
+    virtual void printConstRef(char* message, const int& flag) const {
+        std::cout << "Const Ref Message: " << message << std::endl;
+    }  
+};
+""".}
+
+proc newCpp*[T](): ptr T {.importcpp:"new '*0()".}
+type 
+  Foo = object of RootObj
+  FooPtr = ptr Foo
+  Boo = object of Foo
+  BooPtr = ptr Boo
+  CppPrinter {.importcpp, inheritable.} = object
+  NimPrinter {.exportc.} = object of CppPrinter
+
+proc salute(self:FooPtr) {.virtual.} = 
+  echo "hello foo"
+
+proc salute(self:BooPtr) {.virtual.} =
+  echo "hello boo"
+
+let foo = newCpp[Foo]()
+let boo = newCpp[Boo]()
+let booAsFoo = cast[FooPtr](newCpp[Boo]())
+
+#polymorphism works
+foo.salute()
+boo.salute()
+booAsFoo.salute()
+let message = "hello world".cstring
+
+proc printConst(self:CppPrinter, message:cstring) {.importcpp.}
+CppPrinter().printConst(message)
+
+#notice override is optional. 
+#Will make the cpp compiler to fail if not virtual function with the same signature if found in the base type
+proc printConst(self:NimPrinter, message:cstring) {.virtual:"$1('2 #2) const override".} =
+  echo "NimPrinter: " & $message
+
+proc printConstRef(self:NimPrinter, message:cstring, flag:int32) {.virtual:"$1('2 #2, const '3& #3 ) const override".} =
+  echo "NimPrinterConstRef: " & $message
+
+NimPrinter().printConst(message)
+var val : int32 = 10
+NimPrinter().printConstRef(message, val)
+
+