summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorcheatfate <ka@hardcore.kiev.ua>2016-06-12 15:09:30 +0300
committercheatfate <ka@hardcore.kiev.ua>2016-06-12 15:09:30 +0300
commit43329c59094b88d8e65b7ae2eebf22ffb467649f (patch)
tree2499aa0b3fffdda2ef1bd810b88e1435549c9ca4
parentaadc154c9512559650afe6bdbb0a627829f486fc (diff)
downloadNim-43329c59094b88d8e65b7ae2eebf22ffb467649f.tar.gz
Introduce addRead/addWrite for Windows IOCP.
-rw-r--r--lib/pure/asyncdispatch.nim120
-rw-r--r--lib/windows/winlean.nim116
-rw-r--r--tests/async/twinasyncrw.nim257
3 files changed, 493 insertions, 0 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index 7d765ce75..c0816bf15 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -444,7 +444,16 @@ when defined(windows) or defined(nimdoc):
 
     PCustomOverlapped* = ref CustomOverlapped
 
+    PCD = object
+      ioPort: Handle
+      handleFd: AsyncFD
+      waitFd: Handle
+      ovl: PCustomOverlapped
+    PPCD = ptr PCD
+
     AsyncFD* = distinct int
+
+    Callback = proc (fd: AsyncFD): bool {.closure,gcsafe.}
   {.deprecated: [TCompletionKey: CompletionKey, TAsyncFD: AsyncFD,
                 TCustomOverlapped: CustomOverlapped, TCompletionData: CompletionData].}
 
@@ -953,6 +962,117 @@ when defined(windows) or defined(nimdoc):
     ## Unregisters ``fd``.
     getGlobalDispatcher().handles.excl(fd)
 
