summary refs log tree commit diff stats
path: root/lib/pure/asyncmacro.nim
blob: 11eba427bd346788551b76ee4ca8b2fe79f6cb24 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
#
#
#            Nim's Runtime Library
#        (c) Copyright 2015 Dominik Picheta
#
#    See the file "copying.txt", included in this
#    distribution, for details about the copyright.
#

## `asyncdispatch` module depends on the `asyncmacro` module to work properly.

import macros, strutils, asyncfutures

proc skipUntilStmtList(node: NimNode): NimNode {.compileTime.} =
  ## Descends through a chain of directly nested ``StmtList`` nodes: while the
  ## first child of the current node is a ``StmtList``, step into it. Returns
  ## the innermost node reached (``node`` itself if its first child is not a
  ## ``StmtList``).
  var current = node
  while current[0].kind == nnkStmtList:
    current = current[0]
  result = current

proc skipStmtList(node: NimNode): NimNode {.compileTime.} =
  ## Skips exactly one level of ``StmtList`` nesting: returns ``node[0]`` when
  ## it is a ``StmtList``, otherwise ``node`` unchanged.
  if node[0].kind == nnkStmtList: node[0]
  else: node

template createCb(retFutureSym, iteratorNameSym,
                  strName, identName, futureVarCompletions: untyped) =
  # Emits the continuation callback ``identName`` that drives the closure
  # iterator ``iteratorNameSym``, then calls it once to start the iterator.
  #   retFutureSym:         symbol of the future handed back to the caller.
  #   strName:              the async proc's name, used in the error message.
  #   futureVarCompletions: statements that complete the FutureVar params;
  #                         run when the iterator raises.
  # ``bind`` makes ``finished`` resolve to the symbol visible at this
  # definition instead of whatever is in scope at the expansion site.
  bind finished
  let retFutUnown = unown retFutureSym

  var nameIterVar = iteratorNameSym
  proc identName {.closure.} =
    try:
      if not nameIterVar.finished:
        var next = unown nameIterVar()
        # Continue while the yielded future is already finished.
        while (not next.isNil) and next.finished:
          next = unown nameIterVar()
          if nameIterVar.finished:
            break

        if next == nil:
          # ``nil`` is what the generated ``return nil`` statements produce.
          # It is only an error if the returned future was not completed.
          if not retFutUnown.finished:
            let msg = "Async procedure ($1) yielded `nil`, are you await'ing a " &
                    "`nil` Future?"
            raise newException(AssertionError, msg % strName)
        else:
          # Resume this callback once the pending future completes. The
          # explicit proc-type conversion triggers a hint that is silenced.
          {.gcsafe.}:
            {.push hint[ConvFromXtoItselfNotNeeded]: off.}
            next.callback = (proc() {.closure, gcsafe.})(identName)
            {.pop.}
    except:
      # The iterator raised: complete the FutureVar parameters, then either
      # re-raise (returned future already finished) or fail the future.
      futureVarCompletions
      if retFutUnown.finished:
        # Take a look at tasyncexceptions for the bug which this fixes.
        # That test explains it better than I can here.
        raise
      else:
        retFutUnown.fail(getCurrentException())
  # Start the iterator right away; later steps run from future callbacks.
  identName()

template useVar(result: var NimNode, futureVarNode: NimNode, valueReceiver,
                rootReceiver: untyped, fromNode: NimNode) =
  ## Appends ``yield <future>`` and the statement consuming its value to
  ## ``result``.
  ##
  ## Params:
  ##    futureVarNode: The NimNode which is a symbol identifying the Future[T]
  ##                   variable to yield.
  ##    fromNode: Used for better debug information (to give context).
  ##    valueReceiver: The node which defines an expression that retrieves the
  ##                   future's value. It is assigned ``<future>.read``.
  ##                   Because template parameters are substituted textually,
  ##                   when the caller passes a location inside
  ##                   ``rootReceiver`` (e.g. ``newVarSection[0][^1]``), the
  ##                   ``await x`` there is replaced in place.
  ##
  ##    rootReceiver: The statement that uses the awaited value (the var/let
  ##                  section, assignment, discard or call that contained the
  ##                  ``await``). It is appended after the yield. For a bare
  ##                  ``await x`` callers pass the same variable for both
  ##                  receivers, so the appended node is just ``<future>.read``.
  # -> yield future<x>
  result.add newNimNode(nnkYieldStmt, fromNode).add(futureVarNode)
  # -> future<x>.read
  valueReceiver = newDotExpr(futureVarNode, newIdentNode("read"))
  result.add rootReceiver

template createVar(result: var NimNode, futSymName: string,
                   asyncProc: NimNode,
                   valueReceiver, rootReceiver: untyped,
                   fromNode: NimNode) =
  ## Replaces ``result`` with a new statement list:
  ##   var future<x> = asyncProc
  ##   yield future<x>
  ##   <rootReceiver, with valueReceiver set to future<x>.read>
  ##
  ## ``futSymName`` is not referenced in this body; the variable is always
  ## created with ``genSym(nskVar, "future")``. Since template arguments are
  ## only evaluated where used, the callers' name expressions never run.
  result = newNimNode(nnkStmtList, fromNode)
  var futSym = genSym(nskVar, "future")
  result.add newVarStmt(futSym, asyncProc) # -> var future<x> = y
  useVar(result, futSym, valueReceiver, rootReceiver, fromNode)

