summary refs log tree commit diff stats
path: root/lib/pure/asyncio2.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/asyncio2.nim')
-rw-r--r--lib/pure/asyncio2.nim223
1 files changed, 150 insertions, 73 deletions
diff --git a/lib/pure/asyncio2.nim b/lib/pure/asyncio2.nim
index 12d4cb5a3..eb31eca13 100644
--- a/lib/pure/asyncio2.nim
+++ b/lib/pure/asyncio2.nim
@@ -43,6 +43,14 @@ proc complete*[T](future: PFuture[T], val: T) =
   if future.cb != nil:
     future.cb()
 
+proc complete*(future: PFuture[void]) =
+  ## Completes a void ``future``.
+  assert(not future.finished, "Future already finished, cannot finish twice.")
+  assert(future.error == nil)
+  future.finished = true
+  if future.cb != nil:
+    future.cb()
+
 proc fail*[T](future: PFuture[T], error: ref EBase) =
   ## Completes ``future`` with ``error``.
   assert(not future.finished, "Future already finished, cannot finish twice.")
@@ -76,7 +84,8 @@ proc read*[T](future: PFuture[T]): T =
   ## If the result of the future is an error then that error will be raised.
   if future.finished:
     if future.error != nil: raise future.error
-    return future.value
+    when T isnot void:
+      return future.value
   else:
     # TODO: Make a custom exception type for this?
     raise newException(EInvalidValue, "Future still in progress.")
@@ -94,7 +103,8 @@ proc failed*[T](future: PFuture[T]): bool =
 # TODO: Get rid of register. Do it implicitly.
 
 when defined(windows) or defined(nimdoc):
-  import winlean
+  import winlean, sets, hashes
+  #from hashes import THash
   type
     TCompletionKey = dword
 
@@ -105,7 +115,7 @@ when defined(windows) or defined(nimdoc):
 
     PDispatcher* = ref object
       ioPort: THandle
-      hasHandles: bool
+      handles: TSet[TSocketHandle]
 
     TCustomOverlapped = object
       Internal*: DWORD
@@ -117,21 +127,31 @@ when defined(windows) or defined(nimdoc):
 
     PCustomOverlapped = ptr TCustomOverlapped
 
+  proc hash(x: TSocketHandle): THash {.borrow.}
+
   proc newDispatcher*(): PDispatcher =
     ## Creates a new Dispatcher instance.
     new result
     result.ioPort = CreateIOCompletionPort(INVALID_HANDLE_VALUE, 0, 0, 1)
+    result.handles = initSet[TSocketHandle]()
 
   proc register*(p: PDispatcher, sock: TSocketHandle) =
     ## Registers ``sock`` with the dispatcher ``p``.
     if CreateIOCompletionPort(sock.THandle, p.ioPort,
                               cast[TCompletionKey](sock), 1) == 0:
       OSError(OSLastError())
-    p.hasHandles = true
+    p.handles.incl(sock)
+
+  proc verifyPresence(p: PDispatcher, sock: TSocketHandle) =
+    ## Ensures that socket has been registered with the dispatcher.
+    if sock notin p.handles:
+      raise newException(EInvalidValue,
+        "Operation performed on a socket which has not been registered with" &
+        " the dispatcher yet.")
 
   proc poll*(p: PDispatcher, timeout = 500) =
     ## Waits for completion events and processes them.
-    if not p.hasHandles:
+    if p.handles.len == 0:
       raise newException(EInvalidValue, "No handles registered in dispatcher.")
     
     let llTimeout =
@@ -232,13 +252,13 @@ when defined(windows) or defined(nimdoc):
                   RemoteSockaddr, RemoteSockaddrLength)
 
   proc connect*(p: PDispatcher, socket: TSocketHandle, address: string, port: TPort,
-    af = AF_INET): PFuture[int] =
+    af = AF_INET): PFuture[void] =
     ## Connects ``socket`` to server at ``address:port``.
     ##
     ## Returns a ``PFuture`` which will complete when the connection succeeds
     ## or an error occurs.
-
-    var retFuture = newFuture[int]()# TODO: Change to void when that regression is fixed.
+    verifyPresence(p, socket)
+    var retFuture = newFuture[void]()
     # Apparently ``ConnectEx`` expects the socket to be initially bound:
     var saddr: Tsockaddr_in
     saddr.sin_family = int16(toInt(af))
