summary refs log tree commit diff stats
path: root/lib/pure/asyncio.nim
diff options
context:
space:
mode:
authorDominik Picheta <dominikpicheta@googlemail.com>2012-07-22 23:32:49 +0100
committerDominik Picheta <dominikpicheta@googlemail.com>2012-07-22 23:32:49 +0100
commit5310a3044fa4187274e2bfe59de68f394a81c89d (patch)
treeeab220ba823fb2f706ff6abab7b52cdb982821e3 /lib/pure/asyncio.nim
parentb839e42e92edf6acfca73768cbdd9c7595ca8797 (diff)
downloadNim-5310a3044fa4187274e2bfe59de68f394a81c89d.tar.gz
Many fixes for asynchronous sockets. Asyncio should now work well with buffered and unbuffered plain and ssl sockets. Added asyncio
test to the test suite.
Diffstat (limited to 'lib/pure/asyncio.nim')
-rw-r--r--lib/pure/asyncio.nim169
1 files changed, 142 insertions, 27 deletions
diff --git a/lib/pure/asyncio.nim b/lib/pure/asyncio.nim
index ac94f2087..7765a8b29 100644
--- a/lib/pure/asyncio.nim
+++ b/lib/pure/asyncio.nim
@@ -11,8 +11,9 @@ import sockets, os
 ## This module implements an asynchronous event loop for sockets. 
 ## It is akin to Python's asyncore module. Many modules that use sockets
 ## have an implementation for this module, those modules should all have a 
-## ``register`` function which you should use to add it to a dispatcher so
-## that you can receive the events associated with that module.
+## ``register`` function which you should use to add the desired objects to a 
+## dispatcher which you created so
+## that you can receive the events associated with that module's object.
 ##
 ## Once everything is registered in a dispatcher, you need to call the ``poll``
 ## function in a while loop.
@@ -28,9 +29,46 @@ import sockets, os
 ## Most (if not all) modules that use asyncio provide a userArg which is passed
 ## on with the events. The type that you set userArg to must be inheriting from
 ## TObject!
+##
+## Asynchronous sockets
+## ====================
+##
+## For most purposes you do not need to worry about the ``TDelegate`` type. The
+## ``PAsyncSocket`` is what you are after. It's a reference to the ``TAsyncSocket``
+## object. This object defines events which you should overwrite by your own
+## procedures.
+##
+## For server sockets the only event you need to worry about is the ``handleAccept``
+## event, in your handleAccept proc you should call ``accept`` on the server
+## socket which will give you the client which is connecting. You should then
+## set any events that you want to use on that client and add it to your dispatcher
+## using the ``register`` procedure.
+## 
+## An example ``handleAccept`` follows:
+## 
+## .. code:: nimrod
+##   
+##    var disp: PDispatcher = newDispatcher()
+##    ...
+##    proc handleAccept(s: PAsyncSocket, arg: Pobject) {.nimcall.} =
+##      echo("Accepted client.")
+##      var client: PAsyncSocket
+##      new(client)
+##      s.accept(client)
+##      client.handleRead = ...
+##      disp.register(client)
+##    ...
+## 
+## For client sockets you should only be interested in the ``handleRead`` and
+## ``handleConnect`` events. The former gets called whenever the socket has
+## received messages and can be read from and the latter gets called whenever
+## the socket has established a connection to a server socket; from that point
+## it can be safely written to.
+
+
 
 type
-  TDelegate = object
+  TDelegate* = object
     deleVal*: PObject
 
     handleRead*: proc (h: PObject) {.nimcall.}
@@ -50,7 +88,7 @@ type
     delegates: seq[PDelegate]
 
   PAsyncSocket* = ref TAsyncSocket
-  TAsyncSocket = object of TObject
+  TAsyncSocket* = object of TObject
     socket: TSocket
     info: TInfo
 