proc createFutureVarCompletions(futureVarIdents: seq[NimNode],
    fromNode: NimNode): NimNode {.compileTime.} =
  ## Builds a statement list that completes every ``FutureVar`` parameter the
  ## user has not completed already, one statement per parameter:
  ##   if not <param>.finished: complete(<param>)
  ## ``fromNode`` only supplies line information for the list.
  result = newNimNode(nnkStmtList, fromNode)
  for futVar in futureVarIdents:
    # TODO: Once https://github.com/nim-lang/Nim/issues/5617 is fixed.
    # TODO: Add line info to the complete() call!
    # In the meantime, this was really useful for debugging :)
    #result.add(newCall(newIdentNode("echo"), newStrLitNode(fromNode.lineinfo)))
    let isPending = newCall(newIdentNode("not"),
                            newDotExpr(futVar, newIdentNode("finished")))
    let completeIt = newCall(newIdentNode("complete"), futVar)
    result.add newIfStmt((isPending, completeIt))

proc processBody(node, retFutureSym: NimNode,
                 subTypeIsVoid: bool,
                 futureVarIdents: seq[NimNode]): NimNode {.compileTime.} =
  ## Recursively rewrites the body of an async proc so it can run inside the
  ## generated closure iterator.
  ##
  ## * ``return [x]`` first completes the FutureVar parameters. It then
  ##   completes ``retFutureSym`` with ``x``, with ``result``, or with
  ##   nothing, depending on ``x`` and ``subTypeIsVoid``. Finally it emits
  ##   ``return nil``.
  ## * ``await x`` is rewritten to ``var future = x; yield future; <stmt using
  ##   future.read>``. This applies in these positions:
  ##   * a bare ``await`` call or command;
  ##   * ``foo await x``;
  ##   * ``var/let v = await x``, where only the first definition of the
  ##     section is inspected;
  ##   * ``v = await x``;
  ##   * ``discard await x``.
  ## * Nested routine definitions (other than templates) are not descended
  ##   into, so their ``await``s are left alone.
  ##
  ## ``node``'s children are replaced in place; the rewritten node is returned.
  #echo(node.treeRepr)
  result = node
  case node.kind
  of nnkReturnStmt:
    result = newNimNode(nnkStmtList, node)

    # As I've painfully found out, the order here really DOES matter.
    result.add createFutureVarCompletions(futureVarIdents, node)

    if node[0].kind == nnkEmpty:
      # Plain ``return``: complete with the implicit ``result`` (if any).
      if not subTypeIsVoid:
        result.add newCall(newIdentNode("complete"), retFutureSym,
            newIdentNode("result"))
      else:
        result.add newCall(newIdentNode("complete"), retFutureSym)
    else:
      # ``return x``: rewrite ``x`` first (it may contain an ``await``).
      let x = node[0].processBody(retFutureSym, subTypeIsVoid,
                                  futureVarIdents)
      if x.kind == nnkYieldStmt: result.add x
      else:
        result.add newCall(newIdentNode("complete"), retFutureSym, x)

    result.add newNimNode(nnkReturnStmt, node).add(newNilLit())
    return # Don't process the children of this return stmt
  of nnkCommand, nnkCall:
    if node[0].eqIdent("await"):
      case node[1].kind
      of nnkIdent, nnkInfix, nnkDotExpr, nnkCall, nnkCommand:
        # await x
        # await x or y
        # await foo(p, x)
        # await foo p, x
        var futureValue: NimNode
        result.createVar("future" & $node[1][0].toStrLit, node[1], futureValue,
                  futureValue, node)
      else:
        error("Invalid node kind in 'await', got: " & $node[1].kind)
    elif node.len > 1 and node[1].kind == nnkCommand and
         node[1][0].eqIdent("await"):
      # foo await x
      var newCommand = node
      result.createVar("future" & $node[0].toStrLit, node[1][1], newCommand[1],
                newCommand, node)

  of nnkVarSection, nnkLetSection:
    case node[0][^1].kind
    of nnkCommand:
      if node[0][^1][0].eqIdent("await"):
        # var x = await y
        var newVarSection = node # TODO: Should this use copyNimNode?
        result.createVar("future" & node[0][0].strVal, node[0][^1][1],
          newVarSection[0][^1], newVarSection, node)
    else: discard
  of nnkAsgn:
    case node[1].kind
    of nnkCommand:
      if node[1][0].eqIdent("await"):
        # x = await y
        var newAsgn = node
        result.createVar("future" & $node[0].toStrLit, node[1][1], newAsgn[1],
            newAsgn, node)
    else: discard
  of nnkDiscardStmt:
    # discard await x
    if node[0].kind == nnkCommand and
          node[0][0].eqIdent("await"):
      var newDiscard = node
      result.createVar("futureDiscard_" & $toStrLit(node[0][1]), node[0][1],
                newDiscard[0], newDiscard, node)
  of RoutineNodes-{nnkTemplateDef}:
    # skip all the nested procedure definitions
    return
  else: discard

  # Recurse into the (possibly rewritten) children.
  for i in 0 ..< result.len:
    result[i] = processBody(result[i], retFutureSym, subTypeIsVoid,
                            futureVarIdents)

proc getName(node: NimNode): string {.compileTime.} =
  ## Returns the plain name of a routine's name node:
  ## ``foo*`` -> "foo", ``foo`` -> "foo", and an empty name (lambda) ->
  ## "anonymous". Any other node kind is a compile-time error that points at
  ## the offending node so the user sees where it came from.
  case node.kind
  of nnkPostfix:
    return node[1].strVal
  of nnkIdent, nnkSym:
    return node.strVal
  of nnkEmpty:
    return "anonymous"
  else:
    error("Unknown name.", node)

proc getFutureVarIdents(params: NimNode): seq[NimNode] {.compileTime.} =
  ## Collects the names of all parameters declared with a ``FutureVar[T]``
  ## type in the formal parameter list ``params``. Child 0 (the return type)
  ## is skipped.
  result = @[]
  for i in 1 ..< len(params):
    let identDefs = params[i]
    expectKind(identDefs, nnkIdentDefs)
    # IdentDefs layout: name1, name2, ..., type, default value.
    let paramType = identDefs[^2]
    # Identifier comparison is case-sensitive on the first character, so the
    # type name must be spelled ``FutureVar`` (not ``futurevar``) to match.
    if paramType.kind == nnkBracketExpr and
       paramType[0].eqIdent("FutureVar"):
      # ``a, b: FutureVar[T]`` declares several parameters; collect them all.
      for j in 0 ..< identDefs.len - 2:
        result.add(identDefs[j])