@@ -260,7 +280,7 @@ when defined(windows) or defined(nimdoc):
         proc (sock: TSocketHandle, bytesCount: DWord, errcode: TOSErrorCode) =
           if not retFuture.finished:
             if errcode == TOSErrorCode(-1):
-              retFuture.complete(0)
+              retFuture.complete()
             else:
               retFuture.fail(newException(EOS, osErrorMsg(errcode)))
       )
@@ -270,7 +290,7 @@ when defined(windows) or defined(nimdoc):
       if ret:
         # Request to connect completed immediately.
         success = true
-        retFuture.complete(0)
+        retFuture.complete()
         # We don't deallocate ``ol`` here because even though this completed
         # immediately poll will still be notified about its completion and it will
         # free ``ol``.
@@ -298,7 +318,7 @@ when defined(windows) or defined(nimdoc):
     ## recv operation then the future may complete with only a part of the
     ## requested data read. If socket is disconnected and no data is available
     ## to be read then the future will complete with a value of ``""``.
-
+    verifyPresence(p, socket)
     var retFuture = newFuture[string]()
     
     var dataBuf: TWSABuf
@@ -351,10 +371,11 @@ when defined(windows) or defined(nimdoc):
       # free ``ol``.
     return retFuture
 
-  proc send*(p: PDispatcher, socket: TSocketHandle, data: string): PFuture[int] =
+  proc send*(p: PDispatcher, socket: TSocketHandle, data: string): PFuture[void] =
     ## Sends ``data`` to ``socket``. The returned future will complete once all
     ## data has been sent.
-    var retFuture = newFuture[int]()
+    verifyPresence(p, socket)
+    var retFuture = newFuture[void]()
 
     var dataBuf: TWSABuf
     dataBuf.buf = data
@@ -366,7 +387,7 @@ when defined(windows) or defined(nimdoc):
       proc (sock: TSocketHandle, bytesCount: DWord, errcode: TOSErrorCode) =
         if not retFuture.finished:
           if errcode == TOSErrorCode(-1):
-            retFuture.complete(0)
+            retFuture.complete()
           else:
             retFuture.fail(newException(EOS, osErrorMsg(errcode)))
     )
@@ -379,7 +400,7 @@ when defined(windows) or defined(nimdoc):
         retFuture.fail(newException(EOS, osErrorMsg(err)))
         dealloc(ol)
     else:
-      retFuture.complete(0)
+      retFuture.complete()
       # We don't deallocate ``ol`` here because even though this completed
       # immediately poll will still be notified about its completion and it will
       # free ``ol``.
@@ -390,7 +411,9 @@ when defined(windows) or defined(nimdoc):
     ## Accepts a new connection. Returns a future containing the client socket
     ## corresponding to that connection and the remote address of the client.
     ## The future will complete when the connection is successfully accepted.
-    
+    ##
+    ## The resulting client socket is automatically registered to dispatcher.
+    verifyPresence(p, socket)
     var retFuture = newFuture[tuple[address: string, client: TSocketHandle]]()
 
     var clientSock = socket()
@@ -416,6 +439,7 @@ when defined(windows) or defined(nimdoc):
                            dwLocalAddressLength, dwRemoteAddressLength,
                            addr LocalSockaddr, addr localLen,
                            addr RemoteSockaddr, addr remoteLen)
