summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
authorDominik Picheta <dominikpicheta@googlemail.com>2015-06-22 21:34:21 +0100
committerDominik Picheta <dominikpicheta@googlemail.com>2015-06-22 21:34:21 +0100
commit8853dfb3539048b3d3d905b1a41a45860bf2d327 (patch)
tree2bd57ce43d552392209c8e5a34de8a0e955b1213 /lib
parent37677636bc919c4a8fab25fedbcf917dcf177d98 (diff)
parentdf1cdced1d9ec5663c735065a21dc5b00067b8b2 (diff)
downloadNim-8853dfb3539048b3d3d905b1a41a45860bf2d327.tar.gz
Merge branch 'starttls' of https://github.com/wiml/Nim into wiml-starttls
Conflicts:
	lib/pure/net.nim
Diffstat (limited to 'lib')
-rw-r--r--lib/pure/asyncnet.nim9
-rw-r--r--lib/pure/net.nim133
2 files changed, 89 insertions, 53 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index 01c28a13a..4b221eb72 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -472,6 +472,15 @@ when defined(ssl):
     socket.bioOut = bioNew(bio_s_mem())
     sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
 
+  proc wrapSocket*(ctx: SslContext, socket: AsyncSocket, handshake: SslHandshakeType) =
+    wrapSocket(ctx, socket)
+
+    case handshake
+    of handshakeAsClient:
+      sslSetConnectState(socket.sslHandle)
+    of handshakeAsServer:
+      sslSetAcceptState(socket.sslHandle)
+
 proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {.
   tags: [ReadIOEffect].} =
   ## Retrieves option ``opt`` as a boolean value.
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index 9ce0669bc..7dcc35495 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -26,15 +26,18 @@ when defined(ssl):
 
     SslCVerifyMode* = enum
       CVerifyNone, CVerifyPeer
-    
+
     SslProtVersion* = enum
       protSSLv2, protSSLv3, protTLSv1, protSSLv23
-    
+
     SslContext* = distinct SslCtx
 
     SslAcceptResult* = enum
       AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
 
+    SslHandshakeType* = enum
+      handshakeAsClient, handshakeAsServer
+
   {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
     TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
     TSSLAcceptResult: SSLAcceptResult].}
@@ -86,7 +89,7 @@ type
     IPv6, ## IPv6 address
     IPv4  ## IPv4 address
 
-  IpAddress* = object ## stores an arbitrary IP address    
+  IpAddress* = object ## stores an arbitrary IP address
     case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
     of IpAddressFamily.IPv6:
       address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
@@ -98,6 +101,8 @@ type
 
 proc isIpAddress*(address_str: string): bool {.tags: [].}
 proc parseIpAddress*(address_str: string): IpAddress
+proc socketError*(socket: Socket, err: int = -1, async = false,
+lastError = (-1).OSErrorCode): void
 
 proc isDisconnectionError*(flags: set[SocketFlag],
     lastError: OSErrorCode): bool =
@@ -109,7 +114,7 @@ proc isDisconnectionError*(flags: set[SocketFlag],
                           WSAEDISCON, ERROR_NETNAME_DELETED}
   else:
     SocketFlag.SafeDisconn in flags and
-      lastError.int32 in {ECONNRESET, EPIPE, ENETRESET} 
+      lastError.int32 in {ECONNRESET, EPIPE, ENETRESET}
 
 proc toOSFlags*(socketFlags: set[SocketFlag]): cint =
   ## Converts the flags into the underlying OS representation.
@@ -172,27 +177,27 @@ when defined(ssl):
       raise newException(system.IOError, "Certificate file could not be found: " & certFile)
     if keyFile != "" and not existsFile(keyFile):
       raise newException(system.IOError, "Key file could not be found: " & keyFile)
-    
+
     if certFile != "":
       var ret = SSLCTXUseCertificateChainFile(ctx, certFile)
       if ret != 1:
         raiseSSLError()
-    
+
     # TODO: Password? www.rtfm.com/openssl-examples/part1.pdf
     if keyFile != "":
       if SSL_CTX_use_PrivateKey_file(ctx, keyFile,
                                      SSL_FILETYPE_PEM) != 1:
         raiseSSLError()
-        
+
       if SSL_CTX_check_private_key(ctx) != 1:
         raiseSSLError("Verification of private key file failed.")
 
   proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
                    certFile = "", keyFile = ""): SSLContext =
     ## Creates an SSL context.
-    ## 
-    ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 
-    ## are available with the addition of ``protSSLv23`` which allows for 
+    ##
+    ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
+    ## are available with the addition of ``protSSLv23`` which allows for
     ## compatibility with all of them.
     ##
     ## There are currently only two options for verify mode;