proc isInvalidReturnType(typeName: string): bool =
  ## True unless ``typeName`` is exactly "Future", the only return type name
  ## accepted for async procs ("FutureStream" support is commented out).
  result = typeName != "Future" # or "FutureStream"

proc verifyReturnType(typeName: string) {.compileTime.} =
  ## Aborts compilation with a descriptive error when ``typeName`` is not an
  ## accepted async return type name (see ``isInvalidReturnType``).
  if not isInvalidReturnType(typeName):
    return
  error("Expected return type of 'Future' got '$1'" % typeName)

proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} =
  ## This macro transforms a single procedure into a closure iterator.
  ## The ``async`` macro supports a stmtList holding multiple async procedures.
  ##
  ## ``prc`` must be a proc, method, lambda or ``do`` node. Its return type
  ## must be empty or ``Future[T]``, optionally wrapped in ``owned(...)``.
  ## ``prc`` is modified in place and returned.
  ##
  ## When it has a body, that body is replaced by code that:
  ## * creates the result future;
  ## * declares a closure iterator holding the rewritten original body;
  ## * starts that iterator via ``createCb``;
  ## * returns the future.
  ##
  ## An empty return type becomes ``owned(Future[void])``.
  if prc.kind notin {nnkProcDef, nnkLambda, nnkMethodDef, nnkDo}:
    error("Cannot transform this node kind into an async proc." &
          " proc/method definition or lambda node expected.")

  let prcName = prc.name.getName

  var returnType = prc.params[0]
  var baseType: NimNode
  # Look through ``owned(X)`` so the checks below see ``X``.
  if returnType.kind in nnkCallKinds and returnType[0].eqIdent("owned") and
      returnType.len == 2:
    returnType = returnType[1]
  # Verify that the return type is a Future[T]
  if returnType.kind == nnkBracketExpr:
    let fut = repr(returnType[0])
    verifyReturnType(fut)
    baseType = returnType[1]
  elif returnType.kind in nnkCallKinds and returnType[0].eqIdent("[]"):
    # ``[](Future, T)`` spelling: the type name is child 1 and ``T`` child 2.
    let fut = repr(returnType[1])
    verifyReturnType(fut)
    baseType = returnType[2]
  elif returnType.kind == nnkEmpty:
    baseType = returnType
  else:
    verifyReturnType(repr(returnType))
    # Only a bare ``Future`` gets past the check above. It has no ``T`` and
    # would otherwise leave ``baseType`` nil and crash the macro below.
    error("Expected return type of 'Future[T]' got '$1'" % repr(returnType),
          returnType)

  # Test ``T`` itself. Checking ``returnType[1]`` misses ``void`` in the
  # ``[](Future, T)`` spelling, where child 1 is ``Future``.
  let subtypeIsVoid = returnType.kind == nnkEmpty or
        (baseType.kind in {nnkIdent, nnkSym} and baseType.eqIdent("void"))

  let futureVarIdents = getFutureVarIdents(prc.params)

  var outerProcBody = newNimNode(nnkStmtList, prc.body)

  # Extract the documentation comment from the original procedure declaration.
  # Note that we're not removing it from the body in order not to make this
  # transformation even more complex.
  if prc.body.len > 1 and prc.body[0].kind == nnkCommentStmt:
    outerProcBody.add(prc.body[0])

  # -> var retFuture = newFuture[T]()
  var retFutureSym = genSym(nskVar, "retFuture")
  var subRetType =
    if returnType.kind == nnkEmpty: newIdentNode("void")
    else: baseType
  outerProcBody.add(
    newVarStmt(retFutureSym,
      newCall(
        newNimNode(nnkBracketExpr, prc.body).add(
          newIdentNode("newFuture"),
          subRetType),
      newLit(prcName)))) # Get type from return type of this proc

  # -> iterator nameIter(): FutureBase {.closure.} =
  # ->   {.push warning[resultshadowed]: off.}
  # ->   var result: T
  # ->   {.pop.}
  # ->   <proc_body>
  # ->   complete(retFuture, result)
  var iteratorNameSym = genSym(nskIterator, $prcName & "Iter")
  var procBody = prc.body.processBody(retFutureSym, subtypeIsVoid,
                                    futureVarIdents)
  # don't do anything with forward bodies (empty)
  if procBody.kind != nnkEmpty:
    procBody.add(createFutureVarCompletions(futureVarIdents, nil))

    if not subtypeIsVoid:
      procBody.insert(0, newNimNode(nnkPragma).add(newIdentNode("push"),
        newNimNode(nnkExprColonExpr).add(newNimNode(nnkBracketExpr).add(
          newIdentNode("warning"), newIdentNode("resultshadowed")),
        newIdentNode("off")))) # -> {.push warning[resultshadowed]: off.}

      procBody.insert(1, newNimNode(nnkVarSection, prc.body).add(
        newIdentDefs(newIdentNode("result"), baseType))) # -> var result: T

      procBody.insert(2, newNimNode(nnkPragma).add(
        newIdentNode("pop"))) # -> {.pop.})

      procBody.add(
        newCall(newIdentNode("complete"),
          retFutureSym, newIdentNode("result"))) # -> complete(retFuture, result)
    else:
      # -> complete(retFuture)
      procBody.add(newCall(newIdentNode("complete"), retFutureSym))

    var closureIterator = newProc(iteratorNameSym, [parseExpr("owned(FutureBase)")],
                                  procBody, nnkIteratorDef)
    closureIterator.pragma = newNimNode(nnkPragma, lineInfoFrom = prc.body)
    closureIterator.addPragma(newIdentNode("closure"))

    # If proc has an explicit gcsafe pragma, we add it to iterator as well.
    if prc.pragma.findChild(it.kind in {nnkSym, nnkIdent} and $it ==
        "gcsafe") != nil:
      closureIterator.addPragma(newIdentNode("gcsafe"))
    outerProcBody.add(closureIterator)

    # -> createCb(retFuture)
    # NOTE: The NimAsyncContinueSuffix is checked for in asyncfutures.nim to produce
    # friendlier stack traces:
    var cbName = genSym(nskProc, prcName & NimAsyncContinueSuffix)
    var procCb = getAst createCb(retFutureSym, iteratorNameSym,
                         newStrLitNode(prcName),
                         cbName,
                         createFutureVarCompletions(futureVarIdents, nil))
    outerProcBody.add procCb

    # -> return retFuture
    outerProcBody.add newNimNode(nnkReturnStmt, prc.body[^1]).add(retFutureSym)

  result = prc

  if subtypeIsVoid:
    # A proc without a declared return type gets ``owned(Future[void])``.
    if returnType.kind == nnkEmpty:
      # Add Future[void]
      result.params[0] = parseExpr("owned(Future[void])")
  if procBody.kind != nnkEmpty:
    result.body = outerProcBody
  #echo(treeRepr(result))
  #if prcName == "recvLineInto":
  #  echo(toStrLit(result))

