diff options
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r-- | lib/pure/asyncnet.nim | 214 |
1 files changed, 137 insertions, 77 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index e7325e0d7..cfd3d7666 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -24,7 +24,7 @@ ## ## Chat server ## ^^^^^^^^^^^ -## +## ## The following example demonstrates a simple chat server. ## ## .. code-block::nim @@ -85,35 +85,44 @@ type bioIn: BIO bioOut: BIO of false: nil + domain: Domain + sockType: SockType + protocol: Protocol AsyncSocket* = ref AsyncSocketDesc {.deprecated: [PAsyncSocket: AsyncSocket].} -# TODO: Save AF, domain etc info and reuse it in procs which need it like connect. - -proc newAsyncSocket*(fd: TAsyncFD, isBuff: bool): AsyncSocket = +proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET, + sockType: SockType = SOCK_STREAM, + protocol: Protocol = IPPROTO_TCP, buffered = true): AsyncSocket = ## Creates a new ``AsyncSocket`` based on the supplied params. - assert fd != osInvalidSocket.TAsyncFD + assert fd != osInvalidSocket.AsyncFD new(result) result.fd = fd.SocketHandle - result.isBuffered = isBuff - if isBuff: + result.isBuffered = buffered + result.domain = domain + result.sockType = sockType + result.protocol = protocol + if buffered: result.currPos = 0 -proc newAsyncSocket*(domain: Domain = AF_INET, typ: SockType = SOCK_STREAM, +proc newAsyncSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM, protocol: Protocol = IPPROTO_TCP, buffered = true): AsyncSocket = ## Creates a new asynchronous socket. ## ## This procedure will also create a brand new file descriptor for ## this socket. - result = newAsyncSocket(newAsyncRawSocket(domain, typ, protocol), buffered) + result = newAsyncSocket(newAsyncRawSocket(domain, sockType, protocol), domain, + sockType, protocol, buffered) -proc newAsyncSocket*(domain, typ, protocol: cint, buffered = true): AsyncSocket = +proc newAsyncSocket*(domain, sockType, protocol: cint, + buffered = true): AsyncSocket = ## Creates a new asynchronous socket. ## ## This procedure will also create a brand new file descriptor for ## this socket. - result = newAsyncSocket(newAsyncRawSocket(domain, typ, protocol), buffered) + result = newAsyncSocket(newAsyncRawSocket(domain, sockType, protocol), + Domain(domain), SockType(sockType), Protocol(protocol), buffered) when defined(ssl): proc getSslError(handle: SslPtr, err: cint): cint = @@ -142,7 +151,7 @@ when defined(ssl): if read < 0: raiseSslError() data.setLen(read) - await socket.fd.TAsyncFd.send(data, flags) + await socket.fd.AsyncFd.send(data, flags) proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag], sslError: cint) {.async.} = @@ -150,7 +159,7 @@ when defined(ssl): of SSL_ERROR_WANT_WRITE: await sendPendingSslData(socket, flags) of SSL_ERROR_WANT_READ: - var data = await recv(socket.fd.TAsyncFD, BufferSize, flags) + var data = await recv(socket.fd.AsyncFD, BufferSize, flags) let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint) if ret < 0: raiseSSLError() @@ -169,39 +178,42 @@ when defined(ssl): let err = getSslError(socket.sslHandle, opResult.cint) yield appeaseSsl(socket, flags, err.cint) -proc connect*(socket: AsyncSocket, address: string, port: Port, - af = AF_INET) {.async.} = +proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} = ## Connects ``socket`` to server at ``address:port``. ## ## Returns a ``Future`` which will complete when the connection succeeds ## or an error occurs. - await connect(socket.fd.TAsyncFD, address, port, af) + await connect(socket.fd.AsyncFD, address, port, socket.domain) if socket.isSsl: when defined(ssl): let flags = {SocketFlag.SafeDisconn} sslSetConnectState(socket.sslHandle) sslLoop(socket, flags, sslDoHandshake(socket.sslHandle)) -proc readInto(buf: cstring, size: int, socket: AsyncSocket, - flags: set[SocketFlag]): Future[int] {.async.} = +template readInto(buf: cstring, size: int, socket: AsyncSocket, + flags: set[SocketFlag]): int = + ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. Note that + ## this is a template and not a proc. + var res = 0 if socket.isSsl: when defined(ssl): # SSL mode. sslLoop(socket, flags, sslRead(socket.sslHandle, buf, size.cint)) - result = opResult + res = opResult else: - var data = await recv(socket.fd.TAsyncFD, size, flags) - if data.len != 0: - copyMem(buf, addr data[0], data.len) + var recvIntoFut = recvInto(socket.fd.AsyncFD, buf, size, flags) + yield recvIntoFut # Not in SSL mode. - result = data.len + res = recvIntoFut.read() + res -proc readIntoBuf(socket: AsyncSocket, - flags: set[SocketFlag]): Future[int] {.async.} = - result = await readInto(addr socket.buffer[0], BufferSize, socket, flags) +template readIntoBuf(socket: AsyncSocket, + flags: set[SocketFlag]): int = + var size = readInto(addr socket.buffer[0], BufferSize, socket, flags) socket.currPos = 0 - socket.bufLen = result + socket.bufLen = size + size proc recv*(socket: AsyncSocket, size: int, flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} = @@ -222,10 +234,11 @@ proc recv*(socket: AsyncSocket, size: int, ## to be read then the future will complete with a value of ``""``. if socket.isBuffered: result = newString(size) + shallow(result) let originalBufPos = socket.currPos if socket.bufLen == 0: - let res = await socket.readIntoBuf(flags - {SocketFlag.Peek}) + let res = socket.readIntoBuf(flags - {SocketFlag.Peek}) if res == 0: result.setLen(0) return @@ -236,7 +249,7 @@ proc recv*(socket: AsyncSocket, size: int, if SocketFlag.Peek in flags: # We don't want to get another buffer if we're peeking. break - let res = await socket.readIntoBuf(flags - {SocketFlag.Peek}) + let res = socket.readIntoBuf(flags - {SocketFlag.Peek}) if res == 0: break @@ -251,7 +264,7 @@ proc recv*(socket: AsyncSocket, size: int, result.setLen(read) else: result = newString(size) - let read = await readInto(addr result[0], size, socket, flags) + let read = readInto(addr result[0], size, socket, flags) result.setLen(read) proc send*(socket: AsyncSocket, data: string, @@ -266,7 +279,7 @@ proc send*(socket: AsyncSocket, data: string, sslWrite(socket.sslHandle, addr copy[0], copy.len.cint)) await sendPendingSslData(socket, flags) else: - await send(socket.fd.TAsyncFD, data, flags) + await send(socket.fd.AsyncFD, data, flags) proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}): Future[tuple[address: string, client: AsyncSocket]] = @@ -274,15 +287,16 @@ proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}): ## corresponding to that connection and the remote address of the client. ## The future will complete when the connection is successfully accepted. var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr") - var fut = acceptAddr(socket.fd.TAsyncFD, flags) + var fut = acceptAddr(socket.fd.AsyncFD, flags) fut.callback = - proc (future: Future[tuple[address: string, client: TAsyncFD]]) = + proc (future: Future[tuple[address: string, client: AsyncFD]]) = assert future.finished if future.failed: retFuture.fail(future.readError) else: let resultTup = (future.read.address, - newAsyncSocket(future.read.client, socket.isBuffered)) + newAsyncSocket(future.read.client, socket.domain, + socket.sockType, socket.protocol, socket.isBuffered)) retFuture.complete(resultTup) return retFuture @@ -302,15 +316,14 @@ proc accept*(socket: AsyncSocket, retFut.complete(future.read.client) return retFut -proc recvLine*(socket: AsyncSocket, - flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} = - ## Reads a line of data from ``socket``. Returned future will complete once - ## a full line is read or an error occurs. +proc recvLineInto*(socket: AsyncSocket, resString: ptr string, + flags = {SocketFlag.SafeDisconn}) {.async.} = + ## Reads a line of data from ``socket`` into ``resString``. ## ## If a full line is read ``\r\L`` is not ## added to ``line``, however if solely ``\r\L`` is read then ``line`` ## will be set to it. - ## + ## ## If the socket is disconnected, ``line`` will be set to ``""``. ## ## If the socket is disconnected in the middle of a line (before ``\r\L`` @@ -318,27 +331,32 @@ proc recvLine*(socket: AsyncSocket, ## The partial line **will be lost**. ## ## **Warning**: The ``Peek`` flag is not yet implemented. - ## - ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol - ## uses ``\r\L`` to delimit a new line. - template addNLIfEmpty(): stmt = - if result.len == 0: - result.add("\c\L") + ## + ## **Warning**: ``recvLineInto`` on unbuffered sockets assumes that the + ## protocol uses ``\r\L`` to delimit a new line. + ## + ## **Warning**: ``recvLineInto`` currently uses a raw pointer to a string for + ## performance reasons. This will likely change soon to use FutureVars. assert SocketFlag.Peek notin flags ## TODO: + result = newFuture[void]("asyncnet.recvLineInto") + + template addNLIfEmpty(): stmt = + if resString[].len == 0: + resString[].add("\c\L") + if socket.isBuffered: - result = "" if socket.bufLen == 0: - let res = await socket.readIntoBuf(flags) + let res = socket.readIntoBuf(flags) if res == 0: return var lastR = false while true: if socket.currPos >= socket.bufLen: - let res = await socket.readIntoBuf(flags) + let res = socket.readIntoBuf(flags) if res == 0: - result = "" - break + resString[].setLen(0) + return case socket.buffer[socket.currPos] of '\r': @@ -353,24 +371,53 @@ proc recvLine*(socket: AsyncSocket, socket.currPos.inc() return else: - result.add socket.buffer[socket.currPos] + resString[].add socket.buffer[socket.currPos] socket.currPos.inc() else: - result = "" var c = "" while true: - c = await recv(socket, 1, flags) + let recvFut = recv(socket, 1, flags) + c = recvFut.read() if c.len == 0: - return "" + resString[].setLen(0) + return if c == "\r": - c = await recv(socket, 1, flags) # Skip \L + let recvFut = recv(socket, 1, flags) # Skip \L + c = recvFut.read() assert c == "\L" addNLIfEmpty() return elif c == "\L": addNLIfEmpty() return - add(result.string, c) + resString[].add c + +proc recvLine*(socket: AsyncSocket, + flags = {SocketFlag.SafeDisconn}): Future[string] {.async.} = + ## Reads a line of data from ``socket``. Returned future will complete once + ## a full line is read or an error occurs. + ## + ## If a full line is read ``\r\L`` is not + ## added to ``line``, however if solely ``\r\L`` is read then ``line`` + ## will be set to it. + ## + ## If the socket is disconnected, ``line`` will be set to ``""``. + ## + ## If the socket is disconnected in the middle of a line (before ``\r\L`` + ## is read) then line will be set to ``""``. + ## The partial line **will be lost**. + ## + ## **Warning**: The ``Peek`` flag is not yet implemented. + ## + ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol + ## uses ``\r\L`` to delimit a new line. + template addNLIfEmpty(): stmt = + if result.len == 0: + result.add("\c\L") + assert SocketFlag.Peek notin flags ## TODO: + + result = "" + await socket.recvLineInto(addr result, flags) proc listen*(socket: AsyncSocket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} = ## Marks ``socket`` as accepting connections. @@ -385,29 +432,24 @@ proc bindAddr*(socket: AsyncSocket, port = Port(0), address = "") {. ## Binds ``address``:``port`` to the socket. ## ## If ``address`` is "" then ADDR_ANY will be bound. - - if address == "": - var name: Sockaddr_in - when defined(Windows) or defined(nimdoc): - name.sin_family = toInt(AF_INET).int16 + var realaddr = address + if realaddr == "": + case socket.domain + of AF_INET6: realaddr = "::" + of AF_INET: realaddr = "0.0.0.0" else: - name.sin_family = toInt(AF_INET) - name.sin_port = htons(int16(port)) - name.sin_addr.s_addr = htonl(INADDR_ANY) - if bindAddr(socket.fd, cast[ptr SockAddr](addr(name)), - sizeof(name).Socklen) < 0'i32: - raiseOSError(osLastError()) - else: - var aiList = getAddrInfo(address, port, AF_INET) - if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.Socklen) < 0'i32: - dealloc(aiList) - raiseOSError(osLastError()) + raiseOSError("Unknown socket address family and no address specified to bindAddr") + + var aiList = getAddrInfo(realaddr, port, socket.domain) + if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.Socklen) < 0'i32: dealloc(aiList) + raiseOSError(osLastError()) + dealloc(aiList) proc close*(socket: AsyncSocket) = ## Closes the socket. defer: - socket.fd.TAsyncFD.closeSocket() + socket.fd.AsyncFD.closeSocket() when defined(ssl): if socket.isSSL: let res = SslShutdown(socket.sslHandle) @@ -434,6 +476,24 @@ when defined(ssl): socket.bioOut = bioNew(bio_s_mem()) sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut) + proc wrapConnectedSocket*(ctx: SslContext, socket: AsyncSocket, + handshake: SslHandshakeType) = + ## Wraps a connected socket in an SSL context. This function effectively + ## turns ``socket`` into an SSL socket. + ## + ## This should be called on a connected socket, and will perform + ## an SSL handshake immediately. + ## + ## **Disclaimer**: This code is not well tested, may be very unsafe and + ## prone to security vulnerabilities. + wrapSocket(ctx, socket) + + case handshake + of handshakeAsClient: + sslSetConnectState(socket.sslHandle) + of handshakeAsServer: + sslSetAcceptState(socket.sslHandle) + proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {. tags: [ReadIOEffect].} = ## Retrieves option ``opt`` as a boolean value. @@ -458,7 +518,7 @@ proc isClosed*(socket: AsyncSocket): bool = ## Determines whether the socket has been closed. return socket.closed -when isMainModule: +when not defined(testing) and isMainModule: type TestCases = enum HighClient, LowClient, LowServer @@ -500,11 +560,11 @@ when isMainModule: proc (future: Future[void]) = echo("Send") client.close() - + var f = accept(sock) f.callback = onAccept - + var f = accept(sock) f.callback = onAccept runForever() - + |