+      p.register(clientSock)
       # TODO: IPv6. Check ``sa_family``. http://stackoverflow.com/a/9212542/492186
       retFuture.complete(
         (address: $inet_ntoa(cast[ptr Tsockaddr_in](remoteSockAddr).sin_addr),
@@ -452,6 +476,18 @@ when defined(windows) or defined(nimdoc):
 
     return retFuture
 
+  proc socket*(disp: PDispatcher, domain: TDomain = AF_INET,
+               typ: TType = SOCK_STREAM,
+               protocol: TProtocol = IPPROTO_TCP): TSocketHandle =
+    ## Creates a new socket and registers it with the dispatcher implicitly.
+    result = socket(domain, typ, protocol)
+    disp.register(result)
+
+  proc close*(disp: PDispatcher, socket: TSocketHandle) =
+    ## Closes a socket and ensures that it is unregistered.
+    socket.close()
+    disp.handles.excl(socket)
+
   initAll()
 else:
   import selectors
@@ -473,62 +509,76 @@ else:
 
   proc update(p: PDispatcher, sock: TSocketHandle, events: set[TEvent]) =
     assert sock in p.selector
-    echo("Update: ", events)
-    if events == {}:
-      discard p.selector.unregister(sock)
-    else:
-      discard p.selector.update(sock, events)
+    discard p.selector.update(sock, events)
+
+  proc register(p: PDispatcher, sock: TSocketHandle) =
+    var data = PData(sock: sock, readCBs: @[], writeCBs: @[])
+    p.selector.register(sock, {}, data.PObject)
+
+  proc socket*(disp: PDispatcher, domain: TDomain = AF_INET,
+               typ: TType = SOCK_STREAM,
+               protocol: TProtocol = IPPROTO_TCP): TSocketHandle =
+    result = socket(domain, typ, protocol)
+    disp.register(result)
   
+  proc close*(disp: PDispatcher, sock: TSocketHandle) =
+    sock.close()
+    disp.selector.unregister(sock)
+
   proc addRead(p: PDispatcher, sock: TSocketHandle, cb: TCallback) =
     if sock notin p.selector:
-      var data = PData(sock: sock, readCBs: @[cb], writeCBs: @[])
-      p.selector.register(sock, {EvRead}, data.PObject)
-    else:
-      p.selector[sock].data.PData.readCBs.add(cb)
-      p.update(sock, p.selector[sock].events + {EvRead})
+      raise newException(EInvalidValue, "File descriptor not registered.")
+    p.selector[sock].data.PData.readCBs.add(cb)
+    p.update(sock, p.selector[sock].events + {EvRead})
   
   proc addWrite(p: PDispatcher, sock: TSocketHandle, cb: TCallback) =
     if sock notin p.selector:
-      var data = PData(sock: sock, readCBs: @[], writeCBs: @[cb])
-      p.selector.register(sock, {EvWrite}, data.PObject)
-    else:
-      p.selector[sock].data.PData.writeCBs.add(cb)
-      p.update(sock, p.selector[sock].events + {EvWrite})
+      raise newException(EInvalidValue, "File descriptor not registered.")
+    p.selector[sock].data.PData.writeCBs.add(cb)
+    p.update(sock, p.selector[sock].events + {EvWrite})
   
   proc poll*(p: PDispatcher, timeout = 500) =
     for info in p.selector.select(timeout):
       let data = PData(info.key.data)
       assert data.sock == info.key.fd
-      echo("R: ", data.readCBs.len, " W: ", data.writeCBs.len, ". ", info.events)
-      
+      #echo("In poll ", data.sock.cint)
       if EvRead in info.events:
-        var newReadCBs: seq[TCallback] = @[]
-        for cb in data.readCBs:
+        # Callback may add items to ``data.readCBs`` which causes issues if
+        # we are iterating over ``data.readCBs`` at the same time. We therefore
+        # make a copy to iterate over.
+        let currentCBs = data.readCBs
+        data.readCBs = @[]
+        for cb in currentCBs:
           if not cb(data.sock):
             # Callback wants to be called again.
-            newReadCBs.add(cb)
-        data.readCBs = newReadCBs
+            data.readCBs.add(cb)
       
       if EvWrite in info.events:
-        var newWriteCBs: seq[TCallback] = @[]
-        for cb in data.writeCBs:
+        let currentCBs = data.writeCBs
+        data.writeCBs = @[]
+        for cb in currentCBs:
           if not cb(data.sock):
             # Callback wants to be called again.
-            newWriteCBs.add(cb)
-        data.writeCBs = newWriteCBs
-  
-      var newEvents: set[TEvent]
-      if data.readCBs.len != 0: newEvents = {EvRead}
-      if data.writeCBs.len != 0: newEvents = newEvents + {EvWrite}
-      p.update(data.sock, newEvents)
+            data.writeCBs.add(cb)
+      
+      if info.key in p.selector:
+        var newEvents: set[TEvent]
+        if data.readCBs.len != 0: newEvents = {EvRead}
+        if data.writeCBs.len != 0: newEvents = newEvents + {EvWrite}
+        if newEvents != info.key.events:
+          echo(info.key.events, " -> ", newEvents)
+          p.update(data.sock, newEvents)
+      else:
+        # FD no longer a part of the selector. Likely been closed
+        # (e.g. socket disconnected).
   
   proc connect*(p: PDispatcher, socket: TSocketHandle, address: string, port: TPort,
-    af = AF_INET): PFuture[int] =
-    var retFuture = newFuture[int]()
+    af = AF_INET): PFuture[void] =
+    var retFuture = newFuture[void]()
     
     proc cb(sock: TSocketHandle): bool =
       # We have connected.
-      retFuture.complete(0)
+      retFuture.complete()
       return true
     
     var aiList = getAddrInfo(address, port, af)
@@ -540,7 +590,7 @@ else:
       if ret == 0:
         # Request to connect completed immediately.
         success = true
-        retFuture.complete(0)
+        retFuture.complete()
         break
       else:
         lastError = osLastError()
@@ -568,6 +618,7 @@ else:
       result = true
       let netSize = size - sizeRead
       let res = recv(sock, addr readBuffer[sizeRead], netSize, flags.cint)
+      #echo("recv cb res: ", res)
       if res < 0:
         let lastError = osLastError()
         if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}: 
