diff options
Diffstat (limited to 'lib')
-rwxr-xr-x | lib/impure/ssl.nim | 2 | ||||
-rwxr-xr-x | lib/pure/httpclient.nim | 20 | ||||
-rwxr-xr-x | lib/pure/smtp.nim | 55 | ||||
-rwxr-xr-x | lib/pure/sockets.nim | 454 | ||||
-rwxr-xr-x | lib/wrappers/openssl.nim | 62 |
5 files changed, 473 insertions, 120 deletions
diff --git a/lib/impure/ssl.nim b/lib/impure/ssl.nim index 5fe986b14..4a101ca92 100755 --- a/lib/impure/ssl.nim +++ b/lib/impure/ssl.nim @@ -10,6 +10,8 @@ ## This module provides an easy to use sockets-style ## nimrod interface to the OpenSSL library. +{.deprecate.} + import openssl, strutils, os type diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim index 3af08f040..c4dbd8509 100755 --- a/lib/pure/httpclient.nim +++ b/lib/pure/httpclient.nim @@ -42,6 +42,14 @@ ## body.add("--xyz--") ## ## echo(postContent("http://validator.w3.org/check", headers, body)) +## +## SSL/TLS support +## =============== +## This requires the OpenSSL library, fortunately it's widely used and installed +## on many operating systems. httpclient will use SSL automatically if you give +## any of the functions a url with the ``https`` schema, for example: +## ``https://github.com/``, you also have to compile with ``ssl`` defined like so: +## ``nimrod c -d:ssl ...``. import sockets, strutils, parseurl, parseutils, strtabs @@ -152,7 +160,7 @@ proc parseBody(d: var string, start: int, s: TSocket, result.add(moreData) proc parseResponse(s: TSocket): TResponse = - var d = s.recv.string + var d = s.recv.string # Warning: without a Connection: Close header this will not work. var i = 0 # Parse the version @@ -241,11 +249,19 @@ proc request*(url: string, httpMethod = httpGET, extraHeaders = "", headers.add(" HTTP/1.1\c\L") add(headers, "Host: " & r.hostname & "\c\L") + add(headers, "Connection: Close\c\L") add(headers, extraHeaders) add(headers, "\c\L") var s = socket() - s.connect(r.hostname, TPort(80)) + var port = TPort(80) + if r.scheme == "https": + when defined(ssl): + s.wrapSocket(verifyMode = CVerifyNone) + port = TPort(443) + if r.port != "": + port = TPort(r.port.parseInt) + s.connect(r.hostname, port) s.send(headers) if body != "": s.send(body) diff --git a/lib/pure/smtp.nim b/lib/pure/smtp.nim index 7eeb026d3..58c1d4b58 100755 --- a/lib/pure/smtp.nim +++ b/lib/pure/smtp.nim @@ -25,20 +25,17 @@ ## smtp.sendmail("username@gmail.com", @["foo@gmail.com"], $msg) ## ## -## For SSL support this module relies on the SSL module. If you want to -## disable SSL, compile with ``-d:NoSSL``. +## For SSL support this module relies on OpenSSL. If you want to +## enable SSL, compile with ``-d:ssl``. -import sockets, strutils, strtabs, base64, os +when not defined(ssl): + {.error: "The SMTP module should be compiled with SSL support. Compile with -d:ssl."} -when not defined(noSSL): - import ssl +import sockets, strutils, strtabs, base64, os type TSMTP* {.final.} = object sock: TSocket - when not defined(noSSL): - sslSock: TSecureSocket - ssl: Bool debug: Bool TMessage* {.final.} = object @@ -53,20 +50,13 @@ type proc debugSend(smtp: TSMTP, cmd: string) = if smtp.debug: echo("C:" & cmd) - if not smtp.ssl: - smtp.sock.send(cmd) - else: - when not defined(noSSL): - smtp.sslSock.send(cmd) + smtp.sock.send(cmd) -proc debugRecv(smtp: TSMTP): TaintedString = +proc debugRecv(smtp: var TSMTP): TaintedString = var line = TaintedString"" var ret = False - if not smtp.ssl: - ret = smtp.sock.recvLine(line) - else: - when not defined(noSSL): - ret = smtp.sslSock.recvLine(line) + ret = smtp.sock.recvLine(line) + if ret: if smtp.debug: echo("S:" & line.string) @@ -79,7 +69,7 @@ proc quitExcpt(smtp: TSMTP, msg: string) = smtp.debugSend("QUIT") raise newException(EInvalidReply, msg) -proc checkReply(smtp: TSMTP, reply: string) = +proc checkReply(smtp: var TSMTP, reply: string) = var line = smtp.debugRecv() if not line.string.startswith(reply): quitExcpt(smtp, "Expected " & reply & " reply, got: " & line.string) @@ -88,25 +78,21 @@ proc connect*(address: string, port = 25, ssl = false, debug = false): TSMTP = ## Establishes a connection with a SMTP server. ## May fail with EInvalidReply or with a socket error. - - if not ssl: - result.sock = socket() - result.sock.connect(address, TPort(port)) - else: - when not defined(noSSL): - result.ssl = True - discard result.sslSock.connect(address, port) + result.sock = socket() + if ssl: + when defined(ssl): + result.sock.wrapSocket(verifyMode = CVerifyNone) else: - raise newException(EInvalidReply, + raise newException(ESystem, "SMTP module compiled without SSL support") - + result.sock.connect(address, TPort(port)) result.debug = debug result.checkReply("220") result.debugSend("HELO " & address & "\c\L") result.checkReply("250") -proc auth*(smtp: TSMTP, username, password: string) = +proc auth*(smtp: var TSMTP, username, password: string) = ## Sends an AUTH command to the server to login as the `username` ## using `password`. ## May fail with EInvalidReply. @@ -120,7 +106,7 @@ proc auth*(smtp: TSMTP, username, password: string) = smtp.debugSend(encode(password) & "\c\L") smtp.checkReply("235") # Check whether the authentification was successful. -proc sendmail*(smtp: TSMTP, fromaddr: string, +proc sendmail*(smtp: var TSMTP, fromaddr: string, toaddrs: seq[string], msg: string) = ## Sends `msg` from `fromaddr` to `toaddr`. ## Messages may be formed using ``createMessage`` by converting the @@ -142,10 +128,7 @@ proc sendmail*(smtp: TSMTP, fromaddr: string, proc close*(smtp: TSMTP) = ## Disconnects from the SMTP server and closes the socket. smtp.debugSend("QUIT\c\L") - if not smtp.ssl: - smtp.sock.close() - else: - smtp.sslSock.close() + smtp.sock.close() proc createMessage*(mSubject, mBody: string, mTo, mCc: seq[string], otherHeaders: openarray[tuple[name, value: string]]): TMessage = diff --git a/lib/pure/sockets.nim b/lib/pure/sockets.nim index eeec62843..517952781 100755 --- a/lib/pure/sockets.nim +++ b/lib/pure/sockets.nim @@ -10,11 +10,16 @@ ## This module implements a simple portable type-safe sockets layer. ## ## Most procedures raise EOS on error. - +## +## For OpenSSL support compile with ``-d:ssl``. When using SSL be aware that +## most functions will then raise ``ESSL`` on SSL errors. import os, parseutils from times import epochTime +when defined(ssl): + import openssl + when defined(Windows): import winlean else: @@ -22,8 +27,40 @@ else: # Note: The enumerations are mapped to Window's constants. +when defined(ssl): + type + ESSL* = object of ESynch + + TSSLCVerifyMode* = enum + CVerifyNone, CVerifyPeer + + TSSLProtVersion* = enum + protSSLv2, protSSLv3, protTLSv1, protSSLv23 + + TSSLOptions* = object + verifyMode*: TSSLCVerifyMode + certFile*, keyFile*: string + protVer*: TSSLprotVersion + type - TSocket* = distinct cint ## socket type + TSocketImpl = object ## socket type + fd: cint + case isBuffered: bool # determines whether this socket is buffered. + of true: + buffer: array[0..4000, char] + currPos: int # current index in buffer + bufLen: int # current length of buffer + of false: nil + when defined(ssl): + case isSsl: bool + of true: + sslHandle: PSSL + sslContext: PSSLCTX + wrapOptions: TSSLOptions + of false: nil + + TSocket* = ref TSocketImpl + TPort* = distinct int16 ## port type TDomain* = enum ## domain, which specifies the protocol family of the @@ -65,11 +102,15 @@ type ETimeout* = object of ESynch -const - InvalidSocket* = TSocket(-1'i32) ## invalid socket number +proc newTSocket(fd: int32, isBuff: bool): TSocket = + new(result) + result.fd = fd + result.isBuffered = isBuff + if isBuff: + result.currPos = 0 -proc `==`*(a, b: TSocket): bool {.borrow.} - ## ``==`` for sockets. +let + InvalidSocket*: TSocket = nil ## invalid socket proc `==`*(a, b: TPort): bool {.borrow.} ## ``==`` for ports. @@ -144,18 +185,111 @@ else: result = cint(ord(p)) proc socket*(domain: TDomain = AF_INET, typ: TType = SOCK_STREAM, - protocol: TProtocol = IPPROTO_TCP): TSocket = + protocol: TProtocol = IPPROTO_TCP, buffered = true): TSocket = ## creates a new socket; returns `InvalidSocket` if an error occurs. when defined(Windows): - result = TSocket(winlean.socket(ord(domain), ord(typ), ord(protocol))) + result = newTSocket(winlean.socket(ord(domain), ord(typ), ord(protocol)), buffered) else: - result = TSocket(posix.socket(ToInt(domain), ToInt(typ), ToInt(protocol))) + result = newTSocket(posix.socket(ToInt(domain), ToInt(typ), ToInt(protocol)), buffered) + +when defined(ssl): + CRYPTO_malloc_init() + SslLibraryInit() + SslLoadErrorStrings() + ErrLoadBioStrings() + OpenSSL_add_all_algorithms() + + proc SSLError(s = "") = + if s != "": + raise newException(ESSL, s) + let err = ErrGetError() + if err == 0: + raise newException(ESSL, "An EOF was observed that violates the protocol.") + if err == -1: + OSError() + var errStr = ErrErrorString(err, nil) + raise newException(ESSL, $errStr) + + # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html + proc loadCertificates(socket: var TSocket, certFile, keyFile: string) = + if certFile != "": + if SSLCTXUseCertificateFile(socket.sslContext, certFile, + SSL_FILETYPE_PEM) != 1: + SSLError() + if keyFile != "": + if SSL_CTX_use_PrivateKey_file(socket.sslContext, keyFile, + SSL_FILETYPE_PEM) != 1: + SSLError() + + if SSL_CTX_check_private_key(socket.sslContext) != 1: + SSLError("Verification of private key file failed.") + + proc wrapSocket*(socket: var TSocket, protVersion = ProtSSLv23, + verifyMode = CVerifyPeer, + certFile = "", keyFile = "") = + ## Creates a SSL context for ``socket`` and wraps the socket in it. + ## + ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 are + ## are available with the addition of ``ProtSSLv23`` which allows for + ## compatibility with all of them. + ## + ## There are currently only two options for verify mode; one is ``CVerifyNone`` + ## and with it certificates will not be verified the other is ``CVerifyPeer`` + ## and certificates will be verified for it, ``CVerifyPeer`` is the safest choice. + ## + ## The last two parameters specify the certificate file path and the key file + ## path, a server socket will most likely not work without these. + ## Certificates can be generated using the following command: + ## ``openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout mycert.pem -out mycert.pem``. + ## + ## **Warning:** Because SSL is meant to be secure I feel the need to warn you + ## that this "wrapper" has not been thorougly tested and is therefore + ## most likely very prone to security vulnerabilities. + + socket.isSSL = true + socket.wrapOptions.verifyMode = verifyMode + socket.wrapOptions.certFile = certFile + socket.wrapOptions.keyFile = keyFile + socket.wrapOptions.protVer = protVersion + + case protVersion + of protSSLv23: + socket.sslContext = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support. + of protSSLv2: + socket.sslContext = SSL_CTX_new(SSLv2_method()) + of protSSLv3: + socket.sslContext = SSL_CTX_new(SSLv3_method()) + of protTLSv1: + socket.sslContext = SSL_CTX_new(TLSv1_method()) + + if socket.sslContext.SSLCTXSetCipherList("ALL") != 1: + SSLError() + case verifyMode + of CVerifyPeer: + socket.sslContext.SSLCTXSetVerify(SSLVerifyPeer, nil) + of CVerifyNone: + socket.sslContext.SSLCTXSetVerify(SSLVerifyNone, nil) + if socket.sslContext == nil: + SSLError() + + socket.loadCertificates(certFile, keyFile) + + socket.sslHandle = SSLNew(socket.sslContext) + if socket.sslHandle == nil: + SSLError() + + if SSLSetFd(socket.sslHandle, socket.fd) != 1: + SSLError() + + proc wrapSocket*(socket: var TSocket, wo: TSSLOptions) = + ## A variant of the above with a options object. + wrapSocket(socket, wo.protVer, wo.verifyMode, wo.certFile, wo.keyFile) proc listen*(socket: TSocket, backlog = SOMAXCONN) = ## Marks ``socket`` as accepting connections. ## ``Backlog`` specifies the maximum length of the ## queue of pending connections. - if listen(cint(socket), cint(backlog)) < 0'i32: OSError() + if listen(socket.fd, cint(backlog)) < 0'i32: OSError() proc invalidIp4(s: string) {.noreturn, noinline.} = raise newException(EInvalidValue, "invalid ip4 address: " & s) @@ -208,7 +342,7 @@ proc bindAddr*(socket: TSocket, port = TPort(0), address = "") = name.sin_family = posix.AF_INET name.sin_port = sockets.htons(int16(port)) name.sin_addr.s_addr = sockets.htonl(INADDR_ANY) - if bindSocket(cint(socket), cast[ptr TSockAddr](addr(name)), + if bindSocket(socket.fd, cast[ptr TSockAddr](addr(name)), sizeof(name)) < 0'i32: OSError() else: @@ -218,7 +352,7 @@ proc bindAddr*(socket: TSocket, port = TPort(0), address = "") = hints.ai_socktype = toInt(SOCK_STREAM) hints.ai_protocol = toInt(IPPROTO_TCP) gaiNim(address, port, hints, aiList) - if bindSocket(cint(socket), aiList.ai_addr, aiList.ai_addrLen) < 0'i32: + if bindSocket(socket.fd, aiList.ai_addr, aiList.ai_addrLen) < 0'i32: OSError() when false: @@ -245,47 +379,88 @@ proc getSockName*(socket: TSocket): TPort = #name.sin_port = htons(cint16(port)) #name.sin_addr.s_addr = htonl(INADDR_ANY) var namelen: cint = sizeof(name) - if getsockname(cint(socket), cast[ptr TSockAddr](addr(name)), + if getsockname(socket.fd, cast[ptr TSockAddr](addr(name)), addr(namelen)) == -1'i32: OSError() result = TPort(sockets.ntohs(name.sin_port)) -proc acceptAddr*(server: TSocket): tuple[sock: TSocket, address: string] = +proc selectWrite*(writefds: var seq[TSocket], timeout = 500): int + +proc acceptAddr*(server: TSocket): tuple[client: TSocket, address: string] = ## Blocks until a connection is being made from a client. When a connection - ## is made returns the client socket and address of the connecting client. + ## is made sets ``client`` to the client socket and ``address`` to the address + ## of the connecting client. ## If ``server`` is non-blocking then this function returns immediately, and ## if there are no connections queued the returned socket will be ## ``InvalidSocket``. ## This function will raise EOS if an error occurs. - var address: Tsockaddr_in - var addrLen: cint = sizeof(address) - var sock = accept(cint(server), cast[ptr TSockAddr](addr(address)), + ## + ## **Warning:** This function might block even if socket is non-blocking + ## when using SSL. + var sockAddress: Tsockaddr_in + var addrLen: cint = sizeof(sockAddress) + var sock = accept(server.fd, cast[ptr TSockAddr](addr(sockAddress)), addr(addrLen)) + if sock < 0: # TODO: Test on Windows. when defined(windows): var err = WSAGetLastError() if err == WSAEINPROGRESS: - return (InvalidSocket, "") + client = InvalidSocket else: OSError() else: if errno == EAGAIN or errno == EWOULDBLOCK: return (InvalidSocket, "") else: OSError() - else: return (TSocket(sock), $inet_ntoa(address.sin_addr)) + else: + when defined(ssl): + if server.isSSL: + # We must wrap the client sock in a ssl context. + var client = newTSocket(sock, server.isBuffered) + let wo = server.wrapOptions + wrapSocket(client, wo.protVer, wo.verifyMode, + wo.certFile, wo.keyFile) + let ret = SSLAccept(client.sslHandle) + while ret <= 0: + let err = SSLGetError(client.sslHandle, ret) + if err != SSL_ERROR_WANT_ACCEPT: + case err + of SSL_ERROR_ZERO_RETURN: + SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") + of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_CONNECT: + SSLError("The operation did not complete. Perhaps you should use connectAsync?") + of SSL_ERROR_WANT_ACCEPT: + var sss: seq[TSocket] = @[client] + discard selectWrite(sss, 1500) + continue + of SSL_ERROR_WANT_X509_LOOKUP: + SSLError("Function for x509 lookup has been called.") + of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + SSLError() + else: + SSLError("Unknown error") + return (client, $inet_ntoa(sockAddress.sin_addr)) + return (newTSocket(sock, server.isBuffered), $inet_ntoa(sockAddress.sin_addr)) proc accept*(server: TSocket): TSocket = ## Equivalent to ``acceptAddr`` but doesn't return the address, only the ## socket. - var (client, a) = acceptAddr(server) + let (client, a) = acceptAddr(server) return client proc close*(socket: TSocket) = ## closes a socket. when defined(windows): - discard winlean.closeSocket(cint(socket)) + discard winlean.closeSocket(socket.fd) else: - discard posix.close(cint(socket)) + discard posix.close(socket.fd) + + when defined(ssl): + if socket.isSSL: + discard SSLShutdown(socket.sslHandle) + + SSLCTXFree(socket.sslContext) proc getServByName*(name, proto: string): TServent = ## well-known getservbyname proc. @@ -365,7 +540,7 @@ proc getSockOptInt*(socket: TSocket, level, optname: int): int = ## getsockopt for integer options. var res: cint var size: cint = sizeof(res) - if getsockopt(cint(socket), cint(level), cint(optname), + if getsockopt(socket.fd, cint(level), cint(optname), addr(res), addr(size)) < 0'i32: OSError() result = int(res) @@ -373,7 +548,7 @@ proc getSockOptInt*(socket: TSocket, level, optname: int): int = proc setSockOptInt*(socket: TSocket, level, optname, optval: int) = ## setsockopt for integer options. var value = cint(optval) - if setsockopt(cint(socket), cint(level), cint(optname), addr(value), + if setsockopt(socket.fd, cint(level), cint(optname), addr(value), sizeof(value)) < 0'i32: OSError() @@ -395,7 +570,7 @@ proc connect*(socket: TSocket, name: string, port = TPort(0), var success = false var it = aiList while it != nil: - if connect(cint(socket), it.ai_addr, it.ai_addrlen) == 0'i32: + if connect(socket.fd, it.ai_addr, it.ai_addrlen) == 0'i32: success = true break it = it.ai_next @@ -403,6 +578,24 @@ proc connect*(socket: TSocket, name: string, port = TPort(0), freeaddrinfo(aiList) if not success: OSError() + when defined(ssl): + if socket.isSSL: + let ret = SSLConnect(socket.sslHandle) + if ret <= 0: + let err = SSLGetError(socket.sslHandle, ret) + case err + of SSL_ERROR_ZERO_RETURN: + SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") + of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_CONNECT, + SSL_ERROR_WANT_ACCEPT: + SSLError("The operation did not complete. Perhaps you should use connectAsync?") + of SSL_ERROR_WANT_X509_LOOKUP: + SSLError("Function for x509 lookup has been called.") + of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + SSLError() + else: + SSLError("Unknown error") + when false: var s: TSockAddrIn s.sin_addr.s_addr = inet_addr(name) @@ -415,7 +608,7 @@ proc connect*(socket: TSocket, name: string, port = TPort(0), of AF_INET: s.sin_family = posix.AF_INET of AF_INET6: s.sin_family = posix.AF_INET6 else: nil - if connect(cint(socket), cast[ptr TSockAddr](addr(s)), sizeof(s)) < 0'i32: + if connect(socket.fd, cast[ptr TSockAddr](addr(s)), sizeof(s)) < 0'i32: OSError() proc connectAsync*(socket: TSocket, name: string, port = TPort(0), @@ -431,7 +624,7 @@ proc connectAsync*(socket: TSocket, name: string, port = TPort(0), var success = false var it = aiList while it != nil: - var ret = connect(cint(socket), it.ai_addr, it.ai_addrlen) + var ret = connect(socket.fd, it.ai_addr, it.ai_addrlen) if ret == 0'i32: success = true break @@ -453,6 +646,26 @@ proc connectAsync*(socket: TSocket, name: string, port = TPort(0), freeaddrinfo(aiList) if not success: OSError() + when defined(ssl): + if socket.isSSL: + var ret = SSLConnect(socket.sslHandle) + if ret <= 0: + var errret = SSLGetError(socket.sslHandle, ret) + case errret + of SSL_ERROR_ZERO_RETURN: + SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") + of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, + SSL_ERROR_WANT_ACCEPT: + SSLError("Unexpected error occured.") # This should just not happen. + of SSL_ERROR_WANT_CONNECT: + return + of SSL_ERROR_WANT_X509_LOOKUP: + SSLError("Function for x509 lookup has been called.") + of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + SSLError() + else: + SSLError("Unknown Error") + proc timeValFromMilliseconds(timeout = 500): TTimeVal = if timeout != -1: var seconds = timeout div 1000 @@ -467,14 +680,14 @@ proc timeValFromMilliseconds(timeout = 500): TTimeVal = proc createFdSet(fd: var TFdSet, s: seq[TSocket], m: var int) = FD_ZERO(fd) for i in items(s): - m = max(m, int(i)) - FD_SET(cint(i), fd) + m = max(m, int(i.fd)) + FD_SET(i.fd, fd) proc pruneSocketSet(s: var seq[TSocket], fd: var TFdSet) = var i = 0 var L = s.len while i < L: - if FD_ISSET(cint(s[i]), fd) != 0'i32: + if FD_ISSET(s[i].fd, fd) != 0'i32: s[i] = s[L-1] dec(L) else: @@ -552,19 +765,64 @@ proc select*(readfds: var seq[TSocket], timeout = 500): int = result = int(select(cint(m+1), addr(rd), nil, nil, nil)) pruneSocketSet(readfds, (rd)) - + +proc readIntoBuf(socket: TSocket, flags: int32): int = + result = 0 + when defined(ssl): + if socket.isSSL: + result = SSLRead(socket.sslHandle, addr(socket.buffer), int(socket.buffer.high)) + else: + result = recv(socket.fd, addr(socket.buffer), int(socket.buffer.high), flags) + else: + result = recv(socket.fd, addr(socket.buffer), int(socket.buffer.high), flags) + if result <= 0: return + socket.bufLen = result + socket.currPos = 0 + +template retRead(flags, read: int) = + let res = socket.readIntoBuf(flags) + if res <= 0: + if read > 0: + return read + else: + return res + proc recv*(socket: TSocket, data: pointer, size: int): int = ## receives data from a socket - result = recv(cint(socket), data, size, 0'i32) - -template waitFor(): stmt = - if timeout - int(waited * 1000.0) < 1: - raise newException(ETimeout, "Call to recv() timed out.") - var s = @[socket] - var startTime = epochTime() - if select(s, timeout - int(waited * 1000.0)) != 1: - raise newException(ETimeout, "Call to recv() timed out.") - waited += (epochTime() - startTime) + if socket.isBuffered: + if socket.bufLen == 0: + retRead(0'i32, 0) + + var read = 0 + while read < size: + if socket.currPos >= socket.bufLen: + retRead(0'i32, read) + + let chunk = min(socket.bufLen, size-read) + var d = cast[cstring](data) + copyMem(addr(d[read]), addr(socket.buffer[socket.currPos]), chunk) + read.inc(chunk) + socket.currPos.inc(chunk) + + result = read + else: + when defined(ssl): + if socket.isSSL: + result = SSLRead(socket.sslHandle, data, size) + else: + result = recv(socket.fd, data, size, 0'i32) + else: + result = recv(socket.fd, data, size, 0'i32) + +proc waitFor(socket: TSocket, waited: var float, timeout: int) = + if socket.bufLen == 0: + if timeout - int(waited * 1000.0) < 1: + raise newException(ETimeout, "Call to recv() timed out.") + var s = @[socket] + var startTime = epochTime() + if select(s, timeout - int(waited * 1000.0)) != 1: + raise newException(ETimeout, "Call to recv() timed out.") + waited += (epochTime() - startTime) proc recv*(socket: TSocket, data: var string, size: int, timeout: int): int = ## overload with a ``timeout`` parameter in miliseconds. @@ -572,14 +830,30 @@ proc recv*(socket: TSocket, data: var string, size: int, timeout: int): int = var read = 0 while read < size: - waitFor() - result = recv(cint(socket), addr(data[read]), 1, 0'i32) + waitFor(socket, waited, timeout) + result = recv(socket, addr(data[read]), 1) if result < 0: return inc(read) result = read +proc peekChar(socket: TSocket, c: var char): int = + if socket.isBuffered: + result = 1 + if socket.bufLen == 0 or socket.currPos > socket.bufLen-1: + var res = socket.readIntoBuf(0'i32) + if res <= 0: + result = res + + c = socket.buffer[socket.currPos] + else: + when defined(ssl): + if socket.isSSL: + raise newException(ESSL, "Sorry, you cannot use recvLine on an unbuffered SSL socket.") + + result = recv(socket.fd, addr(c), 1, MSG_PEEK) + proc recvLine*(socket: TSocket, line: var TaintedString): bool = ## retrieves a line from ``socket``. If a full line is received ``\r\L`` is not ## added to ``line``, however if solely ``\r\L`` is received then ``data`` @@ -590,6 +864,9 @@ proc recvLine*(socket: TSocket, line: var TaintedString): bool = ## ## If the socket is disconnected, ``line`` will be set to ``""`` and ``True`` ## will be returned. + ## + ## **Warning:** Using this function on a unbuffered ssl socket will result + ## in an error. template addNLIfEmpty(): stmt = if line.len == 0: line.add("\c\L") @@ -597,13 +874,13 @@ proc recvLine*(socket: TSocket, line: var TaintedString): bool = setLen(line.string, 0) while true: var c: char - var n = recv(cint(socket), addr(c), 1, 0'i32) + var n = recv(socket, addr(c), 1) if n < 0: return elif n == 0: return true if c == '\r': - n = recv(cint(socket), addr(c), 1, MSG_PEEK) + n = peekChar(socket, c) if n > 0 and c == '\L': - discard recv(cint(socket), addr(c), 1, 0'i32) + discard recv(socket, addr(c), 1) elif n <= 0: return false addNlIfEmpty() return true @@ -624,15 +901,15 @@ proc recvLine*(socket: TSocket, line: var TaintedString, timeout: int): bool = setLen(line.string, 0) while true: var c: char - waitFor() - var n = recv(cint(socket), addr(c), 1, 0'i32) + waitFor(socket, waited, timeout) + var n = recv(socket, addr(c), 1) if n < 0: return elif n == 0: return true if c == '\r': - waitFor() - n = recv(cint(socket), addr(c), 1, MSG_PEEK) + waitFor(socket, waited, timeout) + n = peekChar(socket, c) if n > 0 and c == '\L': - discard recv(cint(socket), addr(c), 1, 0'i32) + discard recv(socket, addr(c), 1) elif n <= 0: return false addNlIfEmpty() return true @@ -651,15 +928,15 @@ proc recvLineAsync*(socket: TSocket, line: var TaintedString): TRecvLineResult = setLen(line.string, 0) while true: var c: char - var n = recv(cint(socket), addr(c), 1, 0'i32) + var n = recv(socket, addr(c), 1) if n < 0: return (if line.len == 0: RecvFail else: RecvPartialLine) elif n == 0: return (if line.len == 0: RecvDisconnected else: RecvPartialLine) if c == '\r': - n = recv(cint(socket), addr(c), 1, MSG_PEEK) + n = peekChar(socket, c) if n > 0 and c == '\L': - discard recv(cint(socket), addr(c), 1, 0'i32) + discard recv(socket, addr(c), 1) elif n <= 0: return (if line.len == 0: RecvFail else: RecvPartialLine) return RecvFullLine @@ -671,7 +948,7 @@ proc recv*(socket: TSocket): TaintedString = ## Socket errors will result in an ``EOS`` error. ## If socket is not a connectionless socket and socket is not connected ## ``""`` will be returned. - const bufSize = 1000 + const bufSize = 4000 result = newStringOfCap(bufSize).TaintedString var pos = 0 while true: @@ -699,9 +976,10 @@ proc recvTimeout*(socket: TSocket, timeout: int): TaintedString = ## overloaded variant to support a ``timeout`` parameter, the ``timeout`` ## parameter specifies the amount of miliseconds to wait for data on the ## socket. - var s = @[socket] - if s.select(timeout) != 1: - raise newException(ETimeout, "Call to recv() timed out.") + if socket.bufLen == 0: + var s = @[socket] + if s.select(timeout) != 1: + raise newException(ETimeout, "Call to recv() timed out.") return socket.recv @@ -718,7 +996,24 @@ proc recvAsync*(socket: TSocket, s: var TaintedString): bool = var pos = 0 while true: var bytesRead = recv(socket, addr(string(s)[pos]), bufSize-1) - if bytesRead == -1: + when defined(ssl): + if socket.isSSL: + if bytesRead <= 0: + var ret = SSLGetError(socket.sslHandle, bytesRead) + case ret + of SSL_ERROR_ZERO_RETURN: + SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") + of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT: + SSLError("Unexpected error occured.") # This should just not happen. + of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ: + return false + of SSL_ERROR_WANT_X509_LOOKUP: + SSLError("Function for x509 lookup has been called.") + of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + SSLError() + else: SSLError("Unknown Error") + + if bytesRead == -1 and not (when defined(ssl): socket.isSSL else: false): when defined(windows): # TODO: Test on Windows var err = WSAGetLastError() @@ -746,19 +1041,46 @@ proc skip*(socket: TSocket) = proc send*(socket: TSocket, data: pointer, size: int): int = ## sends data to a socket. + when defined(ssl): + if socket.isSSL: + return SSLWrite(socket.sslHandle, cast[cstring](data), size) + when defined(windows) or defined(macosx): - result = send(cint(socket), data, size, 0'i32) + result = send(socket.fd, data, size, 0'i32) else: - result = send(cint(socket), data, size, int32(MSG_NOSIGNAL)) + result = send(socket.fd, data, size, int32(MSG_NOSIGNAL)) proc send*(socket: TSocket, data: string) = ## sends data to a socket. - if send(socket, cstring(data), data.len) != data.len: OSError() + if send(socket, cstring(data), data.len) != data.len: + when defined(ssl): + if socket.isSSL: + SSLError() + + OSError() proc sendAsync*(socket: TSocket, data: string): bool = ## sends data to a non-blocking socket. Returns whether ``data`` was sent. result = true var bytesSent = send(socket, cstring(data), data.len) + when defined(ssl): + if socket.isSSL: + if bytesSent <= 0: + let ret = SSLGetError(socket.sslHandle, bytesSent) + case ret + of SSL_ERROR_ZERO_RETURN: + SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.") + of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT: + SSLError("Unexpected error occured.") # This should just not happen. + of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ: + return false + of SSL_ERROR_WANT_X509_LOOKUP: + SSLError("Function for x509 lookup has been called.") + of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + SSLError() + else: SSLError("Unknown Error") + else: + return if bytesSent == -1: when defined(windows): var err = WSAGetLastError() @@ -792,15 +1114,15 @@ proc setBlocking*(s: TSocket, blocking: bool) = ## sets blocking mode on socket when defined(Windows): var mode = clong(ord(not blocking)) # 1 for non-blocking, 0 for blocking - if SOCKET_ERROR == ioctlsocket(TWinSocket(s), FIONBIO, addr(mode)): + if SOCKET_ERROR == ioctlsocket(TWinSocket(s.fd), FIONBIO, addr(mode)): OSError() else: # BSD sockets - var x: int = fcntl(cint(s), F_GETFL, 0) + var x: int = fcntl(s.fd, F_GETFL, 0) if x == -1: OSError() else: var mode = if blocking: x and not O_NONBLOCK else: x or O_NONBLOCK - if fcntl(cint(s), F_SETFL, mode) == -1: + if fcntl(s.fd, F_SETFL, mode) == -1: OSError() proc connect*(socket: TSocket, timeout: int, name: string, port = TPort(0), diff --git a/lib/wrappers/openssl.nim b/lib/wrappers/openssl.nim index 5fc6ddd02..b5eed38f3 100755 --- a/lib/wrappers/openssl.nim +++ b/lib/wrappers/openssl.nim @@ -192,19 +192,44 @@ const BIO_C_DO_STATE_MACHINE = 101 BIO_C_GET_SSL = 110 -proc SSL_library_init*(): cInt{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_library_init*(): cInt{.cdecl, dynlib: DLLSSLName, importc, discardable.} proc SSL_load_error_strings*(){.cdecl, dynlib: DLLSSLName, importc.} proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLSSLName, importc.} proc SSLv23_client_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} +proc SSLv23_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} +proc SSLv2_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} +proc SSLv3_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} +proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_new*(context: PSSL_CTX): PSSL{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_free*(ssl: PSSL){.cdecl, dynlib: DLLSSLName, importc.} proc SSL_CTX_new*(meth: PSSL_METHOD): PSSL_CTX{.cdecl, dynlib: DLLSSLName, importc.} proc SSL_CTX_load_verify_locations*(ctx: PSSL_CTX, CAfile: cstring, CApath: cstring): cInt{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_CTX_free*(arg0: PSSL_CTX){.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_CTX_set_verify*(s: PSSL_CTX, mode: int, cb: proc (a: int, b: pointer): int){.cdecl, dynlib: DLLSSLName, importc.} proc SSL_get_verify_result*(ssl: PSSL): int{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_CTX_set_cipher_list*(s: PSSLCTX, ciphers: cstring): cint{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_CTX_use_certificate_file*(ctx: PSSL_CTX, filename: cstring, typ: cInt): cInt{. + cdecl, dynlib: DLLSSLName, importc.} +proc SSL_CTX_use_PrivateKey_file*(ctx: PSSL_CTX, + filename: cstring, typ: cInt): cInt{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_CTX_check_private_key*(ctx: PSSL_CTX): cInt{.cdecl, dynlib: DLLSSLName, + importc.} + +proc SSL_set_fd*(ssl: PSSL, fd: cint): cint{.cdecl, dynlib: DLLSSLName, importc.} + +proc SSL_shutdown*(ssl: PSSL): cInt{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_connect*(ssl: PSSL): cint{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_read*(ssl: PSSL, buf: pointer, num: int): cint{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_write*(ssl: PSSL, buf: cstring, num: int): cint{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_get_error*(s: PSSL, ret_code: cInt): cInt{.cdecl, dynlib: DLLSSLName, importc.} +proc SSL_accept*(ssl: PSSL): cInt{.cdecl, dynlib: DLLSSLName, importc.} + proc BIO_new_ssl_connect*(ctx: PSSL_CTX): PBIO{.cdecl, dynlib: DLLSSLName, importc.} proc BIO_ctrl*(bio: PBIO, cmd: cint, larg: int, arg: cstring): int{.cdecl, @@ -227,16 +252,27 @@ proc BIO_free*(b: PBIO): cInt{.cdecl, dynlib: DLLUtilName, importc.} proc ERR_print_errors_fp*(fp: TFile){.cdecl, dynlib: DLLSSLName, importc.} +proc ERR_error_string*(e: cInt, buf: cstring): cstring{.cdecl, + dynlib: DLLUtilName, importc.} +proc ERR_get_error*(): cInt{.cdecl, dynlib: DLLUtilName, importc.} + +proc OpenSSL_add_all_algorithms*(){.cdecl, dynlib: DLLSSLName, importc: "OPENSSL_add_all_algorithms_conf".} + +proc OPENSSL_config*(configName: cstring){.cdecl, dynlib: DLLSSLName, importc.} + +proc CRYPTO_set_mem_functions(a,b,c: pointer){.cdecl, dynlib: DLLSSLName, importc.} + +proc CRYPTO_malloc_init*() = + CRYPTO_set_mem_functions(alloc, realloc, dealloc) + when True: nil else: - proc SslGetError*(s: PSSL, ret_code: cInt): cInt{.cdecl, dynlib: DLLSSLName, - importc.} proc SslCtxSetCipherList*(arg0: PSSL_CTX, str: cstring): cInt{.cdecl, dynlib: DLLSSLName, importc.} proc SslCtxNew*(meth: PSSL_METHOD): PSSL_CTX{.cdecl, dynlib: DLLSSLName, importc.} - proc SslCtxFree*(arg0: PSSL_CTX){.cdecl, dynlib: DLLSSLName, importc.} + proc SslSetFd*(s: PSSL, fd: cInt): cInt{.cdecl, dynlib: DLLSSLName, importc.} proc SslCtrl*(ssl: PSSL, cmd: cInt, larg: int, parg: Pointer): int{.cdecl, dynlib: DLLSSLName, importc.} @@ -255,19 +291,15 @@ else: dynlib: DLLSSLName, importc.} proc SslCtxUsePrivateKeyASN1*(pk: cInt, ctx: PSSL_CTX, d: cstring, length: int): cInt{.cdecl, dynlib: DLLSSLName, importc.} - proc SslCtxUsePrivateKeyFile*(ctx: PSSL_CTX, - filename: cstring, typ: cInt): cInt{.cdecl, dynlib: DLLSSLName, importc.} + proc SslCtxUseCertificate*(ctx: PSSL_CTX, x: SslPtr): cInt{.cdecl, dynlib: DLLSSLName, importc.} proc SslCtxUseCertificateASN1*(ctx: PSSL_CTX, length: int, d: cstring): cInt{. cdecl, dynlib: DLLSSLName, importc.} - proc SslCtxUseCertificateFile*(ctx: PSSL_CTX, filename: cstring, typ: cInt): cInt{. - cdecl, dynlib: DLLSSLName, importc.} + # function SslCtxUseCertificateChainFile(ctx: PSSL_CTX; const filename: PChar):cInt; proc SslCtxUseCertificateChainFile*(ctx: PSSL_CTX, filename: cstring): cInt{. cdecl, dynlib: DLLSSLName, importc.} - proc SslCtxCheckPrivateKeyFile*(ctx: PSSL_CTX): cInt{.cdecl, dynlib: DLLSSLName, - importc.} proc SslCtxSetDefaultPasswdCb*(ctx: PSSL_CTX, cb: PPasswdCb){.cdecl, dynlib: DLLSSLName, importc.} proc SslCtxSetDefaultPasswdCbUserdata*(ctx: PSSL_CTX, u: SslPtr){.cdecl, @@ -276,10 +308,10 @@ else: proc SslCtxLoadVerifyLocations*(ctx: PSSL_CTX, CAfile: cstring, CApath: cstring): cInt{. cdecl, dynlib: DLLSSLName, importc.} proc SslNew*(ctx: PSSL_CTX): PSSL{.cdecl, dynlib: DLLSSLName, importc.} - proc SslFree*(ssl: PSSL){.cdecl, dynlib: DLLSSLName, importc.} - proc SslAccept*(ssl: PSSL): cInt{.cdecl, dynlib: DLLSSLName, importc.} + + proc SslConnect*(ssl: PSSL): cInt{.cdecl, dynlib: DLLSSLName, importc.} - proc SslShutdown*(ssl: PSSL): cInt{.cdecl, dynlib: DLLSSLName, importc.} + proc SslRead*(ssl: PSSL, buf: SslPtr, num: cInt): cInt{.cdecl, dynlib: DLLSSLName, importc.} proc SslPeek*(ssl: PSSL, buf: SslPtr, num: cInt): cInt{.cdecl, @@ -339,9 +371,7 @@ else: proc EVPcleanup*(){.cdecl, dynlib: DLLUtilName, importc.} # function ErrErrorString(e: cInt; buf: PChar): PChar; proc SSLeayversion*(t: cInt): cstring{.cdecl, dynlib: DLLUtilName, importc.} - proc ErrErrorString*(e: cInt, buf: cstring, length: cInt){.cdecl, - dynlib: DLLUtilName, importc.} - proc ErrGetError*(): cInt{.cdecl, dynlib: DLLUtilName, importc.} + proc ErrClearError*(){.cdecl, dynlib: DLLUtilName, importc.} proc ErrFreeStrings*(){.cdecl, dynlib: DLLUtilName, importc.} proc ErrRemoveState*(pid: cInt){.cdecl, dynlib: DLLUtilName, importc.} |