@@ -217,7 +222,7 @@ when defined(ssl):
       newCTX = SSL_CTX_new(SSLv3_method())
     of protTLSv1:
       newCTX = SSL_CTX_new(TLSv1_method())
-    
+
     if newCTX.SSLCTXSetCipherList("ALL") != 1:
       raiseSSLError()
     case verifyMode
@@ -236,9 +241,13 @@ when defined(ssl):
     ## Wraps a socket in an SSL context. This function effectively turns
     ## ``socket`` into an SSL socket.
     ##
+    ## This must be called on an unconnected socket; an SSL session will
+    ## be started when the socket is connected.
+    ##
     ## **Disclaimer**: This code is not well tested, may be very unsafe and
     ## prone to security vulnerabilities.
-    
+
+    assert (not socket.isSSL)
     socket.isSSL = true
     socket.sslContext = ctx
     socket.sslHandle = SSLNew(SSLCTX(socket.sslContext))
@@ -246,10 +255,28 @@ when defined(ssl):
     socket.sslHasPeekChar = false
     if socket.sslHandle == nil:
       raiseSSLError()
-    
+
     if SSLSetFd(socket.sslHandle, socket.fd) != 1:
       raiseSSLError()
 
+  proc wrapSocket*(ctx: SSLContext, socket: Socket, handshake: SslHandshakeType) =
+    ## Wraps a socket in an SSL context. This function effectively turns
+    ## ``socket`` into an SSL socket.
+    ##
+    ## This should be called on a connected socket, and will perform
+    ## an SSL handshake immediately.
+    ##
+    ## **Disclaimer**: This code is not well tested, may be very unsafe and
+    ## prone to security vulnerabilities.
+    wrapSocket(ctx, socket)
+    case handshake
+    of handshakeAsClient:
+      let ret = SSLConnect(socket.sslHandle)
+      socketError(socket, ret)
+    of handshakeAsServer:
+      let ret = SSLAccept(socket.sslHandle)
+      socketError(socket, ret)
+
 proc getSocketError*(socket: Socket): OSErrorCode =
   ## Checks ``osLastError`` for a valid error. If it has been reset it uses
   ## the last error stored in the socket object.
