summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAndreas Rumpf <rumpf_a@web.de>2020-01-17 11:14:17 +0100
committerGitHub <noreply@github.com>2020-01-17 11:14:17 +0100
commit76269074014b672582bc630ff393a95e5e21dcec (patch)
tree481e3e494d4b8bef21e503a890ecf8b6fe4638c9
parent796aafe7e0856375494a352f738ac14f8de9fd4d (diff)
downloadNim-76269074014b672582bc630ff393a95e5e21dcec.tar.gz
ARC works for async on Windows (#13179)
-rw-r--r--compiler/types.nim4
-rw-r--r--lib/pure/asyncdispatch.nim56
-rw-r--r--lib/pure/asyncfile.nim12
-rw-r--r--lib/system/refs_v2.nim9
-rw-r--r--tests/arc/tasyncawait.nim69
5 files changed, 110 insertions, 40 deletions
diff --git a/compiler/types.nim b/compiler/types.nim
index 2273a0a2f..3c481ec23 100644
--- a/compiler/types.nim
+++ b/compiler/types.nim
@@ -366,8 +366,8 @@ proc canFormAcycleAux(marker: var IntSet, typ: PType, startId: int): bool =
   else: discard
 
 proc isFinal*(t: PType): bool =
-  var t = t.skipTypes(abstractInst)
-  result = t.kind != tyObject or tfFinal in t.flags
+  let t = t.skipTypes(abstractInst)
+  result = t.kind != tyObject or tfFinal in t.flags or isPureObject(t)
 
 proc canFormAcycle*(typ: PType): bool =
   var marker = initIntSet()
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index 43d2066e9..d724242d7 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -247,10 +247,10 @@ when defined(windows) or defined(nimdoc):
       ioPort: Handle
       handles: HashSet[AsyncFD]
 
-    CustomOverlapped = object of OVERLAPPED
+    CustomObj = object of OVERLAPPED
       data*: CompletionData
 
-    PCustomOverlapped* = ref CustomOverlapped
+    CustomRef* = ref CustomObj
 
     AsyncFD* = distinct int
 
@@ -258,7 +258,7 @@ when defined(windows) or defined(nimdoc):
       ioPort: Handle
       handleFd: AsyncFD
       waitFd: Handle
-      ovl: owned PCustomOverlapped
+      ovl: owned CustomRef
     PostCallbackDataPtr = ptr PostCallbackData
 
     AsyncEventImpl = object
@@ -336,13 +336,15 @@ when defined(windows) or defined(nimdoc):
 
     var lpNumberOfBytesTransferred: DWORD
     var lpCompletionKey: ULONG_PTR
-    var customOverlapped: PCustomOverlapped
+    var customOverlapped: CustomRef
     let res = getQueuedCompletionStatus(p.ioPort,
         addr lpNumberOfBytesTransferred, addr lpCompletionKey,
         cast[ptr POVERLAPPED](addr customOverlapped), llTimeout).bool
     result = true
-    when defined(gcDestructors):
-      GC_ref(customOverlapped)
+    # For 'gcDestructors' the destructor of 'customOverlapped' will
+    # be called at the end and we are the only owner here. This means
+    # We do not have to 'GC_unref(customOverlapped)' because the destructor
+    # does that for us.
 
     # http://stackoverflow.com/a/12277264/492186
     # TODO: http://www.serverframework.com/handling-multiple-pending-socket-read-and-write-operations.html
@@ -359,7 +361,8 @@ when defined(windows) or defined(nimdoc):
       if customOverlapped.data.cell.data != nil:
         system.dispose(customOverlapped.data.cell)
 
-      GC_unref(customOverlapped)
+      when not defined(gcDestructors):
+        GC_unref(customOverlapped)
     else:
       let errCode = osLastError()
       if customOverlapped != nil:
@@ -368,7 +371,8 @@ when defined(windows) or defined(nimdoc):
             lpNumberOfBytesTransferred, errCode)
         if customOverlapped.data.cell.data != nil:
           system.dispose(customOverlapped.data.cell)
-        GC_unref(customOverlapped)
+        when not defined(gcDestructors):
+          GC_unref(customOverlapped)
       else:
         if errCode.int32 == WAIT_TIMEOUT:
           # Timed out
@@ -409,6 +413,13 @@ when defined(windows) or defined(nimdoc):
     getAcceptExSockAddrs = cast[WSAPROC_GETACCEPTEXSOCKADDRS](fun)
     close(dummySock)
 
