diff options
author | Dominik Picheta <dominikpicheta@googlemail.com> | 2015-06-22 21:34:21 +0100 |
---|---|---|
committer | Dominik Picheta <dominikpicheta@googlemail.com> | 2015-06-22 21:34:21 +0100 |
commit | 8853dfb3539048b3d3d905b1a41a45860bf2d327 (patch) | |
tree | 2bd57ce43d552392209c8e5a34de8a0e955b1213 /lib | |
parent | 37677636bc919c4a8fab25fedbcf917dcf177d98 (diff) | |
parent | df1cdced1d9ec5663c735065a21dc5b00067b8b2 (diff) | |
download | Nim-8853dfb3539048b3d3d905b1a41a45860bf2d327.tar.gz |
Merge branch 'starttls' of https://github.com/wiml/Nim into wiml-starttls
Conflicts: lib/pure/net.nim
Diffstat (limited to 'lib')
-rw-r--r-- | lib/pure/asyncnet.nim | 9 | ||||
-rw-r--r-- | lib/pure/net.nim | 133 |
2 files changed, 89 insertions, 53 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index 01c28a13a..4b221eb72 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -472,6 +472,15 @@ when defined(ssl): socket.bioOut = bioNew(bio_s_mem()) sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut) + proc wrapSocket*(ctx: SslContext, socket: AsyncSocket, handshake: SslHandshakeType) = + wrapSocket(ctx, socket) + + case handshake + of handshakeAsClient: + sslSetConnectState(socket.sslHandle) + of handshakeAsServer: + sslSetAcceptState(socket.sslHandle) + proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {. tags: [ReadIOEffect].} = ## Retrieves option ``opt`` as a boolean value. diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 9ce0669bc..7dcc35495 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -26,15 +26,18 @@ when defined(ssl): SslCVerifyMode* = enum CVerifyNone, CVerifyPeer - + SslProtVersion* = enum protSSLv2, protSSLv3, protTLSv1, protSSLv23 - + SslContext* = distinct SslCtx SslAcceptResult* = enum AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess + SslHandshakeType* = enum + handshakeAsClient, handshakeAsServer + {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode, TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext, TSSLAcceptResult: SSLAcceptResult].} @@ -86,7 +89,7 @@ type IPv6, ## IPv6 address IPv4 ## IPv4 address - IpAddress* = object ## stores an arbitrary IP address + IpAddress* = object ## stores an arbitrary IP address case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6) of IpAddressFamily.IPv6: address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in @@ -98,6 +101,8 @@ type proc isIpAddress*(address_str: string): bool {.tags: [].} proc parseIpAddress*(address_str: string): IpAddress +proc socketError*(socket: Socket, err: int = -1, async = false, +lastError = (-1).OSErrorCode): void proc isDisconnectionError*(flags: set[SocketFlag], lastError: OSErrorCode): bool = @@ -109,7 +114,7 @@ proc isDisconnectionError*(flags: set[SocketFlag], WSAEDISCON, ERROR_NETNAME_DELETED} else: SocketFlag.SafeDisconn in flags and - lastError.int32 in {ECONNRESET, EPIPE, ENETRESET} + lastError.int32 in {ECONNRESET, EPIPE, ENETRESET} proc toOSFlags*(socketFlags: set[SocketFlag]): cint = ## Converts the flags into the underlying OS representation. @@ -172,27 +177,27 @@ when defined(ssl): raise newException(system.IOError, "Certificate file could not be found: " & certFile) if keyFile != "" and not existsFile(keyFile): raise newException(system.IOError, "Key file could not be found: " & keyFile) - + if certFile != "": var ret = SSLCTXUseCertificateChainFile(ctx, certFile) if ret != 1: raiseSSLError() - + # TODO: Password? www.rtfm.com/openssl-examples/part1.pdf if keyFile != "": if SSL_CTX_use_PrivateKey_file(ctx, keyFile, SSL_FILETYPE_PEM) != 1: raiseSSLError() - + if SSL_CTX_check_private_key(ctx) != 1: raiseSSLError("Verification of private key file failed.") proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer, certFile = "", keyFile = ""): SSLContext = ## Creates an SSL context. - ## - ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 - ## are available with the addition of ``protSSLv23`` which allows for + ## + ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 + ## are available with the addition of ``protSSLv23`` which allows for ## compatibility with all of them. ## ## There are currently only two options for verify mode; @@ -217,7 +222,7 @@ when defined(ssl): newCTX = SSL_CTX_new(SSLv3_method()) of protTLSv1: newCTX = SSL_CTX_new(TLSv1_method()) - + if newCTX.SSLCTXSetCipherList("ALL") != 1: raiseSSLError() case verifyMode @@ -236,9 +241,13 @@ when defined(ssl): ## Wraps a socket in an SSL context. This function effectively turns ## ``socket`` into an SSL socket. ## + ## This must be called on an unconnected socket; an SSL session will + ## be started when the socket is connected. + ## ## **Disclaimer**: This code is not well tested, may be very unsafe and ## prone to security vulnerabilities. - + + assert (not socket.isSSL) socket.isSSL = true socket.sslContext = ctx socket.sslHandle = SSLNew(SSLCTX(socket.sslContext)) @@ -246,10 +255,28 @@ when defined(ssl): socket.sslHasPeekChar = false if socket.sslHandle == nil: raiseSSLError() - + if SSLSetFd(socket.sslHandle, socket.fd) != 1: raiseSSLError() + proc wrapSocket*(ctx: SSLContext, socket: Socket, handshake: SslHandshakeType) = + ## Wraps a socket in an SSL context. This function effectively turns + ## ``socket`` into an SSL socket. + ## + ## This should be called on a connected socket, and will perform + ## an SSL handshake immediately. + ## + ## **Disclaimer**: This code is not well tested, may be very unsafe and + ## prone to security vulnerabilities. + wrapSocket(ctx, socket) + case handshake + of handshakeAsClient: + let ret = SSLConnect(socket.sslHandle) + socketError(socket, ret) + of handshakeAsServer: + let ret = SSLAccept(socket.sslHandle) + socketError(socket, ret) + proc getSocketError*(socket: Socket): OSErrorCode = ## Checks ``osLastError`` for a valid error. If it has been reset it uses ## the last error stored in the socket object. @@ -302,7 +329,7 @@ proc socketError*(socket: Socket, err: int = -1, async = false, of SSL_ERROR_SSL: raiseSSLError() else: raiseSSLError("Unknown Error") - + if err == -1 and not (when defined(ssl): socket.isSSL else: false): var lastE = if lastError.int == -1: getSocketError(socket) else: lastError if async: @@ -317,8 +344,8 @@ proc socketError*(socket: Socket, err: int = -1, async = false, else: raiseOSError(lastE) proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} = - ## Marks ``socket`` as accepting connections. - ## ``Backlog`` specifies the maximum length of the + ## Marks ``socket`` as accepting connections. + ## ``Backlog`` specifies the maximum length of the ## queue of pending connections. ## ## Raises an EOS error upon failure. @@ -360,7 +387,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string, ## The resulting client will inherit any properties of the server socket. For ## example: whether the socket is buffered or not. ## - ## **Note**: ``client`` must be initialised (with ``new``), this function + ## **Note**: ``client`` must be initialised (with ``new``), this function ## makes no effort to initialise the ``client`` variable. ## ## The ``accept`` call may result in an error if the connecting socket @@ -372,7 +399,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string, var addrLen = sizeof(sockAddress).SockLen var sock = accept(server.fd, cast[ptr SockAddr](addr(sockAddress)), addr(addrLen)) - + if sock == osInvalidSocket: let err = osLastError() if flags.isDisconnectionError(err): @@ -386,11 +413,11 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string, when defined(ssl): if server.isSSL: # We must wrap the client sock in a ssl context. - + server.sslContext.wrapSocket(client) let ret = SSLAccept(client.sslHandle) socketError(client, ret, false) - + # Client socket is set above. address = $inet_ntoa(sockAddress.sin_addr) @@ -398,9 +425,9 @@ when false: #defined(ssl): proc acceptAddrSSL*(server: Socket, client: var Socket, address: var string): SSLAcceptResult {. tags: [ReadIOEffect].} = - ## This procedure should only be used for non-blocking **SSL** sockets. + ## This procedure should only be used for non-blocking **SSL** sockets. ## It will immediately return with one of the following values: - ## + ## ## ``AcceptSuccess`` will be returned when a client has been successfully ## accepted and the handshake has been successfully performed between ## ``server`` and the newly connected client. @@ -417,7 +444,7 @@ when false: #defined(ssl): if server.isSSL: client.setBlocking(false) # We must wrap the client sock in a ssl context. - + if not client.isSSL or client.sslHandle == nil: server.sslContext.wrapSocket(client) let ret = SSLAccept(client.sslHandle) @@ -450,7 +477,7 @@ proc accept*(server: Socket, client: var Socket, flags = {SocketFlag.SafeDisconn}) {.tags: [ReadIOEffect].} = ## Equivalent to ``acceptAddr`` but doesn't return the address, only the ## socket. - ## + ## ## **Note**: ``client`` must be initialised (with ``new``), this function ## makes no effort to initialise the ``client`` variable. ## @@ -504,7 +531,7 @@ proc setSockOpt*(socket: Socket, opt: SOBool, value: bool, level = SOL_SOCKET) { var valuei = cint(if value: 1 else: 0) setSockOptInt(socket.fd, cint(level), toCInt(opt), valuei) -proc connect*(socket: Socket, address: string, port = Port(0), +proc connect*(socket: Socket, address: string, port = Port(0), af: Domain = AF_INET) {.tags: [ReadIOEffect].} = ## Connects socket to ``address``:``port``. ``Address`` can be an IP address or a ## host name. If ``address`` is a host name, this function will try each IP @@ -526,7 +553,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), dealloc(aiList) if not success: raiseOSError(lastError) - + when defined(ssl): if socket.isSSL: # RFC3546 for SNI specifies that IP addresses are not allowed. @@ -634,12 +661,12 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect] if socket.isBuffered: if socket.bufLen == 0: retRead(0'i32, 0) - + var read = 0 while read < size: if socket.currPos >= socket.bufLen: retRead(0'i32, read) - + let chunk = min(socket.bufLen-socket.currPos, size-read) var d = cast[cstring](data) assert size-read >= chunk @@ -686,7 +713,7 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int, else: if timeout - int(waited * 1000.0) < 1: raise newException(TimeoutError, "Call to '" & funcName & "' timed out.") - + when defined(ssl): if socket.isSSL: if socket.hasDataBuffered: @@ -695,7 +722,7 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int, let sslPending = SSLPending(socket.sslHandle) if sslPending != 0: return sslPending - + var startTime = epochTime() let selRet = select(socket, timeout - int(waited * 1000.0)) if selRet < 0: raiseOSError(osLastError()) @@ -706,8 +733,8 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int, proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {. tags: [ReadIOEffect, TimeEffect].} = ## overload with a ``timeout`` parameter in milliseconds. - var waited = 0.0 # number of seconds already waited - + var waited = 0.0 # number of seconds already waited + var read = 0 while read < size: let avail = waitFor(socket, waited, timeout, size-read, "recv") @@ -718,7 +745,7 @@ proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {. if result < 0: return result inc(read, result) - + result = read proc recv*(socket: Socket, data: var string, size: int, timeout = -1, @@ -752,7 +779,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} = var res = socket.readIntoBuf(0'i32) if res <= 0: result = res - + c = socket.buffer[socket.currPos] else: when defined(ssl): @@ -760,7 +787,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} = if not socket.sslHasPeekChar: result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1) socket.sslHasPeekChar = true - + c = socket.sslPeekChar return result = recv(socket.fd, addr(c), 1, MSG_PEEK) @@ -773,7 +800,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1, ## If a full line is read ``\r\L`` is not ## added to ``line``, however if solely ``\r\L`` is read then ``line`` ## will be set to it. - ## + ## ## If the socket is disconnected, ``line`` will be set to ``""``. ## ## An EOS exception will be raised in the case of a socket error. @@ -782,7 +809,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1, ## the specified time an ETimeout exception will be raised. ## ## **Warning**: Only the ``SafeDisconn`` flag is currently supported. - + template addNLIfEmpty(): stmt = if line.len == 0: line.add("\c\L") @@ -809,7 +836,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1, elif n <= 0: raiseSockError() addNLIfEmpty() return - elif c == '\L': + elif c == '\L': addNLIfEmpty() return add(line.string, c) @@ -827,7 +854,7 @@ proc recvFrom*(socket: Socket, data: var string, length: int, ## so when ``socket`` is buffered the non-buffered implementation will be ## used. Therefore if ``socket`` contains something in its buffer this ## function will make no effort to return it. - + # TODO: Buffered sockets data.setLen(length) var sockAddress: Sockaddr_in @@ -861,16 +888,16 @@ proc send*(socket: Socket, data: pointer, size: int): int {. tags: [WriteIOEffect].} = ## Sends data to a socket. ## - ## **Note**: This is a low-level version of ``send``. You likely should use + ## **Note**: This is a low-level version of ``send``. You likely should use ## the version below. when defined(ssl): if socket.isSSL: return SSLWrite(socket.sslHandle, cast[cstring](data), size) - + when useWinVersion or defined(macosx): result = send(socket.fd, data, size.cint, 0'i32) else: - when defined(solaris): + when defined(solaris): const MSG_NOSIGNAL = 0 result = send(socket.fd, data, size, int32(MSG_NOSIGNAL)) @@ -895,7 +922,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer, size: int, af: Domain = AF_INET, flags = 0'i32): int {. tags: [WriteIOEffect].} = ## This proc sends ``data`` to the specified ``address``, - ## which may be an IP address or a hostname, if a hostname is specified + ## which may be an IP address or a hostname, if a hostname is specified ## this function will try each IP of that hostname. ## ## @@ -904,7 +931,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer, ## ## **Note:** This proc is not available for SSL sockets. var aiList = getAddrInfo(address, port, af) - + # try all possibilities: var success = false var it = aiList @@ -918,10 +945,10 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer, dealloc(aiList) -proc sendTo*(socket: Socket, address: string, port: Port, +proc sendTo*(socket: Socket, address: string, port: Port, data: string): int {.tags: [WriteIOEffect].} = ## This proc sends ``data`` to the specified ``address``, - ## which may be an IP address or a hostname, if a hostname is specified + ## which may be an IP address or a hostname, if a hostname is specified ## this function will try each IP of that hostname. ## ## This is the high-level version of the above ``sendTo`` function. @@ -958,7 +985,7 @@ proc connectAsync(socket: Socket, name: string, port = Port(0), if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS: success = true break - + it = it.ai_next dealloc(aiList) @@ -971,7 +998,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int, ## The ``timeout`` paremeter specifies the time in milliseconds to allow for ## the connection to the server to be made. socket.fd.setBlocking(false) - + socket.connectAsync(address, port, af) var s = @[socket.fd] if selectWrite(s, timeout) != 1: @@ -983,7 +1010,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int, doAssert socket.handshake() socket.fd.setBlocking(true) -proc isSsl*(socket: Socket): bool = +proc isSsl*(socket: Socket): bool = ## Determines whether ``socket`` is a SSL socket. when defined(ssl): result = socket.isSSL @@ -1014,7 +1041,7 @@ proc IPv4_broadcast*(): IpAddress = proc IPv6_any*(): IpAddress = ## Returns the IPv6 any address (::0), which can be used - ## to listen on all available network adapters + ## to listen on all available network adapters result = IpAddress( family: IpAddressFamily.IPv6, address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) @@ -1152,7 +1179,7 @@ proc parseIPv6Address(address_str: string): IpAddress = if not seperatorValid: raise newException(ValueError, "Invalid IP Address. Address contains an invalid seperator") - if lastWasColon: + if lastWasColon: if dualColonGroup != -1: raise newException(ValueError, "Invalid IP Address. Address contains more than one \"::\" seperator") @@ -1165,14 +1192,14 @@ proc parseIPv6Address(address_str: string): IpAddress = result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) currentShort = 0 - groupCount.inc() + groupCount.inc() if dualColonGroup != -1: seperatorValid = false elif i == 0: # only valid if address starts with :: if address_str[1] != ':': raise newException(ValueError, "Invalid IP Address. Address may not start with \":\"") else: # i == high(address_str) - only valid if address ends with :: - if address_str[high(address_str)-1] != ':': + if address_str[high(address_str)-1] != ':': raise newException(ValueError, "Invalid IP Address. Address may not end with \":\"") lastWasColon = true |