summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAraq <rumpf_a@web.de>2012-08-13 18:52:00 +0200
committerAraq <rumpf_a@web.de>2012-08-13 18:52:00 +0200
commitadb73ec9ed87a5c5fe1ced35f3440b85bac49d8a (patch)
tree1cecba68cdfca101900750e8e84d8dc62f520084
parent8805829d7ff2d3fc6c692cfa4931439d25d1bb6e (diff)
parent8dd1a5a1818584535edbb3cbf90c599cf4b9eefc (diff)
downloadNim-adb73ec9ed87a5c5fe1ced35f3440b85bac49d8a.tar.gz
Merge branch 'master' of github.com:Araq/Nimrod
-rw-r--r--lib/pure/asyncio.nim98
-rw-r--r--lib/pure/ftpclient.nim13
-rwxr-xr-xlib/pure/sockets.nim76
-rwxr-xr-xlib/windows/winlean.nim6
-rw-r--r--tests/run/tasynciossl.nim8
-rw-r--r--tests/run/tasyncudp.nim77
6 files changed, 213 insertions, 65 deletions
diff --git a/lib/pure/asyncio.nim b/lib/pure/asyncio.nim
index e5808baaf..113b1d080 100644
--- a/lib/pure/asyncio.nim
+++ b/lib/pure/asyncio.nim
@@ -30,6 +30,11 @@ import sockets, os
 ## on with the events. The type that you set userArg to must be inheriting from
 ## TObject!
 ##
+## **Note:** If you want to provide async ability to your module please do not 
+## use the ``TDelegate`` object, instead use ``PAsyncSocket``. It is possible 
+## that in the future this type's fields will not be exported therefore breaking
+## your code.
+##
 ## Asynchronous sockets
 ## ====================
 ##
@@ -68,7 +73,8 @@ import sockets, os
 
 
 type
-  TDelegate* = object
+
+  TDelegate = object
     deleVal*: PObject
 
     handleRead*: proc (h: PObject) {.nimcall.}
@@ -92,12 +98,10 @@ type
     socket: TSocket
     info: TInfo
 
-    userArg: PObject
-
-    handleRead*: proc (s: PAsyncSocket, arg: PObject) {.nimcall.}
-    handleConnect*: proc (s:  PAsyncSocket, arg: PObject) {.nimcall.}
+    handleRead*: proc (s: PAsyncSocket) {.closure.}
+    handleConnect*: proc (s:  PAsyncSocket) {.closure.}
 
-    handleAccept*: proc (s:  PAsyncSocket, arg: PObject) {.nimcall.}
+    handleAccept*: proc (s:  PAsyncSocket) {.closure.}
 
     lineBuffer: TaintedString ## Temporary storage for ``recvLine``
     sslNeedAccept: bool
@@ -121,21 +125,20 @@ proc newDelegate*(): PDelegate =
   result.task = (proc (h: PObject) = nil)
   result.mode = MReadable
 
-proc newAsyncSocket(userArg: PObject = nil): PAsyncSocket =
+proc newAsyncSocket(): PAsyncSocket =
   new(result)
   result.info = SockIdle
-  result.userArg = userArg
 
-  result.handleRead = (proc (s: PAsyncSocket, arg: PObject) = nil)
-  result.handleConnect = (proc (s: PAsyncSocket, arg: PObject) = nil)
-  result.handleAccept = (proc (s: PAsyncSocket, arg: PObject) = nil)
+  result.handleRead = (proc (s: PAsyncSocket) = nil)
+  result.handleConnect = (proc (s: PAsyncSocket) = nil)
+  result.handleAccept = (proc (s: PAsyncSocket) = nil)
 
   result.lineBuffer = "".TaintedString
 
 proc AsyncSocket*(domain: TDomain = AF_INET, typ: TType = SOCK_STREAM, 
                   protocol: TProtocol = IPPROTO_TCP, 
-                  userArg: PObject = nil, buffered = true): PAsyncSocket =
-  result = newAsyncSocket(userArg)
+                  buffered = true): PAsyncSocket =
+  result = newAsyncSocket()
   result.socket = socket(domain, typ, protocol, buffered)
   result.proto = protocol
   if result.socket == InvalidSocket: OSError()