macro async*(prc: untyped): untyped =
  ## Macro which processes async procedures into the appropriate
  ## iterators and yield statements.
  ##
  ## Accepts either a single routine or a statement list of routines; each
  ## routine is transformed independently. Define ``nimDumpAsync`` to print
  ## the generated code.
  case prc.kind
  of nnkStmtList:
    result = newStmtList()
    for routine in prc:
      result.add asyncSingleProc(routine)
  else:
    result = asyncSingleProc(prc)
  when defined(nimDumpAsync):
    echo repr result


# Multisync
proc emptyNoop[T](x: T): T =
  ## Identity function. ``multisync`` replaces the ``await``s of the
  ## synchronous variant with calls to this, so values simply pass through.
  when T is void:
    discard
  else:
    result = x

proc stripAwait(node: NimNode): NimNode =
  ## Strips out all ``await`` commands from a procedure body, replaces them
  ## with ``emptyNoop`` for simplicity.
  ##
  ## Recognised positions mirror ``processBody``:
  ## * a bare ``await`` call or command;
  ## * ``foo await x``;
  ## * ``var/let v = await x`` (first definition only);
  ## * ``v = await x``;
  ## * ``discard await x``.
  ## ``node`` is mutated in place and returned.
  # NOTE(review): unlike ``processBody`` this also descends into nested
  # routine definitions, so their ``await``s are replaced too — confirm that
  # is intended.
  result = node

  # Bound symbol, so the replacement resolves to this module's ``emptyNoop``.
  let emptyNoopSym = bindSym("emptyNoop")

  case node.kind
  of nnkCommand, nnkCall:
    if node[0].eqIdent("await"):
      node[0] = emptyNoopSym
    elif node.len > 1 and node[1].kind == nnkCommand and node[1][0].eqIdent("await"):
      # foo await x
      node[1][0] = emptyNoopSym
  of nnkVarSection, nnkLetSection:
    case node[0][^1].kind
    of nnkCommand:
      if node[0][^1][0].eqIdent("await"):
        # var x = await y
        node[0][^1][0] = emptyNoopSym
    else: discard
  of nnkAsgn:
    case node[1].kind
    of nnkCommand:
      if node[1][0].eqIdent("await"):
        # x = await y
        node[1][0] = emptyNoopSym
    else: discard
  of nnkDiscardStmt:
    # discard await x
    if node[0].kind == nnkCommand and node[0][0].eqIdent("await"):
      node[0][0] = emptyNoopSym
  else: discard

  # Recurse into all children.
  for i in 0 ..< result.len:
    result[i] = stripAwait(result[i])

proc splitParamType(paramType: NimNode, async: bool): NimNode =
  ## For a union type ``A | B`` (or ``A or B``) where one side's name contains
  ## "async" (case/underscore-insensitive), returns that side when ``async``
  ## is true and the other side otherwise. Any other ``paramType`` is returned
  ## unchanged.
  result = paramType
  let isUnion = paramType.kind == nnkInfix and
                paramType[0].strVal in ["|", "or"]
  if not isUnion:
    return

  let lhs = paramType[1]
  let rhs = paramType[2]
  let lhsIsAsync = "async" in lhs.strVal.normalize
  let rhsIsAsync = "async" in rhs.strVal.normalize

  if lhsIsAsync:
    result = if async: lhs else: rhs
  elif rhsIsAsync:
    result = if async: rhs else: lhs

proc stripReturnType(returnType: NimNode): NimNode =
  ## Strip out the 'Future' from 'Future[T]': returns ``T`` for a bracket
  ## expression (after checking the outer name is ``Future``), and any other
  ## node unchanged.
  if returnType.kind == nnkBracketExpr:
    verifyReturnType(repr(returnType[0]))
    result = returnType[1]
  else:
    result = returnType

