summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/ci_packages.yml8
-rw-r--r--azure-pipelines.yml7
-rw-r--r--lib/pure/net.nim65
-rw-r--r--lib/wrappers/openssl.nim134
-rw-r--r--tests/stdlib/tssl.nim44
5 files changed, 118 insertions, 140 deletions
diff --git a/.github/workflows/ci_packages.yml b/.github/workflows/ci_packages.yml
index 281e55b61..1be06e696 100644
--- a/.github/workflows/ci_packages.yml
+++ b/.github/workflows/ci_packages.yml
@@ -42,7 +42,10 @@ jobs:
               valgrind libc6-dbg libblas-dev xorg-dev
       - name: 'Install dependencies (macOS)'
         if: runner.os == 'macOS'
-        run: brew install boehmgc make sfml gtk+3
+        run: |
+          brew install boehmgc make sfml gtk+3 openssl@1.1
+          ln -s $(brew --prefix)/opt/openssl/lib/libcrypto.1.1.dylib /usr/local/lib
+          ln -s $(brew --prefix)/opt/openssl/lib/libssl.1.1.dylib /usr/local/lib/
       - name: 'Install dependencies (Windows)'
         if: runner.os == 'Windows'
         shell: bash
@@ -66,4 +69,5 @@ jobs:
 
       - name: 'koch, Run CI'
         shell: bash
-        run: . ci/funs.sh && nimInternalBuildKochAndRunCI
+        run: |
+          . ci/funs.sh && nimInternalBuildKochAndRunCI
diff --git a/azure-pipelines.yml b/azure-pipelines.yml
index bfc58d072..734adbc7e 100644
--- a/azure-pipelines.yml
+++ b/azure-pipelines.yml
@@ -131,6 +131,13 @@ jobs:
       condition: and(succeeded(), eq(variables['skipci'], 'false'), eq(variables['Agent.OS'], 'Darwin'))
 
     - bash: |
+        brew install openssl@1.1
+        ln -s $(brew --prefix)/opt/openssl/lib/libcrypto.1.1.dylib /usr/local/lib
+        ln -s $(brew --prefix)/opt/openssl/lib/libssl.1.1.dylib /usr/local/lib/
+      displayName: 'Install OpenSSL (OSX)'
+      condition: and(succeeded(), eq(variables['skipci'], 'false'), eq(variables['Agent.OS'], 'Darwin'))
+
+    - bash: |
         set -e
         . ci/funs.sh
         nimInternalInstallDepsWindows
diff --git a/lib/pure/net.nim b/lib/pure/net.nim
index f18e64463..dcc35d65d 100644
--- a/lib/pure/net.nim
+++ b/lib/pure/net.nim
@@ -543,18 +543,14 @@ proc fromSockAddr*(sa: Sockaddr_storage | SockAddr | Sockaddr_in | Sockaddr_in6,
   fromSockAddrAux(cast[ptr Sockaddr_storage](unsafeAddr sa), sl, address, port)
 
 when defineSsl:
-  CRYPTO_malloc_init()
-  doAssert SslLibraryInit() == 1
-  SSL_load_error_strings()
-  ERR_load_BIO_strings()
-  OpenSSL_add_all_algorithms()
+  # OpenSSL >= 1.1.0 does not need explicit init.
 
   proc sslHandle*(self: Socket): SslPtr =
     ## Retrieve the ssl pointer of `socket`.
     ## Useful for interfacing with `openssl`.
     self.sslHandle
 
-  proc raiseSSLError*(s = "") =
+  proc raiseSSLError*(s = "") {.raises: [SslError].}=
     ## Raises a new SSL error.
     if s != "":
       raise newException(SslError, s)
@@ -619,9 +615,7 @@ when defineSsl:
                    caDir = "", caFile = ""): SslContext =
     ## Creates an SSL context.
     ##
-    ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
-    ## are available with the addition of `protSSLv23` which allows for
-    ## compatibility with all of them.
+    ## protVersion is currently unsed.
     ##
     ## There are three options for verify mode:
     ## `CVerifyNone`: certificates are not verified;
@@ -648,16 +642,12 @@ when defineSsl:
     ## or using ECDSA:
     ## - `openssl ecparam -out mykey.pem -name secp256k1 -genkey`
     ## - `openssl req -new -key mykey.pem -x509 -nodes -days 365 -out mycert.pem`
