summary refs log tree commit diff stats
path: root/lib/pure/asyncnet.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r--lib/pure/asyncnet.nim214
1 files changed, 137 insertions, 77 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index e7325e0d7..cfd3d7666 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -24,7 +24,7 @@
 ##
 ## Chat server
 ## ^^^^^^^^^^^
-## 
+##
 ## The following example demonstrates a simple chat server.
 ##
 ## .. code-block::nim
@@ -85,35 +85,44 @@ type
         bioIn: BIO
         bioOut: BIO
     of false: nil
+    domain: Domain
+    sockType: SockType
+    protocol: Protocol
   AsyncSocket* = ref AsyncSocketDesc
 
 {.deprecated: [PAsyncSocket: AsyncSocket].}
 
-# TODO: Save AF, domain etc info and reuse it in procs which need it like connect.
-
-proc newAsyncSocket*(fd: TAsyncFD, isBuff: bool): AsyncSocket =
+proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
+    sockType: SockType = SOCK_STREAM,
+    protocol: Protocol = IPPROTO_TCP, buffered = true): AsyncSocket =
   ## Creates a new ``AsyncSocket`` based on the supplied params.
-  assert fd != osInvalidSocket.TAsyncFD
+  assert fd != osInvalidSocket.AsyncFD
   new(result)
   result.fd = fd.SocketHandle
-  result.isBuffered = isBuff
-  if isBuff:
+  result.isBuffered = buffered
+  result.domain = domain
+  result.sockType = sockType
+  result.protocol = protocol
+  if buffered:
     result.currPos = 0
 
-proc newAsyncSocket*(domain: Domain = AF_INET, typ: SockType = SOCK_STREAM,
+proc newAsyncSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
     protocol: Protocol = IPPROTO_TCP, buffered = true): AsyncSocket =
   ## Creates a new asynchronous socket.
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  result = newAsyncSocket(newAsyncRawSocket(domain, typ, protocol), buffered)
+  result = newAsyncSocket(newAsyncRawSocket(domain, sockType, protocol), domain,
+      sockType, protocol, buffered)
 
-proc newAsyncSocket*(domain, typ, protocol: cint, buffered = true): AsyncSocket =
+proc newAsyncSocket*(domain, sockType, protocol: cint,
+    buffered = true): AsyncSocket =
   ## Creates a new asynchronous socket.
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  result = newAsyncSocket(newAsyncRawSocket(domain, typ, protocol), buffered)
+  result = newAsyncSocket(newAsyncRawSocket(domain, sockType, protocol),
+      Domain(domain), SockType(sockType), Protocol(protocol), buffered)
 
 when defined(ssl):
   proc getSslError(handle: SslPtr, err: cint): cint =
@@ -142,7 +151,7 @@ when defined(ssl):
       if read < 0:
         raiseSslError()
       data.setLen(read)
-      await socket.fd.TAsyncFd.send(data, flags)
+      await socket.fd.AsyncFd.send(data, flags)
 
   proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag],
                   sslError: cint) {.async.} =
@@ -150,7 +159,7 @@ when defined(ssl):
     of SSL_ERROR_WANT_WRITE:
       await sendPendingSslData(socket, flags)
     of SSL_ERROR_WANT_READ:
-      var data = await recv(socket.fd.TAsyncFD, BufferSize, flags)
+      var data = await recv(socket.fd.AsyncFD, BufferSize, flags)
       let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint)
       if ret < 0:
         raiseSSLError()
@@ -169,39 +178,42 @@ when defined(ssl):
         let err = getSslError(socket.sslHandle, opResult.cint)
         yield appeaseSsl(socket, flags, err.cint)
 
-proc connect*(socket: AsyncSocket, address: string, port: Port,
-    af = AF_INET) {.async.} =
+proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} =
   ## Connects ``socket`` to server at ``address:port``.
   ##
   ## Returns a ``Future`` which will complete when the connection succeeds
   ## or an error occurs.