@@ -62,6 +100,7 @@ type
     handleAccept*: proc (s:  PAsyncSocket, arg: PObject) {.nimcall.}
 
     lineBuffer: TaintedString ## Temporary storage for ``recvLine``
+    sslNeedAccept: bool
 
   TInfo* = enum
     SockIdle, SockConnecting, SockConnected, SockListening, SockClosed
@@ -94,29 +133,60 @@ proc newAsyncSocket(userArg: PObject = nil): PAsyncSocket =
 
 proc AsyncSocket*(domain: TDomain = AF_INET, typ: TType = SOCK_STREAM, 
                   protocol: TProtocol = IPPROTO_TCP, 
-                  userArg: PObject = nil): PAsyncSocket =
+                  userArg: PObject = nil, buffered = true): PAsyncSocket =
   result = newAsyncSocket(userArg)
-  result.socket = socket(domain, typ, protocol)
+  result.socket = socket(domain, typ, protocol, buffered)
   if result.socket == InvalidSocket: OSError()
   result.socket.setBlocking(false)
 
+proc asyncSockHandleConnect(h: PObject) =
+  when defined(ssl):
+    if PAsyncSocket(h).socket.isSSL and not
+         PAsyncSocket(h).socket.gotHandshake:
+      return  
+      
+  PAsyncSocket(h).info = SockConnected
+  PAsyncSocket(h).handleConnect(PAsyncSocket(h),
+     PAsyncSocket(h).userArg)
+
+proc asyncSockHandleRead(h: PObject) =
+  when defined(ssl):
+    if PAsyncSocket(h).socket.isSSL and not
+         PAsyncSocket(h).socket.gotHandshake:
+      return
+  PAsyncSocket(h).handleRead(PAsyncSocket(h), PAsyncSocket(h).userArg)
+
+when defined(ssl):
+  proc asyncSockDoHandshake(h: PObject) =
+    if PAsyncSocket(h).socket.isSSL and not
+         PAsyncSocket(h).socket.gotHandshake:
+      if PAsyncSocket(h).sslNeedAccept:
+        var d = ""
+        let ret = PAsyncSocket(h).socket.acceptAddrSSL(PAsyncSocket(h).socket, d)
+        assert ret != AcceptNoClient
+        if ret == AcceptSuccess:
+          PAsyncSocket(h).info = SockConnected
+      else:
+        # handshake will set socket's ``sslNoHandshake`` field.
+        discard PAsyncSocket(h).socket.handshake()
+
 proc toDelegate(sock: PAsyncSocket): PDelegate =
   result = newDelegate()
   result.deleVal = sock
   result.getSocket = (proc (h: PObject): tuple[info: TInfo, sock: TSocket] =
                         return (PAsyncSocket(h).info, PAsyncSocket(h).socket))
 
-  result.handleConnect = (proc (h: PObject) =
-                            PAsyncSocket(h).info = SockConnected
-                            PAsyncSocket(h).handleConnect(PAsyncSocket(h),
-                               PAsyncSocket(h).userArg))
-  result.handleRead = (proc (h: PObject) =
-                         PAsyncSocket(h).handleRead(PAsyncSocket(h),
-                            PAsyncSocket(h).userArg))
+  result.handleConnect = asyncSockHandleConnect
+  
+  result.handleRead = asyncSockHandleRead
+  
   result.handleAccept = (proc (h: PObject) =
                            PAsyncSocket(h).handleAccept(PAsyncSocket(h),
                               PAsyncSocket(h).userArg))
 
+  when defined(ssl):
+    result.task = asyncSockDoHandshake
+
 proc connect*(sock: PAsyncSocket, name: string, port = TPort(0),
                    af: TDomain = AF_INET) =
   ## Begins connecting ``sock`` to ``name``:``port``.
@@ -137,21 +207,61 @@ proc listen*(sock: PAsyncSocket) =
   sock.socket.listen()
   sock.info = SockListening
 