+  {.push stackTrace:off.}
+  proc waitableCallback(param: pointer,
+                        TimerOrWaitFired: WINBOOL): void {.stdcall.} =
+    var p = cast[PPCD](param)
+    discard postQueuedCompletionStatus(p.ioPort, Dword(TimerOrWaitFired),
+                                       ULONG_PTR(p.handleFd),
+                                       cast[pointer](p.ovl))
+  {.pop.}
+
+  template registerWaitableEvent(mask) =
+    let p = getGlobalDispatcher()
+    var flags = (WT_EXECUTEINWAITTHREAD or WT_EXECUTEONLYONCE).Dword
+    var hEvent = wsaCreateEvent()
+    if hEvent == 0:
+      raiseOSError(osLastError())
+    var pcd = cast[PPCD](allocShared0(sizeof(PCD)))
+    pcd.ioPort = p.ioPort
+    pcd.handleFd = fd
+    var ol = PCustomOverlapped()
+    GC_ref(ol)
+
+    ol.data = CompletionData(fd: fd, cb:
+      proc(fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) =
+        # we excluding our `fd` because cb(fd) can register own handler
+        # for this `fd`
+        p.handles.excl(fd)
+        # unregisterWait() is called before callback, because appropriate
+        # winsockets function can re-enable event.
+        # https://msdn.microsoft.com/en-us/library/windows/desktop/ms741576(v=vs.85).aspx
+        discard unregisterWait(pcd.waitFd)
+        if cb(fd):
+          # callback returned `true`, so we free all allocated resources
+          discard wsaCloseEvent(hEvent)
+          deallocShared(cast[pointer](pcd))
+          # pcd.ovl will be unrefed in poll().
+        else:
+          # callback returned `false` we need to continue
+          if p.handles.contains(fd):
+            # new callback was already registered with `fd`, so we free all
+            # allocated resources. This happens because in callback `cb`
+            # addRead/addWrite was called with same `fd`.
+            discard wsaCloseEvent(hEvent)
+            deallocShared(cast[pointer](pcd))
+          else:
+            # we need to include `fd` again
+            p.handles.incl(fd)
+            # and register WaitForSingleObject again
+            if not registerWaitForSingleObject(addr(pcd.waitFd), hEvent,
+                                    cast[WAITORTIMERCALLBACK](waitableCallback),
+                                       cast[pointer](pcd), INFINITE, flags):
+              # pcd.ovl will be unrefed in poll()
+              discard wsaCloseEvent(hEvent)
+              deallocShared(cast[pointer](pcd))
+              raiseOSError(osLastError())
+            else:
+              # we ref pcd.ovl one more time, because it will be unrefed in
+              # poll()
+              GC_ref(pcd.ovl)
+    )
+    # This is main part of `hacky way` is using WSAEventSelect, so `hEvent`
+    # will be signaled when appropriate `mask` events will be triggered.
+    if wsaEventSelect(fd.SocketHandle, hEvent, mask) != 0:
+      GC_unref(ol)
+      deallocShared(cast[pointer](pcd))
+      discard wsaCloseEvent(hEvent)
+      raiseOSError(osLastError())
+
+    pcd.ovl = ol
+    if not registerWaitForSingleObject(addr(pcd.waitFd), hEvent,
+                                    cast[WAITORTIMERCALLBACK](waitableCallback),
+                                       cast[pointer](pcd), INFINITE, flags):
+      GC_unref(ol)
+      deallocShared(cast[pointer](pcd))
+      discard wsaCloseEvent(hEvent)
+      raiseOSError(osLastError())
+    p.handles.incl(fd)
+
+  proc addRead*(fd: AsyncFD, cb: Callback) =
+    ## Start watching the file descriptor for read availability and then call
+    ## the callback ``cb``.
+    ##
+    ## This is not ``pure`` mechanism for Windows Completion Ports (IOCP),
+    ## so if you can avoid it, please do it. Use `addRead` only if really
+    ## need it (main usecase is adaptation of `unix like` libraries to be
+    ## asynchronous on Windows).
+    ## If you use this function, you dont need to use asyncdispatch.recv()
+    ## or asyncdispatch.accept(), because they are using IOCP, please use
+    ## nativesockets.recv() and nativesockets.accept() instead.
+    ##
+    ## Be sure your callback ``cb`` returns ``true``, if you want to remove
+    ## watch of `read` notifications, and ``false``, if you want to continue
+    ## receiving notifies.
+    registerWaitableEvent(FD_READ or FD_ACCEPT or FD_OOB or FD_CLOSE)
+
+  proc addWrite*(fd: AsyncFD, cb: Callback) =
+    ## Start watching the file descriptor for write availability and then call
+    ## the callback ``cb``.
+    ##
+    ## This is not ``pure`` mechanism for Windows Completion Ports (IOCP),
+    ## so if you can avoid it, please do it. Use `addWrite` only if really
+    ## need it (main usecase is adaptation of `unix like` libraries to be
+    ## asynchronous on Windows).
+    ## If you use this function, you dont need to use asyncdispatch.send()
+    ## or asyncdispatch.connect(), because they are using IOCP, please use
+    ## nativesockets.send() and nativesockets.connect() instead.
+    ##
+    ## Be sure your callback ``cb`` returns ``true``, if you want to remove
+    ## watch of `write` notifications, and ``false``, if you want to continue
+    ## receiving notifies.
+    registerWaitableEvent(FD_WRITE or FD_CONNECT or FD_CLOSE)
+
   initAll()
 else:
   import selectors
diff --git a/lib/windows/winlean.nim b/lib/windows/winlean.nim
index 645ccd57b..2989e5ac9 100644
--- a/lib/windows/winlean.nim
+++ b/lib/windows/winlean.nim
@@ -442,6 +442,8 @@ type
     sa_family*: int16 # unsigned
     sa_data: array[0..13, char]
 
+  PSockAddr = ptr SockAddr
+
   InAddr* {.importc: "IN_ADDR", header: "winsock2.h".} = object
     s_addr*: uint32  # IP address
 