-  await connect(socket.fd.TAsyncFD, address, port, af)
+  await connect(socket.fd.AsyncFD, address, port, socket.domain)
   if socket.isSsl:
     when defined(ssl):
       let flags = {SocketFlag.SafeDisconn}
       sslSetConnectState(socket.sslHandle)
       sslLoop(socket, flags, sslDoHandshake(socket.sslHandle))
 
-proc readInto(buf: cstring, size: int, socket: AsyncSocket,
-              flags: set[SocketFlag]): Future[int] {.async.} =
+template readInto(buf: cstring, size: int, socket: AsyncSocket,
+                  flags: set[SocketFlag]): int =
+  ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. Note that
+  ## this is a template and not a proc.
+  var res = 0
   if socket.isSsl:
     when defined(ssl):
       # SSL mode.
       sslLoop(socket, flags,
         sslRead(socket.sslHandle, buf, size.cint))
-      result = opResult
+      res = opResult
   else:
-    var data = await recv(socket.fd.TAsyncFD, size, flags)
-    if data.len != 0:
-      copyMem(buf, addr data[0], data.len)
+    var recvIntoFut = recvInto(socket.fd.AsyncFD, buf, size, flags)
+    yield recvIntoFut
     # Not in SSL mode.
-    result = data.len
+    res = recvIntoFut.read()
+  res
 
-proc readIntoBuf(socket: AsyncSocket,
-    flags: set[SocketFlag]): Future[int] {.async.} =
-  result = await readInto(addr socket.buffer[0], BufferSize, socket, flags)
+template readIntoBuf(socket: AsyncSocket,
+    flags: set[SocketFlag]): int =
+  var size = readInto(addr socket.buffer[0], BufferSize, socket, flags)
   socket.currPos = 0
-  socket.bufLen = result
+  socket.bufLen = size
+  size
 
 proc recv*(socket: AsyncSocket, size: int,
            flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} =
@@ -222,10 +234,11 @@ proc recv*(socket: AsyncSocket, size: int,
   ## to be read then the future will complete with a value of ``""``.
   if socket.isBuffered:
     result = newString(size)
+    shallow(result)
     let originalBufPos = socket.currPos
 
     if socket.bufLen == 0:
-      let res = await socket.readIntoBuf(flags - {SocketFlag.Peek})
+      let res = socket.readIntoBuf(flags - {SocketFlag.Peek})
       if res == 0:
         result.setLen(0)
         return
@@ -236,7 +249,7 @@ proc recv*(socket: AsyncSocket, size: int,
         if SocketFlag.Peek in flags:
           # We don't want to get another buffer if we're peeking.
           break
-        let res = await socket.readIntoBuf(flags - {SocketFlag.Peek})
+        let res = socket.readIntoBuf(flags - {SocketFlag.Peek})
         if res == 0:
           break
 
@@ -251,7 +264,7 @@ proc recv*(socket: AsyncSocket, size: int,
     result.setLen(read)
   else:
     result = newString(size)
-    let read = await readInto(addr result[0], size, socket, flags)
+    let read = readInto(addr result[0], size, socket, flags)
     result.setLen(read)
 
 proc send*(socket: AsyncSocket, data: string,
@@ -266,7 +279,7 @@ proc send*(socket: AsyncSocket, data: string,
         sslWrite(socket.sslHandle, addr copy[0], copy.len.cint))
       await sendPendingSslData(socket, flags)
   else:
-    await send(socket.fd.TAsyncFD, data, flags)
+    await send(socket.fd.AsyncFD, data, flags)
 
 proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}):
       Future[tuple[address: string, client: AsyncSocket]] =
@@ -274,15 +287,16 @@ proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}):
   ## corresponding to that connection and the remote address of the client.
   ## The future will complete when the connection is successfully accepted.
   var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr")
-  var fut = acceptAddr(socket.fd.TAsyncFD, flags)
+  var fut = acceptAddr(socket.fd.AsyncFD, flags)
   fut.callback =
