summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--lib/pure/asyncdispatch.nim149
-rw-r--r--lib/windows/winlean.nim16
-rw-r--r--tests/async/tnewasyncudp.nim102
3 files changed, 265 insertions, 2 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim
index 455bebc7f..bb19f87ef 100644
--- a/lib/pure/asyncdispatch.nim
+++ b/lib/pure/asyncdispatch.nim
@@ -862,6 +862,101 @@ when defined(windows) or defined(nimdoc):
       # free ``ol``.
     return retFuture
 
+  proc sendTo*(socket: AsyncFD, data: pointer, size: int, saddr: ptr SockAddr,
+               saddrLen: Socklen,
+               flags = {SocketFlag.SafeDisconn}): Future[void] =
+    ## Sends ``data`` to specified destination ``saddr``, using
+    ## socket ``socket``. The returned future will complete once all data
+    ## has been sent.
+    verifyPresence(socket)
+    var retFuture = newFuture[void]("sendTo")
+    var dataBuf: TWSABuf
+    dataBuf.buf = cast[cstring](data)
+    dataBuf.len = size.ULONG
+    var bytesSent = 0.Dword
+    var lowFlags = 0.Dword
+
+    # we will preserve address in our stack
+    var staddr: array[128, char] # SOCKADDR_STORAGE size is 128 bytes
+    var stalen: cint = cint(saddrLen)
+    zeroMem(addr(staddr[0]), 128)
+    copyMem(addr(staddr[0]), saddr, saddrLen)
+
+    var ol = PCustomOverlapped()
+    GC_ref(ol)
+    ol.data = CompletionData(fd: socket, cb:
+      proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) =
+        if not retFuture.finished:
+          if errcode == OSErrorCode(-1):
+            retFuture.complete()
+          else:
+            retFuture.fail(newException(OSError, osErrorMsg(errcode)))
+    )
+
+    let ret = WSASendTo(socket.SocketHandle, addr dataBuf, 1, addr bytesSent,
+                        lowFlags, cast[ptr SockAddr](addr(staddr[0])),
+                        stalen, cast[POVERLAPPED](ol), nil)
+    if ret == -1:
+      let err = osLastError()
+      if err.int32 != ERROR_IO_PENDING:
+        GC_unref(ol)
+        retFuture.fail(newException(OSError, osErrorMsg(err)))
+    else:
+      retFuture.complete()
+      # We don't deallocate ``ol`` here because even though this completed
+      # immediately poll will still be notified about its completion and it will
+      # free ``ol``.
+    return retFuture
+
+  proc recvFromInto*(socket: AsyncFD, data: pointer, size: int,
+                     saddr: ptr SockAddr, saddrLen: ptr SockLen,
+                     flags = {SocketFlag.SafeDisconn}): Future[int] =
+    ## Receives a datagram data from ``socket`` into ``buf``, which must
+    ## be at least of size ``size``, address of datagram's sender will be
+    ## stored into ``saddr`` and ``saddrLen``. Returned future will complete
+    ## once one datagram has been received, and will return size of packet
+    ## received.
+    verifyPresence(socket)
+    var retFuture = newFuture[int]("recvFromInto")
+
+    var dataBuf = TWSABuf(buf: cast[cstring](data), len: size.ULONG)
+
+    var bytesReceived = 0.Dword
+    var lowFlags = 0.Dword
+
+    var ol = PCustomOverlapped()
+    GC_ref(ol)
+    ol.data = CompletionData(fd: socket, cb:
+      proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) =
+        if not retFuture.finished:
+          if errcode == OSErrorCode(-1):
+            assert bytesCount <= size
+            retFuture.complete(bytesCount)
+          else:
+            # datagram sockets don't have disconnection,
+            # so we can just raise an exception
+            retFuture.fail(newException(OSError, osErrorMsg(errcode)))
+    )
+
+    let res = WSARecvFrom(socket.SocketHandle, addr dataBuf, 1,
+                          addr bytesReceived, addr lowFlags,
+                          saddr, cast[ptr cint](saddrLen),
+                          cast[POVERLAPPED](ol), nil)
+    if res == -1:
+      let err = osLastError()
+      if err.int32 != ERROR_IO_PENDING:
+        GC_unref(ol)
+        retFuture.fail(newException(OSError, osErrorMsg(err)))
+    else:
+      # Request completed immediately.
+      if bytesReceived != 0:
+        assert bytesReceived <= size
+        retFuture.complete(bytesReceived)
+      else:
+        if hasOverlappedIoCompleted(cast[POVERLAPPED](ol)):
+          retFuture.complete(bytesReceived)
+    return retFuture
+
   proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}):
       Future[tuple[address: string, client: AsyncFD]] =
     ## Accepts a new connection. Returns a future containing the client socket