proc splitProc(prc: NimNode): (NimNode, NimNode) =
  ## Takes a procedure definition which takes a generic union of arguments,
  ## for example: proc (socket: Socket | AsyncSocket).
  ## It transforms them so that ``proc (socket: Socket)`` and
  ## ``proc (socket: AsyncSocket)`` are returned (synchronous first).
  ##
  ## The synchronous copy also loses its ``Future[...]`` wrapper and has its
  ## ``await``s replaced via ``stripAwait``. The asynchronous copy keeps the
  ## wrapper but resolves unions inside it.

  # Child 3 of a routine node is FormalParams; child 6 is the body.
  let syncPrc = prc.copyNimTree()
  let syncParams = syncPrc[3]
  let plainReturn = stripReturnType(syncParams[0])
  syncParams[0] = splitParamType(plainReturn, async = false)
  for i in 1 ..< syncParams.len:
    # IdentDefs (i) -> parameter type (1).
    syncParams[i][1] = splitParamType(syncParams[i][1], async = false)
  syncPrc[6] = stripAwait(syncPrc[6])

  let asyncPrc = prc.copyNimTree()
  let asyncParams = asyncPrc[3]
  if asyncParams[0].kind == nnkBracketExpr:
    asyncParams[0][1] = splitParamType(asyncParams[0][1], async = true)
  for i in 1 ..< asyncParams.len:
    # IdentDefs (i) -> parameter type (1).
    asyncParams[i][1] = splitParamType(asyncParams[i][1], async = true)

  result = (syncPrc, asyncPrc)

macro multisync*(prc: untyped): untyped =
  ## Macro which processes async procedures into both asynchronous and
  ## synchronous procedures.
  ##
  ## The generated async procedures use the ``async`` macro, whereas the
  ## generated synchronous procedures simply strip off the ``await`` calls.
  ## The async variant is emitted before the sync one.
  let (syncVariant, asyncVariant) = splitProc(prc)
  result = newStmtList(asyncSingleProc(asyncVariant), syncVariant)

proc await*[T](x: T) =
  ## The 'await' keyword is also defined here for technical
  ## reasons. (Generic symbol lookup prepass.)
  ##
  ## Inside ``{.async.}``/``{.multisync.}`` procs, ``await`` is rewritten by
  ## the macros before this proc can be used. Any other use instantiates it
  ## and triggers the compile-time error below.
  {.error: "Await only available within .async".}
                     
                         
                               

                                                          
                                   



                                    
                                           




                                                                    
                                                 



                                      
                                     







                           

              
#
#
#            Nim's Runtime Library
#        (c) Copyright 2014 Dominik Picheta
#
#    See the file "copying.txt", included in this
#    distribution, for details about the copyright.
#

## This module implements a high-level asynchronous sockets API based on the
## asynchronous dispatcher defined in the ``asyncdispatch`` module.
##
## SSL
## ---
##
## SSL can be enabled by compiling with the ``-d:ssl`` flag.
##
## You must create a new SSL context with the ``newContext`` function defined
## in the ``net`` module. You may then call ``wrapSocket`` on your socket using
## the newly created SSL context to get an SSL socket.
##
## Examples
## --------
##
## Chat server
## ^^^^^^^^^^^
## 
## The following example demonstrates a simple chat server.
##
## .. code-block::nim
##
##   import asyncnet, asyncdispatch
##
##   var clients {.threadvar.}: seq[AsyncSocket]
##
##   proc processClient(client: AsyncSocket) {.async.} =
##     while true:
##       let line = await client.recvLine()
##       for c in clients:
##         await c.send(line & "\c\L")
##
##   proc serve() {.async.} =
##     clients = @[]
##     var server = newAsyncSocket()
##     server.bindAddr(Port(12345))
##     server.listen()
##
##     while true:
##       let client = await server.accept()
##       clients.add client
##
##       asyncCheck processClient(client)
##
##   asyncCheck serve()
##   runForever()
##

import asyncdispatch
import rawsockets
import net
import os

when defined(ssl):
  import openssl

type
  # TODO: I would prefer to just do:
  # PAsyncSocket* {.borrow: `.`.} = distinct PSocket. But that doesn't work.
  # State of an asynchronous socket; handled through ``AsyncSocket`` refs.
  AsyncSocketDesc  = object
    fd*: SocketHandle
    closed*: bool ## determines whether this socket has been closed
    case isBuffered*: bool ## determines whether this socket is buffered.
    of true:
      # Indices 0..BufferSize, i.e. BufferSize+1 slots; ``readIntoBuf``
      # fills at most BufferSize of them.
      buffer*: array[0..BufferSize, char]
      currPos*: int # current index in buffer
      bufLen*: int # current length of buffer
    of false: nil
    case isSsl: bool
    of true:
      # Set up by ``wrapSocket``. bioIn receives bytes read from ``fd``;
      # bioOut holds bytes waiting to be sent (see sendPendingSslData).
      when defined(ssl):
        sslHandle: SslPtr
        sslContext: SslContext
        bioIn: BIO
        bioOut: BIO
    of false: nil
  AsyncSocket* = ref AsyncSocketDesc

# Keeps the legacy ``PAsyncSocket`` name usable as an alias.
{.deprecated: [PAsyncSocket: AsyncSocket].}

# TODO: Save AF, domain etc info and reuse it in procs which need it like connect.

proc newSocket(fd: TAsyncFD, isBuff: bool): PAsyncSocket =
  ## Allocates an ``AsyncSocket`` around the already-open descriptor ``fd``.
  ## ``isBuff`` selects the buffered variant, whose read cursor starts at 0.
  assert fd != osInvalidSocket.TAsyncFD
  result = new(AsyncSocketDesc)
  result.fd = SocketHandle(fd)
  result.isBuffered = isBuff
  if result.isBuffered:
    result.currPos = 0

proc newAsyncSocket*(domain: TDomain = AF_INET, typ: TType = SOCK_STREAM,
    protocol: TProtocol = IPPROTO_TCP, buffered = true): PAsyncSocket =
  ## Creates a new asynchronous socket.
  ##
  ## A new non-blocking OS socket is opened for ``domain``/``typ``/``protocol``
  ## and wrapped; ``buffered`` selects the buffered reading mode.
  let rawFd = newAsyncRawSocket(domain, typ, protocol)
  result = newSocket(rawFd, buffered)