-    var newCTX: SslCtx
-    case protVersion
-    of protSSLv23:
-      newCTX = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support.
-    of protSSLv2:
-      raiseSSLError("SSLv2 is no longer secure and has been deprecated, use protSSLv23")
-    of protSSLv3:
-      raiseSSLError("SSLv3 is no longer secure and has been deprecated, use protSSLv23")
-    of protTLSv1:
-      newCTX = SSL_CTX_new(TLSv1_method())
+    let mtd = TLS_method()
+    if mtd == nil:
+      raiseSSLError("Failed to create TLS context")
+    var newCTX = SSL_CTX_new(mtd)
+    if newCTX == nil:
+      raiseSSLError("Failed to create TLS context")
 
     if newCTX.SSL_CTX_set_cipher_list(cipherList) != 1:
       raiseSSLError()
@@ -812,24 +802,28 @@ when defineSsl:
     if SSL_set_fd(socket.sslHandle, socket.fd) != 1:
       raiseSSLError()
 
-  proc checkCertName(socket: Socket, hostname: string) =
+  proc checkCertName(socket: Socket, hostname: string) {.raises: [SslError], tags:[RootEffect].} =
     ## Check if the certificate Subject Alternative Name (SAN) or Subject CommonName (CN) matches hostname.
     ## Wildcards match only in the left-most label.
     ## When name starts with a dot it will be matched by a certificate valid for any subdomain
     when not defined(nimDisableCertificateValidation) and not defined(windows):
       assert socket.isSsl
-      let certificate = socket.sslHandle.SSL_get_peer_certificate()
-      if certificate.isNil:
-        raiseSSLError("No SSL certificate found.")
-
-      const X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT = 0x1.cuint
-      # https://www.openssl.org/docs/man1.1.1/man3/X509_check_host.html
-      let match = certificate.X509_check_host(hostname.cstring, hostname.len.cint,
-        X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT, nil)
-      # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_certificate.html
-      X509_free(certificate)
-      if match != 1:
-        raiseSSLError("SSL Certificate check failed.")
+      try:
+        let certificate = socket.sslHandle.SSL_get_peer_certificate()
+        if certificate.isNil:
+          raiseSSLError("No SSL certificate found.")
+
+        const X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT = 0x1.cuint
+        # https://www.openssl.org/docs/man1.1.1/man3/X509_check_host.html
+        let match = certificate.X509_check_host(hostname.cstring, hostname.len.cint,
+          X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT, nil)
+        # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_certificate.html
+        X509_free(certificate)
+        if match != 1:
+          raiseSSLError("SSL Certificate check failed.")
+
+      except LibraryError:
+        raiseSSLError("SSL import failed")
 
   proc wrapConnectedSocket*(ctx: SslContext, socket: Socket,
                             handshake: SslHandshakeType,
@@ -856,6 +850,7 @@ when defineSsl:
       let ret = SSL_connect(socket.sslHandle)
       socketError(socket, ret)
       when not defined(nimDisableCertificateValidation) and not defined(windows):
+        # FIXME: this should be skipped on CVerifyNone
         if hostname.len > 0 and not isIpAddress(hostname):
           socket.checkCertName(hostname)
     of handshakeAsServer:
@@ -1311,7 +1306,7 @@ when defined(nimdoc) or (defined(posix) and not useNimNetLite):
           (sizeof(socketAddr.sun_family) + path.len).SockLen) != 0'i32:
         raiseOSError(osLastError())
 
-when defined(ssl):
+when defineSsl:
   proc gotHandshake*(socket: Socket): bool =
     ## Determines whether a handshake has occurred between a client (`socket`)
     ## and the server that `socket` is connected to.
@@ -1998,7 +1993,7 @@ proc dial*(address: string, port: Port,
     raise newException(IOError, "Couldn't resolve address: " & address)
 
 proc connect*(socket: Socket, address: string,
-    port = Port(0)) {.tags: [ReadIOEffect].} =
+    port = Port(0)) {.tags: [ReadIOEffect, RootEffect].} =
   ## Connects socket to `address`:`port`. `Address` can be an IP address or a
   ## host name. If `address` is a host name, this function will try each IP
   ## of that host name. `htons` is already performed on `port` so you must
