summary refs log tree commit diff stats
path: root/lib/pure/asyncnet.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asyncnet.nim')
-rw-r--r--lib/pure/asyncnet.nim25
1 files changed, 18 insertions, 7 deletions
diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim
index 5be457d2a..e7552e3e3 100644
--- a/lib/pure/asyncnet.nim
+++ b/lib/pure/asyncnet.nim
@@ -134,15 +134,20 @@ type
     protocol: Protocol
   AsyncSocket* = ref AsyncSocketDesc
 
-{.deprecated: [PAsyncSocket: AsyncSocket].}
-
 proc newAsyncSocket*(fd: AsyncFD, domain: Domain = AF_INET,
     sockType: SockType = SOCK_STREAM,
     protocol: Protocol = IPPROTO_TCP, buffered = true): AsyncSocket =
   ## Creates a new ``AsyncSocket`` based on the supplied params.
+  ##
+  ## The supplied ``fd``'s non-blocking state will be enabled implicitly.
+  ##
+  ## **Note**: This procedure will **NOT** register ``fd`` with the global
+  ## async dispatcher. You need to do this manually. If you have used
+  ## ``newAsyncNativeSocket`` to create ``fd`` then it's already registered.
   assert fd != osInvalidSocket.AsyncFD
   new(result)
   result.fd = fd.SocketHandle
+  fd.SocketHandle.setBlocking(false)
   result.isBuffered = buffered
   result.domain = domain
   result.sockType = sockType
@@ -156,8 +161,10 @@ proc newAsyncSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  result = newAsyncSocket(newAsyncNativeSocket(domain, sockType, protocol),
-                          domain, sockType, protocol, buffered)
+  let fd = createAsyncNativeSocket(domain, sockType, protocol)
+  if fd.SocketHandle == osInvalidSocket:
+    raiseOSError(osLastError())
+  result = newAsyncSocket(fd, domain, sockType, protocol, buffered)
 
 proc newAsyncSocket*(domain, sockType, protocol: cint,
     buffered = true): AsyncSocket =
@@ -165,8 +172,10 @@ proc newAsyncSocket*(domain, sockType, protocol: cint,
   ##
   ## This procedure will also create a brand new file descriptor for
   ## this socket.
-  result = newAsyncSocket(newAsyncNativeSocket(domain, sockType, protocol),
-                          Domain(domain), SockType(sockType),
+  let fd = createAsyncNativeSocket(domain, sockType, protocol)
+  if fd.SocketHandle == osInvalidSocket:
+    raiseOSError(osLastError())
+  result = newAsyncSocket(fd, Domain(domain), SockType(sockType),
                           Protocol(protocol), buffered)
 
 when defineSsl:
@@ -190,7 +199,7 @@ when defineSsl:
       flags: set[SocketFlag]) {.async.} =
     let len = bioCtrlPending(socket.bioOut)
     if len > 0:
-      var data = newStringOfCap(len)
+      var data = newString(len)
       let read = bioRead(socket.bioOut, addr data[0], len)
       assert read != 0
       if read < 0:
@@ -277,6 +286,7 @@ template readInto(buf: pointer, size: int, socket: AsyncSocket,
                   flags: set[SocketFlag]): int =
   ## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. Note that
   ## this is a template and not a proc.
+  assert(not socket.closed, "Cannot `recv` on a closed socket")
   var res = 0
   if socket.isSsl:
     when defineSsl:
@@ -403,6 +413,7 @@ proc send*(socket: AsyncSocket, buf: pointer, size: int,
   ## Sends ``size`` bytes from ``buf`` to ``socket``. The returned future will complete once all
   ## data has been sent.
   assert socket != nil
+  assert(not socket.closed, "Cannot `send` on a closed socket")
   if socket.isSsl:
     when defineSsl:
       sslLoop(socket, flags,
a> 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325