diff options
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r-- | lib/pure/asyncnet.nim | 174 |
1 files changed, 156 insertions, 18 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index 8734bab4c..f55442488 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -47,6 +47,7 @@ import asyncdispatch import rawsockets import net +import os when defined(ssl): import openssl @@ -54,7 +55,22 @@ when defined(ssl): type # TODO: I would prefer to just do: # PAsyncSocket* {.borrow: `.`.} = distinct PSocket. But that doesn't work. - AsyncSocketDesc {.borrow: `.`.} = distinct TSocketImpl + AsyncSocketDesc = object + fd*: SocketHandle + case isBuffered*: bool # determines whether this socket is buffered. + of true: + buffer*: array[0..BufferSize, char] + currPos*: int # current index in buffer + bufLen*: int # current length of buffer + of false: nil + case isSsl: bool + of true: + when defined(ssl): + sslHandle: SslPtr + sslContext: SslContext + bioIn: BIO + bioOut: BIO + of false: nil AsyncSocket* = ref AsyncSocketDesc {.deprecated: [PAsyncSocket: AsyncSocket].} @@ -63,7 +79,7 @@ type proc newSocket(fd: TAsyncFD, isBuff: bool): PAsyncSocket = assert fd != osInvalidSocket.TAsyncFD - new(result.PSocket) + new(result) result.fd = fd.SocketHandle result.isBuffered = isBuff if isBuff: @@ -74,22 +90,94 @@ proc newAsyncSocket*(domain: TDomain = AF_INET, typ: TType = SOCK_STREAM, ## Creates a new asynchronous socket. result = newSocket(newAsyncRawSocket(domain, typ, protocol), buffered) +when defined(ssl): + proc getSslError(handle: SslPtr, err: cint): cint = + assert err < 0 + var ret = SSLGetError(handle, err.cint) + case ret + of SSL_ERROR_ZERO_RETURN: + raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") + of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT: + return ret + of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ: + return ret + of SSL_ERROR_WANT_X509_LOOKUP: + raiseSSLError("Function for x509 lookup has been called.") + of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + raiseSSLError() + else: raiseSSLError("Unknown Error") + + proc sendPendingSslData(socket: AsyncSocket, + flags: set[TSocketFlags]) {.async.} = + let len = bioCtrlPending(socket.bioOut) + if len > 0: + var data = newStringOfCap(len) + let read = bioRead(socket.bioOut, addr data[0], len) + assert read != 0 + if read < 0: + raiseSslError() + data.setLen(read) + await socket.fd.TAsyncFd.send(data, flags) + + proc appeaseSsl(socket: AsyncSocket, flags: set[TSocketFlags], + sslError: cint) {.async.} = + case sslError + of SSL_ERROR_WANT_WRITE: + await sendPendingSslData(socket, flags) + of SSL_ERROR_WANT_READ: + var data = await recv(socket.fd.TAsyncFD, BufferSize, flags) + let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint) + if ret < 0: + raiseSSLError() + else: + raiseSSLError("Cannot appease SSL.") + + template sslLoop(socket: AsyncSocket, flags: set[TSocketFlags], + op: expr) = + var opResult {.inject.} = -1.cint + while opResult < 0: + opResult = op + # Bit hackish here. + # TODO: Introduce an async template transformation pragma? + yield sendPendingSslData(socket, flags) + if opResult < 0: + let err = getSslError(socket.sslHandle, opResult.cint) + yield appeaseSsl(socket, flags, err.cint) + proc connect*(socket: PAsyncSocket, address: string, port: TPort, - af = AF_INET): Future[void] = + af = AF_INET) {.async.} = ## Connects ``socket`` to server at ``address:port``. ## ## Returns a ``Future`` which will complete when the connection succeeds ## or an error occurs. - result = connect(socket.fd.TAsyncFD, address, port, af) + await connect(socket.fd.TAsyncFD, address, port, af) + let flags = {TSocketFlags.SafeDisconn} + if socket.isSsl: + when defined(ssl): + sslSetConnectState(socket.sslHandle) + sslLoop(socket, flags, sslDoHandshake(socket.sslHandle)) proc readIntoBuf(socket: PAsyncSocket, flags: set[TSocketFlags]): Future[int] {.async.} = var data = await recv(socket.fd.TAsyncFD, BufferSize, flags) if data.len != 0: copyMem(addr socket.buffer[0], addr data[0], data.len) - socket.bufLen = data.len - socket.currPos = 0 - result = data.len + if socket.isSsl: + when defined(ssl): + # SSL mode. + let ret = bioWrite(socket.bioIn, addr socket.buffer[0], data.len.cint) + if ret < 0: + raiseSSLError() + sslLoop(socket, flags, + sslRead(socket.sslHandle, addr socket.buffer[0], BufferSize.cint)) + socket.currPos = 0 + socket.bufLen = opResult # Injected from sslLoop template. + result = opResult + else: + # Not in SSL mode. + socket.bufLen = data.len + socket.currPos = 0 + result = data.len proc recv*(socket: PAsyncSocket, size: int, flags = {TSocketFlags.SafeDisconn}): Future[string] {.async.} = @@ -131,11 +219,18 @@ proc recv*(socket: PAsyncSocket, size: int, result = await recv(socket.fd.TAsyncFD, size, flags) proc send*(socket: PAsyncSocket, data: string, - flags = {TSocketFlags.SafeDisconn}): Future[void] = + flags = {TSocketFlags.SafeDisconn}) {.async.} = ## Sends ``data`` to ``socket``. The returned future will complete once all ## data has been sent. assert socket != nil - result = send(socket.fd.TAsyncFD, data, flags) + if socket.isSsl: + when defined(ssl): + var copy = data + sslLoop(socket, flags, + sslWrite(socket.sslHandle, addr copy[0], copy.len.cint)) + await sendPendingSslData(socket, flags) + else: + await send(socket.fd.TAsyncFD, data, flags) proc acceptAddr*(socket: PAsyncSocket, flags = {TSocketFlags.SafeDisconn}): Future[tuple[address: string, client: PAsyncSocket]] = @@ -240,24 +335,67 @@ proc recvLine*(socket: PAsyncSocket, return add(result.string, c) -proc bindAddr*(socket: PAsyncSocket, port = TPort(0), address = "") = - ## Binds ``address``:``port`` to the socket. - ## - ## If ``address`` is "" then ADDR_ANY will be bound. - socket.PSocket.bindAddr(port, address) - -proc listen*(socket: PAsyncSocket, backlog = SOMAXCONN) = +proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} = ## Marks ``socket`` as accepting connections. ## ``Backlog`` specifies the maximum length of the ## queue of pending connections. ## ## Raises an EOS error upon failure. - socket.PSocket.listen(backlog) + if listen(socket.fd, backlog) < 0'i32: raiseOSError(osLastError()) + +proc bindAddr*(socket: Socket, port = Port(0), address = "") {. + tags: [ReadIOEffect].} = + ## Binds ``address``:``port`` to the socket. + ## + ## If ``address`` is "" then ADDR_ANY will be bound. + + if address == "": + var name: Sockaddr_in + when defined(Windows) or defined(nimdoc): + name.sin_family = toInt(AF_INET).int16 + else: + name.sin_family = toInt(AF_INET) + name.sin_port = htons(int16(port)) + name.sin_addr.s_addr = htonl(INADDR_ANY) + if bindAddr(socket.fd, cast[ptr SockAddr](addr(name)), + sizeof(name).Socklen) < 0'i32: + raiseOSError(osLastError()) + else: + var aiList = getAddrInfo(address, port, AF_INET) + if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.Socklen) < 0'i32: + dealloc(aiList) + raiseOSError(osLastError()) + dealloc(aiList) proc close*(socket: PAsyncSocket) = ## Closes the socket. socket.fd.TAsyncFD.closeSocket() - # TODO SSL + when defined(ssl): + if socket.isSSL: + let res = SslShutdown(socket.sslHandle) + if res == 0: + if SslShutdown(socket.sslHandle) != 1: + raiseSslError() + elif res != 1: + raiseSslError() + +when defined(ssl): + proc wrapSocket*(ctx: SslContext, socket: AsyncSocket) = + ## Wraps a socket in an SSL context. This function effectively turns + ## ``socket`` into an SSL socket. + ## + ## **Disclaimer**: This code is not well tested, may be very unsafe and + ## prone to security vulnerabilities. + socket.isSsl = true + socket.sslContext = ctx + socket.sslHandle = SSLNew(PSSLCTX(socket.sslContext)) + if socket.sslHandle == nil: + raiseSslError() + + socket.bioIn = bioNew(bio_s_mem()) + socket.bioOut = bioNew(bio_s_mem()) + sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut) + when isMainModule: type |