@@ -2073,7 +2068,7 @@ proc connectAsync(socket: Socket, name: string, port = Port(0),
   if not success: raiseOSError(lastError)
 
 proc connect*(socket: Socket, address: string, port = Port(0),
-    timeout: int) {.tags: [ReadIOEffect, WriteIOEffect].} =
+    timeout: int) {.tags: [ReadIOEffect, WriteIOEffect, RootEffect].} =
   ## Connects to server as specified by `address` on port specified by `port`.
   ##
   ## The `timeout` parameter specifies the time in milliseconds to allow for
diff --git a/lib/wrappers/openssl.nim b/lib/wrappers/openssl.nim
index 1db3fc239..b872e6f2d 100644
--- a/lib/wrappers/openssl.nim
+++ b/lib/wrappers/openssl.nim
@@ -7,27 +7,30 @@
 #    distribution, for details about the copyright.
 #
 
-## OpenSSL support
+## OpenSSL wrapper. Supports OpenSSL >= 1.1.0 dynamically (as default) or statically linked
+## using `--dynlibOverride:ssl`.
 ##
-## When OpenSSL is dynamically linked, the wrapper provides partial forward and backward
-## compatibility for OpenSSL versions above and below 1.1.0
-##
-## OpenSSL can also be statically linked using `--dynlibOverride:ssl` for OpenSSL >= 1.1.0.
-## If you want to statically link against OpenSSL 1.0.x, you now have to
-## define the `openssl10` symbol via `-d:openssl10`.
+## To use openSSL 3 set the symbol: -d:sslVersion=3
 ##
 ## Build and test examples:
 ##
 ## .. code-block::
+##   ./bin/nim c -d:ssl -p:. -r tests/stdlib/tssl.nim
+##   ./bin/nim c -d:ssl --threads:on -p:. -r tests/stdlib/thttpclient_ssl.nim
 ##   ./bin/nim c -d:ssl -p:. -r tests/untestable/tssl.nim
 ##   ./bin/nim c -d:ssl -p:. --dynlibOverride:ssl --passl:-lcrypto --passl:-lssl -r tests/untestable/tssl.nim
+##   ./bin/nim r --putenv:NIM_TESTAMENT_REMOTE_NETWORKING:1 -d:ssl -p:testament/lib --threads:on tests/untestable/thttpclient_ssl_remotenetwork.nim
+
+# https://www.feistyduck.com/library/openssl-cookbook/online/ch-testing-with-openssl.html
+#
+from strutils import startsWith
 
 when defined(nimHasStyleChecks):
   {.push styleChecks: off.}
 
 const useWinVersion = defined(windows) or defined(nimdoc)
 
-# To force openSSL version use -d:sslVersion=1.0.0
+# To force openSSL version use -d:sslVersion=1.2.3
 # See: #10281, #10230
 # General issue:
 # Other dynamic libraries (like libpg) load different openSSL version then what nim loads.
@@ -52,7 +55,7 @@ when sslVersion != "":
     from posix import SocketHandle
 
 elif useWinVersion:
-  when defined(openssl10) or defined(nimOldDlls):
+  when defined(nimOldDlls):
     when defined(cpu64):
       const
         DLLSSLName* = "(ssleay32|ssleay64).dll"
@@ -72,10 +75,11 @@ elif useWinVersion:
 
   from winlean import SocketHandle
 else:
-  when defined(osx):
-    const versions = "(.1.1|.38|.39|.41|.43|.44|.45|.46|.47|.48|.10|.1.0.2|.1.0.1|.1.0.0|.0.9.9|.0.9.8|)"
+  when defined(macosx):
+    # use only versioned soname
+    const versions = ".1.1"
   else:
-    const versions = "(.1.1|.1.0.2|.1.0.1|.1.0.0|.0.9.9|.0.9.8|.48|.47|.46|.45|.44|.43|.41|.39|.38|.10|)"
+    const versions = "(.1.1|.48|.47|.46|.45|.44|.43|.41|.39|.38|.10|)"
 
   when defined(macosx):
     const
