diff options
Diffstat (limited to 'lib/pure/sockets.nim')
-rwxr-xr-x | lib/pure/sockets.nim | 342 |
1 files changed, 245 insertions, 97 deletions
diff --git a/lib/pure/sockets.nim b/lib/pure/sockets.nim index 2cd9be786..67dbd6d9f 100755 --- a/lib/pure/sockets.nim +++ b/lib/pure/sockets.nim @@ -37,10 +37,10 @@ when defined(ssl): TSSLProtVersion* = enum protSSLv2, protSSLv3, protTLSv1, protSSLv23 - TSSLOptions* = object - verifyMode*: TSSLCVerifyMode - certFile*, keyFile*: string - protVer*: TSSLprotVersion + PSSLContext* = distinct PSSLCTX + + TSSLAcceptResult* = enum + AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess type TSocketImpl = object ## socket type @@ -55,8 +55,8 @@ type case isSsl: bool of true: sslHandle: PSSL - sslContext: PSSLCTX - wrapOptions: TSSLOptions + sslContext: PSSLContext + sslNoHandshake: bool # True if needs handshake. of false: nil TSocket* = ref TSocketImpl @@ -211,22 +211,49 @@ when defined(ssl): raise newException(ESSL, $errStr) # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html - proc loadCertificates(socket: var TSocket, certFile, keyFile: string) = + proc loadCertificates(ctx: PSSL_CTX, certFile, keyFile: string) = if certFile != "": - if SSLCTXUseCertificateFile(socket.sslContext, certFile, - SSL_FILETYPE_PEM) != 1: + var ret = SSLCTXUseCertificateFile(ctx, certFile, + SSL_FILETYPE_PEM) + if ret != 1: SSLError() + + # TODO: Password? www.rtfm.com/openssl-examples/part1.pdf if keyFile != "": - if SSL_CTX_use_PrivateKey_file(socket.sslContext, keyFile, + if SSL_CTX_use_PrivateKey_file(ctx, keyFile, SSL_FILETYPE_PEM) != 1: SSLError() - if SSL_CTX_check_private_key(socket.sslContext) != 1: + if SSL_CTX_check_private_key(ctx) != 1: SSLError("Verification of private key file failed.") - proc wrapSocket*(socket: var TSocket, protVersion = ProtSSLv23, - verifyMode = CVerifyPeer, - certFile = "", keyFile = "") = + proc newContext*(protVersion = ProtSSLv23, verifyMode = CVerifyPeer, + certFile = "", keyFile = ""): PSSLContext = + var newCTX: PSSL_CTX + case protVersion + of protSSLv23: + newCTX = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support. + of protSSLv2: + newCTX = SSL_CTX_new(SSLv2_method()) + of protSSLv3: + newCTX = SSL_CTX_new(SSLv3_method()) + of protTLSv1: + newCTX = SSL_CTX_new(TLSv1_method()) + + if newCTX.SSLCTXSetCipherList("ALL") != 1: + SSLError() + case verifyMode + of CVerifyPeer: + newCTX.SSLCTXSetVerify(SSLVerifyPeer, nil) + of CVerifyNone: + newCTX.SSLCTXSetVerify(SSLVerifyNone, nil) + if newCTX == nil: + SSLError() + + newCTX.loadCertificates(certFile, keyFile) + return PSSLContext(newCTX) + + proc wrapSocket*(ctx: PSSLContext, socket: TSocket) = ## Creates a SSL context for ``socket`` and wraps the socket in it. ## ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 are @@ -247,44 +274,15 @@ when defined(ssl): ## most likely very prone to security vulnerabilities. socket.isSSL = true - socket.wrapOptions.verifyMode = verifyMode - socket.wrapOptions.certFile = certFile - socket.wrapOptions.keyFile = keyFile - socket.wrapOptions.protVer = protVersion - - case protVersion - of protSSLv23: - socket.sslContext = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support. - of protSSLv2: - socket.sslContext = SSL_CTX_new(SSLv2_method()) - of protSSLv3: - socket.sslContext = SSL_CTX_new(SSLv3_method()) - of protTLSv1: - socket.sslContext = SSL_CTX_new(TLSv1_method()) - - if socket.sslContext.SSLCTXSetCipherList("ALL") != 1: - SSLError() - case verifyMode - of CVerifyPeer: - socket.sslContext.SSLCTXSetVerify(SSLVerifyPeer, nil) - of CVerifyNone: - socket.sslContext.SSLCTXSetVerify(SSLVerifyNone, nil) - if socket.sslContext == nil: - SSLError() - - socket.loadCertificates(certFile, keyFile) - - socket.sslHandle = SSLNew(socket.sslContext) + socket.sslContext = ctx + socket.sslHandle = SSLNew(PSSLCTX(socket.sslContext)) + socket.sslNoHandshake = false if socket.sslHandle == nil: SSLError() if SSLSetFd(socket.sslHandle, socket.fd) != 1: SSLError() - proc wrapSocket*(socket: var TSocket, wo: TSSLOptions) = - ## A variant of the above with a options object. - wrapSocket(socket, wo.protVer, wo.verifyMode, wo.certFile, wo.keyFile) - proc listen*(socket: TSocket, backlog = SOMAXCONN) = ## Marks ``socket`` as accepting connections. ## ``Backlog`` specifies the maximum length of the @@ -352,7 +350,7 @@ proc bindAddr*(socket: TSocket, port = TPort(0), address = "") = hints.ai_socktype = toInt(SOCK_STREAM) hints.ai_protocol = toInt(IPPROTO_TCP) gaiNim(address, port, hints, aiList) - if bindSocket(socket.fd, aiList.ai_addr, aiList.ai_addrLen.cint) < 0'i32: + if bindSocket(socket.fd, aiList.ai_addr, aiList.ai_addrLen.cuint) < 0'i32: OSError() when false: @@ -386,17 +384,8 @@ proc getSockName*(socket: TSocket): TPort = proc selectWrite*(writefds: var seq[TSocket], timeout = 500): int -proc acceptAddr*(server: TSocket): tuple[client: TSocket, address: string] = - ## Blocks until a connection is being made from a client. When a connection - ## is made sets ``client`` to the client socket and ``address`` to the address - ## of the connecting client. - ## If ``server`` is non-blocking then this function returns immediately, and - ## if there are no connections queued the returned socket will be - ## ``InvalidSocket``. - ## This function will raise EOS if an error occurs. - ## - ## **Warning:** This function might block even if socket is non-blocking - ## when using SSL. +template acceptAddrPlain(noClientRet, successRet: expr, sslImplementation: stmt): stmt = + assert(client != nil) var sockAddress: Tsockaddr_in var addrLen = sizeof(sockAddress).TSockLen var sock = accept(server.fd, cast[ptr TSockAddr](addr(sockAddress)), @@ -407,20 +396,56 @@ proc acceptAddr*(server: TSocket): tuple[client: TSocket, address: string] = when defined(windows): var err = WSAGetLastError() if err == WSAEINPROGRESS: - return (InvalidSocket, "") + client = InvalidSocket + address = "" + when noClientRet.int == -1: + return + else: + return noClientRet else: OSError() else: if errno == EAGAIN or errno == EWOULDBLOCK: - return (InvalidSocket, "") + client = InvalidSocket + address = "" + when noClientRet.int == -1: + return + else: + return noClientRet else: OSError() - else: + else: + client.fd = sock + client.isBuffered = server.isBuffered + sslImplementation + # Client socket is set above. + address = $inet_ntoa(sockAddress.sin_addr) + when successRet.int == -1: + return + else: + return successRet + +proc acceptAddr*(server: TSocket, client: var TSocket, address: var string) = + ## Blocks until a connection is being made from a client. When a connection + ## is made sets ``client`` to the client socket and ``address`` to the address + ## of the connecting client. + ## If ``server`` is non-blocking then this function returns immediately, and + ## if there are no connections queued the returned socket will be + ## ``InvalidSocket``. + ## This function will raise EOS if an error occurs. + ## + ## The resulting client will inherit any properties of the server socket. For + ## example: whether the socket is buffered or not. + ## + ## **Note**: ``client`` must be initialised, this function makes no effort to + ## initialise the ``client`` variable. + ## + ## **Warning:** When using SSL with non-blocking sockets, it is best to use + ## the acceptAddrAsync procedure as this procedure will most likely block. + acceptAddrPlain(-1, -1): when defined(ssl): if server.isSSL: # We must wrap the client sock in a ssl context. - var client = newTSocket(sock, server.isBuffered) - let wo = server.wrapOptions - wrapSocket(client, wo.protVer, wo.verifyMode, - wo.certFile, wo.keyFile) + + server.sslContext.wrapSocket(client) let ret = SSLAccept(client.sslHandle) while ret <= 0: let err = SSLGetError(client.sslHandle, ret) @@ -428,26 +453,93 @@ proc acceptAddr*(server: TSocket): tuple[client: TSocket, address: string] = case err of SSL_ERROR_ZERO_RETURN: SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") - of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_CONNECT: - SSLError("The operation did not complete. Perhaps you should use connectAsync?") - of SSL_ERROR_WANT_ACCEPT: - var sss: seq[TSocket] = @[client] - discard selectWrite(sss, 1500) - continue + of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, + SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT: + SSLError("Please use acceptAsync instead of accept.") of SSL_ERROR_WANT_X509_LOOKUP: SSLError("Function for x509 lookup has been called.") of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: SSLError() else: SSLError("Unknown error") - return (client, $inet_ntoa(sockAddress.sin_addr)) - return (newTSocket(sock, server.isBuffered), $inet_ntoa(sockAddress.sin_addr)) -proc accept*(server: TSocket): TSocket = +proc setBlocking*(s: TSocket, blocking: bool) +when defined(ssl): + proc acceptAddrSSL*(server: TSocket, client: var TSocket, + address: var string): TSSLAcceptResult = + ## This procedure should only be used for non-blocking **SSL** sockets. + ## It will immediatelly return with one of the following values: + ## + ## ``AcceptSuccess`` will be returned when a client has been successfully + ## accepted and the handshake has been successfully performed between + ## ``server`` and the newly connected client. + ## + ## ``AcceptNoHandshake`` will be returned when a client has been accepted + ## but no handshake could be performed. This can happen when the client + ## connects but does not yet initiate a handshake. In this case + ## ``acceptAddrSSL`` should be called again with the same parameters. + ## + ## ``AcceptNoClient`` will be returned when no client is currently attempting + ## to connect. + template doHandshake(): stmt = + when defined(ssl): + if server.isSSL: + client.setBlocking(false) + # We must wrap the client sock in a ssl context. + + if not client.isSSL or client.sslHandle == nil: + server.sslContext.wrapSocket(client) + let ret = SSLAccept(client.sslHandle) + while ret <= 0: + let err = SSLGetError(client.sslHandle, ret) + if err != SSL_ERROR_WANT_ACCEPT: + case err + of SSL_ERROR_ZERO_RETURN: + SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") + of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, + SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT: + client.sslNoHandshake = true + return AcceptNoHandshake + of SSL_ERROR_WANT_X509_LOOKUP: + SSLError("Function for x509 lookup has been called.") + of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + SSLError() + else: + SSLError("Unknown error") + client.sslNoHandshake = false + + if client.isSSL and client.sslNoHandshake: + doHandshake() + return AcceptSuccess + else: + acceptAddrPlain(AcceptNoClient, AcceptSuccess): + doHandshake() + +proc accept*(server: TSocket, client: var TSocket) = ## Equivalent to ``acceptAddr`` but doesn't return the address, only the ## socket. - let (client, a) = acceptAddr(server) - return client + ## + ## **Note**: ``client`` must be initialised, this function makes no effort to + ## initialise the ``client`` variable. + + var addrDummy = "" + acceptAddr(server, client, addrDummy) + +proc acceptAddr*(server: TSocket): tuple[client: TSocket, address: string] {.deprecated.} = + ## Slightly different version of ``acceptAddr``. + ## + ## **Warning**: This function is now deprecated, you shouldn't use it! + var client: TSocket + new(client) + var address = "" + acceptAddr(server, client, address) + return (client, address) + +proc accept*(server: TSocket): TSocket {.deprecated.} = + ## **Warning**: This function is now deprecated, you shouldn't use it! + new(result) + var address = "" + acceptAddr(server, result, address) proc close*(socket: TSocket) = ## closes a socket. @@ -459,8 +551,6 @@ proc close*(socket: TSocket) = when defined(ssl): if socket.isSSL: discard SSLShutdown(socket.sslHandle) - - SSLCTXFree(socket.sslContext) proc getServByName*(name, proto: string): TServent = ## well-known getservbyname proc. @@ -492,11 +582,11 @@ proc getHostByAddr*(ip: string): THostEnt = myaddr.s_addr = inet_addr(ip) when defined(windows): - var s = winlean.gethostbyaddr(addr(myaddr), sizeof(myaddr).cint, + var s = winlean.gethostbyaddr(addr(myaddr), sizeof(myaddr).cuint, cint(sockets.AF_INET)) if s == nil: OSError() else: - var s = posix.gethostbyaddr(addr(myaddr), sizeof(myaddr).cint, + var s = posix.gethostbyaddr(addr(myaddr), sizeof(myaddr).cuint, cint(posix.AF_INET)) if s == nil: raise newException(EOS, $hStrError(h_errno)) @@ -539,7 +629,7 @@ proc getHostByName*(name: string): THostEnt = proc getSockOptInt*(socket: TSocket, level, optname: int): int = ## getsockopt for integer options. var res: cint - var size = sizeof(res).cint + var size = sizeof(res).cuint if getsockopt(socket.fd, cint(level), cint(optname), addr(res), addr(size)) < 0'i32: OSError() @@ -549,7 +639,7 @@ proc setSockOptInt*(socket: TSocket, level, optname, optval: int) = ## setsockopt for integer options. var value = cint(optval) if setsockopt(socket.fd, cint(level), cint(optname), addr(value), - sizeof(value).cint) < 0'i32: + sizeof(value).cuint) < 0'i32: OSError() proc connect*(socket: TSocket, name: string, port = TPort(0), @@ -558,7 +648,8 @@ proc connect*(socket: TSocket, name: string, port = TPort(0), ## host name. If ``name`` is a host name, this function will try each IP ## of that host name. ``htons`` is already performed on ``port`` so you must ## not do it. - + ## + ## If ``socket`` is an SSL socket a handshake will be automatically performed. var hints: TAddrInfo var aiList: ptr TAddrInfo = nil hints.ai_family = toInt(af) @@ -570,7 +661,7 @@ proc connect*(socket: TSocket, name: string, port = TPort(0), var success = false var it = aiList while it != nil: - if connect(socket.fd, it.ai_addr, it.ai_addrlen.cint) == 0'i32: + if connect(socket.fd, it.ai_addr, it.ai_addrlen.cuint) == 0'i32: success = true break it = it.ai_next @@ -614,6 +705,13 @@ proc connect*(socket: TSocket, name: string, port = TPort(0), proc connectAsync*(socket: TSocket, name: string, port = TPort(0), af: TDomain = AF_INET) = ## A variant of ``connect`` for non-blocking sockets. + ## + ## This procedure will immediatelly return, it will not block until a connection + ## is made. It is up to the caller to make sure the connections has been established + ## by checking (using ``select``) whether the socket is writeable. + ## + ## **Note**: For SSL sockets, the ``handshake`` procedure must be called + ## whenever the socket successfully connects to a server. var hints: TAddrInfo var aiList: ptr TAddrInfo = nil hints.ai_family = toInt(af) @@ -624,7 +722,7 @@ proc connectAsync*(socket: TSocket, name: string, port = TPort(0), var success = false var it = aiList while it != nil: - var ret = connect(socket.fd, it.ai_addr, it.ai_addrlen.cint) + var ret = connect(socket.fd, it.ai_addr, it.ai_addrlen.cuint) if ret == 0'i32: success = true break @@ -634,37 +732,61 @@ proc connectAsync*(socket: TSocket, name: string, port = TPort(0), var err = WSAGetLastError() # Windows EINTR doesn't behave same as POSIX. if err == WSAEWOULDBLOCK: - freeaddrinfo(aiList) - return + success = true + break else: if errno == EINTR or errno == EINPROGRESS: - freeaddrinfo(aiList) - return + success = true + break it = it.ai_next freeaddrinfo(aiList) if not success: OSError() - when defined(ssl): if socket.isSSL: + socket.sslNoHandshake = true + +when defined(ssl): + proc handshake*(socket: TSocket): bool = + ## This proc needs to be called on a socket after it connects. This is + ## only applicable when using ``connectAsync``. + ## This proc performs the SSL handshake. + ## + ## Returns ``False`` whenever the socket is not yet ready for a handshake, + ## ``True`` whenever handshake completed successfully. + ## + ## A ESSL error is raised on any other errors. + result = true + if socket.isSSL: var ret = SSLConnect(socket.sslHandle) if ret <= 0: var errret = SSLGetError(socket.sslHandle, ret) case errret of SSL_ERROR_ZERO_RETURN: SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") - of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, - SSL_ERROR_WANT_ACCEPT: - SSLError("Unexpected error occured.") # This should just not happen. - of SSL_ERROR_WANT_CONNECT: - return + of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT, + SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE: + return false of SSL_ERROR_WANT_X509_LOOKUP: SSLError("Function for x509 lookup has been called.") of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: SSLError() else: SSLError("Unknown Error") + socket.sslNoHandshake = false + else: + SSLError("Socket is not an SSL socket.") + + proc gotHandshake*(socket: TSocket): bool = + ## Determines whether a handshake has occurred between a client - ``socket`` + ## and the server that ``socket`` is connected to. + ## + ## Throws ESSL if ``socket`` is not an SSL socket. + if socket.isSSL: + return not socket.sslNoHandshake + else: + SSLError("Socket is not an SSL socket.") proc timeValFromMilliseconds(timeout = 500): TTimeVal = if timeout != -1: @@ -694,6 +816,21 @@ proc pruneSocketSet(s: var seq[TSocket], fd: var TFdSet) = inc(i) setLen(s, L) +proc checkBuffer(readfds: var seq[TSocket]): int = + ## Checks the buffer of each socket in ``readfds`` to see whether there is data. + ## Removes the sockets from ``readfds`` and returns the count of removed sockets. + var res: seq[TSocket] = @[] + result = 0 + for s in readfds: + if s.isBuffered: + if s.bufLen <= 0 or s.currPos == s.bufLen: + res.add(s) + else: + inc(result) + else: + res.add(s) + readfds = res + proc select*(readfds, writefds, exceptfds: var seq[TSocket], timeout = 500): int = ## Traditional select function. This function will return the number of @@ -702,6 +839,9 @@ proc select*(readfds, writefds, exceptfds: var seq[TSocket], ## ## You can determine whether a socket is ready by checking if it's still ## in one of the TSocket sequences. + let buffersFilled = checkBuffer(readfds) + if buffersFilled > 0: + return buffersFilled var tv {.noInit.}: TTimeVal = timeValFromMilliseconds(timeout) @@ -722,6 +862,9 @@ proc select*(readfds, writefds, exceptfds: var seq[TSocket], proc select*(readfds, writefds: var seq[TSocket], timeout = 500): int = + let buffersFilled = checkBuffer(readfds) + if buffersFilled > 0: + return buffersFilled var tv {.noInit.}: TTimeVal = timeValFromMilliseconds(timeout) var rd, wr: TFdSet @@ -753,6 +896,9 @@ proc selectWrite*(writefds: var seq[TSocket], pruneSocketSet(writefds, (wr)) proc select*(readfds: var seq[TSocket], timeout = 500): int = + let buffersFilled = checkBuffer(readfds) + if buffersFilled > 0: + return buffersFilled var tv {.noInit.}: TTimeVal = timeValFromMilliseconds(timeout) var rd: TFdSet @@ -878,7 +1024,7 @@ proc peekChar(socket: TSocket, c: var char): int = proc recvLine*(socket: TSocket, line: var TaintedString): bool = ## retrieves a line from ``socket``. If a full line is received ``\r\L`` is not - ## added to ``line``, however if solely ``\r\L`` is received then ``data`` + ## added to ``line``, however if solely ``\r\L`` is received then ``line`` ## will be set to it. ## ## ``True`` is returned if data is available. ``False`` usually suggests an @@ -945,7 +1091,7 @@ proc recvLineAsync*(socket: TSocket, line: var TaintedString): TRecvLineResult = ## The values of the returned enum should be pretty self explanatory: ## If a full line has been retrieved; ``RecvFullLine`` is returned. ## If some data has been retrieved; ``RecvPartialLine`` is returned. - ## If the socket has been disconnected; ``RecvDisconncted`` is returned. + ## If the socket has been disconnected; ``RecvDisconnected`` is returned. ## If call to ``recv`` failed; ``RecvFail`` is returned. setLen(line.string, 0) while true: @@ -1162,6 +1308,8 @@ proc connect*(socket: TSocket, timeout: int, name: string, port = TPort(0), if selectWrite(s, timeout) != 1: raise newException(ETimeout, "Call to connect() timed out.") +proc isSSL*(socket: TSocket): bool = return socket.isSSL + when defined(Windows): var wsa: TWSADATA if WSAStartup(0x0101'i16, wsa) != 0: OSError() |