summary refs log tree commit diff stats
path: root/lib/upcoming/asyncdispatch.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/upcoming/asyncdispatch.nim')
-rw-r--r--lib/upcoming/asyncdispatch.nim181
1 files changed, 32 insertions, 149 deletions
diff --git a/lib/upcoming/asyncdispatch.nim b/lib/upcoming/asyncdispatch.nim
index feee87bae..1623d8375 100644
--- a/lib/upcoming/asyncdispatch.nim
+++ b/lib/upcoming/asyncdispatch.nim
@@ -9,7 +9,7 @@
 
 include "system/inclrtl"
 
-import os, oids, tables, strutils, times, heapqueue, lists
+import os, tables, strutils, times, heapqueue, lists, options
 
 import nativesockets, net, deques
 
@@ -219,6 +219,11 @@ when defined(windows) or defined(nimdoc):
     if gDisp.isNil: gDisp = newDispatcher()
     result = gDisp
 
+  proc setGlobalDispatcher*(disp: PDispatcher) =
+    if not gDisp.isNil:
+      assert gDisp.callbacks.len == 0
+    gDisp = disp
+
   proc register*(fd: AsyncFD) =
     ## Registers ``fd`` with the dispatcher.
     let p = getGlobalDispatcher()
@@ -325,68 +330,6 @@ when defined(windows) or defined(nimdoc):
     getAcceptExSockAddrs = cast[WSAPROC_GETACCEPTEXSOCKADDRS](fun)
     close(dummySock)
 
-  proc connect*(socket: AsyncFD, address: string, port: Port,
-    domain = nativesockets.AF_INET): Future[void] =
-    ## Connects ``socket`` to server at ``address:port``.
-    ##
-    ## Returns a ``Future`` which will complete when the connection succeeds
-    ## or an error occurs.
-    verifyPresence(socket)
-    var retFuture = newFuture[void]("connect")
-    # Apparently ``ConnectEx`` expects the socket to be initially bound:
-    var saddr: Sockaddr_in
-    saddr.sin_family = int16(toInt(domain))
-    saddr.sin_port = 0
-    saddr.sin_addr.s_addr = INADDR_ANY
-    if bindAddr(socket.SocketHandle, cast[ptr SockAddr](addr(saddr)),
-                  sizeof(saddr).SockLen) < 0'i32:
-      raiseOSError(osLastError())
-
-    var aiList = getAddrInfo(address, port, domain)
-    var success = false
-    var lastError: OSErrorCode
-    var it = aiList
-    while it != nil:
-      # "the OVERLAPPED structure must remain valid until the I/O completes"
-      # http://blogs.msdn.com/b/oldnewthing/archive/2011/02/02/10123392.aspx
-      var ol = PCustomOverlapped()
-      GC_ref(ol)
-      ol.data = CompletionData(fd: socket, cb:
-        proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) =
-          if not retFuture.finished:
-            if errcode == OSErrorCode(-1):
-              retFuture.complete()
-            else:
-              retFuture.fail(newException(OSError, osErrorMsg(errcode)))
-      )
-
-      var ret = connectEx(socket.SocketHandle, it.ai_addr,
-                          sizeof(Sockaddr_in).cint, nil, 0, nil,
-                          cast[POVERLAPPED](ol))
-      if ret:
-        # Request to connect completed immediately.
-        success = true
-        retFuture.complete()
-        # We don't deallocate ``ol`` here because even though this completed
-        # immediately poll will still be notified about its completion and it will
-        # free ``ol``.
-        break
-      else:
-        lastError = osLastError()
-        if lastError.int32 == ERROR_IO_PENDING:
-          # In this case ``ol`` will be deallocated in ``poll``.
-          success = true
-          break
-        else:
-          GC_unref(ol)
-          success = false
-      it = it.ai_next
-
-    freeAddrInfo(aiList)
-    if not success:
-      retFuture.fail(newException(OSError, osErrorMsg(lastError)))
-    return retFuture
-
   proc recv*(socket: AsyncFD, size: int,
              flags = {SocketFlag.SafeDisconn}): Future[string] =
     ## Reads **up to** ``size`` bytes from ``socket``. Returned future will
@@ -739,8 +682,8 @@ when defined(windows) or defined(nimdoc):
     var lpOutputBuf = newString(lpOutputLen)
     var dwBytesReceived: Dword
     let dwReceiveDataLength = 0.Dword # We don't want any data to be read.
-    let dwLocalAddressLength = Dword(sizeof(Sockaddr_in) + 16)
-    let dwRemoteAddressLength = Dword(sizeof(Sockaddr_in) + 16)
+    let dwLocalAddressLength = Dword(sizeof(Sockaddr_in6) + 16)
+    let dwRemoteAddressLength = Dword(sizeof(Sockaddr_in6) + 16)
 
     template failAccept(errcode) =
       if flags.isDisconnectionError(errcode):
@@ -770,12 +713,14 @@ when defined(windows) or defined(nimdoc):
                              dwLocalAddressLength, dwRemoteAddressLength,
                              addr localSockaddr, addr localLen,
                              addr remoteSockaddr, addr remoteLen)
-        register(clientSock.AsyncFD)
-        # TODO: IPv6. Check ``sa_family``. http://stackoverflow.com/a/9212542/492186
-        retFuture.complete(
-          (address: $inet_ntoa(cast[ptr Sockaddr_in](remoteSockAddr).sin_addr),
-          client: clientSock.AsyncFD)
-        )
+        try:
+          let address = getAddrString(remoteSockAddr)
+          register(clientSock.AsyncFD)
+          retFuture.complete((address: address, client: clientSock.AsyncFD))
+        except:
+          # getAddrString may raise
+          clientSock.close()
+          retFuture.fail(getCurrentException())
 
     var ol = PCustomOverlapped()
     GC_ref(ol)
