summary refs log tree commit diff stats
path: root/lib/pure/asyncnet.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r--lib/pure/asyncnet.nim174
1 files changed, 156 insertions, 18 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index 8734bab4c..f55442488 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -47,6 +47,7 @@
 import asyncdispatch
 import rawsockets
 import net
+import os
 
 when defined(ssl):
   import openssl
@@ -54,7 +55,22 @@ when defined(ssl):
 type
   # TODO: I would prefer to just do:
   # PAsyncSocket* {.borrow: `.`.} = distinct PSocket. But that doesn't work.
-  AsyncSocketDesc {.borrow: `.`.} = distinct TSocketImpl
+  AsyncSocketDesc  = object
+    fd*: SocketHandle
+    case isBuffered*: bool # determines whether this socket is buffered.
+    of true:
+      buffer*: array[0..BufferSize, char]
+      currPos*: int # current index in buffer
+      bufLen*: int # current length of buffer
+    of false: nil
+    case isSsl: bool
+    of true:
+      when defined(ssl):
+        sslHandle: SslPtr
+        sslContext: SslContext
+        bioIn: BIO
+        bioOut: BIO
+    of false: nil
   AsyncSocket* = ref AsyncSocketDesc
 
 {.deprecated: [PAsyncSocket: AsyncSocket].}
@@ -63,7 +79,7 @@ type
 
 proc newSocket(fd: TAsyncFD, isBuff: bool): PAsyncSocket =
   assert fd != osInvalidSocket.TAsyncFD
-  new(result.PSocket)
+  new(result)
   result.fd = fd.SocketHandle
   result.isBuffered = isBuff
   if isBuff:
@@ -74,22 +90,94 @@ proc newAsyncSocket*(domain: TDomain = AF_INET, typ: TType = SOCK_STREAM,
   ## Creates a new asynchronous socket.
   result = newSocket(newAsyncRawSocket(domain, typ, protocol), buffered)
 
+when defined(ssl):
+  proc getSslError(handle: SslPtr, err: cint): cint =
+    assert err < 0
+    var ret = SSLGetError(handle, err.cint)
+    case ret
+    of SSL_ERROR_ZERO_RETURN:
+      raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
+    of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
+      return ret
+    of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ:
+      return ret
+    of SSL_ERROR_WANT_X509_LOOKUP:
+      raiseSSLError("Function for x509 lookup has been called.")
+    of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
+      raiseSSLError()
+    else: raiseSSLError("Unknown Error")
+
+  proc sendPendingSslData(socket: AsyncSocket,
+      flags: set[TSocketFlags]) {.async.} =
+    let len = bioCtrlPending(socket.bioOut)
+    if len > 0:
+      var data = newStringOfCap(len)
+      let read = bioRead(socket.bioOut, addr data[0], len)
+      assert read != 0
+      if read < 0:
+        raiseSslError()
+      data.setLen(read)
+      await socket.fd.TAsyncFd.send(data, flags)
+
+  proc appeaseSsl(socket: AsyncSocket, flags: set[TSocketFlags],
+                  sslError: cint) {.async.} =
+    case sslError
+    of SSL_ERROR_WANT_WRITE:
+      await sendPendingSslData(socket, flags)
+    of SSL_ERROR_WANT_READ:
+      var data = await recv(socket.fd.TAsyncFD, BufferSize, flags)
+      let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint)
+      if ret < 0:
+        raiseSSLError()
+    else:
+      raiseSSLError("Cannot appease SSL.")
+
+  template sslLoop(socket: AsyncSocket, flags: set[TSocketFlags],
+                   op: expr) =
+    var opResult {.inject.} = -1.cint
+    while opResult < 0:
+      opResult = op
+      # Bit hackish here.
+      # TODO: Introduce an async template transformation pragma?
+      yield sendPendingSslData(socket, flags)
+      if opResult < 0:
+        let err = getSslError(socket.sslHandle, opResult.cint)
+        yield appeaseSsl(socket, flags, err.cint)
+
 proc connect*(socket: PAsyncSocket, address: string, port: TPort,
-    af = AF_INET): Future[void] =
+    af = AF_INET) {.async.} =
   ## Connects ``socket`` to server at ``address:port``.
   ##
   ## Returns a ``Future`` which will complete when the connection succeeds
   ## or an error occurs.
-  result = connect(socket.fd.TAsyncFD, address, port, af)
+  await connect(socket.fd.TAsyncFD, address, port, af)
+  let flags = {TSocketFlags.SafeDisconn}
+  if socket.isSsl:
+    when defined(ssl):
+      sslSetConnectState(socket.sslHandle)
+      sslLoop(socket, flags, sslDoHandshake(socket.sslHandle))
 
 proc readIntoBuf(socket: PAsyncSocket,
     flags: set[TSocketFlags]): Future[int] {.async.} =
   var data = await recv(socket.fd.TAsyncFD, BufferSize, flags)
   if data.len != 0:
     copyMem(addr socket.buffer[0], addr data[0], data.len)
