summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorWim Lewis <wiml@hhhh.org>2015-04-26 22:10:35 -0700
committerWim Lewis <wiml@hhhh.org>2015-04-26 22:10:35 -0700
commitdf1cdced1d9ec5663c735065a21dc5b00067b8b2 (patch)
treec35916e3363e49d7cd1cea0d2141cd75a665a001
parent9c19ce0698fe8bf5baad5cebf78776c161dceb8e (diff)
downloadNim-df1cdced1d9ec5663c735065a21dc5b00067b8b2.tar.gz
Make the post-connection wrapSocket() call available in both the synchronous and asynchrinous net modules.
-rw-r--r--lib/pure/asyncnet.nim13
-rw-r--r--lib/pure/net.nim29
2 files changed, 32 insertions, 10 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index 1b11aaffc..c328649e5 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -87,12 +87,6 @@ type
     of false: nil
   AsyncSocket* = ref AsyncSocketDesc
 
-when defined(ssl):
-  type HandshakeType* = enum
-    handshakeNone,
-    handshakeAsClient,
-    handshakeAsServer
-
 {.deprecated: [PAsyncSocket: AsyncSocket].}
 
 # TODO: Save AF, domain etc info and reuse it in procs which need it like connect.
@@ -424,7 +418,7 @@ proc close*(socket: AsyncSocket) =
   socket.closed = true # TODO: Add extra debugging checks for this.
 
 when defined(ssl):
-  proc wrapSocket*(ctx: SslContext, socket: AsyncSocket, handshake: HandshakeType = handshakeNone) =
+  proc wrapSocket*(ctx: SslContext, socket: AsyncSocket) =
     ## Wraps a socket in an SSL context. This function effectively turns
     ## ``socket`` into an SSL socket.
     ##
@@ -440,9 +434,10 @@ when defined(ssl):
     socket.bioOut = bioNew(bio_s_mem())
     sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
 
+  proc wrapSocket*(ctx: SslContext, socket: AsyncSocket, handshake: SslHandshakeType) =
+    wrapSocket(ctx, socket)
+
     case handshake
-    of handshakeNone:
-      discard
     of handshakeAsClient:
       sslSetConnectState(socket.sslHandle)
     of handshakeAsServer:
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index ffbc6e320..a21485655 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -35,6 +35,9 @@ when defined(ssl):
     SslAcceptResult* = enum
       AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
 
+    SslHandshakeType* = enum
+      handshakeAsClient, handshakeAsServer
+
   {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
     TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
     TSSLAcceptResult: SSLAcceptResult].}
@@ -97,6 +100,8 @@ type
 
 proc isIpAddress*(address_str: string): bool {.tags: [].}
 proc parseIpAddress*(address_str: string): TIpAddress
+proc socketError*(socket: Socket, err: int = -1, async = false,
+lastError = (-1).OSErrorCode): void
 
 proc isDisconnectionError*(flags: set[SocketFlag],
     lastError: OSErrorCode): bool =
@@ -235,9 +240,13 @@ when defined(ssl):
     ## Wraps a socket in an SSL context. This function effectively turns
     ## ``socket`` into an SSL socket.
     ##
+    ## This must be called on an unconnected socket; an SSL session will
+    ## be started when the socket is connected.
+    ##
     ## **Disclaimer**: This code is not well tested, may be very unsafe and
     ## prone to security vulnerabilities.
-    
+
+    assert (not socket.isSSL)
     socket.isSSL = true
     socket.sslContext = ctx
     socket.sslHandle = SSLNew(SSLCTX(socket.sslContext))
@@ -249,6 +258,24 @@ when defined(ssl):
     if SSLSetFd(socket.sslHandle, socket.fd) != 1:
       raiseSSLError()
 
+  proc wrapSocket*(ctx: SSLContext, socket: Socket, handshake: SslHandshakeType) =
+    ## Wraps a socket in an SSL context. This function effectively turns
+    ## ``socket`` into an SSL socket.
+    ##
+    ## This should be called on a connected socket, and will perform
+    ## an SSL handshake immediately.
+    ##
+    ## **Disclaimer**: This code is not well tested, may be very unsafe and
+    ## prone to security vulnerabilities.
+    wrapSocket(ctx, socket)
+    case handshake
+    of handshakeAsClient:
+      let ret = SSLConnect(socket.sslHandle)
+      socketError(socket, ret)
+    of handshakeAsServer:
+      let ret = SSLAccept(socket.sslHandle)
+      socketError(socket, ret)
+
 proc getSocketError*(socket: Socket): OSErrorCode =
   ## Checks ``osLastError`` for a valid error. If it has been reset it uses
   ## the last error stored in the socket object.