diff options
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r-- | lib/pure/net.nim | 165 |
1 files changed, 115 insertions, 50 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 8afc6c5c5..ffbc6e320 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -1,7 +1,7 @@ # # # Nim's Runtime Library -# (c) Copyright 2014 Dominik Picheta +# (c) Copyright 2015 Dominik Picheta # # See the file "copying.txt", included in this # distribution, for details about the copyright. @@ -44,23 +44,24 @@ const type SocketImpl* = object ## socket type - fd*: SocketHandle - case isBuffered*: bool # determines whether this socket is buffered. + fd: SocketHandle + case isBuffered: bool # determines whether this socket is buffered. of true: - buffer*: array[0..BufferSize, char] - currPos*: int # current index in buffer - bufLen*: int # current length of buffer + buffer: array[0..BufferSize, char] + currPos: int # current index in buffer + bufLen: int # current length of buffer of false: nil when defined(ssl): - case isSsl*: bool + case isSsl: bool of true: - sslHandle*: SSLPtr - sslContext*: SSLContext - sslNoHandshake*: bool # True if needs handshake. - sslHasPeekChar*: bool - sslPeekChar*: char + sslHandle: SSLPtr + sslContext: SSLContext + sslNoHandshake: bool # True if needs handshake. + sslHasPeekChar: bool + sslPeekChar: char of false: nil - + lastError: OSErrorCode ## stores the last error on this socket + Socket* = ref SocketImpl SOBool* = enum ## Boolean socket options. @@ -80,6 +81,23 @@ type TReadLineResult: ReadLineResult, TSOBool: SOBool, PSocket: Socket, TSocketImpl: SocketImpl].} +type + IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address + IPv6, ## IPv6 address + IPv4 ## IPv4 address + + TIpAddress* = object ## stores an arbitrary IP address + case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6) + of IpAddressFamily.IPv6: + address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in + ## case of IPv6 + of IpAddressFamily.IPv4: + address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in + ## case of IPv4 + +proc isIpAddress*(address_str: string): bool {.tags: [].} +proc parseIpAddress*(address_str: string): TIpAddress + proc isDisconnectionError*(flags: set[SocketFlag], lastError: OSErrorCode): bool = ## Determines whether ``lastError`` is a disconnection error. Only does this @@ -100,7 +118,8 @@ proc toOSFlags*(socketFlags: set[SocketFlag]): cint = result = result or MSG_PEEK of SocketFlag.SafeDisconn: continue -proc createSocket(fd: SocketHandle, isBuff: bool): Socket = +proc newSocket(fd: SocketHandle, isBuff: bool): Socket = + ## Creates a new socket as specified by the params. assert fd != osInvalidSocket new(result) result.fd = fd @@ -115,17 +134,17 @@ proc newSocket*(domain, typ, protocol: cint, buffered = true): Socket = let fd = newRawSocket(domain, typ, protocol) if fd == osInvalidSocket: raiseOSError(osLastError()) - result = createSocket(fd, buffered) + result = newSocket(fd, buffered) proc newSocket*(domain: Domain = AF_INET, typ: SockType = SOCK_STREAM, - protocol: Protocol = IPPROTO_TCP, buffered = true): Socket = + protocol: Protocol = IPPROTO_TCP, buffered = true): Socket = ## Creates a new socket. ## ## If an error occurs EOS will be raised. let fd = newRawSocket(domain, typ, protocol) if fd == osInvalidSocket: raiseOSError(osLastError()) - result = createSocket(fd, buffered) + result = newSocket(fd, buffered) when defined(ssl): CRYPTO_malloc_init() @@ -230,6 +249,15 @@ when defined(ssl): if SSLSetFd(socket.sslHandle, socket.fd) != 1: raiseSSLError() +proc getSocketError*(socket: Socket): OSErrorCode = + ## Checks ``osLastError`` for a valid error. If it has been reset it uses + ## the last error stored in the socket object. + result = osLastError() + if result == 0.OSErrorCode: + result = socket.lastError + if result == 0.OSErrorCode: + raise newException(OSError, "No valid socket error code available") + proc socketError*(socket: Socket, err: int = -1, async = false, lastError = (-1).OSErrorCode) = ## Raises an OSError based on the error code returned by ``SSLGetError`` @@ -256,12 +284,26 @@ proc socketError*(socket: Socket, err: int = -1, async = false, else: raiseSSLError("Not enough data on socket.") of SSL_ERROR_WANT_X509_LOOKUP: raiseSSLError("Function for x509 lookup has been called.") - of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + of SSL_ERROR_SYSCALL: + var errStr = "IO error has occurred " + let sslErr = ErrPeekLastError() + if sslErr == 0 and err == 0: + errStr.add "because an EOF was observed that violates the protocol" + elif sslErr == 0 and err == -1: + errStr.add "in the BIO layer" + else: + let errStr = $ErrErrorString(sslErr, nil) + raiseSSLError(errStr & ": " & errStr) + let osMsg = osErrorMsg osLastError() + if osMsg != "": + errStr.add ". The OS reports: " & osMsg + raise newException(OSError, errStr) + of SSL_ERROR_SSL: raiseSSLError() else: raiseSSLError("Unknown Error") if err == -1 and not (when defined(ssl): socket.isSSL else: false): - let lastE = if lastError.int == -1: osLastError() else: lastError + var lastE = if lastError.int == -1: getSocketError(socket) else: lastError if async: when useWinVersion: if lastE.int32 == WSAEWOULDBLOCK: @@ -279,7 +321,8 @@ proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} = ## queue of pending connections. ## ## Raises an EOS error upon failure. - if listen(socket.fd, backlog) < 0'i32: raiseOSError(osLastError()) + if rawsockets.listen(socket.fd, backlog) < 0'i32: + raiseOSError(osLastError()) proc bindAddr*(socket: Socket, port = Port(0), address = "") {. tags: [ReadIOEffect].} = @@ -306,7 +349,8 @@ proc bindAddr*(socket: Socket, port = Port(0), address = "") {. dealloc(aiList) proc acceptAddr*(server: Socket, client: var Socket, address: var string, - flags = {SocketFlag.SafeDisconn}) {.tags: [ReadIOEffect].} = + flags = {SocketFlag.SafeDisconn}) {. + tags: [ReadIOEffect], gcsafe, locks: 0.} = ## Blocks until a connection is being made from a client. When a connection ## is made sets ``client`` to the client socket and ``address`` to the address ## of the connecting client. @@ -418,15 +462,23 @@ proc accept*(server: Socket, client: var Socket, proc close*(socket: Socket) = ## Closes a socket. - socket.fd.close() - when defined(ssl): - if socket.isSSL: - let res = SSLShutdown(socket.sslHandle) - if res == 0: - if SSLShutdown(socket.sslHandle) != 1: - socketError(socket) - elif res != 1: - socketError(socket) + try: + when defined(ssl): + if socket.isSSL: + ErrClearError() + # As we are closing the underlying socket immediately afterwards, + # it is valid, under the TLS standard, to perform a unidirectional + # shutdown i.e not wait for the peers "close notify" alert with a second + # call to SSLShutdown + let res = SSLShutdown(socket.sslHandle) + SSLFree(socket.sslHandle) + socket.sslHandle = nil + if res == 0: + discard + elif res != 1: + socketError(socket, res) + finally: + socket.fd.close() proc toCInt*(opt: SOBool): cint = ## Converts a ``SOBool`` into its Socket Option cint representation. @@ -476,6 +528,12 @@ proc connect*(socket: Socket, address: string, port = Port(0), when defined(ssl): if socket.isSSL: + # RFC3546 for SNI specifies that IP addresses are not allowed. + if not isIpAddress(address): + # Discard result in case OpenSSL version doesn't support SNI, or we're + # not using TLSv1+ + discard SSL_set_tlsext_host_name(socket.sslHandle, address) + let ret = SSLConnect(socket.sslHandle) socketError(socket, ret) @@ -547,6 +605,10 @@ proc readIntoBuf(socket: Socket, flags: int32): int = result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags) else: result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags) + if result < 0: + # Save it in case it gets reset (the Nim codegen occassionally may call + # Win API functions which reset it). + socket.lastError = osLastError() if result <= 0: socket.bufLen = 0 socket.currPos = 0 @@ -602,6 +664,9 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect] result = recv(socket.fd, data, size.cint, 0'i32) else: result = recv(socket.fd, data, size.cint, 0'i32) + if result < 0: + # Save the error in case it gets reset. + socket.lastError = osLastError() proc waitFor(socket: Socket, waited: var float, timeout, size: int, funcName: string): int {.tags: [TimeEffect].} = @@ -674,7 +739,7 @@ proc recv*(socket: Socket, data: var string, size: int, timeout = -1, result = recv(socket, cstring(data), size, timeout) if result < 0: data.setLen(0) - let lastError = osLastError() + let lastError = getSocketError(socket) if flags.isDisconnectionError(lastError): return socket.socketError(result, lastError = lastError) data.setLen(result) @@ -722,7 +787,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1, line.add("\c\L") template raiseSockError(): stmt {.dirty, immediate.} = - let lastError = osLastError() + let lastError = getSocketError(socket) if flags.isDisconnectionError(lastError): setLen(line.string, 0); return socket.socketError(n, lastError = lastError) @@ -865,7 +930,7 @@ proc connectAsync(socket: Socket, name: string, port = Port(0), af: Domain = AF_INET) {.tags: [ReadIOEffect].} = ## A variant of ``connect`` for non-blocking sockets. ## - ## This procedure will immediatelly return, it will not block until a connection + ## This procedure will immediately return, it will not block until a connection ## is made. It is up to the caller to make sure the connection has been established ## by checking (using ``select``) whether the socket is writeable. ## @@ -917,26 +982,16 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int, doAssert socket.handshake() socket.fd.setBlocking(true) -proc isSSL*(socket: Socket): bool = return socket.isSSL +proc isSsl*(socket: Socket): bool = ## Determines whether ``socket`` is a SSL socket. + when defined(ssl): + result = socket.isSSL + else: + result = false -proc getFD*(socket: Socket): SocketHandle = return socket.fd +proc getFd*(socket: Socket): SocketHandle = return socket.fd ## Returns the socket's file descriptor -type - IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address - IPv6, ## IPv6 address - IPv4 ## IPv4 address - - TIpAddress* = object ## stores an arbitrary IP address - case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6) - of IpAddressFamily.IPv6: - address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in - ## case of IPv6 - of IpAddressFamily.IPv4: - address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in - ## case of IPv4 - proc IPv4_any*(): TIpAddress = ## Returns the IPv4 any address, which can be used to listen on all available ## network adapters @@ -1195,7 +1250,7 @@ proc parseIPv6Address(address_str: string): TIpAddress = raise newException(ValueError, "Invalid IP Address. The address consists of too many groups") -proc parseIpAddress*(address_str: string): TIpAddress = +proc parseIpAddress(address_str: string): TIpAddress = ## Parses an IP address ## Raises EInvalidValue on error if address_str == nil: @@ -1204,3 +1259,13 @@ proc parseIpAddress*(address_str: string): TIpAddress = return parseIPv6Address(address_str) else: return parseIPv4Address(address_str) + + +proc isIpAddress(address_str: string): bool = + ## Checks if a string is an IP address + ## Returns true if it is, false otherwise + try: + discard parseIpAddress(address_str) + except ValueError: + return false + return true |