@@ -575,6 +626,7 @@ else:
         else:
           result = false # We still want this callback to be called.
       elif res == 0:
+        #echo("Disconnected recv: ", sizeRead)
         # Disconnected
         if sizeRead == 0:
           retFuture.complete("")
@@ -587,12 +639,13 @@ else:
           result = false # We want to read all the data requested.
         else:
           retFuture.complete(readBuffer)
+      #echo("Recv cb result: ", result)
   
     addRead(p, socket, cb)
     return retFuture
 
-  proc send*(p: PDispatcher, socket: TSocketHandle, data: string): PFuture[int] =
-    var retFuture = newFuture[int]()
+  proc send*(p: PDispatcher, socket: TSocketHandle, data: string): PFuture[void] =
+    var retFuture = newFuture[void]()
     
     var written = 0
     
@@ -612,10 +665,9 @@ else:
         if res != netSize:
           result = false # We still have data to send.
         else:
-          retFuture.complete(0)
+          retFuture.complete()
     addWrite(p, socket, cb)
     return retFuture
-        
 
   proc acceptAddr*(p: PDispatcher, socket: TSocketHandle): 
       PFuture[tuple[address: string, client: TSocketHandle]] =
@@ -634,6 +686,7 @@ else:
         else:
           retFuture.fail(newException(EOS, osErrorMsg(lastError)))
       else:
+        p.register(client)
         retFuture.complete(($inet_ntoa(sockAddress.sin_addr), client))
     addRead(p, socket, cb)
     return retFuture
@@ -745,12 +798,17 @@ macro async*(prc: stmt): stmt {.immediate.} =
 
   hint("Processing " & prc[0].getName & " as an async proc.")
 
+  let returnType = prc[3][0]
+  var subtypeName = ""
   # Verify that the return type is a PFuture[T]
-  if prc[3][0].kind == nnkIdent:
-    error("Expected return type of 'PFuture' got '" & $prc[3][0] & "'")
-  elif prc[3][0].kind == nnkBracketExpr:
-    if $prc[3][0][0] != "PFuture":
-      error("Expected return type of 'PFuture' got '" & $prc[3][0][0] & "'")
+  if returnType.kind == nnkIdent:
+    error("Expected return type of 'PFuture' got '" & $returnType & "'")
+  elif returnType.kind == nnkBracketExpr:
+    if $returnType[0] != "PFuture":
+      error("Expected return type of 'PFuture' got '" & $returnType[0] & "'")
+    subtypeName = $returnType[1].ident
+  elif returnType.kind == nnkEmpty:
+    subtypeName = "void"
   
   # TODO: Why can't I use genSym? I get illegal capture errors for Syms.
   # TODO: It seems genSym is broken. Change all usages back to genSym when fixed
