diff options
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r-- | lib/pure/asyncnet.nim | 400 |
1 files changed, 293 insertions, 107 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index 88852fb84..ee07e599e 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -8,27 +8,27 @@ # ## This module implements a high-level asynchronous sockets API based on the -## asynchronous dispatcher defined in the ``asyncdispatch`` module. +## asynchronous dispatcher defined in the `asyncdispatch` module. ## ## Asynchronous IO in Nim ## ====================== ## ## Async IO in Nim consists of multiple layers (from highest to lowest): ## -## * ``asyncnet`` module +## * `asyncnet` module ## ## * Async await ## -## * ``asyncdispatch`` module (event loop) +## * `asyncdispatch` module (event loop) ## -## * ``selectors`` module +## * `selectors` module ## ## Each builds on top of the layers below it. The selectors module is an -## abstraction for the various system ``select()`` mechanisms such as epoll or +## abstraction for the various system `select()` mechanisms such as epoll or ## kqueue. If you wish you can use it directly, and some people have done so ## `successfully <http://goran.krampe.se/2014/10/25/nim-socketserver/>`_. ## But you must be aware that on Windows it only supports -## ``select()``. +## `select()`. ## ## The async dispatcher implements the proactor pattern and also has an ## implementation of IOCP. It implements the proactor pattern for other @@ -45,16 +45,16 @@ ## layers interchangeably (as long as you only care about non-Windows ## platforms). ## -## For most applications using ``asyncnet`` is the way to go as it builds +## For most applications using `asyncnet` is the way to go as it builds ## over all the layers, providing some extra features such as buffering. ## ## SSL ## === ## -## SSL can be enabled by compiling with the ``-d:ssl`` flag. +## SSL can be enabled by compiling with the `-d:ssl` flag. ## -## You must create a new SSL context with the ``newContext`` function defined -## in the ``net`` module. You may then call ``wrapSocket`` on your socket using +## You must create a new SSL context with the `newContext` function defined +## in the `net` module. You may then call `wrapSocket` on your socket using ## the newly created SSL context to get an SSL socket. ## ## Examples @@ -65,9 +65,8 @@ ## ## The following example demonstrates a simple chat server. ## -## .. code-block::nim -## -## import asyncnet, asyncdispatch +## ```Nim +## import std/[asyncnet, asyncdispatch] ## ## var clients {.threadvar.}: seq[AsyncSocket] ## @@ -93,21 +92,25 @@ ## ## asyncCheck serve() ## runForever() -## +## ``` + +import std/private/since + +when defined(nimPreviewSlimSystem): + import std/[assertions, syncio] -import asyncdispatch -import nativesockets -import net -import os +import std/[asyncdispatch, nativesockets, net, os] export SOBool # TODO: Remove duplication introduced by PR #4683. const defineSsl = defined(ssl) or defined(nimdoc) +const useNimNetLite = defined(nimNetLite) or defined(freertos) or defined(zephyr) or + defined(nuttx) when defineSsl: - import openssl + import std/openssl type # TODO: I would prefer to just do: @@ -125,25 +128,33 @@ type sslContext: SslContext bioIn: BIO bioOut: BIO + sslNoShutdown: bool domain: Domain sockType: SockType protocol: Protocol AsyncSocket* = ref AsyncSocketDesc proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET, - sockType: SockType = SOCK_STREAM, - protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) = - ## Creates a new ``AsyncSocket`` based on the supplied params. + sockType: SockType = SOCK_STREAM, + protocol: Protocol = IPPROTO_TCP, + buffered = true, + inheritable = defined(nimInheritHandles)): owned(AsyncSocket) = + ## Creates a new `AsyncSocket` based on the supplied params. ## - ## The supplied ``fd``'s non-blocking state will be enabled implicitly. + ## The supplied `fd`'s non-blocking state will be enabled implicitly. ## - ## **Note**: This procedure will **NOT** register ``fd`` with the global + ## If `inheritable` is false (the default), the supplied `fd` will not + ## be inheritable by child processes. + ## + ## **Note**: This procedure will **NOT** register `fd` with the global ## async dispatcher. You need to do this manually. If you have used - ## ``newAsyncNativeSocket`` to create ``fd`` then it's already registered. + ## `newAsyncNativeSocket` to create `fd` then it's already registered. assert fd != osInvalidSocket.AsyncFD new(result) result.fd = fd.SocketHandle fd.SocketHandle.setBlocking(false) + if not fd.SocketHandle.setInheritable(inheritable): + raiseOSError(osLastError()) result.isBuffered = buffered result.domain = domain result.sockType = sockType @@ -152,15 +163,19 @@ proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET, result.currPos = 0 proc newAsyncSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM, - protocol: Protocol = IPPROTO_TCP, buffered = true): owned(AsyncSocket) = + protocol: Protocol = IPPROTO_TCP, buffered = true, + inheritable = defined(nimInheritHandles)): owned(AsyncSocket) = ## Creates a new asynchronous socket. ## ## This procedure will also create a brand new file descriptor for ## this socket. - let fd = createAsyncNativeSocket(domain, sockType, protocol) + ## + ## If `inheritable` is false (the default), the new file descriptor will not + ## be inheritable by child processes. + let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable) if fd.SocketHandle == osInvalidSocket: raiseOSError(osLastError()) - result = newAsyncSocket(fd, domain, sockType, protocol, buffered) + result = newAsyncSocket(fd, domain, sockType, protocol, buffered, inheritable) proc getLocalAddr*(socket: AsyncSocket): (string, Port) = ## Get the socket's local address and port number. @@ -168,28 +183,34 @@ proc getLocalAddr*(socket: AsyncSocket): (string, Port) = ## This is high-level interface for `getsockname`:idx:. getLocalAddr(socket.fd, socket.domain) -proc getPeerAddr*(socket: AsyncSocket): (string, Port) = - ## Get the socket's peer address and port number. - ## - ## This is high-level interface for `getpeername`:idx:. - getPeerAddr(socket.fd, socket.domain) +when not useNimNetLite: + proc getPeerAddr*(socket: AsyncSocket): (string, Port) = + ## Get the socket's peer address and port number. + ## + ## This is high-level interface for `getpeername`:idx:. + getPeerAddr(socket.fd, socket.domain) proc newAsyncSocket*(domain, sockType, protocol: cint, - buffered = true): owned(AsyncSocket) = + buffered = true, + inheritable = defined(nimInheritHandles)): owned(AsyncSocket) = ## Creates a new asynchronous socket. ## ## This procedure will also create a brand new file descriptor for ## this socket. - let fd = createAsyncNativeSocket(domain, sockType, protocol) + ## + ## If `inheritable` is false (the default), the new file descriptor will not + ## be inheritable by child processes. + let fd = createAsyncNativeSocket(domain, sockType, protocol, inheritable) if fd.SocketHandle == osInvalidSocket: raiseOSError(osLastError()) result = newAsyncSocket(fd, Domain(domain), SockType(sockType), - Protocol(protocol), buffered) + Protocol(protocol), buffered, inheritable) when defineSsl: - proc getSslError(handle: SslPtr, err: cint): cint = + proc getSslError(socket: AsyncSocket, err: cint): cint = + assert socket.isSsl assert err < 0 - var ret = SSL_get_error(handle, err.cint) + var ret = SSL_get_error(socket.sslHandle, err.cint) case ret of SSL_ERROR_ZERO_RETURN: raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") @@ -200,6 +221,7 @@ when defineSsl: of SSL_ERROR_WANT_X509_LOOKUP: raiseSSLError("Function for x509 lookup has been called.") of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + socket.sslNoShutdown = true raiseSSLError() else: raiseSSLError("Unknown Error") @@ -208,7 +230,7 @@ when defineSsl: let len = bioCtrlPending(socket.bioOut) if len > 0: var data = newString(len) - let read = bioRead(socket.bioOut, addr data[0], len) + let read = bioRead(socket.bioOut, cast[cstring](addr data[0]), len) assert read != 0 if read < 0: raiseSSLError() @@ -217,7 +239,7 @@ when defineSsl: proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag], sslError: cint): owned(Future[bool]) {.async.} = - ## Returns ``true`` if ``socket`` is still connected, otherwise ``false``. + ## Returns `true` if `socket` is still connected, otherwise `false`. result = true case sslError of SSL_ERROR_WANT_WRITE: @@ -226,7 +248,7 @@ when defineSsl: var data = await recv(socket.fd.AsyncFD, BufferSize, flags) let length = len(data) if length > 0: - let ret = bioWrite(socket.bioIn, addr data[0], length.cint) + let ret = bioWrite(socket.bioIn, cast[cstring](addr data[0]), length.cint) if ret < 0: raiseSSLError() elif length == 0: @@ -240,18 +262,20 @@ when defineSsl: op: untyped) = var opResult {.inject.} = -1.cint while opResult < 0: + ErrClearError() # Call the desired operation. opResult = op - # Bit hackish here. - # TODO: Introduce an async template transformation pragma? - + let err = + if opResult < 0: + getSslError(socket, opResult.cint) + else: + SSL_ERROR_NONE # Send any remaining pending SSL data. - yield sendPendingSslData(socket, flags) + await sendPendingSslData(socket, flags) # If the operation failed, try to see if SSL has some data to read # or write. if opResult < 0: - let err = getSslError(socket.sslHandle, opResult.cint) let fut = appeaseSsl(socket, flags, err.cint) yield fut if not fut.read(): @@ -264,9 +288,9 @@ when defineSsl: proc dial*(address: string, port: Port, protocol = IPPROTO_TCP, buffered = true): owned(Future[AsyncSocket]) {.async.} = - ## Establishes connection to the specified ``address``:``port`` pair via the + ## Establishes connection to the specified `address`:`port` pair via the ## specified protocol. The procedure iterates through possible - ## resolutions of the ``address`` until it succeeds, meaning that it + ## resolutions of the `address` until it succeeds, meaning that it ## seamlessly works with both IPv4 and IPv6. ## Returns AsyncSocket ready to send or receive data. let asyncFd = await asyncdispatch.dial(address, port, protocol) @@ -275,9 +299,9 @@ proc dial*(address: string, port: Port, protocol = IPPROTO_TCP, result = newAsyncSocket(asyncFd, domain, sockType, protocol, buffered) proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} = - ## Connects ``socket`` to server at ``address:port``. + ## Connects `socket` to server at `address:port`. ## - ## Returns a ``Future`` which will complete when the connection succeeds + ## Returns a `Future` which will complete when the connection succeeds ## or an error occurs. await connect(socket.fd.AsyncFD, address, port, socket.domain) if socket.isSsl: @@ -293,7 +317,7 @@ proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} = template readInto(buf: pointer, size: int, socket: AsyncSocket, flags: set[SocketFlag]): int = - ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. Note that + ## Reads **up to** `size` bytes from `socket` into `buf`. Note that ## this is a template and not a proc. assert(not socket.closed, "Cannot `recv` on a closed socket") var res = 0 @@ -304,10 +328,8 @@ template readInto(buf: pointer, size: int, socket: AsyncSocket, sslRead(socket.sslHandle, cast[cstring](buf), size.cint)) res = opResult else: - var recvIntoFut = asyncdispatch.recvInto(socket.fd.AsyncFD, buf, size, flags) - yield recvIntoFut # Not in SSL mode. - res = recvIntoFut.read() + res = await asyncdispatch.recvInto(socket.fd.AsyncFD, buf, size, flags) res template readIntoBuf(socket: AsyncSocket, @@ -319,10 +341,10 @@ template readIntoBuf(socket: AsyncSocket, proc recvInto*(socket: AsyncSocket, buf: pointer, size: int, flags = {SocketFlag.SafeDisconn}): owned(Future[int]) {.async.} = - ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. + ## Reads **up to** `size` bytes from `socket` into `buf`. ## ## For buffered sockets this function will attempt to read all the requested - ## data. It will read this data in ``BufferSize`` chunks. + ## data. It will read this data in `BufferSize` chunks. ## ## For unbuffered sockets this function makes no effort to read ## all the data requested. It will return as much data as the operating system @@ -333,7 +355,7 @@ proc recvInto*(socket: AsyncSocket, buf: pointer, size: int, ## requested data. ## ## If socket is disconnected and no data is available - ## to be read then the future will complete with a value of ``0``. + ## to be read then the future will complete with a value of `0`. if socket.isBuffered: let originalBufPos = socket.currPos @@ -367,10 +389,10 @@ proc recvInto*(socket: AsyncSocket, buf: pointer, size: int, proc recv*(socket: AsyncSocket, size: int, flags = {SocketFlag.SafeDisconn}): owned(Future[string]) {.async.} = - ## Reads **up to** ``size`` bytes from ``socket``. + ## Reads **up to** `size` bytes from `socket`. ## ## For buffered sockets this function will attempt to read all the requested - ## data. It will read this data in ``BufferSize`` chunks. + ## data. It will read this data in `BufferSize` chunks. ## ## For unbuffered sockets this function makes no effort to read ## all the data requested. It will return as much data as the operating system @@ -381,10 +403,11 @@ proc recv*(socket: AsyncSocket, size: int, ## requested data. ## ## If socket is disconnected and no data is available - ## to be read then the future will complete with a value of ``""``. + ## to be read then the future will complete with a value of `""`. if socket.isBuffered: result = newString(size) - shallow(result) + when not defined(nimSeqsV2): + shallow(result) let originalBufPos = socket.currPos if socket.bufLen == 0: @@ -419,7 +442,7 @@ proc recv*(socket: AsyncSocket, size: int, proc send*(socket: AsyncSocket, buf: pointer, size: int, flags = {SocketFlag.SafeDisconn}) {.async.} = - ## Sends ``size`` bytes from ``buf`` to ``socket``. The returned future will complete once all + ## Sends `size` bytes from `buf` to `socket`. The returned future will complete once all ## data has been sent. assert socket != nil assert(not socket.closed, "Cannot `send` on a closed socket") @@ -433,25 +456,30 @@ proc send*(socket: AsyncSocket, buf: pointer, size: int, proc send*(socket: AsyncSocket, data: string, flags = {SocketFlag.SafeDisconn}) {.async.} = - ## Sends ``data`` to ``socket``. The returned future will complete once all + ## Sends `data` to `socket`. The returned future will complete once all ## data has been sent. assert socket != nil if socket.isSsl: when defineSsl: var copy = data sslLoop(socket, flags, - sslWrite(socket.sslHandle, addr copy[0], copy.len.cint)) + sslWrite(socket.sslHandle, cast[cstring](addr copy[0]), copy.len.cint)) await sendPendingSslData(socket, flags) else: await send(socket.fd.AsyncFD, data, flags) -proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}): +proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}, + inheritable = defined(nimInheritHandles)): owned(Future[tuple[address: string, client: AsyncSocket]]) = ## Accepts a new connection. Returns a future containing the client socket ## corresponding to that connection and the remote address of the client. + ## + ## If `inheritable` is false (the default), the resulting client socket will + ## not be inheritable by child processes. + ## ## The future will complete when the connection is successfully accepted. var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr") - var fut = acceptAddr(socket.fd.AsyncFD, flags) + var fut = acceptAddr(socket.fd.AsyncFD, flags, inheritable) fut.callback = proc (future: Future[tuple[address: string, client: AsyncFD]]) = assert future.finished @@ -460,7 +488,7 @@ proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}): else: let resultTup = (future.read.address, newAsyncSocket(future.read.client, socket.domain, - socket.sockType, socket.protocol, socket.isBuffered)) + socket.sockType, socket.protocol, socket.isBuffered, inheritable)) retFuture.complete(resultTup) return retFuture @@ -468,6 +496,8 @@ proc accept*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) = ## Accepts a new connection. Returns a future containing the client socket ## corresponding to that connection. + ## If `inheritable` is false (the default), the resulting client socket will + ## not be inheritable by child processes. ## The future will complete when the connection is successfully accepted. var retFut = newFuture[AsyncSocket]("asyncnet.accept") var fut = acceptAddr(socket, flags) @@ -482,25 +512,24 @@ proc accept*(socket: AsyncSocket, proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string], flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.async.} = - ## Reads a line of data from ``socket`` into ``resString``. + ## Reads a line of data from `socket` into `resString`. ## - ## If a full line is read ``\r\L`` is not - ## added to ``line``, however if solely ``\r\L`` is read then ``line`` + ## If a full line is read `\r\L` is not + ## added to `line`, however if solely `\r\L` is read then `line` ## will be set to it. ## - ## If the socket is disconnected, ``line`` will be set to ``""``. + ## If the socket is disconnected, `line` will be set to `""`. ## - ## If the socket is disconnected in the middle of a line (before ``\r\L`` - ## is read) then line will be set to ``""``. + ## If the socket is disconnected in the middle of a line (before `\r\L` + ## is read) then line will be set to `""`. ## The partial line **will be lost**. ## - ## The ``maxLength`` parameter determines the maximum amount of characters - ## that can be read. ``resString`` will be truncated after that. + ## The `maxLength` parameter determines the maximum amount of characters + ## that can be read. `resString` will be truncated after that. ## - ## **Warning**: The ``Peek`` flag is not yet implemented. + ## .. warning:: The `Peek` flag is not yet implemented. ## - ## **Warning**: ``recvLineInto`` on unbuffered sockets assumes that the - ## protocol uses ``\r\L`` to delimit a new line. + ## .. warning:: `recvLineInto` on unbuffered sockets assumes that the protocol uses `\r\L` to delimit a new line. assert SocketFlag.Peek notin flags ## TODO: result = newFuture[void]("asyncnet.recvLineInto") @@ -575,26 +604,25 @@ proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string], proc recvLine*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength): owned(Future[string]) {.async.} = - ## Reads a line of data from ``socket``. Returned future will complete once + ## Reads a line of data from `socket`. Returned future will complete once ## a full line is read or an error occurs. ## - ## If a full line is read ``\r\L`` is not - ## added to ``line``, however if solely ``\r\L`` is read then ``line`` + ## If a full line is read `\r\L` is not + ## added to `line`, however if solely `\r\L` is read then `line` ## will be set to it. ## - ## If the socket is disconnected, ``line`` will be set to ``""``. + ## If the socket is disconnected, `line` will be set to `""`. ## - ## If the socket is disconnected in the middle of a line (before ``\r\L`` - ## is read) then line will be set to ``""``. + ## If the socket is disconnected in the middle of a line (before `\r\L` + ## is read) then line will be set to `""`. ## The partial line **will be lost**. ## - ## The ``maxLength`` parameter determines the maximum amount of characters + ## The `maxLength` parameter determines the maximum amount of characters ## that can be read. The result is truncated after that. ## - ## **Warning**: The ``Peek`` flag is not yet implemented. + ## .. warning:: The `Peek` flag is not yet implemented. ## - ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol - ## uses ``\r\L`` to delimit a new line. + ## .. warning:: `recvLine` on unbuffered sockets assumes that the protocol uses `\r\L` to delimit a new line. assert SocketFlag.Peek notin flags ## TODO: # TODO: Optimise this @@ -605,8 +633,8 @@ proc recvLine*(socket: AsyncSocket, proc listen*(socket: AsyncSocket, backlog = SOMAXCONN) {.tags: [ ReadIOEffect].} = - ## Marks ``socket`` as accepting connections. - ## ``Backlog`` specifies the maximum length of the + ## Marks `socket` as accepting connections. + ## `Backlog` specifies the maximum length of the ## queue of pending connections. ## ## Raises an OSError error upon failure. @@ -614,9 +642,9 @@ proc listen*(socket: AsyncSocket, backlog = SOMAXCONN) {.tags: [ proc bindAddr*(socket: AsyncSocket, port = Port(0), address = "") {. tags: [ReadIOEffect].} = - ## Binds ``address``:``port`` to the socket. + ## Binds `address`:`port` to the socket. ## - ## If ``address`` is "" then ADDR_ANY will be bound. + ## If `address` is "" then ADDR_ANY will be bound. var realaddr = address if realaddr == "": case socket.domain @@ -628,11 +656,16 @@ proc bindAddr*(socket: AsyncSocket, port = Port(0), address = "") {. var aiList = getAddrInfo(realaddr, port, socket.domain) if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.SockLen) < 0'i32: - freeaddrinfo(aiList) + freeAddrInfo(aiList) raiseOSError(osLastError()) - freeaddrinfo(aiList) + freeAddrInfo(aiList) -when defined(posix): +proc hasDataBuffered*(s: AsyncSocket): bool {.since: (1, 5).} = + ## Determines whether an AsyncSocket has data buffered. + # xxx dedup with std/net + s.isBuffered and s.bufLen > 0 and s.currPos != s.bufLen + +when defined(posix) and not useNimNetLite: proc connectUnix*(socket: AsyncSocket, path: string): owned(Future[void]) = ## Binds Unix socket to `path`. @@ -649,12 +682,12 @@ when defined(posix): elif ret == EINTR: return false else: - retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(ret)))) + retFuture.fail(newOSError(OSErrorCode(ret))) return true var socketAddr = makeUnixAddr(path) let ret = socket.fd.connect(cast[ptr SockAddr](addr socketAddr), - (sizeof(socketAddr.sun_family) + path.len).SockLen) + (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) if ret == 0: # Request to connect completed immediately. retFuture.complete() @@ -663,7 +696,7 @@ when defined(posix): if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS: addWrite(AsyncFD(socket.fd), cb) else: - retFuture.fail(newException(OSError, osErrorMsg(lastError))) + retFuture.fail(newOSError(lastError)) proc bindUnix*(socket: AsyncSocket, path: string) {. tags: [ReadIOEffect].} = @@ -672,7 +705,7 @@ when defined(posix): when not defined(nimdoc): var socketAddr = makeUnixAddr(path) if socket.fd.bindAddr(cast[ptr SockAddr](addr socketAddr), - (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32: + (offsetOf(socketAddr, sun_path) + path.len + 1).SockLen) != 0'i32: raiseOSError(osLastError()) elif defined(nimdoc): @@ -689,22 +722,38 @@ elif defined(nimdoc): proc close*(socket: AsyncSocket) = ## Closes the socket. + if socket.closed: return + defer: socket.fd.AsyncFD.closeSocket() + socket.closed = true # TODO: Add extra debugging checks for this. + when defineSsl: if socket.isSsl: - let res = SSL_shutdown(socket.sslHandle) + let res = + # Don't call SSL_shutdown if the connection has not been fully + # established, see: + # https://github.com/openssl/openssl/issues/710#issuecomment-253897666 + if not socket.sslNoShutdown and SSL_in_init(socket.sslHandle) == 0: + ErrClearError() + SSL_shutdown(socket.sslHandle) + else: + 0 SSL_free(socket.sslHandle) if res == 0: discard elif res != 1: raiseSSLError() - socket.closed = true # TODO: Add extra debugging checks for this. when defineSsl: + proc sslHandle*(self: AsyncSocket): SslPtr = + ## Retrieve the ssl pointer of `socket`. + ## Useful for interfacing with `openssl`. + self.sslHandle + proc wrapSocket*(ctx: SslContext, socket: AsyncSocket) = ## Wraps a socket in an SSL context. This function effectively turns - ## ``socket`` into an SSL socket. + ## `socket` into an SSL socket. ## ## **Disclaimer**: This code is not well tested, may be very unsafe and ## prone to security vulnerabilities. @@ -718,12 +767,14 @@ when defineSsl: socket.bioOut = bioNew(bioSMem()) sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut) + socket.sslNoShutdown = true + proc wrapConnectedSocket*(ctx: SslContext, socket: AsyncSocket, handshake: SslHandshakeType, hostname: string = "") = ## Wraps a connected socket in an SSL context. This function effectively - ## turns ``socket`` into an SSL socket. - ## ``hostname`` should be specified so that the client knows which hostname + ## turns `socket` into an SSL socket. + ## `hostname` should be specified so that the client knows which hostname ## the server certificate should be validated against. ## ## This should be called on a connected socket, and will perform @@ -743,20 +794,31 @@ when defineSsl: of handshakeAsServer: sslSetAcceptState(socket.sslHandle) + proc getPeerCertificates*(socket: AsyncSocket): seq[Certificate] {.since: (1, 1).} = + ## Returns the certificate chain received by the peer we are connected to + ## through the given socket. + ## The handshake must have been completed and the certificate chain must + ## have been verified successfully or else an empty sequence is returned. + ## The chain is ordered from leaf certificate to root certificate. + if not socket.isSsl: + result = newSeq[Certificate]() + else: + result = getPeerCertificates(socket.sslHandle) + proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {. tags: [ReadIOEffect].} = - ## Retrieves option ``opt`` as a boolean value. + ## Retrieves option `opt` as a boolean value. var res = getSockOptInt(socket.fd, cint(level), toCInt(opt)) result = res != 0 proc setSockOpt*(socket: AsyncSocket, opt: SOBool, value: bool, level = SOL_SOCKET) {.tags: [WriteIOEffect].} = - ## Sets option ``opt`` to a boolean value specified by ``value``. + ## Sets option `opt` to a boolean value specified by `value`. var valuei = cint(if value: 1 else: 0) setSockOptInt(socket.fd, cint(level), toCInt(opt), valuei) proc isSsl*(socket: AsyncSocket): bool = - ## Determines whether ``socket`` is a SSL socket. + ## Determines whether `socket` is a SSL socket. socket.isSsl proc getFd*(socket: AsyncSocket): SocketHandle = @@ -767,6 +829,130 @@ proc isClosed*(socket: AsyncSocket): bool = ## Determines whether the socket has been closed. return socket.closed +proc sendTo*(socket: AsyncSocket, address: string, port: Port, data: string, + flags = {SocketFlag.SafeDisconn}): owned(Future[void]) + {.async, since: (1, 3).} = + ## This proc sends `data` to the specified `address`, which may be an IP + ## address or a hostname. If a hostname is specified this function will try + ## each IP of that hostname. The returned future will complete once all data + ## has been sent. + ## + ## If an error occurs an OSError exception will be raised. + ## + ## This proc is normally used with connectionless sockets (UDP sockets). + assert(socket.protocol != IPPROTO_TCP, + "Cannot `sendTo` on a TCP socket. Use `send` instead") + assert(not socket.closed, "Cannot `sendTo` on a closed socket") + + let aiList = getAddrInfo(address, port, socket.domain, socket.sockType, + socket.protocol) + + var + it = aiList + success = false + lastException: ref Exception + + while it != nil: + let fut = sendTo(socket.fd.AsyncFD, cstring(data), len(data), it.ai_addr, + it.ai_addrlen.SockLen, flags) + + yield fut + + if not fut.failed: + success = true + + break + + lastException = fut.readError() + + it = it.ai_next + + freeAddrInfo(aiList) + + if not success: + if lastException != nil: + raise lastException + else: + raise newException(IOError, "Couldn't resolve address: " & address) + +proc recvFrom*(socket: AsyncSocket, data: FutureVar[string], size: int, + address: FutureVar[string], port: FutureVar[Port], + flags = {SocketFlag.SafeDisconn}): owned(Future[int]) + {.async, since: (1, 3).} = + ## Receives a datagram data from `socket` into `data`, which must be at + ## least of size `size`. The address and port of datagram's sender will be + ## stored into `address` and `port`, respectively. Returned future will + ## complete once one datagram has been received, and will return size of + ## packet received. + ## + ## If an error occurs an OSError exception will be raised. + ## + ## This proc is normally used with connectionless sockets (UDP sockets). + ## + ## **Notes** + ## * `data` must be initialized to the length of `size`. + ## * `address` must be initialized to 46 in length. + template adaptRecvFromToDomain(domain: Domain) = + var lAddr = sizeof(sAddr).SockLen + + result = await recvFromInto(AsyncFD(getFd(socket)), cstring(data.mget()), size, + cast[ptr SockAddr](addr sAddr), addr lAddr, + flags) + + data.mget().setLen(result) + data.complete() + + getAddrString(cast[ptr SockAddr](addr sAddr), address.mget()) + + address.complete() + + when domain == AF_INET6: + port.complete(ntohs(sAddr.sin6_port).Port) + else: + port.complete(ntohs(sAddr.sin_port).Port) + + assert(socket.protocol != IPPROTO_TCP, + "Cannot `recvFrom` on a TCP socket. Use `recv` or `recvInto` instead") + assert(not socket.closed, "Cannot `recvFrom` on a closed socket") + assert(size == len(data.mget()), + "`date` was not initialized correctly. `size` != `len(data.mget())`") + assert(46 == len(address.mget()), + "`address` was not initialized correctly. 46 != `len(address.mget())`") + + case socket.domain + of AF_INET6: + var sAddr: Sockaddr_in6 + adaptRecvFromToDomain(AF_INET6) + of AF_INET: + var sAddr: Sockaddr_in + adaptRecvFromToDomain(AF_INET) + else: + raise newException(ValueError, "Unknown socket address family") + +proc recvFrom*(socket: AsyncSocket, size: int, + flags = {SocketFlag.SafeDisconn}): + owned(Future[tuple[data: string, address: string, port: Port]]) + {.async, since: (1, 3).} = + ## Receives a datagram data from `socket`, which must be at least of size + ## `size`. Returned future will complete once one datagram has been received + ## and will return tuple with: data of packet received; and address and port + ## of datagram's sender. + ## + ## If an error occurs an OSError exception will be raised. + ## + ## This proc is normally used with connectionless sockets (UDP sockets). + var + data = newFutureVar[string]() + address = newFutureVar[string]() + port = newFutureVar[Port]() + + data.mget().setLen(size) + address.mget().setLen(46) + + let read = await recvFrom(socket, data, size, address, port, flags) + + result = (data.mget(), address.mget(), port.mget()) + when not defined(testing) and isMainModule: type TestCases = enum |