@@ -268,40 +272,24 @@ proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
 # and support SSLv3, TLSv1, TLSv1.1 and TLSv1.2
 # SSLv23_method(), SSLv23_server_method(), SSLv23_client_method() are removed in 1.1.0
 
-when compileOption("dynlibOverride", "ssl") or defined(noOpenSSLHacks):
+when compileOption("dynlibOverride", "ssl"):
   # Static linking
 
-  when defined(openssl10):
-    proc SSL_library_init*(): cint {.cdecl, dynlib: DLLSSLName, importc, discardable.}
-    proc SSL_load_error_strings*() {.cdecl, dynlib: DLLSSLName, importc.}
-    proc SSLv23_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
-    proc SSLeay(): culong {.cdecl, dynlib: DLLUtilName, importc.}
-
-    proc getOpenSSLVersion*(): culong =
-      SSLeay()
-  else:
-    proc OPENSSL_init_ssl*(opts: uint64, settings: uint8): cint {.cdecl, dynlib: DLLSSLName, importc, discardable.}
-    proc SSL_library_init*(): cint {.discardable.} =
-      ## Initialize SSL using OPENSSL_init_ssl for OpenSSL >= 1.1.0
-      return OPENSSL_init_ssl(0.uint64, 0.uint8)
-
-    proc TLS_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
-    proc SSLv23_method*(): PSSL_METHOD =
-      TLS_method()
+  proc TLS_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
 
-    proc OpenSSL_version_num(): culong {.cdecl, dynlib: DLLUtilName, importc.}
+  proc OpenSSL_version_num(): culong {.cdecl, dynlib: DLLUtilName, importc.}
 
-    proc getOpenSSLVersion*(): culong =
-      ## Return OpenSSL version as unsigned long
-      OpenSSL_version_num()
+  proc getOpenSSLVersion*(): culong =
+    ## Return OpenSSL version as unsigned long
+    OpenSSL_version_num()
 
-    proc SSL_load_error_strings*() =
-      ## Removed from OpenSSL 1.1.0
-      # This proc prevents breaking existing code calling SslLoadErrorStrings
-      # Static linking against OpenSSL < 1.1.0 is not supported
-      discard
+  proc SSL_load_error_strings*() =
+    ## Removed from OpenSSL 1.1.0
+    # This proc prevents breaking existing code calling SslLoadErrorStrings
+    # Static linking against OpenSSL < 1.1.0 is not supported
+    discard
 
-  when defined(libressl) or defined(openssl10):
+  when defined(libressl):
     proc SSL_state(ssl: SslPtr): cint {.cdecl, dynlib: DLLSSLName, importc.}
     proc SSL_in_init*(ssl: SslPtr): cint {.inline.} =
       SSl_state(ssl) and SSL_ST_INIT
@@ -311,12 +299,8 @@ when compileOption("dynlibOverride", "ssl") or defined(noOpenSSLHacks):
 
   template OpenSSL_add_all_algorithms*() = discard
 
-  proc SSLv23_client_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
-  proc SSLv2_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
-  proc SSLv3_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
-
 else:
-  # Here we're trying to stay compatible with openssl 1.0.* and 1.1.*. Some
+  # Here we're trying to stay compatible with openssl 1.1.*. Some
   # symbols are loaded dynamically and we don't use them if not found.
   proc thisModule(): LibHandle {.inline.} =
     var thisMod {.global.}: LibHandle
@@ -324,9 +308,12 @@ else:
 
     result = thisMod
 
-  proc sslModule(): LibHandle {.inline.} =
+  proc sslModule(): LibHandle {.inline, raises: [LibraryError], tags:[RootEffect].} =
     var sslMod {.global.}: LibHandle
-    if sslMod.isNil: sslMod = loadLibPattern(DLLSSLName)
+    try:
+      if sslMod.isNil: sslMod = loadLibPattern(DLLSSLName)
+    except:
+      raise newException(LibraryError, "Could not load SSL using " & DLLSSLName)
 
     result = sslMod
 
@@ -351,63 +338,37 @@ else:
       if result.isNil and alternativeName.len > 0:
         result = symAddr(thisDynlib, alternativeName)
 
-  proc sslSymNullable(name: string, alternativeName = ""): pointer =
+  proc sslSymNullable(name: string, alternativeName = ""): pointer {.raises: [LibraryError], tags:[RootEffect].} =
     sslModule().symNullable(name, alternativeName)
 