+proc acceptAddr*(server: PAsyncSocket, client: var PAsyncSocket,
+                 address: var string) =
+  ## Equivalent to ``sockets.acceptAddr``. This procedure should be called in
+  ## a ``handleAccept`` event handler **only** once.
+  ##
+  ## **Note**: ``client`` needs to be initialised.
+  assert(client != nil)
+  var c: TSocket
+  new(c)
+  when defined(ssl):
+    if server.socket.isSSL:
+      var ret = server.socket.acceptAddrSSL(c, address)
+      # The following shouldn't happen because when this function is called
+      # it is guaranteed that there is a client waiting.
+      # (This should be called in handleAccept)
+      assert(ret != AcceptNoClient)
+      if ret == AcceptNoHandshake:
+        client.sslNeedAccept = true
+      else:
+        client.sslNeedAccept = false
+        client.info = SockConnected
+    else:
+      server.socket.acceptAddr(c, address)
+      client.sslNeedAccept = false
+      client.info = SockConnected
+  else:
+    server.socket.acceptAddr(c, address)
+    client.sslNeedAccept = false
+    client.info = SockConnected
+
+  if c == InvalidSocket: OSError()
+  c.setBlocking(false) # TODO: Needs to be tested.
+  
+  client.socket = c
+  client.lineBuffer = ""
+
+proc accept*(server: PAsyncSocket, client: var PAsyncSocket) =
+  ## Equivalent to ``sockets.accept``.
+  var dummyAddr = ""
+  server.acceptAddr(client, dummyAddr)
+
 proc acceptAddr*(server: PAsyncSocket): tuple[sock: PAsyncSocket,
-                                              address: string] =
+                                              address: string] {.deprecated.} =
   ## Equivalent to ``sockets.acceptAddr``.
-  var (client, a) = server.socket.acceptAddr()
-  if client == InvalidSocket: OSError()
-  client.setBlocking(false) # TODO: Needs to be tested.
-  
-  var aSock: PAsyncSocket = newAsyncSocket()
-  aSock.socket = client
-  aSock.info = SockConnected
-  
-  return (aSock, a)
+  ## 
+  ## **Warning**: This is deprecated in favour of the above.
+  var client = newAsyncSocket()
+  var address: string = ""
+  acceptAddr(server, client, address)
+  return (client, address)
 
-proc accept*(server: PAsyncSocket): PAsyncSocket =
+proc accept*(server: PAsyncSocket): PAsyncSocket {.deprecated.} =
   ## Equivalent to ``sockets.accept``.
+  ##
+  ## **Warning**: This is deprecated.
   var (client, a) = server.acceptAddr()
   return client
 
@@ -210,8 +320,9 @@ proc recvLine*(s: PAsyncSocket, line: var TaintedString): bool =
     if s.lineBuffer.len > 0:
       string(line).add(s.lineBuffer.string)
       setLen(s.lineBuffer.string, 0)
-    
     string(line).add(dataReceived.string)
+    if string(line) == "":
+      line = "\c\L".TaintedString
     result = true
   of RecvPartialLine:
     string(s.lineBuffer).add(dataReceived.string)
@@ -263,7 +374,7 @@ proc poll*(d: PDispatcher, timeout: int = 500): bool =
   
   if readSocks.len() == 0 and writeSocks.len() == 0:
     return False
-  
+
   if select(readSocks, writeSocks, timeout) != 0:
     for i in 0..len(d.delegates)-1:
       if i > len(d.delegates)-1: break # One delegate might've been removed.
@@ -294,7 +405,11 @@ proc poll*(d: PDispatcher, timeout: int = 500): bool =
   # Execute tasks
   for i in items(d.delegates):
     i.task(i.deleVal)
-  
+
+proc len*(disp: PDispatcher): int =
+  ## Retrieves the amount of delegates in ``disp``.
+  return disp.delegates.len
+
 when isMainModule:
   type
     PIntType = ref TIntType