proc newAsyncSocket*(domain, typ, protocol: cint, buffered = true): PAsyncSocket =
  ## Creates a new asynchronous socket.
  ##
  ## Overload taking raw ``cint`` domain/type/protocol values.
  let rawFd = newAsyncRawSocket(domain, typ, protocol)
  result = newSocket(rawFd, buffered)

when defined(ssl):
  proc getSslError(handle: SslPtr, err: cint): cint =
    ## Classifies the negative return value ``err`` of an OpenSSL call.
    ## "Want" conditions (read/write/connect/accept) are returned so the caller
    ## can retry; every other condition raises an SSL error.
    assert err < 0
    result = SSLGetError(handle, err.cint)
    case result
    of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT,
       SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ:
      discard
    of SSL_ERROR_ZERO_RETURN:
      raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
    of SSL_ERROR_WANT_X509_LOOKUP:
      raiseSSLError("Function for x509 lookup has been called.")
    of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
      raiseSSLError()
    else:
      raiseSSLError("Unknown Error")

  proc sendPendingSslData(socket: AsyncSocket,
      flags: set[TSocketFlags]) {.async.} =
    ## Sends whatever OpenSSL has queued in the outgoing memory BIO
    ## (``socket.bioOut``) over the socket's descriptor.
    ## Raises an SSL error if reading from the BIO fails.
    let len = bioCtrlPending(socket.bioOut)
    if len > 0:
      # ``newString`` (not ``newStringOfCap``): the string must already have
      # ``len`` bytes, otherwise ``data[0]`` is out of bounds.
      var data = newString(len)
      let read = bioRead(socket.bioOut, addr data[0], len)
      assert read != 0
      if read < 0:
        raiseSslError()
      data.setLen(read)
      await socket.fd.TAsyncFd.send(data, flags)

  proc appeaseSsl(socket: AsyncSocket, flags: set[TSocketFlags],
                  sslError: cint) {.async.} =
    ## Performs the I/O OpenSSL asked for via ``sslError``:
    ## * WANT_WRITE: flush the pending outgoing data;
    ## * WANT_READ: receive up to ``BufferSize`` bytes and feed them to
    ##   ``socket.bioIn``.
    ## Any other code, a closed connection, or a failed BIO write raises an
    ## SSL error.
    case sslError
    of SSL_ERROR_WANT_WRITE:
      await sendPendingSslData(socket, flags)
    of SSL_ERROR_WANT_READ:
      var data = await recv(socket.fd.TAsyncFD, BufferSize, flags)
      if data.len == 0:
        # The peer closed the connection. ``data[0]`` would be out of bounds,
        # and with no new input the SSL operation could be retried forever.
        raiseSSLError("Connection closed by peer during SSL operation.")
      let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint)
      if ret < 0:
        raiseSSLError()
    else:
      raiseSSLError("Cannot appease SSL.")

  template sslLoop(socket: AsyncSocket, flags: set[TSocketFlags],
                   op: expr) =
    # Repeats the OpenSSL call ``op`` until it returns a non-negative value.
    # After every attempt pending output is flushed. On failure the error is
    # classified with ``getSslError`` and ``appeaseSsl`` performs the I/O
    # OpenSSL asked for before retrying. The final value of ``op`` is left in
    # the injected ``opResult`` for the caller.
    #
    # Uses ``yield`` directly, so it only works where ``yield`` is valid,
    # i.e. inside the iterator an ``{.async.}`` proc is turned into.
    # NOTE(review): the yielded futures are never read, so an exception
    # raised inside sendPendingSslData/appeaseSsl appears to be dropped here
    # — confirm.
    var opResult {.inject.} = -1.cint
    while opResult < 0:
      opResult = op
      # Bit hackish here.
      # TODO: Introduce an async template transformation pragma?
      yield sendPendingSslData(socket, flags)
      if opResult < 0:
        let err = getSslError(socket.sslHandle, opResult.cint)
        yield appeaseSsl(socket, flags, err.cint)

proc connect*(socket: PAsyncSocket, address: string, port: TPort,
    af = AF_INET) {.async.} =
  ## Connects ``socket`` to server at ``address:port``.
  ##
  ## Returns a ``Future`` which will complete when the connection succeeds
  ## or an error occurs.
  ##
  ## For SSL sockets the TLS client handshake is performed as well.
  await connect(socket.fd.TAsyncFD, address, port, af)
  if not socket.isSsl:
    return
  when defined(ssl):
    let flags = {TSocketFlags.SafeDisconn}
    sslSetConnectState(socket.sslHandle)
    sslLoop(socket, flags, sslDoHandshake(socket.sslHandle))

proc readInto(buf: cstring, size: int, socket: PAsyncSocket,
              flags: set[TSocketFlags]): Future[int] {.async.} =
  ## Reads up to ``size`` bytes from ``socket`` into ``buf`` and returns the
  ## number of bytes stored (0 when nothing was received).
  if not socket.isSsl:
    # Plain socket: receive into a temporary string, then copy it over.
    var chunk = await recv(socket.fd.TAsyncFD, size, flags)
    let received = chunk.len
    if received > 0:
      copyMem(buf, addr chunk[0], received)
    result = received
  else:
    when defined(ssl):
      # SSL mode: OpenSSL writes the decrypted bytes straight into ``buf``.
      sslLoop(socket, flags, sslRead(socket.sslHandle, buf, size.cint))
      result = opResult

proc readIntoBuf(socket: PAsyncSocket,
    flags: set[TSocketFlags]): Future[int] {.async.} =
  ## Refills the socket's internal buffer from the start. Resets the read
  ## cursor, records the new buffer length and returns the number of bytes
  ## read (0 on disconnect).
  let count = await readInto(addr socket.buffer[0], BufferSize, socket, flags)
  socket.currPos = 0
  socket.bufLen = count
  result = count