-  proc sslSymThrows(name: string, alternativeName = ""): pointer =
+  proc sslSymThrows(name: string, alternativeName = ""): pointer {.raises: [LibraryError].} =
     result = sslSymNullable(name, alternativeName)
     if result.isNil: raiseInvalidLibrary(name)
 
   proc utilSymNullable(name: string, alternativeName = ""): pointer =
     utilModule().symNullable(name, alternativeName)
 
-  proc loadPSSLMethod(method1, method2: string): PSSL_METHOD =
+  proc loadPSSLMethod(method1, method2: string): PSSL_METHOD {.raises: [LibraryError], tags:[RootEffect].} =
     ## Load <method1> from OpenSSL if available, otherwise <method2>
     ##
     let methodSym = sslSymNullable(method1, method2)
     if methodSym.isNil:
       raise newException(LibraryError, "Could not load " & method1 & " nor " & method2)
 
-    let method2Proc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](methodSym)
+    let method2Proc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe, raises: [].}](methodSym)
     return method2Proc()
 
-  proc SSL_library_init*(): cint {.discardable.} =
-    ## Initialize SSL using OPENSSL_init_ssl for OpenSSL >= 1.1.0 otherwise
-    ## SSL_library_init
-    let newInitSym = sslSymNullable("OPENSSL_init_ssl")
-    if not newInitSym.isNil:
-      let newInitProc =
-        cast[proc(opts: uint64, settings: uint8): cint {.cdecl.}](newInitSym)
-      return newInitProc(0, 0)
-    let olderProc = cast[proc(): cint {.cdecl.}](sslSymThrows("SSL_library_init"))
-    if not olderProc.isNil: result = olderProc()
-
   proc SSL_load_error_strings*() =
     # TODO: Are we ignoring this on purpose? SSL GitHub CI fails otherwise.
     let theProc = cast[proc() {.cdecl.}](sslSymNullable("SSL_load_error_strings"))
     if not theProc.isNil: theProc()
 
-  proc SSLv23_client_method*(): PSSL_METHOD =
-    loadPSSLMethod("SSLv23_client_method", "TLS_client_method")
-
-  proc SSLv23_method*(): PSSL_METHOD =
-    loadPSSLMethod("SSLv23_method", "TLS_method")
-
-  proc SSLv2_method*(): PSSL_METHOD =
-    loadPSSLMethod("SSLv2_method", "TLS_method")
-
   proc SSLv3_method*(): PSSL_METHOD =
     loadPSSLMethod("SSLv3_method", "TLS_method")
 
   proc TLS_method*(): PSSL_METHOD =
     loadPSSLMethod("TLS_method", "SSLv23_method")
 
-  proc TLS_client_method*(): PSSL_METHOD =
-    loadPSSLMethod("TLS_client_method", "SSLv23_client_method")
-
-  proc TLS_server_method*(): PSSL_METHOD =
-    loadPSSLMethod("TLS_server_method", "SSLv23_server_method")
-
   proc OpenSSL_add_all_algorithms*() =
     # TODO: Are we ignoring this on purpose? SSL GitHub CI fails otherwise.
     let theProc = cast[proc() {.cdecl.}](sslSymNullable("OPENSSL_add_all_algorithms_conf"))
@@ -441,7 +402,11 @@ else:
       theProc = cast[typeof(theProc)](sslSymThrows("SSL_CTX_set_ciphersuites"))
     theProc(ctx, str)
 
-proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLUtilName, importc.}
+
+proc OPENSSL_init_ssl*(opts: uint64, settings: uint8): cint {.cdecl, dynlib: DLLSSLName, importc.}
+
+proc TLS_client_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
+
 
 proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.}
 proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.}
@@ -804,8 +769,17 @@ when defined(nimHasStyleChecks):
 # On old openSSL version some of these symbols are not available
 when not defined(nimDisableCertificateValidation) and not defined(windows):
 
