summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAraq <rumpf_a@web.de>2017-02-07 14:44:27 +0100
committerAraq <rumpf_a@web.de>2017-02-07 14:44:27 +0100
commit4790b6b203ff8c0c54787fec4978eb7befa8c688 (patch)
tree5c09bec04b5d140ea48ec571744e8f87b9b2da22
parentb5b9c7d2e2839767fc514d7499b31d45eb732150 (diff)
parent7a839d7b02770987967c5485cb1ccb33cd45b380 (diff)
downloadNim-4790b6b203ff8c0c54787fec4978eb7befa8c688.tar.gz
Merge branch 'accept-close-race-5279' of https://github.com/endragor/Nim into endragor-accept-close-race-5279
-rw-r--r--lib/pure/asyncdispatch.nim43
-rw-r--r--lib/pure/nativesockets.nim28
-rw-r--r--lib/upcoming/asyncdispatch.nim43
-rw-r--r--lib/windows/winlean.nim8
-rw-r--r--tests/async/tacceptcloserace.nim36
5 files changed, 114 insertions, 44 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index 8db7eba25..107e26c0c 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -753,26 +753,6 @@ when defined(windows) or defined(nimdoc):
     let dwLocalAddressLength = Dword(sizeof(Sockaddr_in) + 16)
     let dwRemoteAddressLength = Dword(sizeof(Sockaddr_in) + 16)
 
-    template completeAccept() {.dirty.} =
-      var listenSock = socket
-      let setoptRet = setsockopt(clientSock, SOL_SOCKET,
-          SO_UPDATE_ACCEPT_CONTEXT, addr listenSock,
-          sizeof(listenSock).SockLen)
-      if setoptRet != 0: raiseOSError(osLastError())
-
-      var localSockaddr, remoteSockaddr: ptr SockAddr
-      var localLen, remoteLen: int32
-      getAcceptExSockaddrs(addr lpOutputBuf[0], dwReceiveDataLength,
-                           dwLocalAddressLength, dwRemoteAddressLength,
-                           addr localSockaddr, addr localLen,
-                           addr remoteSockaddr, addr remoteLen)
-      register(clientSock.AsyncFD)
-      # TODO: IPv6. Check ``sa_family``. http://stackoverflow.com/a/9212542/492186
-      retFuture.complete(
-        (address: $inet_ntoa(cast[ptr Sockaddr_in](remoteSockAddr).sin_addr),
-         client: clientSock.AsyncFD)
-      )
-
     template failAccept(errcode) =
       if flags.isDisconnectionError(errcode):
         var newAcceptFut = acceptAddr(socket, flags)
@@ -785,6 +765,29 @@ when defined(windows) or defined(nimdoc):
       else:
         retFuture.fail(newException(OSError, osErrorMsg(errcode)))
 