@@ -763,20 +821,24 @@ macro async*(prc: stmt): stmt {.immediate.} =
     newVarStmt(retFutureSym, 
       newCall(
         newNimNode(nnkBracketExpr).add(
-          newIdentNode("newFuture"),
-          prc[3][0][1])))) # Get type from return type of this proc.
-
+          newIdentNode(!"newFuture"), # TODO: Strange bug here? Remove the `!`.
+          newIdentNode(subtypeName))))) # Get type from return type of this proc
+  echo(treeRepr(outerProcBody))
   # -> iterator nameIter(): PFutureBase {.closure.} = 
   # ->   var result: T
   # ->   <proc_body>
   # ->   complete(retFuture, result)
   var iteratorNameSym = newIdentNode($prc[0].getName & "Iter") #genSym(nskIterator, $prc[0].ident & "Iter")
   var procBody = prc[6].processBody(retFutureSym)
-  procBody.insert(0, newNimNode(nnkVarSection).add(
-    newIdentDefs(newIdentNode("result"), prc[3][0][1]))) # -> var result: T
-  procBody.add(
-    newCall(newIdentNode("complete"),
-      retFutureSym, newIdentNode("result"))) # -> complete(retFuture, result)
+  if subtypeName != "void":
+    procBody.insert(0, newNimNode(nnkVarSection).add(
+      newIdentDefs(newIdentNode("result"), returnType[1]))) # -> var result: T
+    procBody.add(
+      newCall(newIdentNode("complete"),
+        retFutureSym, newIdentNode("result"))) # -> complete(retFuture, result)
+  else:
+    # -> complete(retFuture)
+    procBody.add(newCall(newIdentNode("complete"), retFutureSym))
   
   var closureIterator = newProc(iteratorNameSym, [newIdentNode("PFutureBase")],
                                 procBody, nnkIteratorDef)
@@ -811,6 +873,12 @@ macro async*(prc: stmt): stmt {.immediate.} =
   for i in 0 .. <result[4].len:
     if result[4][i].ident == !"async":
       result[4].del(i)
+  if subtypeName == "void":
+    # Add discardable pragma.
+    result[4].add(newIdentNode("discardable"))
+    if returnType.kind == nnkEmpty:
+      # Add PFuture[void]
+      result[3][0] = parseExpr("PFuture[void]")
 
   result[6] = outerProcBody
 
@@ -833,9 +901,13 @@ proc recvLine*(p: PDispatcher, socket: TSocketHandle): PFuture[string] {.async.}
   result = ""
   var c = ""
   while true:
+    #echo("1")
     c = await p.recv(socket, 1)
+    #echo("Received ", c.len)
     if c.len == 0:
+      #echo("returning")
       return
+    #echo("2")
     if c == "\r":
       c = await p.recv(socket, 1, MSG_PEEK)
       if c.len > 0 and c == "\L":
@@ -845,12 +917,14 @@ proc recvLine*(p: PDispatcher, socket: TSocketHandle): PFuture[string] {.async.}
     elif c == "\L":
       addNLIfEmpty()
       return
+    #echo("3")
     add(result.string, c)
+  #echo("4")
 
 when isMainModule:
   
   var p = newDispatcher()
-  var sock = socket()
+  var sock = p.socket()
   sock.setBlocking false
 
 
@@ -859,6 +933,7 @@ when isMainModule:
     proc main(p: PDispatcher): PFuture[int] {.async.} =
       discard await p.connect(sock, "irc.freenode.net", TPort(6667))
       while true:
+        echo("recvLine")
         var line = await p.recvLine(sock)
         echo("Line is: ", line.repr)
         if line == "":
@@ -882,7 +957,7 @@ when isMainModule:
   else:
     when false:
 
-      var f = p.connect(sock, "irc.freenode.org", TPort(6667))
+      var f = p.connect(sock, "irc.poop.nl", TPort(6667))
       f.callback =
         proc (future: PFuture[int]) =
           echo("Connected in future!")
@@ -898,11 +973,13 @@ when isMainModule:
       sock.bindAddr(TPort(6667))
       sock.listen()
       proc onAccept(future: PFuture[TSocketHandle]) =
-        echo "Accepted"
-        var t = p.send(future.read, "test\c\L")
+        let client = future.read
+        echo "Accepted ", client.cint
+        var t = p.send(client, "test\c\L")
         t.callback =
           proc (future: PFuture[int]) =
-            echo(future.read)
+            echo("Send: ", future.read)
+            client.close()
         
         var f = p.accept(sock)
         f.callback = onAccept
@@ -919,4 +996,4 @@ when isMainModule:
 
   
 
-  
\ No newline at end of file
+