proc recv*(socket: PAsyncSocket, size: int,
           flags = {TSocketFlags.SafeDisconn}): Future[string] {.async.} =
  ## Reads **up to** ``size`` bytes from ``socket``.
  ##
  ## For buffered sockets this function will attempt to read all the requested
  ## data. It will read this data in ``BufferSize`` chunks.
  ##
  ## For unbuffered sockets this function makes no effort to read
  ## all the data requested. It will return as much data as the operating system
  ## gives it.
  ##
  ## If socket is disconnected during the
  ## recv operation then the future may complete with only a part of the
  ## requested data.
  ##
  ## If socket is disconnected and no data is available
  ## to be read then the future will complete with a value of ``""``.
  ##
  ## With the ``Peek`` flag on a buffered socket the read cursor is restored
  ## afterwards and no further buffer refills happen while copying.
  if socket.isBuffered:
    result = newString(size)
    let originalBufPos = socket.currPos

    if socket.bufLen == 0:
      let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek})
      if res == 0:
        result.setLen(0)
        return

    var read = 0
    while read < size:
      if socket.currPos >= socket.bufLen:
        if TSocketFlags.Peek in flags:
          # We don't want to get another buffer if we're peeking.
          break
        let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek})
        if res == 0:
          break

      let chunk = min(socket.bufLen-socket.currPos, size-read)
      copyMem(addr(result[read]), addr(socket.buffer[socket.currPos]), chunk)
      read.inc(chunk)
      socket.currPos.inc(chunk)

    if TSocketFlags.Peek in flags:
      # Restore old buffer cursor position.
      socket.currPos = originalBufPos
    result.setLen(read)
  else:
    if size == 0:
      # Nothing requested; ``result[0]`` below would be out of bounds.
      return ""
    result = newString(size)
    let read = await readInto(addr result[0], size, socket, flags)
    result.setLen(read)

proc send*(socket: PAsyncSocket, data: string,
           flags = {TSocketFlags.SafeDisconn}) {.async.} =
  ## Sends ``data`` to ``socket``. The returned future will complete once all
  ## data has been sent.
  ##
  ## For SSL sockets ``data`` is encrypted through OpenSSL and the resulting
  ## bytes are flushed before completing.
  assert socket != nil
  if not socket.isSsl:
    await send(socket.fd.TAsyncFD, data, flags)
  else:
    when defined(ssl):
      # sslWrite needs a mutable buffer address, so work on a copy.
      var payload = data
      sslLoop(socket, flags,
        sslWrite(socket.sslHandle, addr payload[0], payload.len.cint))
      await sendPendingSslData(socket, flags)

proc acceptAddr*(socket: PAsyncSocket, flags = {TSocketFlags.SafeDisconn}):
      Future[tuple[address: string, client: PAsyncSocket]] =
  ## Accepts a new connection. Returns a future containing the client socket
  ## corresponding to that connection and the remote address of the client.
  ## The future will complete when the connection is successfully accepted.
  ##
  ## The client socket inherits ``socket``'s buffering mode; a failure of the
  ## underlying accept is forwarded to the returned future.
  var accepted = newFuture[tuple[address: string, client: PAsyncSocket]]("asyncnet.acceptAddr")
  var rawAccept = acceptAddr(socket.fd.TAsyncFD, flags)
  rawAccept.callback =
    proc (future: Future[tuple[address: string, client: TAsyncFD]]) =
      assert future.finished
      if not future.failed:
        let fdInfo = future.read
        accepted.complete((fdInfo.address,
                           newSocket(fdInfo.client, socket.isBuffered)))
      else:
        accepted.fail(future.readError)
  result = accepted

proc accept*(socket: PAsyncSocket,
    flags = {TSocketFlags.SafeDisconn}): Future[PAsyncSocket] =
  ## Accepts a new connection. Returns a future containing the client socket
  ## corresponding to that connection.
  ## The future will complete when the connection is successfully accepted.
  ##
  ## Thin wrapper over ``acceptAddr`` that drops the remote address.
  var clientFut = newFuture[PAsyncSocket]("asyncnet.accept")
  var addrFut = acceptAddr(socket, flags)
  addrFut.callback =
    proc (future: Future[tuple[address: string, client: PAsyncSocket]]) =
      assert future.finished
      if not future.failed:
        clientFut.complete(future.read.client)
      else:
        clientFut.fail(future.readError)
  result = clientFut

proc recvLine*(socket: PAsyncSocket,
    flags = {TSocketFlags.SafeDisconn}): Future[string] {.async.} =
  ## Reads a line of data from ``socket``. Returned future will complete once
  ## a full line is read or an error occurs.
  ##
  ## If a full line is read ``\r\L`` is not
  ## added to ``line``, however if solely ``\r\L`` is read then ``line``
  ## will be set to it.
  ##
  ## If the socket is disconnected, ``line`` will be set to ``""``.
  ##
  ## If the socket is disconnected in the middle of a line (before ``\r\L``
  ## is read) then line will be set to ``""``.
  ## The partial line **will be lost**.
  ##
  ## **Warning**: The ``Peek`` flag is not yet implemented.
  ##
  ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol
  ## uses ``\r\L`` to delimit a new line.
  template addNLIfEmpty(): stmt =
    if result.len == 0:
      result.add("\c\L")
  assert TSocketFlags.Peek notin flags ## TODO:
  if socket.isBuffered:
    result = ""
    if socket.bufLen == 0:
      let res = await socket.readIntoBuf(flags)
      if res == 0:
        return

    var lastR = false
    while true:
      if socket.currPos >= socket.bufLen:
        let res = await socket.readIntoBuf(flags)
        if res == 0:
          result = ""
          break

      case socket.buffer[socket.currPos]
      of '\r':
        lastR = true
        addNLIfEmpty()
      of '\L':
        addNLIfEmpty()
        socket.currPos.inc()
        return
      else:
        if lastR:
          # A lone '\r' ended the line. The current character is the first
          # one of the next line, so leave the cursor on it (don't consume).
          return
        else:
          result.add socket.buffer[socket.currPos]
      socket.currPos.inc()
  else:
    result = ""
    var c = ""
    while true:
      c = await recv(socket, 1, flags)
      if c.len == 0:
        return ""
      if c == "\r":
        c = await recv(socket, 1, flags) # Skip \L
        assert c == "\L"
        addNLIfEmpty()
        return
      elif c == "\L":
        addNLIfEmpty()
        return
      add(result.string, c)