+  proc newCustom*(): CustomRef =
+    result = CustomRef() # 0
+    GC_ref(result) # 1  prevent destructor from doing a premature free.
+    # destructor of newCustom's caller --> 0. This means
+    # Windows holds a ref for us with RC == 0 (single owner).
+    # This is passed back to us in the IO completion port.
+
   proc recv*(socket: AsyncFD, size: int,
              flags = {SocketFlag.SafeDisconn}): owned(Future[string]) =
     ## Reads **up to** ``size`` bytes from ``socket``. Returned future will
@@ -435,8 +446,7 @@ when defined(windows) or defined(nimdoc):
 
     var bytesReceived: DWORD
     var flagsio = flags.toOSFlags().DWORD
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: socket, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
@@ -512,8 +522,7 @@ when defined(windows) or defined(nimdoc):
 
     var bytesReceived: DWORD
     var flagsio = flags.toOSFlags().DWORD
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: socket, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
@@ -565,8 +574,7 @@ when defined(windows) or defined(nimdoc):
     dataBuf.len = size.ULONG
 
     var bytesReceived, lowFlags: DWORD
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: socket, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
@@ -616,8 +624,7 @@ when defined(windows) or defined(nimdoc):
     zeroMem(addr(staddr[0]), 128)
     copyMem(addr(staddr[0]), saddr, saddrLen)
 
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: socket, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
@@ -658,8 +665,7 @@ when defined(windows) or defined(nimdoc):
     var bytesReceived = 0.DWORD
     var lowFlags = 0.DWORD
 
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: socket, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
@@ -754,8 +760,7 @@ when defined(windows) or defined(nimdoc):
           clientSock.close()
           retFuture.fail(getCurrentException())
 
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: socket, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
@@ -799,7 +804,7 @@ when defined(windows) or defined(nimdoc):
 
   {.push stackTrace: off.}
   proc waitableCallback(param: pointer,
-                        timerOrWaitFired: WINBOOL): void {.stdcall.} =
+                        timerOrWaitFired: WINBOOL) {.stdcall.} =
     var p = cast[PostCallbackDataPtr](param)
     discard postQueuedCompletionStatus(p.ioPort, timerOrWaitFired.DWORD,
                                        ULONG_PTR(p.handleFd),
@@ -815,8 +820,7 @@ when defined(windows) or defined(nimdoc):
     var pcd = cast[PostCallbackDataPtr](allocShared0(sizeof(PostCallbackData)))
     pcd.ioPort = p.ioPort
     pcd.handleFd = fd
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
 
     ol.data = CompletionData(fd: fd, cb:
       proc(fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) {.gcsafe.} =
@@ -931,8 +935,7 @@ when defined(windows) or defined(nimdoc):
     let handleFD = AsyncFD(hEvent)
     pcd.ioPort = p.ioPort
     pcd.handleFd = handleFD
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data.fd = handleFD
     ol.data.cb = handleCallback
     # We need to protect our callback environment value, so GC will not free it
@@ -1621,8 +1624,7 @@ when defined(windows) or defined(nimdoc):
     let retFuture = newFuture[void]("doConnect")
     result = retFuture
 
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: socket, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
diff --git a/lib/pure/asyncfile.nim b/lib/pure/asyncfile.nim
index 0bf4e9d27..ad111ec50 100644
--- a/lib/pure/asyncfile.nim
+++ b/lib/pure/asyncfile.nim
@@ -132,8 +132,7 @@ proc readBuffer*(f: AsyncFile, buf: pointer, size: int): Future[int] =
   var retFuture = newFuture[int]("asyncfile.readBuffer")
 
   when defined(windows) or defined(nimdoc):
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: f.fd, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
@@ -212,8 +211,7 @@ proc read*(f: AsyncFile, size: int): Future[string] =
   when defined(windows) or defined(nimdoc):
     var buffer = alloc0(size)
 
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: f.fd, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
@@ -340,8 +338,7 @@ proc writeBuffer*(f: AsyncFile, buf: pointer, size: int): Future[void] =
   ## specified file.
   var retFuture = newFuture[void]("asyncfile.writeBuffer")
   when defined(windows) or defined(nimdoc):
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: f.fd, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
@@ -414,8 +411,7 @@ proc write*(f: AsyncFile, data: string): Future[void] =
     var buffer = alloc0(data.len)
     copyMem(buffer, addr copy[0], data.len)
 
-    var ol = PCustomOverlapped()
-    GC_ref(ol)
+    var ol = newCustom()
     ol.data = CompletionData(fd: f.fd, cb:
       proc (fd: AsyncFD, bytesCount: DWORD, errcode: OSErrorCode) =
         if not retFuture.finished:
diff --git a/lib/system/refs_v2.nim b/lib/system/refs_v2.nim
index e07c33086..261f215e0 100644
--- a/lib/system/refs_v2.nim
+++ b/lib/system/refs_v2.nim
@@ -83,7 +83,8 @@ proc nimDecWeakRef(p: pointer) {.compilerRtl, inl.} =
 
 proc nimIncRef(p: pointer) {.compilerRtl, inl.} =
   inc head(p).rc, rcIncrement
-  #cprintf("[INCREF] %p\n", p)
+  when traceCollector:
+    cprintf("[INCREF] %p\n", head(p))
 
 proc nimRawDispose(p: pointer) {.compilerRtl.} =
   when not defined(nimscript):
@@ -131,11 +132,13 @@ proc nimDecRefIsLast(p: pointer): bool {.compilerRtl, inl.} =
     var cell = head(p)
     if (cell.rc and not rcMask) == 0:
       result = true
-      #cprintf("[DESTROY] %p\n", p)
+      when traceCollector:
+        cprintf("[ABOUT TO DESTROY] %p\n", cell)
     else:
       dec cell.rc, rcIncrement
       # According to Lins it's correct to do nothing else here.
-      #cprintf("[DeCREF] %p\n", p)
+      when traceCollector:
+        cprintf("[DeCREF] %p\n", cell)
 
 proc GC_unref*[T](x: ref T) =
   ## New runtime only supports this operation for 'ref T'.
diff --git a/tests/arc/tasyncawait.nim b/tests/arc/tasyncawait.nim
new file mode 100644
index 000000000..89e924104
--- /dev/null
+++ b/tests/arc/tasyncawait.nim
@@ -0,0 +1,69 @@
+discard """
+  output: "5000"
+  cmd: "nim c --gc:arc $file"
+"""
+
+import asyncdispatch, asyncnet, nativesockets, net, strutils, os
+
+var msgCount = 0
+
+const
+  swarmSize = 50
+  messagesToSend = 100
+
+var clientCount = 0
+
+proc sendMessages(client: AsyncFD) {.async.} =
+  for i in 0 ..< messagesToSend:
+    await send(client, "Message " & $i & "\c\L")
+
+proc launchSwarm(port: Port) {.async.} =
+  for i in 0 ..< swarmSize:
+    var sock = createAsyncNativeSocket()
+
+    await connect(sock, "localhost", port)
+    await sendMessages(sock)
+    closeSocket(sock)
+
+proc readMessages(client: AsyncFD) {.async.} =
+  # wrapping the AsyncFd into a AsyncSocket object
+  var sockObj = newAsyncSocket(client)
+  var (ipaddr, port) = sockObj.getPeerAddr()
+  doAssert ipaddr == "127.0.0.1"
+  (ipaddr, port) = sockObj.getLocalAddr()
+  doAssert ipaddr == "127.0.0.1"
+  while true:
+    var line = await recvLine(sockObj)
+    if line == "":
+      closeSocket(client)
+      clientCount.inc
+      break
+    else:
+      if line.startswith("Message "):
+        msgCount.inc
+      else:
+        doAssert false
+
+proc createServer(port: Port) {.async.} =
+  var server = createAsyncNativeSocket()
+  block:
+    var name: Sockaddr_in
+    name.sin_family = typeof(name.sin_family)(toInt(AF_INET))
+    name.sin_port = htons(uint16(port))
+    name.sin_addr.s_addr = htonl(INADDR_ANY)
+    if bindAddr(server.SocketHandle, cast[ptr SockAddr](addr(name)),
+                sizeof(name).Socklen) < 0'i32:
+      raiseOSError(osLastError())
+
+  discard server.SocketHandle.listen()
+  while true:
+    asyncCheck readMessages(await accept(server))
+
+asyncCheck createServer(Port(10335))
+asyncCheck launchSwarm(Port(10335))
+while true:
+  poll()
+  if clientCount == swarmSize: break
+
+assert msgCount == swarmSize * messagesToSend
+echo msgCount