-  proc SSL_get_peer_certificate*(ssl: SslCtx): PX509{.cdecl, dynlib: DLLSSLName,
-      importc.}
+  # proc SSL_get_peer_certificate*(ssl: SslCtx): PX509 =
+  #  loadPSSLMethod("SSL_get_peer_certificate", "SSL_get1_peer_certificate")
+
+  when sslVersion.startsWith('3'):
+    proc SSL_get1_peer_certificate*(ssl: SslCtx): PX509 {.cdecl, dynlib: DLLSSLName, importc.}
+    proc SSL_get_peer_certificate*(ssl: SslCtx): PX509 =
+      SSL_get1_peer_certificate(ssl)
+
+  else:
+    proc SSL_get_peer_certificate*(ssl: SslCtx): PX509 {.cdecl, dynlib: DLLSSLName, importc.}
+
 
   proc X509_get_subject_name*(a: PX509): PX509_NAME{.cdecl, dynlib: DLLSSLName, importc.}
 
diff --git a/tests/stdlib/tssl.nim b/tests/stdlib/tssl.nim
index 379c1b1e5..fd85cb55b 100644
--- a/tests/stdlib/tssl.nim
+++ b/tests/stdlib/tssl.nim
@@ -16,9 +16,22 @@ when not defined(ssl):
 
 const DummyData = "dummy data\n"
 
+proc createSocket(): Socket =
+  result = newSocket(buffered = false)
+  result.setSockOpt(OptReuseAddr, true)
+  result.setSockOpt(OptReusePort, true)
+
+proc createServer(serverContext: SslContext): (Socket, Port) =
+  var server = createSocket()
+  serverContext.wrapSocket(server)
+  server.bindAddr(address = "localhost")
+  let (_, port) = server.getLocalAddr()
+  server.listen()
+  return (server, port)
+
 proc abruptShutdown(port: Port) {.thread.} =
   let clientContext = newContext(verifyMode = CVerifyNone)
-  var client = newSocket(buffered = false)
+  var client = createSocket()
   clientContext.wrapSocket(client)
   client.connect("localhost", port)
 
@@ -27,7 +40,7 @@ proc abruptShutdown(port: Port) {.thread.} =
 
 proc notifiedShutdown(port: Port) {.thread.} =
   let clientContext = newContext(verifyMode = CVerifyNone)
-  var client = newSocket(buffered = false)
+  var client = createSocket()
   clientContext.wrapSocket(client)
   client.connect("localhost", port)
 
@@ -49,13 +62,7 @@ proc main() =
                                  keyFile = "tests/testdata/mycert.pem")
 
   block peer_close_during_write_without_shutdown:
-    var server = newSocket(buffered = false)
-    defer: server.close()
-    serverContext.wrapSocket(server)
-    server.bindAddr(address = "localhost")
-    let (_, port) = server.getLocalAddr()
-    server.listen()
-
+    var (server, port) = createServer(serverContext)
     var clientThread: Thread[Port]
     createThread(clientThread, abruptShutdown, port)
 
@@ -73,19 +80,14 @@ proc main() =
       discard
     finally:
       peer.close()
+      server.close()
 
   when defined(posix):
     if sigaction(SIGPIPE, oldSigPipeHandler, nil) == -1:
       raiseOSError(osLastError(), "Couldn't restore SIGPIPE handler")
 
   block peer_close_before_received_shutdown:
-    var server = newSocket(buffered = false)
-    defer: server.close()
-    serverContext.wrapSocket(server)
-    server.bindAddr(address = "localhost")
-    let (_, port) = server.getLocalAddr()
-    server.listen()
-
+    var (server, port) = createServer(serverContext)
     var clientThread: Thread[Port]
     createThread(clientThread, abruptShutdown, port)
 
@@ -104,15 +106,10 @@ proc main() =
         discard peer.getFd.shutdown(SD_SEND)
     finally:
       peer.close()
+      server.close()
 
   block peer_close_after_received_shutdown:
-    var server = newSocket(buffered = false)
-    defer: server.close()
-    serverContext.wrapSocket(server)
-    server.bindAddr(address = "localhost")
-    let (_, port) = server.getLocalAddr()
-    server.listen()
-
+    var (server, port) = createServer(serverContext)
     var clientThread: Thread[Port]
     createThread(clientThread, notifiedShutdown, port)
 
@@ -132,5 +129,6 @@ proc main() =
         discard peer.getFd.shutdown(SD_SEND)
     finally:
       peer.close()
+      server.close()
 
 when isMainModule: main()