-    proc (future: Future[tuple[address: string, client: TAsyncFD]]) =
+    proc (future: Future[tuple[address: string, client: AsyncFD]]) =
       assert future.finished
       if future.failed:
         retFuture.fail(future.readError)
       else:
         let resultTup = (future.read.address,
-                         newAsyncSocket(future.read.client, socket.isBuffered))
+                         newAsyncSocket(future.read.client, socket.domain,
+                         socket.sockType, socket.protocol, socket.isBuffered))
         retFuture.complete(resultTup)
   return retFuture
 
@@ -302,15 +316,14 @@ proc accept*(socket: AsyncSocket,
         retFut.complete(future.read.client)
   return retFut
 
-proc recvLine*(socket: AsyncSocket,
-    flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} =
-  ## Reads a line of data from ``socket``. Returned future will complete once
-  ## a full line is read or an error occurs.
+proc recvLineInto*(socket: AsyncSocket, resString: ptr string,
+    flags = {SocketFlag.SafeDisconn}) {.async.} =
+  ## Reads a line of data from ``socket`` into ``resString``.
   ##
   ## If a full line is read ``\r\L`` is not
   ## added to ``line``, however if solely ``\r\L`` is read then ``line``
   ## will be set to it.
-  ## 
+  ##
   ## If the socket is disconnected, ``line`` will be set to ``""``.
   ##
   ## If the socket is disconnected in the middle of a line (before ``\r\L``
@@ -318,27 +331,32 @@ proc recvLine*(socket: AsyncSocket,
   ## The partial line **will be lost**.
   ##
   ## **Warning**: The ``Peek`` flag is not yet implemented.
-  ## 
-  ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol
-  ## uses ``\r\L`` to delimit a new line.
-  template addNLIfEmpty(): stmt =
-    if result.len == 0:
-      result.add("\c\L")
+  ##
+  ## **Warning**: ``recvLineInto`` on unbuffered sockets assumes that the
+  ## protocol uses ``\r\L`` to delimit a new line.
+  ##
+  ## **Warning**: ``recvLineInto`` currently uses a raw pointer to a string for
+  ## performance reasons. This will likely change soon to use FutureVars.
   assert SocketFlag.Peek notin flags ## TODO:
+  result = newFuture[void]("asyncnet.recvLineInto")
+
+  template addNLIfEmpty(): stmt =
+    if resString[].len == 0:
+      resString[].add("\c\L")
+
   if socket.isBuffered:
-    result = ""
     if socket.bufLen == 0:
-      let res = await socket.readIntoBuf(flags)
+      let res = socket.readIntoBuf(flags)
       if res == 0:
         return
 
     var lastR = false
     while true:
       if socket.currPos >= socket.bufLen:
-        let res = await socket.readIntoBuf(flags)
+        let res = socket.readIntoBuf(flags)
         if res == 0:
-          result = ""
-          break
+          resString[].setLen(0)
+          return
 
       case socket.buffer[socket.currPos]
       of '\r':
@@ -353,24 +371,53 @@ proc recvLine*(socket: AsyncSocket,
           socket.currPos.inc()
           return
         else:
-          result.add socket.buffer[socket.currPos]
+          resString[].add socket.buffer[socket.currPos]
       socket.currPos.inc()
   else:
-    result = ""
     var c = ""
     while true:
-      c = await recv(socket, 1, flags)
+      let recvFut = recv(socket, 1, flags)
+      c = recvFut.read()
       if c.len == 0:
-        return ""
+        resString[].setLen(0)
+        return
       if c == "\r":
-        c = await recv(socket, 1, flags) # Skip \L
+        let recvFut = recv(socket, 1, flags) # Skip \L
+        c = recvFut.read()
         assert c == "\L"
         addNLIfEmpty()
         return
       elif c == "\L":
         addNLIfEmpty()
         return
-      add(result.string, c)
+      resString[].add c
+
+proc recvLine*(socket: AsyncSocket,
+    flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} =
+  ## Reads a line of data from ``socket``. Returned future will complete once
+  ## a full line is read or an error occurs.
+  ##
+  ## If a full line is read ``\r\L`` is not
+  ## added to ``line``, however if solely ``\r\L`` is read then ``line``
+  ## will be set to it.
+  ##
+  ## If the socket is disconnected, ``line`` will be set to ``""``.
+  ##
+  ## If the socket is disconnected in the middle of a line (before ``\r\L``
+  ## is read) then line will be set to ``""``.
+  ## The partial line **will be lost**.
+  ##
+  ## **Warning**: The ``Peek`` flag is not yet implemented.
+  ##
+  ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol
+  ## uses ``\r\L`` to delimit a new line.
+  template addNLIfEmpty(): stmt =
+    if result.len == 0:
+      result.add("\c\L")
+  assert SocketFlag.Peek notin flags ## TODO:
+
+  result = ""
+  await socket.recvLineInto(addr result, flags)
 
 proc listen*(socket: AsyncSocket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
   ## Marks ``socket`` as accepting connections.
@@ -385,29 +432,24 @@ proc bindAddr*(socket: AsyncSocket, port = Port(0), address = "") {.
   ## Binds ``address``:``port`` to the socket.
   ##
   ## If ``address`` is "" then ADDR_ANY will be bound.
-
-  if address == "":
-    var name: Sockaddr_in
-    when defined(Windows) or defined(nimdoc):
-      name.sin_family = toInt(AF_INET).int16
+  var realaddr = address
+  if realaddr == "":
+    case socket.domain
+    of AF_INET6: realaddr = "::"
+    of AF_INET:  realaddr = "0.0.0.0"
     else:
-      name.sin_family = toInt(AF_INET)
-    name.sin_port = htons(int16(port))
-    name.sin_addr.s_addr = htonl(INADDR_ANY)
-    if bindAddr(socket.fd, cast[ptr SockAddr](addr(name)),
-                  sizeof(name).Socklen) < 0'i32:
-      raiseOSError(osLastError())
-  else:
-    var aiList = getAddrInfo(address, port, AF_INET)
-    if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.Socklen) < 0'i32:
-      dealloc(aiList)
-      raiseOSError(osLastError())
+      raiseOSError("Unknown socket address family and no address specified to bindAddr")
+
+  var aiList = getAddrInfo(realaddr, port, socket.domain)
+  if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.Socklen) < 0'i32:
     dealloc(aiList)
+    raiseOSError(osLastError())
+  dealloc(aiList)
 
 proc close*(socket: AsyncSocket) =
   ## Closes the socket.
   defer:
-    socket.fd.TAsyncFD.closeSocket()
+    socket.fd.AsyncFD.closeSocket()
   when defined(ssl):
     if socket.isSSL:
       let res = SslShutdown(socket.sslHandle)
@@ -434,6 +476,24 @@ when defined(ssl):
     socket.bioOut = bioNew(bio_s_mem())
     sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
 
+  proc wrapConnectedSocket*(ctx: SslContext, socket: AsyncSocket,
+                            handshake: SslHandshakeType) =
+    ## Wraps a connected socket in an SSL context. This function effectively
+    ## turns ``socket`` into an SSL socket.
+    ##
+    ## This should be called on a connected socket, and will perform
+    ## an SSL handshake immediately.
+    ##
+    ## **Disclaimer**: This code is not well tested, may be very unsafe and
+    ## prone to security vulnerabilities.
+    wrapSocket(ctx, socket)
+
+    case handshake
+    of handshakeAsClient:
+      sslSetConnectState(socket.sslHandle)
+    of handshakeAsServer:
+      sslSetAcceptState(socket.sslHandle)
+
 proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {.
   tags: [ReadIOEffect].} =
   ## Retrieves option ``opt`` as a boolean value.
@@ -458,7 +518,7 @@ proc isClosed*(socket: AsyncSocket): bool =
   ## Determines whether the socket has been closed.
   return socket.closed
 
-when isMainModule:
+when not defined(testing) and isMainModule:
   type
     TestCases = enum
       HighClient, LowClient, LowServer
@@ -500,11 +560,11 @@ when isMainModule:
         proc (future: Future[void]) =
           echo("Send")
           client.close()
-      
+
       var f = accept(sock)
       f.callback = onAccept
-      
+
     var f = accept(sock)
     f.callback = onAccept
   runForever()
-    
+