@@ -1359,6 +1454,60 @@ else:
     addWrite(socket, cb)
     return retFuture
 
+  proc sendTo*(socket: AsyncFD, data: pointer, size: int, saddr: ptr SockAddr,
+               saddrLen: SockLen,
+               flags = {SocketFlag.SafeDisconn}): Future[void] =
+    ## Sends ``data`` of size ``size`` in bytes to specified destination
+    ## (``saddr`` of size ``saddrLen`` in bytes, using socket ``socket``.
+    ## The returned future will complete once all data has been sent.
+    var retFuture = newFuture[void]("sendTo")
+
+    # we will preserve address in our stack
+    var staddr: array[128, char] # SOCKADDR_STORAGE size is 128 bytes
+    var stalen = saddrLen
+    zeroMem(addr(staddr[0]), 128)
+    copyMem(addr(staddr[0]), saddr, saddrLen)
+
+    proc cb(sock: AsyncFD): bool =
+      result = true
+      let res = sendto(sock.SocketHandle, data, size, MSG_NOSIGNAL,
+                       cast[ptr SockAddr](addr(staddr[0])), stalen)
+      if res < 0:
+        let lastError = osLastError()
+        if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
+          retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+        else:
+          result = false # We still want this callback to be called.
+      else:
+        retFuture.complete()
+
+    addWrite(socket, cb)
+    return retFuture
+
+  proc recvFromInto*(socket: AsyncFD, data: pointer, size: int,
+                     saddr: ptr SockAddr, saddrLen: ptr SockLen,
+                     flags = {SocketFlag.SafeDisconn}): Future[int] =
+    ## Receives a datagram data from ``socket`` into ``data``, which must
+    ## be at least of size ``size`` in bytes, address of datagram's sender
+    ## will be stored into ``saddr`` and ``saddrLen``. Returned future will
+    ## complete once one datagram has been received, and will return size
+    ## of packet received.
+    var retFuture = newFuture[int]("recvFromInto")
+    proc cb(sock: AsyncFD): bool =
+      result = true
+      let res = recvfrom(sock.SocketHandle, data, size.cint, flags.toOSFlags(),
+                         saddr, saddrLen)
+      if res < 0:
+        let lastError = osLastError()
+        if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
+          retFuture.fail(newException(OSError, osErrorMsg(lastError)))
+        else:
+          result = false
+      else:
+        retFuture.complete(res)
+    addRead(socket, cb)
+    return retFuture
+
   proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}):
       Future[tuple[address: string, client: AsyncFD]] =
     var retFuture = newFuture[tuple[address: string,
diff --git a/lib/windows/winlean.nim b/lib/windows/winlean.nim
index 2989e5ac9..04edeb2cb 100644
--- a/lib/windows/winlean.nim
+++ b/lib/windows/winlean.nim
@@ -455,7 +455,7 @@ type
     sin_zero*: array[0..7, char]
 
   In6_addr* {.importc: "IN6_ADDR", header: "winsock2.h".} = object
-    bytes*: array[0..15, char]
+    bytes* {.importc: "u.Byte".}: array[0..15, char]
 
   Sockaddr_in6* {.importc: "SOCKADDR_IN6",
                    header: "ws2tcpip.h".} = object
@@ -825,11 +825,23 @@ proc WSARecv*(s: SocketHandle, buf: ptr TWSABuf, bufCount: DWORD,
   completionProc: POVERLAPPED_COMPLETION_ROUTINE): cint {.
   stdcall, importc: "WSARecv", dynlib: "Ws2_32.dll".}
 
+proc WSARecvFrom*(s: SocketHandle, buf: ptr TWSABuf, bufCount: DWORD,
+                  bytesReceived: PDWORD, flags: PDWORD, name: ptr SockAddr,
+                  namelen: ptr cint, lpOverlapped: POVERLAPPED,
+                  completionProc: POVERLAPPED_COMPLETION_ROUTINE): cint {.
+     stdcall, importc: "WSARecvFrom", dynlib: "Ws2_32.dll".}
+
 proc WSASend*(s: SocketHandle, buf: ptr TWSABuf, bufCount: DWORD,
   bytesSent: PDWORD, flags: DWORD, lpOverlapped: POVERLAPPED,
   completionProc: POVERLAPPED_COMPLETION_ROUTINE): cint {.
   stdcall, importc: "WSASend", dynlib: "Ws2_32.dll".}
 
+proc WSASendTo*(s: SocketHandle, buf: ptr TWSABuf, bufCount: DWORD,
+                bytesSent: PDWORD, flags: DWORD, name: ptr SockAddr,
+                namelen: cint, lpOverlapped: POVERLAPPED,
+                completionProc: POVERLAPPED_COMPLETION_ROUTINE): cint {.
+     stdcall, importc: "WSASendTo", dynlib: "Ws2_32.dll".}
+
 proc get_osfhandle*(fd:FileHandle): Handle {.
   importc: "_get_osfhandle", header:"<io.h>".}
 
@@ -994,7 +1006,7 @@ const
   FD_ALL_EVENTS* = 0x000003FF'i32
 
 proc wsaEventSelect*(s: SocketHandle, hEventObject: Handle,
-                     lNetworkEvents: clong): cint 
+                     lNetworkEvents: clong): cint
     {.stdcall, importc: "WSAEventSelect", dynlib: "ws2_32.dll".}
 
 proc wsaCreateEvent*(): Handle
diff --git a/tests/async/tnewasyncudp.nim b/tests/async/tnewasyncudp.nim
new file mode 100644
index 000000000..7025fa20d
--- /dev/null
+++ b/tests/async/tnewasyncudp.nim
@@ -0,0 +1,102 @@
+discard """
+  file: "tnewasyncudp.nim"
+  output: "5000"
+"""
+import asyncdispatch, nativesockets, net, strutils, os
+
+when defined(windows):
+  import winlean
+else:
+  import posix
+
+var msgCount = 0
+var recvCount = 0
+
+const
+  messagesToSend = 100
+  swarmSize = 50
+  serverPort = 10333
+
+var
+  sendports = 0
+  recvports = 0
+
+proc saveSendingPort(port: int) =
+  sendports = sendports + port
+
+proc saveReceivedPort(port: int) =
+  recvports = recvports + port
+
+proc prepareAddress(intaddr: uint32, intport: uint16): ptr Sockaddr_in =
+  result = cast[ptr Sockaddr_in](alloc0(sizeof(Sockaddr_in)))
+  when defined(windows):
+    result.sin_family = toInt(nativesockets.AF_INET).int16
+  else:
+    result.sin_family = toInt(nativesockets.AF_INET)
+  result.sin_port = htons(intport)
+  result.sin_addr.s_addr = htonl(intaddr)
+
+proc launchSwarm(name: ptr SockAddr) {.async.} =
+  var i = 0
+  var k = 0
+  while i < swarmSize:
+    var peeraddr = prepareAddress(INADDR_ANY, 0)
+    var sock = newAsyncNativeSocket(nativesockets.AF_INET,
+                                    nativesockets.SOCK_DGRAM,
+                                    Protocol.IPPROTO_UDP)
+    if bindAddr(sock.SocketHandle, cast[ptr SockAddr](peeraddr),
+              sizeof(Sockaddr_in).Socklen) < 0'i32:
+      raiseOSError(osLastError())
+    let sockport = getSockName(sock.SocketHandle).int
+    k = 0
+    while k < messagesToSend:
+      var message = "Message " & $(i * messagesToSend + k)
+      await sendTo(sock, addr message[0], len(message),
+                   name, sizeof(Sockaddr_in).SockLen)
+      saveSendingPort(sockport)
+      inc(k)
+    closeSocket(sock)
+    inc(i)
+
+proc readMessages(server: AsyncFD) {.async.} =
+  var buffer: array[16384, char]
+  var slen = sizeof(Sockaddr_in).SockLen
+  var saddr = Sockaddr_in()
+  var maxResponses = (swarmSize * messagesToSend)
+
+  var i = 0
+  while i < maxResponses:
+    zeroMem(addr(buffer[0]), 16384)
+    zeroMem(cast[pointer](addr(saddr)), sizeof(Sockaddr_in))
+    var size = await recvFromInto(server, cast[cstring](addr buffer[0]),
+                                  16384, cast[ptr SockAddr](addr(saddr)),
+                                  addr(slen))
+    size = 0
+    var grammString = $buffer
+    if grammString.startswith("Message ") and
+       saddr.sin_addr.s_addr == 0x100007F:
+      inc(msgCount)
+      saveReceivedPort(ntohs(saddr.sin_port).int)
+      inc(recvCount)
+    inc(i)
+
+proc createServer() {.async.} =
+  var name = prepareAddress(INADDR_ANY, serverPort)
+  var server = newAsyncNativeSocket(nativesockets.AF_INET,
+                                    nativesockets.SOCK_DGRAM,
+                                    Protocol.IPPROTO_UDP)
+  if bindAddr(server.SocketHandle, cast[ptr SockAddr](name),
+              sizeof(Sockaddr_in).Socklen) < 0'i32:
+    raiseOSError(osLastError())
+  asyncCheck readMessages(server)
+
+var name = prepareAddress(0x7F000001, serverPort) # 127.0.0.1
+asyncCheck createServer()
+asyncCheck launchSwarm(cast[ptr SockAddr](name))
+while true:
+  poll()
+  if recvCount == swarmSize * messagesToSend:
+    break
+assert msgCount == swarmSize * messagesToSend
+assert sendports == recvports
+echo msgCount