@@ -808,20 +753,6 @@ when defined(windows) or defined(nimdoc):
 
     return retFuture
 
-  proc newAsyncNativeSocket*(domain, sockType, protocol: cint): AsyncFD =
-    ## Creates a new socket and registers it with the dispatcher implicitly.
-    result = newNativeSocket(domain, sockType, protocol).AsyncFD
-    result.SocketHandle.setBlocking(false)
-    register(result)
-
-  proc newAsyncNativeSocket*(domain: Domain = nativesockets.AF_INET,
-                             sockType: SockType = SOCK_STREAM,
-                             protocol: Protocol = IPPROTO_TCP): AsyncFD =
-    ## Creates a new socket and registers it with the dispatcher implicitly.
-    result = newNativeSocket(domain, sockType, protocol).AsyncFD
-    result.SocketHandle.setBlocking(false)
-    register(result)
-
   proc closeSocket*(socket: AsyncFD) =
     ## Closes a socket and ensures that it is unregistered.
     socket.SocketHandle.close()
@@ -1154,28 +1085,16 @@ else:
     if gDisp.isNil: gDisp = newDispatcher()
     result = gDisp
 
+  proc setGlobalDispatcher*(disp: PDispatcher) =
+    if not gDisp.isNil:
+      assert gDisp.callbacks.len == 0
+    gDisp = disp
+
   proc register*(fd: AsyncFD) =
     let p = getGlobalDispatcher()
     var data = newAsyncData()
     p.selector.registerHandle(fd.SocketHandle, {}, data)
 
-  proc newAsyncNativeSocket*(domain: cint, sockType: cint,
-                             protocol: cint): AsyncFD =
-    result = newNativeSocket(domain, sockType, protocol).AsyncFD
-    result.SocketHandle.setBlocking(false)
-    when defined(macosx):
-      result.SocketHandle.setSockOptInt(SOL_SOCKET, SO_NOSIGPIPE, 1)
-    register(result)
-
-  proc newAsyncNativeSocket*(domain: Domain = AF_INET,
-                             sockType: SockType = SOCK_STREAM,
-                             protocol: Protocol = IPPROTO_TCP): AsyncFD =
-    result = newNativeSocket(domain, sockType, protocol).AsyncFD
-    result.SocketHandle.setBlocking(false)
-    when defined(macosx):
-      result.SocketHandle.setSockOptInt(SOL_SOCKET, SO_NOSIGPIPE, 1)
-    register(result)
-
   proc closeSocket*(sock: AsyncFD) =
     let disp = getGlobalDispatcher()
     disp.selector.unregister(sock.SocketHandle)
@@ -1331,50 +1250,6 @@ else:
     # Callback queue processing
     processPendingCallbacks(p)
 
-  proc connect*(socket: AsyncFD, address: string, port: Port,
-    domain = AF_INET): Future[void] =
-    var retFuture = newFuture[void]("connect")
-
-    proc cb(fd: AsyncFD): bool =
-      var ret = SocketHandle(fd).getSockOptInt(cint(SOL_SOCKET), cint(SO_ERROR))
-      if ret == 0:
-          # We have connected.
-          retFuture.complete()
-          return true
-      elif ret == EINTR:
-          # interrupted, keep waiting
-          return false
-      else:
-          retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(ret))))
-          return true
-
-    assert getSockDomain(socket.SocketHandle) == domain
-    var aiList = getAddrInfo(address, port, domain)
-    var success = false
-    var lastError: OSErrorCode
-    var it = aiList
-    while it != nil:
-      var ret = connect(socket.SocketHandle, it.ai_addr, it.ai_addrlen.Socklen)
-      if ret == 0:
-        # Request to connect completed immediately.
-        success = true
-        retFuture.complete()
-        break
-      else:
-        lastError = osLastError()
-        if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS:
-          success = true
-          addWrite(socket, cb)
-          break
-        else:
-          success = false
-      it = it.ai_next
-
-    freeAddrInfo(aiList)
-    if not success:
-      retFuture.fail(newException(OSError, osErrorMsg(lastError)))
-    return retFuture
-
   proc recv*(socket: AsyncFD, size: int,
              flags = {SocketFlag.SafeDisconn}): Future[string] =
     var retFuture = newFuture[string]("recv")
@@ -1568,9 +1443,14 @@ else:
           else:
             retFuture.fail(newException(OSError, osErrorMsg(lastError)))
       else:
-        register(client.AsyncFD)
-        retFuture.complete((getAddrString(cast[ptr SockAddr](addr sockAddress)),
-                            client.AsyncFD))
+        try:
+          let address = getAddrString(cast[ptr SockAddr](addr sockAddress))
+          register(client.AsyncFD)
+          retFuture.complete((address, client.AsyncFD))
+        except:
+          # getAddrString may raise
+          client.close()
+          retFuture.fail(getCurrentException())
     addRead(socket, cb)
     return retFuture
 
@@ -1623,6 +1503,9 @@ else:
     data.readList.add(cb)
     p.selector.registerEvent(SelectEvent(ev), data)
 
+# Common procedures between current and upcoming asyncdispatch
+include includes.asynccommon
+
 proc sleepAsync*(ms: int): Future[void] =
   ## Suspends the execution of the current async procedure for the next
   ## ``ms`` milliseconds.