@@ -148,15 +151,14 @@ proc asyncSockHandleConnect(h: PObject) =
       return  
       
   PAsyncSocket(h).info = SockConnected
-  PAsyncSocket(h).handleConnect(PAsyncSocket(h),
-     PAsyncSocket(h).userArg)
+  PAsyncSocket(h).handleConnect(PAsyncSocket(h))
 
 proc asyncSockHandleRead(h: PObject) =
   when defined(ssl):
     if PAsyncSocket(h).socket.isSSL and not
          PAsyncSocket(h).socket.gotHandshake:
       return
-  PAsyncSocket(h).handleRead(PAsyncSocket(h), PAsyncSocket(h).userArg)
+  PAsyncSocket(h).handleRead(PAsyncSocket(h))
 
 when defined(ssl):
   proc asyncSockDoHandshake(h: PObject) =
@@ -183,8 +185,7 @@ proc toDelegate(sock: PAsyncSocket): PDelegate =
   result.handleRead = asyncSockHandleRead
   
   result.handleAccept = (proc (h: PObject) =
-                           PAsyncSocket(h).handleAccept(PAsyncSocket(h),
-                              PAsyncSocket(h).userArg))
+                           PAsyncSocket(h).handleAccept(PAsyncSocket(h)))
 
   when defined(ssl):
     result.task = asyncSockDoHandshake
@@ -337,7 +338,6 @@ proc recvLine*(s: PAsyncSocket, line: var TaintedString): bool =
   of RecvFail:
     result = false
 
-
 proc poll*(d: PDispatcher, timeout: int = 500): bool =
   ## This function checks for events on all the sockets in the `PDispatcher`.
   ## It then proceeds to call the correct event handler.
@@ -417,59 +417,47 @@ proc len*(disp: PDispatcher): int =
   return disp.delegates.len
 
 when isMainModule:
-  type
-    PIntType = ref TIntType
-    TIntType = object of TObject
-      val: int
-
-    PMyArg = ref TMyArg
-    TMyArg = object of TObject
-      dispatcher: PDispatcher
-      val: int
-
-  proc testConnect(s: PAsyncSocket, arg: PObject) =
-    echo("Connected! " & $PIntType(arg).val)
+
+  proc testConnect(s: PAsyncSocket, no: int) =
+    echo("Connected! " & $no)
   
-  proc testRead(s: PAsyncSocket, arg: PObject) =
-    echo("Reading! " & $PIntType(arg).val)
+  proc testRead(s: PAsyncSocket, no: int) =
+    echo("Reading! " & $no)
     var data = s.getSocket.recv()
     if data == "":
-      echo("Closing connection. " & $PIntType(arg).val)
+      echo("Closing connection. " & $no)
       s.close()
     echo(data)
-    echo("Finished reading! " & $PIntType(arg).val)
+    echo("Finished reading! " & $no)
 
-  proc testAccept(s: PAsyncSocket, arg: PObject) =
-    echo("Accepting client! " & $PMyArg(arg).val)
+  proc testAccept(s: PAsyncSocket, disp: PDispatcher, no: int) =
+    echo("Accepting client! " & $no)
     var client: PAsyncSocket
     new(client)
     var address = ""
     s.acceptAddr(client, address)
     echo("Accepted ", address)
-    client.handleRead = testRead
-    var userArg: PIntType
-    new(userArg)
-    userArg.val = 78
-    client.userArg = userArg
-    PMyArg(arg).dispatcher.register(client)
+    client.handleRead = 
+      proc (s: PAsyncSocket) =
+        testRead(s, 2)
+    disp.register(client)
 
   var d = newDispatcher()
   
-  var userArg: PIntType
-  new(userArg)
-  userArg.val = 0
-  var s = AsyncSocket(userArg = userArg)
+  var s = AsyncSocket()
   s.connect("amber.tenthbit.net", TPort(6667))
-  s.handleConnect = testConnect
-  s.handleRead = testRead
+  s.handleConnect = 
+    proc (s: PAsyncSocket) =
+      testConnect(s, 1)
+  s.handleRead = 
+    proc (s: PAsyncSocket) =
+      testRead(s, 1)
   d.register(s)
   