+    template completeAccept() {.dirty.} =
+      var listenSock = socket
+      let setoptRet = setsockopt(clientSock, SOL_SOCKET,
+          SO_UPDATE_ACCEPT_CONTEXT, addr listenSock,
+          sizeof(listenSock).SockLen)
+      if setoptRet != 0:
+        let errcode = osLastError()
+        checkCloseError clientSock.closeSocket()
+        failAccept(errcode)
+      else:
+        var localSockaddr, remoteSockaddr: ptr SockAddr
+        var localLen, remoteLen: int32
+        getAcceptExSockaddrs(addr lpOutputBuf[0], dwReceiveDataLength,
+                             dwLocalAddressLength, dwRemoteAddressLength,
+                             addr localSockaddr, addr localLen,
+                             addr remoteSockaddr, addr remoteLen)
+        register(clientSock.AsyncFD)
+        # TODO: IPv6. Check ``sa_family``. http://stackoverflow.com/a/9212542/492186
+        retFuture.complete(
+          (address: $inet_ntoa(cast[ptr Sockaddr_in](remoteSockAddr).sin_addr),
+          client: clientSock.AsyncFD)
+        )
+
     var ol = PCustomOverlapped()
     GC_ref(ol)
     ol.data = CompletionData(fd: socket, cb:
diff --git a/lib/pure/nativesockets.nim b/lib/pure/nativesockets.nim
index 5f10a7b4c..17e23c8e0 100644
--- a/lib/pure/nativesockets.nim
+++ b/lib/pure/nativesockets.nim
@@ -22,11 +22,12 @@ const useWinVersion = defined(Windows) or defined(nimdoc)
 when useWinVersion:
   import winlean
   export WSAEWOULDBLOCK, WSAECONNRESET, WSAECONNABORTED, WSAENETRESET,
+         WSANOTINITIALISED, WSAENOTSOCK, WSAEINPROGRESS, WSAEINTR,
          WSAEDISCON, ERROR_NETNAME_DELETED
 else:
   import posix
   export fcntl, F_GETFL, O_NONBLOCK, F_SETFL, EAGAIN, EWOULDBLOCK, MSG_NOSIGNAL,
-    EINTR, EINPROGRESS, ECONNRESET, EPIPE, ENETRESET
+    EINTR, EINPROGRESS, ECONNRESET, EPIPE, ENETRESET, EBADF
   export Sockaddr_storage, Sockaddr_un, Sockaddr_un_path_length
 
 export SocketHandle, Sockaddr_in, Addrinfo, INADDR_ANY, SockAddr, SockLen,
@@ -619,3 +620,28 @@ proc selectWrite*(writefds: var seq[SocketHandle],
 when defined(Windows):
   var wsa: WSAData
   if wsaStartup(0x0101'i16, addr wsa) != 0: raiseOSError(osLastError())
+
+proc checkCloseError*(ret: cint) =
+  ## Asserts that the return value of close() or closeSocket() syscall
+  ## does not indicate a programming error (such as invalid descriptor).
+  ## This must only be used when an error has already occurred and
+  ## you are performing a cleanup.
+  ## Otherwise, error handling must be performed as usual.
+  ##
+  ## This procedure must be called right after performing the syscall. Example:
+  ##
+  ## .. code-block:: nim
+  ##
+  ##  let ret = someSysCall()
+  ##  if ret != 0:
+  ##    let errcode = osLastError()
+  ##    checkCloseError sock.closeSocket()
+  ##    raise newException(OSError, osErrorMsg(errcode))
+
+  if ret != 0:
+    let errcode = osLastError()
+    when useWinVersion:
+      doAssert(errcode.int32 notin {WSANOTINITIALISED, WSAENOTSOCK,
+                                    WSAEINPROGRESS, WSAEINTR, WSAEWOULDBLOCK})
+    else:
+      doAssert(errcode.int32 notin {EBADF})
diff --git a/lib/upcoming/asyncdispatch.nim b/lib/upcoming/asyncdispatch.nim
index c4b3e11e9..74619ab42 100644
--- a/lib/upcoming/asyncdispatch.nim
+++ b/lib/upcoming/asyncdispatch.nim
@@ -738,26 +738,6 @@ when defined(windows) or defined(nimdoc):
     let dwLocalAddressLength = Dword(sizeof(Sockaddr_in) + 16)
     let dwRemoteAddressLength = Dword(sizeof(Sockaddr_in) + 16)
 
-    template completeAccept() {.dirty.} =
-      var listenSock = socket
-      let setoptRet = setsockopt(clientSock, SOL_SOCKET,
-          SO_UPDATE_ACCEPT_CONTEXT, addr listenSock,
-          sizeof(listenSock).SockLen)
-      if setoptRet != 0: raiseOSError(osLastError())
-
-      var localSockaddr, remoteSockaddr: ptr SockAddr
-      var localLen, remoteLen: int32
-      getAcceptExSockaddrs(addr lpOutputBuf[0], dwReceiveDataLength,
-                           dwLocalAddressLength, dwRemoteAddressLength,
-                           addr localSockaddr, addr localLen,
-                           addr remoteSockaddr, addr remoteLen)
-      register(clientSock.AsyncFD)
-      # TODO: IPv6. Check ``sa_family``. http://stackoverflow.com/a/9212542/492186
-      retFuture.complete(
-        (address: $inet_ntoa(cast[ptr Sockaddr_in](remoteSockAddr).sin_addr),
-         client: clientSock.AsyncFD)
-      )
-
     template failAccept(errcode) =
       if flags.isDisconnectionError(errcode):
         var newAcceptFut = acceptAddr(socket, flags)
@@ -770,6 +750,29 @@ when defined(windows) or defined(nimdoc):
       else:
         retFuture.fail(newException(OSError, osErrorMsg(errcode)))
 
+    template completeAccept() {.dirty.} =
+      var listenSock = socket
+      let setoptRet = setsockopt(clientSock, SOL_SOCKET,
+          SO_UPDATE_ACCEPT_CONTEXT, addr listenSock,
+          sizeof(listenSock).SockLen)
+      if setoptRet != 0:
+        let errcode = osLastError()
+        checkCloseError clientSock.closeSocket()
+        failAccept(errcode)
+      else:
+        var localSockaddr, remoteSockaddr: ptr SockAddr
+        var localLen, remoteLen: int32
+        getAcceptExSockaddrs(addr lpOutputBuf[0], dwReceiveDataLength,
+                             dwLocalAddressLength, dwRemoteAddressLength,
+                             addr localSockaddr, addr localLen,
+                             addr remoteSockaddr, addr remoteLen)
+        register(clientSock.AsyncFD)
+        # TODO: IPv6. Check ``sa_family``. http://stackoverflow.com/a/9212542/492186
+        retFuture.complete(
+          (address: $inet_ntoa(cast[ptr Sockaddr_in](remoteSockAddr).sin_addr),
+          client: clientSock.AsyncFD)
+        )
+
     var ol = PCustomOverlapped()
     GC_ref(ol)
     ol.data = CompletionData(fd: socket, cb:
diff --git a/lib/windows/winlean.nim b/lib/windows/winlean.nim
index 367fa8b81..02821b792 100644
--- a/lib/windows/winlean.nim
+++ b/lib/windows/winlean.nim
@@ -419,9 +419,6 @@ const
 
   ws2dll = "Ws2_32.dll"
 
-  WSAEWOULDBLOCK* = 10035
-  WSAEINPROGRESS* = 10036
-
 proc wsaGetLastError*(): cint {.importc: "WSAGetLastError", dynlib: ws2dll.}
 
 type
@@ -760,6 +757,11 @@ const
   WSAEDISCON* = 10101
   WSAENETRESET* = 10052
   WSAETIMEDOUT* = 10060
+  WSANOTINITIALISED* = 10093
+  WSAENOTSOCK* = 10038
+  WSAEINPROGRESS* = 10036
+  WSAEINTR* = 10004
+  WSAEWOULDBLOCK* = 10035
   ERROR_NETNAME_DELETED* = 64
   STATUS_PENDING* = 0x103
 
diff --git a/tests/async/tacceptcloserace.nim b/tests/async/tacceptcloserace.nim
new file mode 100644
index 000000000..cbb5b5098
--- /dev/null
+++ b/tests/async/tacceptcloserace.nim
@@ -0,0 +1,36 @@
+discard """
+  exitcode: 0
+  output: ""
+"""
+
+import asyncdispatch, net, os, nativesockets
+
+# bug: https://github.com/nim-lang/Nim/issues/5279
+
+proc setupServerSocket(hostname: string, port: Port): AsyncFD =
+  let fd = newNativeSocket()
+  if fd == osInvalidSocket:
+    raiseOSError(osLastError())
+  setSockOptInt(fd, SOL_SOCKET, SO_REUSEADDR, 1)
+  var aiList = getAddrInfo(hostname, port)
+  if bindAddr(fd, aiList.ai_addr, aiList.ai_addrlen.Socklen) < 0'i32:
+    freeAddrInfo(aiList)
+    raiseOSError(osLastError())
+  freeAddrInfo(aiList)
+  if listen(fd) != 0:
+    raiseOSError(osLastError())
+  setBlocking(fd, false)
+  result = fd.AsyncFD
+  register(result)
+
+const port = Port(5614)
+for i in 0..100:
+  let serverFd = setupServerSocket("localhost", port)
+  serverFd.accept().callback = proc(fut: Future[AsyncFD]) =
+    if not fut.failed:
+      fut.read().closeSocket()
+
+  var fd = newAsyncNativeSocket()
+  waitFor fd.connect("localhost", port)
+  serverFd.closeSocket()
+  fd.closeSocket()