-  socket.bufLen = data.len
-  socket.currPos = 0
-  result = data.len
+  if socket.isSsl:
+    when defined(ssl):
+      # SSL mode.
+      let ret = bioWrite(socket.bioIn, addr socket.buffer[0], data.len.cint)
+      if ret < 0:
+        raiseSSLError()
+      sslLoop(socket, flags,
+        sslRead(socket.sslHandle, addr socket.buffer[0], BufferSize.cint))
+      socket.currPos = 0
+      socket.bufLen = opResult # Injected from sslLoop template.
+      result = opResult
+  else:
+    # Not in SSL mode.
+    socket.bufLen = data.len
+    socket.currPos = 0
+    result = data.len
 
 proc recv*(socket: PAsyncSocket, size: int,
            flags = {TSocketFlags.SafeDisconn}): Future[string] {.async.} =
@@ -131,11 +219,18 @@ proc recv*(socket: PAsyncSocket, size: int,
     result = await recv(socket.fd.TAsyncFD, size, flags)
 
 proc send*(socket: PAsyncSocket, data: string,
-           flags = {TSocketFlags.SafeDisconn}): Future[void] =
+           flags = {TSocketFlags.SafeDisconn}) {.async.} =
   ## Sends ``data`` to ``socket``. The returned future will complete once all
   ## data has been sent.
   assert socket != nil
-  result = send(socket.fd.TAsyncFD, data, flags)
+  if socket.isSsl:
+    when defined(ssl):
+      var copy = data
+      sslLoop(socket, flags,
+        sslWrite(socket.sslHandle, addr copy[0], copy.len.cint))
+      await sendPendingSslData(socket, flags)
+  else:
+    await send(socket.fd.TAsyncFD, data, flags)
 
 proc acceptAddr*(socket: PAsyncSocket, flags = {TSocketFlags.SafeDisconn}):
       Future[tuple[address: string, client: PAsyncSocket]] =
@@ -240,24 +335,67 @@ proc recvLine*(socket: PAsyncSocket,
         return
       add(result.string, c)
 
-proc bindAddr*(socket: PAsyncSocket, port = TPort(0), address = "") =
-  ## Binds ``address``:``port`` to the socket.
-  ##
-  ## If ``address`` is "" then ADDR_ANY will be bound.
-  socket.PSocket.bindAddr(port, address)
-
-proc listen*(socket: PAsyncSocket, backlog = SOMAXCONN) =
+proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
   ## Marks ``socket`` as accepting connections.
   ## ``Backlog`` specifies the maximum length of the
   ## queue of pending connections.
   ##
   ## Raises an EOS error upon failure.
-  socket.PSocket.listen(backlog)
+  if listen(socket.fd, backlog) < 0'i32: raiseOSError(osLastError())
+
+proc bindAddr*(socket: Socket, port = Port(0), address = "") {.
+  tags: [ReadIOEffect].} =
+  ## Binds ``address``:``port`` to the socket.
+  ##
+  ## If ``address`` is "" then ADDR_ANY will be bound.
+
+  if address == "":
+    var name: Sockaddr_in
+    when defined(Windows) or defined(nimdoc):
+      name.sin_family = toInt(AF_INET).int16
+    else:
+      name.sin_family = toInt(AF_INET)
+    name.sin_port = htons(int16(port))
+    name.sin_addr.s_addr = htonl(INADDR_ANY)
+    if bindAddr(socket.fd, cast[ptr SockAddr](addr(name)),
+                  sizeof(name).Socklen) < 0'i32:
+      raiseOSError(osLastError())
+  else:
+    var aiList = getAddrInfo(address, port, AF_INET)
+    if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.Socklen) < 0'i32:
+      dealloc(aiList)
+      raiseOSError(osLastError())
+    dealloc(aiList)
 
 proc close*(socket: PAsyncSocket) =
   ## Closes the socket.
   socket.fd.TAsyncFD.closeSocket()
-  # TODO SSL
+  when defined(ssl):
+    if socket.isSSL:
+      let res = SslShutdown(socket.sslHandle)
+      if res == 0:
+        if SslShutdown(socket.sslHandle) != 1:
+          raiseSslError()
+      elif res != 1:
+        raiseSslError()
+
+when defined(ssl):
+  proc wrapSocket*(ctx: SslContext, socket: AsyncSocket) =
+    ## Wraps a socket in an SSL context. This function effectively turns
+    ## ``socket`` into an SSL socket.
+    ##
+    ## **Disclaimer**: This code is not well tested, may be very unsafe and
+    ## prone to security vulnerabilities.
+    socket.isSsl = true
+    socket.sslContext = ctx
+    socket.sslHandle = SSLNew(PSSLCTX(socket.sslContext))
+    if socket.sslHandle == nil:
+      raiseSslError()
+
+    socket.bioIn = bioNew(bio_s_mem())
+    socket.bioOut = bioNew(bio_s_mem())
+    sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
+
 
 when isMainModule:
   type