@@ -302,7 +329,7 @@ proc socketError*(socket: Socket, err: int = -1, async = false,
         of SSL_ERROR_SSL:
           raiseSSLError()
         else: raiseSSLError("Unknown Error")
-  
+
   if err == -1 and not (when defined(ssl): socket.isSSL else: false):
     var lastE = if lastError.int == -1: getSocketError(socket) else: lastError
     if async:
@@ -317,8 +344,8 @@ proc socketError*(socket: Socket, err: int = -1, async = false,
     else: raiseOSError(lastE)
 
 proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
-  ## Marks ``socket`` as accepting connections. 
-  ## ``Backlog`` specifies the maximum length of the 
+  ## Marks ``socket`` as accepting connections.
+  ## ``Backlog`` specifies the maximum length of the
   ## queue of pending connections.
   ##
   ## Raises an EOS error upon failure.
@@ -360,7 +387,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
   ## The resulting client will inherit any properties of the server socket. For
   ## example: whether the socket is buffered or not.
   ##
-  ## **Note**: ``client`` must be initialised (with ``new``), this function 
+  ## **Note**: ``client`` must be initialised (with ``new``), this function
   ## makes no effort to initialise the ``client`` variable.
   ##
   ## The ``accept`` call may result in an error if the connecting socket
@@ -372,7 +399,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
   var addrLen = sizeof(sockAddress).SockLen
   var sock = accept(server.fd, cast[ptr SockAddr](addr(sockAddress)),
                     addr(addrLen))
-  
+
   if sock == osInvalidSocket:
     let err = osLastError()
     if flags.isDisconnectionError(err):
@@ -386,11 +413,11 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
     when defined(ssl):
       if server.isSSL:
         # We must wrap the client sock in a ssl context.
-        
+
         server.sslContext.wrapSocket(client)
         let ret = SSLAccept(client.sslHandle)
         socketError(client, ret, false)
-    
+
     # Client socket is set above.
     address = $inet_ntoa(sockAddress.sin_addr)
 
@@ -398,9 +425,9 @@ when false: #defined(ssl):
   proc acceptAddrSSL*(server: Socket, client: var Socket,
                       address: var string): SSLAcceptResult {.
                       tags: [ReadIOEffect].} =
-    ## This procedure should only be used for non-blocking **SSL** sockets. 
+    ## This procedure should only be used for non-blocking **SSL** sockets.
     ## It will immediately return with one of the following values:
-    ## 
+    ##
     ## ``AcceptSuccess`` will be returned when a client has been successfully
     ## accepted and the handshake has been successfully performed between
     ## ``server`` and the newly connected client.
@@ -417,7 +444,7 @@ when false: #defined(ssl):
         if server.isSSL:
           client.setBlocking(false)
           # We must wrap the client sock in a ssl context.
-          
+
           if not client.isSSL or client.sslHandle == nil:
             server.sslContext.wrapSocket(client)
           let ret = SSLAccept(client.sslHandle)
@@ -450,7 +477,7 @@ proc accept*(server: Socket, client: var Socket,
              flags = {SocketFlag.SafeDisconn}) {.tags: [ReadIOEffect].} =
   ## Equivalent to ``acceptAddr`` but doesn't return the address, only the
   ## socket.
-  ## 
+  ##
   ## **Note**: ``client`` must be initialised (with ``new``), this function
   ## makes no effort to initialise the ``client`` variable.
   ##
@@ -504,7 +531,7 @@ proc setSockOpt*(socket: Socket, opt: SOBool, value: bool, level = SOL_SOCKET) {
   var valuei = cint(if value: 1 else: 0)
   setSockOptInt(socket.fd, cint(level), toCInt(opt), valuei)
 
-proc connect*(socket: Socket, address: string, port = Port(0), 
+proc connect*(socket: Socket, address: string, port = Port(0),
               af: Domain = AF_INET) {.tags: [ReadIOEffect].} =
   ## Connects socket to ``address``:``port``. ``Address`` can be an IP address or a
   ## host name. If ``address`` is a host name, this function will try each IP
@@ -526,7 +553,7 @@ proc connect*(socket: Socket, address: string, port = Port(0),
 
   dealloc(aiList)
   if not success: raiseOSError(lastError)
-  
+
   when defined(ssl):
     if socket.isSSL:
       # RFC3546 for SNI specifies that IP addresses are not allowed.
@@ -634,12 +661,12 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect]
   if socket.isBuffered:
     if socket.bufLen == 0:
       retRead(0'i32, 0)
-    
+
     var read = 0
     while read < size:
       if socket.currPos >= socket.bufLen:
         retRead(0'i32, read)
-    
+
       let chunk = min(socket.bufLen-socket.currPos, size-read)
       var d = cast[cstring](data)
       assert size-read >= chunk
@@ -686,7 +713,7 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int,
   else:
     if timeout - int(waited * 1000.0) < 1:
       raise newException(TimeoutError, "Call to '" & funcName & "' timed out.")
-    
+
     when defined(ssl):
       if socket.isSSL:
         if socket.hasDataBuffered:
@@ -695,7 +722,7 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int,
         let sslPending = SSLPending(socket.sslHandle)
         if sslPending != 0:
           return sslPending
-    
+
     var startTime = epochTime()
     let selRet = select(socket, timeout - int(waited * 1000.0))
     if selRet < 0: raiseOSError(osLastError())
@@ -706,8 +733,8 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int,
 proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {.
   tags: [ReadIOEffect, TimeEffect].} =
   ## overload with a ``timeout`` parameter in milliseconds.
-  var waited = 0.0 # number of seconds already waited  
-  
+  var waited = 0.0 # number of seconds already waited
+
   var read = 0
   while read < size:
     let avail = waitFor(socket, waited, timeout, size-read, "recv")
@@ -718,7 +745,7 @@ proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {.
     if result < 0:
       return result
     inc(read, result)
-  
+
   result = read
 
 proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
@@ -752,7 +779,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
       var res = socket.readIntoBuf(0'i32)
       if res <= 0:
         result = res
-    
+
     c = socket.buffer[socket.currPos]
   else:
     when defined(ssl):
@@ -760,7 +787,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
         if not socket.sslHasPeekChar:
           result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1)
           socket.sslHasPeekChar = true
-        
+
         c = socket.sslPeekChar
         return
     result = recv(socket.fd, addr(c), 1, MSG_PEEK)
@@ -773,7 +800,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
   ## If a full line is read ``\r\L`` is not
   ## added to ``line``, however if solely ``\r\L`` is read then ``line``
   ## will be set to it.
-  ## 
+  ##
   ## If the socket is disconnected, ``line`` will be set to ``""``.
   ##
   ## An EOS exception will be raised in the case of a socket error.
@@ -782,7 +809,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
   ## the specified time an ETimeout exception will be raised.
   ##
   ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
-  
+
   template addNLIfEmpty(): stmt =
     if line.len == 0:
       line.add("\c\L")
@@ -809,7 +836,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
       elif n <= 0: raiseSockError()
       addNLIfEmpty()
       return
-    elif c == '\L': 
+    elif c == '\L':
       addNLIfEmpty()
       return
     add(line.string, c)
@@ -827,7 +854,7 @@ proc recvFrom*(socket: Socket, data: var string, length: int,
   ## so when ``socket`` is buffered the non-buffered implementation will be
   ## used. Therefore if ``socket`` contains something in its buffer this
   ## function will make no effort to return it.
-  
+
   # TODO: Buffered sockets
   data.setLen(length)
   var sockAddress: Sockaddr_in
@@ -861,16 +888,16 @@ proc send*(socket: Socket, data: pointer, size: int): int {.
   tags: [WriteIOEffect].} =
   ## Sends data to a socket.
   ##
-  ## **Note**: This is a low-level version of ``send``. You likely should use 
+  ## **Note**: This is a low-level version of ``send``. You likely should use
   ## the version below.
   when defined(ssl):
     if socket.isSSL:
       return SSLWrite(socket.sslHandle, cast[cstring](data), size)
-  
+
   when useWinVersion or defined(macosx):
     result = send(socket.fd, data, size.cint, 0'i32)
   else:
-    when defined(solaris): 
+    when defined(solaris):
       const MSG_NOSIGNAL = 0
     result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
 
@@ -895,7 +922,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
              size: int, af: Domain = AF_INET, flags = 0'i32): int {.
              tags: [WriteIOEffect].} =
   ## This proc sends ``data`` to the specified ``address``,
-  ## which may be an IP address or a hostname, if a hostname is specified 
+  ## which may be an IP address or a hostname, if a hostname is specified
   ## this function will try each IP of that hostname.
   ##
   ##
@@ -904,7 +931,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
   ##
   ## **Note:** This proc is not available for SSL sockets.
   var aiList = getAddrInfo(address, port, af)
-  
+
   # try all possibilities:
   var success = false
   var it = aiList
@@ -918,10 +945,10 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
 
   dealloc(aiList)
 
-proc sendTo*(socket: Socket, address: string, port: Port, 
+proc sendTo*(socket: Socket, address: string, port: Port,
              data: string): int {.tags: [WriteIOEffect].} =
   ## This proc sends ``data`` to the specified ``address``,
-  ## which may be an IP address or a hostname, if a hostname is specified 
+  ## which may be an IP address or a hostname, if a hostname is specified
   ## this function will try each IP of that hostname.
   ##
   ## This is the high-level version of the above ``sendTo`` function.
@@ -958,7 +985,7 @@ proc connectAsync(socket: Socket, name: string, port = Port(0),
         if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS:
           success = true
           break
-        
+
     it = it.ai_next
 
   dealloc(aiList)
@@ -971,7 +998,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int,
   ## The ``timeout`` paremeter specifies the time in milliseconds to allow for
   ## the connection to the server to be made.
   socket.fd.setBlocking(false)
-  
+
   socket.connectAsync(address, port, af)
   var s = @[socket.fd]
   if selectWrite(s, timeout) != 1:
@@ -983,7 +1010,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int,
         doAssert socket.handshake()
   socket.fd.setBlocking(true)
 
-proc isSsl*(socket: Socket): bool = 
+proc isSsl*(socket: Socket): bool =
   ## Determines whether ``socket`` is a SSL socket.
   when defined(ssl):
     result = socket.isSSL
@@ -1014,7 +1041,7 @@ proc IPv4_broadcast*(): IpAddress =
 
 proc IPv6_any*(): IpAddress =
   ## Returns the IPv6 any address (::0), which can be used
-  ## to listen on all available network adapters 
+  ## to listen on all available network adapters
   result = IpAddress(
     family: IpAddressFamily.IPv6,
     address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
@@ -1152,7 +1179,7 @@ proc parseIPv6Address(address_str: string): IpAddress =
       if not seperatorValid:
         raise newException(ValueError,
           "Invalid IP Address. Address contains an invalid seperator")
-      if lastWasColon:        
+      if lastWasColon:
         if dualColonGroup != -1:
           raise newException(ValueError,
             "Invalid IP Address. Address contains more than one \"::\" seperator")
@@ -1165,14 +1192,14 @@ proc parseIPv6Address(address_str: string): IpAddress =
         result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
         result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
         currentShort = 0
-        groupCount.inc()        
+        groupCount.inc()
         if dualColonGroup != -1: seperatorValid = false
       elif i == 0: # only valid if address starts with ::
         if address_str[1] != ':':
           raise newException(ValueError,
             "Invalid IP Address. Address may not start with \":\"")
       else: # i == high(address_str) - only valid if address ends with ::
-        if address_str[high(address_str)-1] != ':': 
+        if address_str[high(address_str)-1] != ':':
           raise newException(ValueError,
             "Invalid IP Address. Address may not end with \":\"")
       lastWasColon = true