-  var userArg1: PMyArg
-  new(userArg1)
-  userArg1.val = 1
-  userArg1.dispatcher = d
-  var server = AsyncSocket(userArg = userArg1)
-  server.handleAccept = testAccept
+  var server = AsyncSocket()
+  server.handleAccept =
+    proc (s: PAsyncSocket) = 
+      testAccept(s, d, 78)
   server.bindAddr(TPort(5555))
   server.listen()
   d.register(server)
diff --git a/lib/pure/ftpclient.nim b/lib/pure/ftpclient.nim
index fc049e3b7..4cf25a3cb 100644
--- a/lib/pure/ftpclient.nim
+++ b/lib/pure/ftpclient.nim
@@ -227,7 +227,18 @@ proc listDirs*(ftp: var TFTPClient, dir: string = "",
     ftp.deleteJob()
   else: return @[]
 
-proc fileExists*(ftp: var TFTPClient, file: string): bool =
+proc fileExists*(ftp: var TFTPClient, file: string): bool {.deprecated.} =
+  ## **Deprecated:** Please use ``existsFile``.
+  ##
+  ## Determines whether ``file`` exists.
+  ##
+  ## Warning: This function may block. Especially on directories with many
+  ## files, because a full list of file names must be retrieved.
+  var files = ftp.listDirs()
+  for f in items(files):
+    if f.normalizePathSep == file.normalizePathSep: return true
+
+proc existsFile*(ftp: var TFTPClient, file: string): bool =
   ## Determines whether ``file`` exists.
   ##
   ## Warning: This function may block. Especially on directories with many
diff --git a/lib/pure/sockets.nim b/lib/pure/sockets.nim
index c2774e84f..28f0a325e 100755
--- a/lib/pure/sockets.nim
+++ b/lib/pure/sockets.nim
@@ -74,7 +74,7 @@ type
     SOCK_STREAM = 1,   ## reliable stream-oriented service or Stream Sockets
     SOCK_DGRAM = 2,    ## datagram service or Datagram Sockets
     SOCK_RAW = 3,      ## raw protocols atop the network layer.
-    SOCK_SEQPACKET = 5 ## reliable sequenced packet service, or
+    SOCK_SEQPACKET = 5 ## reliable sequenced packet service
 
   TProtocol* = enum     ## third argument to `socket` proc
     IPPROTO_TCP = 6,    ## Transmission control protocol. 
@@ -1203,7 +1203,45 @@ proc recvAsync*(socket: TSocket, s: var TaintedString): bool =
     setLen(s.string, s.string.len + bufSize)
     inc(pos, bytesRead)
   result = True
+
+proc recvFrom*(socket: TSocket, data: var string, length: int,
+               address: var string, flags = 0'i32): int =
+  ## Receives data from ``socket``. This function should normally be used with
+  ## connection-less sockets (UDP sockets).
+  ##
+  ## **Warning:** This function does not yet have a buffered implementation,
+  ## so when ``socket`` is buffered the non-buffered implementation will be
+  ## used. Therefore if ``socket`` contains something in its buffer this
+  ## function will make no effort to return it.
+  
+  # TODO: Buffered sockets
+  data = newString(length)
+  var sockAddress: Tsockaddr_in
+  var addrLen = sizeof(sockAddress).TSockLen
+  result = recvFrom(socket.fd, cstring(data), length.cint, flags.cint,
+                    cast[ptr TSockAddr](addr(sockAddress)), addr(addrLen))
   
+  if result != -1:
+    address = $inet_ntoa(sockAddress.sin_addr)
+
+proc recvFromAsync*(socket: TSocket, data: var String, length: int,
+                    address: var string, flags = 0'i32): bool =
+  ## Similar to ``recvFrom`` but raises an EOS error when an error occurs.
+  ## Returns False if no messages could be received from ``socket``.
+  result = true
+  var callRes = recvFrom(socket, data, length, address)
+  if callRes < 0:
+    when defined(windows):
+      # TODO: Test on Windows
+      var err = WSAGetLastError()
+      if err == WSAEWOULDBLOCK:
+        return False
+      else: OSError()
+    else:
+      if errno == EAGAIN or errno == EWOULDBLOCK:
+        return False
+      else: OSError()
+
 proc skip*(socket: TSocket) =
   ## skips all the data that is pending for the socket
   const bufSize = 1000
@@ -1270,6 +1308,37 @@ proc trySend*(socket: TSocket, data: string): bool =
   ## and instead returns ``false`` on failure.
   result = send(socket, cstring(data), data.len) == data.len
 
+proc sendTo*(socket: TSocket, address: string, port: TPort, data: pointer,
+             size: int, af: TDomain = AF_INET, flags = 0'i32): int =
+  ## low-level sendTo proc. This proc sends ``data`` to the specified ``address``,
+  ## which may be an IP address or a hostname, if a hostname is specified 
+  ## this function will try each IP of that hostname.
+  ##
+  ## **Note:** This proc is not available for SSL sockets.
+  var hints: TAddrInfo
+  var aiList: ptr TAddrInfo = nil
+  hints.ai_family = toInt(af)
+  hints.ai_socktype = toInt(SOCK_STREAM)
+  hints.ai_protocol = toInt(IPPROTO_TCP)
+  gaiNim(address, port, hints, aiList)
+  
+  # try all possibilities:
+  var success = false
+  var it = aiList
+  while it != nil:
+    result = sendTo(socket.fd, data, size.cint, flags.cint, it.ai_addr,
+                    it.ai_addrlen.TSockLen)
+    if result != -1'i32:
+      success = true
+      break
+    it = it.ai_next
+
+  freeaddrinfo(aiList)
+
+proc sendTo*(socket: TSocket, address: string, port: TPort, data: string): int =
+  ## Friendlier version of the low-level ``sendTo``.
+  result = socket.sendTo(address, port, cstring(data), data.len)
+
 when defined(Windows):
   const 
     SOCKET_ERROR = -1
@@ -1303,8 +1372,8 @@ proc connect*(socket: TSocket, timeout: int, name: string, port = TPort(0),
   ## specifies the time in miliseconds of how long to wait for a connection
   ## to be made.
   ##
-  ## **Warning:** If ``socket`` is non-blocking and timeout is not ``-1`` then
-  ## this function may set blocking mode on ``socket`` to true.
+  ## **Warning:** If ``socket`` is non-blocking then
+  ## this function will set blocking mode on ``socket`` to true.
   socket.setBlocking(true)
   
   socket.connectAsync(name, port, af)
@@ -1313,6 +1382,7 @@ proc connect*(socket: TSocket, timeout: int, name: string, port = TPort(0),
     raise newException(ETimeout, "Call to connect() timed out.")
 
 proc isSSL*(socket: TSocket): bool = return socket.isSSL
+  ## Determines whether ``socket`` is a SSL socket.
 
 when defined(Windows):
   var wsa: TWSADATA
diff --git a/lib/windows/winlean.nim b/lib/windows/winlean.nim
index 4c0671df5..1ea00c737 100755
--- a/lib/windows/winlean.nim
+++ b/lib/windows/winlean.nim
@@ -451,15 +451,15 @@ proc listen*(s: TWinSocket, backlog: cint): cint {.
 proc recv*(s: TWinSocket, buf: pointer, len, flags: cint): cint {.
   stdcall, importc: "recv", dynlib: ws2dll.}
 proc recvfrom*(s: TWinSocket, buf: cstring, len, flags: cint, 
-               fromm: ptr TSockAddr, fromlen: ptr cint): cint {.
+               fromm: ptr TSockAddr, fromlen: ptr Tsocklen): cint {.
   stdcall, importc: "recvfrom", dynlib: ws2dll.}
 proc select*(nfds: cint, readfds, writefds, exceptfds: ptr TFdSet,
              timeout: ptr TTimeval): cint {.
   stdcall, importc: "select", dynlib: ws2dll.}
 proc send*(s: TWinSocket, buf: pointer, len, flags: cint): cint {.
   stdcall, importc: "send", dynlib: ws2dll.}
-proc sendto*(s: TWinSocket, buf: cstring, len, flags: cint,
-             to: ptr TSockAddr, tolen: cint): cint {.
+proc sendto*(s: TWinSocket, buf: pointer, len, flags: cint,
+             to: ptr TSockAddr, tolen: Tsocklen): cint {.
   stdcall, importc: "sendto", dynlib: ws2dll.}
 
 proc shutdown*(s: TWinSocket, how: cint): cint {.
diff --git a/tests/run/tasynciossl.nim b/tests/run/tasynciossl.nim
index 99e7df172..e5fb9610c 100644
--- a/tests/run/tasynciossl.nim
+++ b/tests/run/tasynciossl.nim
@@ -18,13 +18,13 @@ const
   swarmSize = 50
   messagesToSend = 100
 
-proc swarmConnect(s: PAsyncSocket, arg: PObject) {.nimcall.} =
+proc swarmConnect(s: PAsyncSocket) =
   #echo("Connected")
   for i in 1..messagesToSend:
     s.send("Message " & $i & "\c\L")
   s.close()
 
-proc serverRead(s: PAsyncSocket, arg: PObject) {.nimcall.} =
+proc serverRead(s: PAsyncSocket) =
   var line = ""
   assert s.recvLine(line)
   if line != "":
@@ -36,7 +36,7 @@ proc serverRead(s: PAsyncSocket, arg: PObject) {.nimcall.} =
   else:
     s.close()
 
-proc serverAccept(s: PAsyncSocket, arg: Pobject) {.nimcall.} =
+proc serverAccept(s: PAsyncSocket) =
   var client: PAsyncSocket
   new(client)
   s.accept(client)
@@ -83,6 +83,8 @@ while true:
     break
   if not disp.poll(): break
   if disp.len == serverCount:
+    # Only the servers are left in the dispatcher. All clients finished,
+    # we need to therefore break.
     break
 
 assert msgCount == (swarmSize * messagesToSend) * serverCount
diff --git a/tests/run/tasyncudp.nim b/tests/run/tasyncudp.nim
new file mode 100644
index 000000000..b404169dc
--- /dev/null
+++ b/tests/run/tasyncudp.nim
@@ -0,0 +1,77 @@
+discard """
+  file: "tasyncudp.nim"
+  output: "2000"
+"""
+import asyncio, sockets, strutils, times
+
+const
+  swarmSize = 5
+  messagesToSend = 200
+
+var
+  disp = newDispatcher()
+  msgCount = 0
+  currentClient = 0
+
+proc serverRead(s: PAsyncSocket) =
+  var data = ""
+  var address = ""
+  if s.recvFromAsync(data, 9, address):
+    assert address == "127.0.0.1"
+    msgCount.inc()
+  
+  discard """
+  
+  var line = ""
+  assert s.recvLine(line)
+  
+  if line == "":
+    assert(false)
+  else:
+    if line.startsWith("Message "):
+      msgCount.inc()
+    else:
+      assert(false)
+  """
+
+proc swarmConnect(s: PAsyncSocket) =
+  for i in 1..messagesToSend:
+    s.send("Message\c\L")
+
+proc createClient(disp: var PDispatcher, port: TPort,
+                  buffered = true) =
+  currentClient.inc()
+  var client = AsyncSocket(typ = SOCK_DGRAM, protocol = IPPROTO_UDP,
+                           buffered = buffered)
+  client.handleConnect = swarmConnect
+  disp.register(client)
+  client.connect("localhost", port)
+
+proc createServer(port: TPort, buffered = true) =
+  var server = AsyncSocket(typ = SOCK_DGRAM, protocol = IPPROTO_UDP,
+                           buffered = buffered)
+  server.handleRead = serverRead
+  disp.register(server)
+  server.bindAddr(port)
+
+let serverCount = 2
+
+createServer(TPort(10335), false)
+createServer(TPort(10336), true)
+var startTime = epochTime()
+while true:
+  if epochTime() - startTime >= 300.0:
+    break
+
+  if not disp.poll():
+    break
+  
+  if (msgCount div messagesToSend) * serverCount == currentClient:
+    createClient(disp, TPort(10335), false)
+    createClient(disp, TPort(10336), true)
+  
+  if msgCount == messagesToSend * serverCount * swarmSize:
+    break
+
+assert msgCount == messagesToSend * serverCount * swarmSize
+echo(msgCount)
\ No newline at end of file