proc listen*(socket: PAsyncSocket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
  ## Marks ``socket`` as accepting connections.
  ## ``Backlog`` specifies the maximum length of the
  ## queue of pending connections.
  ##
  ## Raises an EOS error upon failure.
  let rc = listen(socket.fd, backlog)
  if rc < 0'i32:
    raiseOSError(osLastError())

proc bindAddr*(socket: PAsyncSocket, port = Port(0), address = "") {.
  tags: [ReadIOEffect].} =
  ## Binds ``address``:``port`` to the socket.
  ##
  ## If ``address`` is "" then ADDR_ANY will be bound.
  ##
  ## Raises an OS error (via ``raiseOSError``) if binding fails.

  if address == "":
    var name: Sockaddr_in
    when defined(Windows) or defined(nimdoc):
      name.sin_family = toInt(AF_INET).int16
    else:
      name.sin_family = toInt(AF_INET)
    # NOTE(review): ``int16(port)`` may be range-checked for ports above
    # 32767 depending on the Nim version — confirm.
    name.sin_port = htons(int16(port))
    name.sin_addr.s_addr = htonl(INADDR_ANY)
    if bindAddr(socket.fd, cast[ptr SockAddr](addr(name)),
                  sizeof(name).Socklen) < 0'i32:
      raiseOSError(osLastError())
  else:
    var aiList = getAddrInfo(address, port, AF_INET)
    let bindResult = bindAddr(socket.fd, aiList.ai_addr,
                              aiList.ai_addrlen.Socklen)
    # Read the OS error before ``dealloc`` so the cleanup call cannot
    # overwrite it; the address list is freed on both paths.
    let bindError = osLastError()
    dealloc(aiList)
    if bindResult < 0'i32:
      raiseOSError(bindError)

proc close*(socket: PAsyncSocket) =
  ## Closes the socket.
  ##
  ## The descriptor is closed first. For SSL sockets ``SslShutdown`` is then
  ## called (a second time if the first call returns 0), and an SSL error is
  ## raised if shutdown does not end with a return value of 1 — in that case
  ## ``closed`` is not set. Otherwise ``closed`` is set to ``true``.
  # NOTE(review): the close_notify data SslShutdown writes into ``bioOut`` is
  # never sent (the descriptor is already closed and the BIO is not flushed),
  # so the second SslShutdown call likely cannot succeed — confirm.
  socket.fd.TAsyncFD.closeSocket()
  when defined(ssl):
    if socket.isSSL:
      let res = SslShutdown(socket.sslHandle)
      if res == 0:
        if SslShutdown(socket.sslHandle) != 1:
          raiseSslError()
      elif res != 1:
        raiseSslError()
  socket.closed = true # TODO: Add extra debugging checks for this.

when defined(ssl):
  proc wrapSocket*(ctx: SslContext, socket: AsyncSocket) =
    ## Wraps a socket in an SSL context. This function effectively turns
    ## ``socket`` into an SSL socket.
    ##
    ## A new SSL handle is created from ``ctx`` and attached to two memory
    ## BIOs: ``bioIn`` (fed with received bytes) and ``bioOut`` (drained and
    ## sent over the descriptor).
    ##
    ## **Disclaimer**: This code is not well tested, may be very unsafe and
    ## prone to security vulnerabilities.
    socket.isSsl = true
    socket.sslContext = ctx
    socket.sslHandle = SSLNew(PSSLCTX(ctx))
    if socket.sslHandle.isNil:
      raiseSslError()

    let inBio = bioNew(bio_s_mem())
    socket.bioIn = inBio
    let outBio = bioNew(bio_s_mem())
    socket.bioOut = outBio
    sslSetBio(socket.sslHandle, inBio, outBio)


when isMainModule:
  # Manual smoke tests; change ``test`` to choose which scenario runs.
  type
    TestCases = enum
      HighClient, LowClient, LowServer

  const test = HighClient

  when test == HighClient:
    # Async/await client: prints every line received from an IRC server.
    proc main() {.async.} =
      var sock = newAsyncSocket()
      await sock.connect("irc.freenode.net", TPort(6667))
      while true:
        let line = await sock.recvLine()
        if line == "":
          echo("Disconnected")
          break
        else:
          echo("Got line: ", line)
    asyncCheck main()
  elif test == LowClient:
    # Callback-style client: after connecting, issues 51 recv(10) calls.
    var sock = newAsyncSocket()
    var f = connect(sock, "irc.freenode.net", TPort(6667))
    f.callback =
      proc (future: Future[void]) =
        echo("Connected in future!")
        for i in 0 .. 50:
          var recvF = recv(sock, 10)
          recvF.callback =
            proc (future: Future[string]) =
              echo("Read ", future.read.len, ": ", future.read.repr)
  elif test == LowServer:
    # Callback-style server on port 6667: greets each client, then closes it.
    var sock = newAsyncSocket()
    sock.bindAddr(TPort(6667))
    sock.listen()
    proc onAccept(future: Future[PAsyncSocket]) =
      let client = future.read
      echo "Accepted ", client.fd.cint
      var t = send(client, "test\c\L")
      t.callback =
        proc (future: Future[void]) =
          echo("Send")
          client.close()

      # Re-arm: wait for the next client with the same handler.
      var f = accept(sock)
      f.callback = onAccept

    var f = accept(sock)
    f.callback = onAccept
  runForever()