@@ -889,3 +891,117 @@ proc inet_ntop*(family: cint, paddr: pointer, pStringBuffer: cstring,
     result = inet_ntop_real(family, paddr, pStringBuffer, stringBufSize)
   else:
     result = inet_ntop_emulated(family, paddr, pStringBuffer, stringBufSize)
+
+type
+  WSAPROC_ACCEPTEX* = proc (sListenSocket: SocketHandle,
+                            sAcceptSocket: SocketHandle,
+                            lpOutputBuffer: pointer, dwReceiveDataLength: DWORD,
+                            dwLocalAddressLength: DWORD,
+                            dwRemoteAddressLength: DWORD,
+                            lpdwBytesReceived: ptr DWORD,
+                            lpOverlapped: POVERLAPPED): bool {.
+                            stdcall,gcsafe.}
+
+  WSAPROC_CONNECTEX* = proc (s: SocketHandle, name: ptr SockAddr, namelen: cint,
+                             lpSendBuffer: pointer, dwSendDataLength: DWORD,
+                             lpdwBytesSent: ptr DWORD,
+                             lpOverlapped: POVERLAPPED): bool {.
+                             stdcall,gcsafe.}
+
+  WSAPROC_GETACCEPTEXSOCKADDRS* = proc(lpOutputBuffer: pointer,
+                                       dwReceiveDataLength: DWORD,
+                                       dwLocalAddressLength: DWORD,
+                                       dwRemoteAddressLength: DWORD,
+                                       LocalSockaddr: ptr PSockAddr,
+                                       LocalSockaddrLength: ptr cint,
+                                       RemoteSockaddr: ptr PSockAddr,
+                                       RemoteSockaddrLength: ptr cint) {.
+                                       stdcall,gcsafe.}
+
+const
+  WT_EXECUTEDEFAULT* = 0x00000000'i32
+  WT_EXECUTEINIOTHREAD* = 0x00000001'i32
+  WT_EXECUTEINUITHREAD* = 0x00000002'i32
+  WT_EXECUTEINWAITTHREAD* = 0x00000004'i32
+  WT_EXECUTEONLYONCE* = 0x00000008'i32
+  WT_EXECUTELONGFUNCTION* = 0x00000010'i32
+  WT_EXECUTEINTIMERTHREAD* = 0x00000020'i32
+  WT_EXECUTEINPERSISTENTIOTHREAD* = 0x00000040'i32
+  WT_EXECUTEINPERSISTENTTHREAD* = 0x00000080'i32
+  WT_TRANSFER_IMPERSONATION* = 0x00000100'i32
+  PROCESS_TERMINATE* = 0x00000001'i32
+  PROCESS_CREATE_THREAD* = 0x00000002'i32
+  PROCESS_SET_SESSIONID* = 0x00000004'i32
+  PROCESS_VM_OPERATION* = 0x00000008'i32
+  PROCESS_VM_READ* = 0x00000010'i32
+  PROCESS_VM_WRITE* = 0x00000020'i32
+  PROCESS_DUP_HANDLE* = 0x00000040'i32
+  PROCESS_CREATE_PROCESS* = 0x00000080'i32
+  PROCESS_SET_QUOTA* = 0x00000100'i32
+  PROCESS_SET_INFORMATION* = 0x00000200'i32
+  PROCESS_QUERY_INFORMATION* = 0x00000400'i32
+  PROCESS_SUSPEND_RESUME* = 0x00000800'i32
+  PROCESS_QUERY_LIMITED_INFORMATION* = 0x00001000'i32
+  PROCESS_SET_LIMITED_INFORMATION* = 0x00002000'i32
+type
+  WAITORTIMERCALLBACK* = proc(para1: pointer, para2: int32): void {.stdcall.}
+
+proc postQueuedCompletionStatus*(CompletionPort: HANDLE,
+                                dwNumberOfBytesTransferred: DWORD,
+                                dwCompletionKey: ULONG_PTR,
+                                lpOverlapped: pointer): bool
+     {.stdcall, dynlib: "kernel32", importc: "PostQueuedCompletionStatus".}
+
+proc registerWaitForSingleObject*(phNewWaitObject: ptr Handle, hObject: Handle,
+                                 Callback: WAITORTIMERCALLBACK,
+                                 Context: pointer,
+                                 dwMilliseconds: ULONG,
+                                 dwFlags: ULONG): bool
+     {.stdcall, dynlib: "kernel32", importc: "RegisterWaitForSingleObject".}
+
+proc unregisterWait*(WaitHandle: HANDLE): DWORD
+     {.stdcall, dynlib: "kernel32", importc: "UnregisterWait".}
+
+proc openProcess*(dwDesiredAccess: DWORD, bInheritHandle: WINBOOL,
+                    dwProcessId: DWORD): Handle
+     {.stdcall, dynlib: "kernel32", importc: "OpenProcess".}
+
+when defined(useWinAnsi):
+  proc createEvent*(lpEventAttributes: ptr SECURITY_ATTRIBUTES,
+                    bManualReset: DWORD, bInitialState: DWORD,
+                    lpName: cstring): Handle
+       {.stdcall, dynlib: "kernel32", importc: "CreateEventA".}
+else:
+  proc createEvent*(lpEventAttributes: ptr SECURITY_ATTRIBUTES,
+                    bManualReset: DWORD, bInitialState: DWORD,
+                    lpName: ptr Utf16Char): Handle
+       {.stdcall, dynlib: "kernel32", importc: "CreateEventW".}
+
+proc setEvent*(hEvent: Handle): cint
+     {.stdcall, dynlib: "kernel32", importc: "SetEvent".}
+
+const
+  FD_READ* = 0x00000001'i32
+  FD_WRITE* = 0x00000002'i32
+  FD_OOB* = 0x00000004'i32
+  FD_ACCEPT* = 0x00000008'i32
+  FD_CONNECT* = 0x00000010'i32
+  FD_CLOSE* = 0x00000020'i32
+  FD_QQS* = 0x00000040'i32
+  FD_GROUP_QQS* = 0x00000080'i32
+  FD_ROUTING_INTERFACE_CHANGE* = 0x00000100'i32
+  FD_ADDRESS_LIST_CHANGE* = 0x00000200'i32
+  FD_ALL_EVENTS* = 0x000003FF'i32
+
+proc wsaEventSelect*(s: SocketHandle, hEventObject: Handle,
+                     lNetworkEvents: clong): cint 
+    {.stdcall, importc: "WSAEventSelect", dynlib: "ws2_32.dll".}
+
+proc wsaCreateEvent*(): Handle
+    {.stdcall, importc: "WSACreateEvent", dynlib: "ws2_32.dll".}
+
+proc wsaCloseEvent*(hEvent: Handle): bool
+     {.stdcall, importc: "WSACloseEvent", dynlib: "ws2_32.dll".}
+
+proc wsaResetEvent*(hEvent: Handle): bool
+     {.stdcall, importc: "WSAResetEvent", dynlib: "ws2_32.dll".}
\ No newline at end of file
diff --git a/tests/async/twinasyncrw.nim b/tests/async/twinasyncrw.nim
new file mode 100644
index 000000000..17b7d1cf5
--- /dev/null
+++ b/tests/async/twinasyncrw.nim
@@ -0,0 +1,257 @@
+discard """
+  file: "twinasyncrw.nim"
+  output: "5000"
+"""
+when defined(windows):
+  import asyncdispatch, nativesockets, net, strutils, os, winlean
+
+  var msgCount = 0
+
+  const
+    swarmSize = 50
+    messagesToSend = 100
+
+  var clientCount = 0
+
+  proc winConnect*(socket: AsyncFD, address: string, port: Port,
+    domain = Domain.AF_INET): Future[void] =
+    var retFuture = newFuture[void]("winConnect")
+    proc cb(fd: AsyncFD): bool =
+      var ret = SocketHandle(fd).getSockOptInt(cint(SOL_SOCKET), cint(SO_ERROR))
+      if ret == 0:
+          # We have connected.
+          retFuture.complete()
+          return true
+      else:
+          retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(ret))))
+          return true
+
+    var aiList = getAddrInfo(address, port, domain)
+    var success = false
+    var lastError: OSErrorCode = OSErrorCode(0)
+    var it = aiList
+    while it != nil:
+      var ret = nativesockets.connect(socket.SocketHandle, it.ai_addr, it.ai_addrlen.Socklen)
+      if ret == 0:
+        # Request to connect completed immediately.
+        success = true
+        retFuture.complete()
+        break
+      else:
+        lastError = osLastError()
+        if lastError.int32 == WSAEWOULDBLOCK:
+          success = true
+          addWrite(socket, cb)
+          break
+        else:
+          success = false
+      it = it.ai_next
+
+    dealloc(aiList)
+    if not success:
+      retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+    return retFuture
+
+  proc winRecv*(socket: AsyncFD, size: int,
+             flags = {SocketFlag.SafeDisconn}): Future[string] =
+    var retFuture = newFuture[string]("recv")
+
+    var readBuffer = newString(size)
+
+    proc cb(sock: AsyncFD): bool =
+      result = true
+      let res = recv(sock.SocketHandle, addr readBuffer[0], size.cint,
+                     flags.toOSFlags())
+      if res < 0:
+        let lastError = osLastError()
+        if flags.isDisconnectionError(lastError):
+          retFuture.complete("")
+        else:
+          retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+      elif res == 0:
+        # Disconnected
+        retFuture.complete("")
+      else:
+        readBuffer.setLen(res)
+        retFuture.complete(readBuffer)
+    # TODO: The following causes a massive slowdown.
+    #if not cb(socket):
+    addRead(socket, cb)
+    return retFuture
+
+  proc winRecvInto*(socket: AsyncFD, buf: cstring, size: int,
+                  flags = {SocketFlag.SafeDisconn}): Future[int] =
+    var retFuture = newFuture[int]("winRecvInto")
+
+    proc cb(sock: AsyncFD): bool =
+      result = true
+      let res = nativesockets.recv(sock.SocketHandle, buf, size.cint,
+                                   flags.toOSFlags())
+      if res < 0:
+        let lastError = osLastError()
+        if flags.isDisconnectionError(lastError):
+          retFuture.complete(0)
+        else:
+          retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+      else:
+        retFuture.complete(res)
+    # TODO: The following causes a massive slowdown.
+    #if not cb(socket):
+    addRead(socket, cb)
+    return retFuture
+
+  proc winSend*(socket: AsyncFD, data: string,
+             flags = {SocketFlag.SafeDisconn}): Future[void] =
+    var retFuture = newFuture[void]("winSend")
+
+    var written = 0
+
+    proc cb(sock: AsyncFD): bool =
+      result = true
+      let netSize = data.len-written
+      var d = data.cstring
+      let res = nativesockets.send(sock.SocketHandle, addr d[written], netSize.cint, 0)
+      if res < 0:
+        let lastError = osLastError()
+        if flags.isDisconnectionError(lastError):
+          retFuture.complete()
+        else:
+          retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+      else:
+        written.inc(res)
+        if res != netSize:
+          result = false # We still have data to send.
+        else:
+          retFuture.complete()
+    # TODO: The following causes crashes.
+    #if not cb(socket):
+    addWrite(socket, cb)
+    return retFuture
+
+  proc winAcceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}):
+      Future[tuple[address: string, client: AsyncFD]] =
+    var retFuture = newFuture[tuple[address: string,
+        client: AsyncFD]]("winAcceptAddr")
+    proc cb(sock: AsyncFD): bool =
+      result = true
+      if not retFuture.finished:
+        var sockAddress = Sockaddr()
+        var addrLen = sizeof(sockAddress).Socklen
+        var client = nativesockets.accept(sock.SocketHandle,
+                                          cast[ptr SockAddr](addr(sockAddress)), addr(addrLen))
+        if client == osInvalidSocket:
+          retFuture.fail(newException(OSError, osErrorMsg(osLastError())))
+        else:
+          retFuture.complete((getAddrString(cast[ptr SockAddr](addr sockAddress)), client.AsyncFD))
+
+    addRead(socket, cb)
+    return retFuture
+
+  proc winAccept*(socket: AsyncFD,
+      flags = {SocketFlag.SafeDisconn}): Future[AsyncFD] =
+    ## Accepts a new connection. Returns a future containing the client socket
+    ## corresponding to that connection.
+    ## The future will complete when the connection is successfully accepted.
+    var retFut = newFuture[AsyncFD]("winAccept")
+    var fut = winAcceptAddr(socket, flags)
+    fut.callback =
+      proc (future: Future[tuple[address: string, client: AsyncFD]]) =
+        assert future.finished
+        if future.failed:
+          retFut.fail(future.error)
+        else:
+          retFut.complete(future.read.client)
+    return retFut
+
+
+  proc winRecvLine*(socket: AsyncFD): Future[string] {.async.} =
+    ## Reads a line of data from ``socket``. Returned future will complete once
+    ## a full line is read or an error occurs.
+    ##
+    ## If a full line is read ``\r\L`` is not
+    ## added to ``line``, however if solely ``\r\L`` is read then ``line``
+    ## will be set to it.
+    ##
+    ## If the socket is disconnected, ``line`` will be set to ``""``.
+    ##
+    ## If the socket is disconnected in the middle of a line (before ``\r\L``
+    ## is read) then line will be set to ``""``.
+    ## The partial line **will be lost**.
+    ##
+    ## **Warning**: This assumes that lines are delimited by ``\r\L``.
+    ##
+    ## **Note**: This procedure is mostly used for testing. You likely want to
+    ## use ``asyncnet.recvLine`` instead.
+
+    template addNLIfEmpty(): stmt =
+      if result.len == 0:
+        result.add("\c\L")
+
+    result = ""
+    var c = ""
+    while true:
+      c = await winRecv(socket, 1)
+      if c.len == 0:
+        return ""
+      if c == "\r":
+        c = await winRecv(socket, 1)
+        assert c == "\l"
+        addNLIfEmpty()
+        return
+      elif c == "\L":
+        addNLIfEmpty()
+        return
+      add(result, c)
+
+  proc sendMessages(client: AsyncFD) {.async.} =
+    for i in 0 .. <messagesToSend:
+      await winSend(client, "Message " & $i & "\c\L")
+
+  proc launchSwarm(port: Port) {.async.} =
+    for i in 0 .. <swarmSize:
+      var sock = newNativeSocket()
+      setBlocking(sock, false)
+
+      await winConnect(AsyncFD(sock), "localhost", port)
+      await sendMessages(AsyncFD(sock))
+      discard closeSocket(sock)
+
+  proc readMessages(client: AsyncFD) {.async.} =
+    while true:
+      var line = await winRecvLine(client)
+      if line == "":
+        closeSocket(client)
+        clientCount.inc
+        break
+      else:
+        if line.startswith("Message "):
+          msgCount.inc
+        else:
+          doAssert false
+
+  proc createServer(port: Port) {.async.} =
+    var server = newNativeSocket()
+    setBlocking(server, false)
+    block:
+      var name = Sockaddr_in()
+      name.sin_family = toInt(Domain.AF_INET).int16
+      name.sin_port = htons(uint16(port))
+      name.sin_addr.s_addr = htonl(INADDR_ANY)
+      if bindAddr(server, cast[ptr SockAddr](addr(name)),
+                  sizeof(name).Socklen) < 0'i32:
+        raiseOSError(osLastError())
+
+    discard server.listen()
+    while true:
+      asyncCheck readMessages(await winAccept(AsyncFD(server)))
+
+  asyncCheck createServer(Port(10335))
+  asyncCheck launchSwarm(Port(10335))
+  while true:
+    poll()
+    if clientCount == swarmSize: break
+
+  assert msgCount == swarmSize * messagesToSend
+  echo msgCount
+else:
+  echo(5000)