diff options
Diffstat (limited to 'lib/pure')
-rw-r--r-- | lib/pure/asyncdispatch.nim | 180 | ||||
-rw-r--r-- | lib/pure/asyncfile.nim | 37 | ||||
-rw-r--r-- | lib/pure/asyncnet.nim | 20 | ||||
-rw-r--r-- | lib/pure/base64.nim | 41 | ||||
-rw-r--r-- | lib/pure/collections/tables.nim | 45 | ||||
-rw-r--r-- | lib/pure/cookies.nim | 4 | ||||
-rw-r--r-- | lib/pure/httpclient.nim | 55 | ||||
-rw-r--r-- | lib/pure/includes/asynccommon.nim | 201 | ||||
-rw-r--r-- | lib/pure/json.nim | 532 | ||||
-rw-r--r-- | lib/pure/nativesockets.nim | 79 | ||||
-rw-r--r-- | lib/pure/net.nim | 411 | ||||
-rw-r--r-- | lib/pure/oids.nim | 3 | ||||
-rw-r--r-- | lib/pure/os.nim | 2 | ||||
-rw-r--r-- | lib/pure/ospaths.nim | 2 | ||||
-rw-r--r-- | lib/pure/osproc.nim | 48 | ||||
-rw-r--r-- | lib/pure/parseutils.nim | 2 | ||||
-rw-r--r-- | lib/pure/strutils.nim | 2 | ||||
-rw-r--r-- | lib/pure/times.nim | 34 | ||||
-rw-r--r-- | lib/pure/uri.nim | 21 |
19 files changed, 1264 insertions, 455 deletions
diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim index 1696c4ed9..1697384e0 100644 --- a/lib/pure/asyncdispatch.nim +++ b/lib/pure/asyncdispatch.nim @@ -9,7 +9,7 @@ include "system/inclrtl" -import os, oids, tables, strutils, times, heapqueue +import os, tables, strutils, times, heapqueue, options import nativesockets, net, deques @@ -242,6 +242,11 @@ when defined(windows) or defined(nimdoc): if gDisp.isNil: gDisp = newDispatcher() result = gDisp + proc setGlobalDispatcher*(disp: PDispatcher) = + if not gDisp.isNil: + assert gDisp.callbacks.len == 0 + gDisp = disp + proc register*(fd: AsyncFD) = ## Registers ``fd`` with the dispatcher. let p = getGlobalDispatcher() @@ -385,68 +390,6 @@ when defined(windows) or defined(nimdoc): dwRemoteAddressLength, LocalSockaddr, LocalSockaddrLength, RemoteSockaddr, RemoteSockaddrLength) - proc connect*(socket: AsyncFD, address: string, port: Port, - domain = nativesockets.AF_INET): Future[void] = - ## Connects ``socket`` to server at ``address:port``. - ## - ## Returns a ``Future`` which will complete when the connection succeeds - ## or an error occurs. - verifyPresence(socket) - var retFuture = newFuture[void]("connect") - # Apparently ``ConnectEx`` expects the socket to be initially bound: - var saddr: Sockaddr_in - saddr.sin_family = int16(toInt(domain)) - saddr.sin_port = 0 - saddr.sin_addr.s_addr = INADDR_ANY - if bindAddr(socket.SocketHandle, cast[ptr SockAddr](addr(saddr)), - sizeof(saddr).SockLen) < 0'i32: - raiseOSError(osLastError()) - - var aiList = getAddrInfo(address, port, domain) - var success = false - var lastError: OSErrorCode - var it = aiList - while it != nil: - # "the OVERLAPPED structure must remain valid until the I/O completes" - # http://blogs.msdn.com/b/oldnewthing/archive/2011/02/02/10123392.aspx - var ol = PCustomOverlapped() - GC_ref(ol) - ol.data = CompletionData(fd: socket, cb: - proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) = - if not retFuture.finished: - if errcode == OSErrorCode(-1): - retFuture.complete() - else: - retFuture.fail(newException(OSError, osErrorMsg(errcode))) - ) - - var ret = connectEx(socket.SocketHandle, it.ai_addr, - sizeof(Sockaddr_in).cint, nil, 0, nil, - cast[POVERLAPPED](ol)) - if ret: - # Request to connect completed immediately. - success = true - retFuture.complete() - # We don't deallocate ``ol`` here because even though this completed - # immediately poll will still be notified about its completion and it will - # free ``ol``. - break - else: - lastError = osLastError() - if lastError.int32 == ERROR_IO_PENDING: - # In this case ``ol`` will be deallocated in ``poll``. - success = true - break - else: - GC_unref(ol) - success = false - it = it.ai_next - - freeAddrInfo(aiList) - if not success: - retFuture.fail(newException(OSError, osErrorMsg(lastError))) - return retFuture - proc recv*(socket: AsyncFD, size: int, flags = {SocketFlag.SafeDisconn}): Future[string] = ## Reads **up to** ``size`` bytes from ``socket``. Returned future will @@ -754,8 +697,8 @@ when defined(windows) or defined(nimdoc): var lpOutputBuf = newString(lpOutputLen) var dwBytesReceived: Dword let dwReceiveDataLength = 0.Dword # We don't want any data to be read. - let dwLocalAddressLength = Dword(sizeof(Sockaddr_in) + 16) - let dwRemoteAddressLength = Dword(sizeof(Sockaddr_in) + 16) + let dwLocalAddressLength = Dword(sizeof(Sockaddr_in6) + 16) + let dwRemoteAddressLength = Dword(sizeof(Sockaddr_in6) + 16) template failAccept(errcode) = if flags.isDisconnectionError(errcode): @@ -785,12 +728,14 @@ when defined(windows) or defined(nimdoc): dwLocalAddressLength, dwRemoteAddressLength, addr localSockaddr, addr localLen, addr remoteSockaddr, addr remoteLen) - register(clientSock.AsyncFD) - # TODO: IPv6. Check ``sa_family``. http://stackoverflow.com/a/9212542/492186 - retFuture.complete( - (address: $inet_ntoa(cast[ptr Sockaddr_in](remoteSockAddr).sin_addr), - client: clientSock.AsyncFD) - ) + try: + let address = getAddrString(remoteSockAddr) + register(clientSock.AsyncFD) + retFuture.complete((address: address, client: clientSock.AsyncFD)) + except: + # getAddrString may raise + clientSock.close() + retFuture.fail(getCurrentException()) var ol = PCustomOverlapped() GC_ref(ol) @@ -823,20 +768,6 @@ when defined(windows) or defined(nimdoc): return retFuture - proc newAsyncNativeSocket*(domain, sockType, protocol: cint): AsyncFD = - ## Creates a new socket and registers it with the dispatcher implicitly. - result = newNativeSocket(domain, sockType, protocol).AsyncFD - result.SocketHandle.setBlocking(false) - register(result) - - proc newAsyncNativeSocket*(domain: Domain = nativesockets.AF_INET, - sockType: SockType = SOCK_STREAM, - protocol: Protocol = IPPROTO_TCP): AsyncFD = - ## Creates a new socket and registers it with the dispatcher implicitly. - result = newNativeSocket(domain, sockType, protocol).AsyncFD - result.SocketHandle.setBlocking(false) - register(result) - proc closeSocket*(socket: AsyncFD) = ## Closes a socket and ensures that it is unregistered. socket.SocketHandle.close() @@ -1005,6 +936,11 @@ else: if gDisp.isNil: gDisp = newDispatcher() result = gDisp + proc setGlobalDispatcher*(disp: PDispatcher) = + if not gDisp.isNil: + assert gDisp.callbacks.len == 0 + gDisp = disp + proc update(fd: AsyncFD, events: set[Event]) = let p = getGlobalDispatcher() assert fd.SocketHandle in p.selector @@ -1015,23 +951,6 @@ else: var data = PData(fd: fd, readCBs: @[], writeCBs: @[]) p.selector.register(fd.SocketHandle, {}, data.RootRef) - proc newAsyncNativeSocket*(domain: cint, sockType: cint, - protocol: cint): AsyncFD = - result = newNativeSocket(domain, sockType, protocol).AsyncFD - result.SocketHandle.setBlocking(false) - when defined(macosx): - result.SocketHandle.setSockOptInt(SOL_SOCKET, SO_NOSIGPIPE, 1) - register(result) - - proc newAsyncNativeSocket*(domain: Domain = AF_INET, - sockType: SockType = SOCK_STREAM, - protocol: Protocol = IPPROTO_TCP): AsyncFD = - result = newNativeSocket(domain, sockType, protocol).AsyncFD - result.SocketHandle.setBlocking(false) - when defined(macosx): - result.SocketHandle.setSockOptInt(SOL_SOCKET, SO_NOSIGPIPE, 1) - register(result) - proc closeSocket*(sock: AsyncFD) = let disp = getGlobalDispatcher() disp.selector.unregister(sock.SocketHandle) @@ -1115,50 +1034,6 @@ else: # Callback queue processing processPendingCallbacks(p) - proc connect*(socket: AsyncFD, address: string, port: Port, - domain = AF_INET): Future[void] = - var retFuture = newFuture[void]("connect") - - proc cb(fd: AsyncFD): bool = - var ret = SocketHandle(fd).getSockOptInt(cint(SOL_SOCKET), cint(SO_ERROR)) - if ret == 0: - # We have connected. - retFuture.complete() - return true - elif ret == EINTR: - # interrupted, keep waiting - return false - else: - retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(ret)))) - return true - - assert getSockDomain(socket.SocketHandle) == domain - var aiList = getAddrInfo(address, port, domain) - var success = false - var lastError: OSErrorCode - var it = aiList - while it != nil: - var ret = connect(socket.SocketHandle, it.ai_addr, it.ai_addrlen.Socklen) - if ret == 0: - # Request to connect completed immediately. - success = true - retFuture.complete() - break - else: - lastError = osLastError() - if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS: - success = true - addWrite(socket, cb) - break - else: - success = false - it = it.ai_next - - freeAddrInfo(aiList) - if not success: - retFuture.fail(newException(OSError, osErrorMsg(lastError))) - return retFuture - proc recv*(socket: AsyncFD, size: int, flags = {SocketFlag.SafeDisconn}): Future[string] = var retFuture = newFuture[string]("recv") @@ -1320,11 +1195,20 @@ else: else: retFuture.fail(newException(OSError, osErrorMsg(lastError))) else: - register(client.AsyncFD) - retFuture.complete((getAddrString(cast[ptr SockAddr](addr sockAddress)), client.AsyncFD)) + try: + let address = getAddrString(cast[ptr SockAddr](addr sockAddress)) + register(client.AsyncFD) + retFuture.complete((address, client.AsyncFD)) + except: + # getAddrString may raise + client.close() + retFuture.fail(getCurrentException()) addRead(socket, cb) return retFuture +# Common procedures between current and upcoming asyncdispatch +include includes.asynccommon + proc sleepAsync*(ms: int): Future[void] = ## Suspends the execution of the current async procedure for the next ## ``ms`` milliseconds. diff --git a/lib/pure/asyncfile.nim b/lib/pure/asyncfile.nim index c58e6c11b..8fb30075c 100644 --- a/lib/pure/asyncfile.nim +++ b/lib/pure/asyncfile.nim @@ -70,14 +70,16 @@ else: result = O_RDWR result = result or O_NONBLOCK -proc getFileSize(f: AsyncFile): int64 = +proc getFileSize*(f: AsyncFile): int64 = ## Retrieves the specified file's size. when defined(windows) or defined(nimdoc): var high: DWord let low = getFileSize(f.fd.Handle, addr high) if low == INVALID_FILE_SIZE: raiseOSError(osLastError()) - return (high shl 32) or low + result = (high shl 32) or low + else: + result = lseek(f.fd.cint, 0, SEEK_END) proc openAsync*(filename: string, mode = fmRead): AsyncFile = ## Opens a file specified by the path in ``filename`` using @@ -310,7 +312,7 @@ proc setFilePos*(f: AsyncFile, pos: int64) = ## operations. The file's first byte has the index zero. f.offset = pos when not defined(windows) and not defined(nimdoc): - let ret = lseek(f.fd.cint, pos, SEEK_SET) + let ret = lseek(f.fd.cint, pos.Off, SEEK_SET) if ret == -1: raiseOSError(osLastError()) @@ -337,13 +339,17 @@ proc writeBuffer*(f: AsyncFile, buf: pointer, size: int): Future[void] = if not retFuture.finished: if errcode == OSErrorCode(-1): assert bytesCount == size.int32 - f.offset.inc(size) retFuture.complete() else: retFuture.fail(newException(OSError, osErrorMsg(errcode))) ) + # passing -1 here should work according to MSDN, but doesn't. For more + # information see + # http://stackoverflow.com/questions/33650899/does-asynchronous-file- + # appending-in-windows-preserve-order ol.offset = DWord(f.offset and 0xffffffff) ol.offsetHigh = DWord(f.offset shr 32) + f.offset.inc(size) # According to MSDN we're supposed to pass nil to lpNumberOfBytesWritten. let ret = writeFile(f.fd.Handle, buf, size.int32, nil, @@ -362,7 +368,6 @@ proc writeBuffer*(f: AsyncFile, buf: pointer, size: int): Future[void] = retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) else: assert bytesWritten == size.int32 - f.offset.inc(size) retFuture.complete() else: var written = 0 @@ -408,7 +413,6 @@ proc write*(f: AsyncFile, data: string): Future[void] = if not retFuture.finished: if errcode == OSErrorCode(-1): assert bytesCount == data.len.int32 - f.offset.inc(data.len) retFuture.complete() else: retFuture.fail(newException(OSError, osErrorMsg(errcode))) @@ -418,6 +422,7 @@ proc write*(f: AsyncFile, data: string): Future[void] = ) ol.offset = DWord(f.offset and 0xffffffff) ol.offsetHigh = DWord(f.offset shr 32) + f.offset.inc(data.len) # According to MSDN we're supposed to pass nil to lpNumberOfBytesWritten. let ret = writeFile(f.fd.Handle, buffer, data.len.int32, nil, @@ -439,7 +444,6 @@ proc write*(f: AsyncFile, data: string): Future[void] = retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) else: assert bytesWritten == data.len.int32 - f.offset.inc(data.len) retFuture.complete() else: var written = 0 @@ -466,6 +470,23 @@ proc write*(f: AsyncFile, data: string): Future[void] = addWrite(f.fd, cb) return retFuture +proc setFileSize*(f: AsyncFile, length: int64) = + ## Set a file length. + when defined(windows) or defined(nimdoc): + var + high = (length shr 32).Dword + let + low = (length and 0xffffffff).Dword + status = setFilePointer(f.fd.Handle, low, addr high, 0) + lastErr = osLastError() + if (status == INVALID_SET_FILE_POINTER and lastErr.int32 != NO_ERROR) or + (setEndOfFile(f.fd.Handle) == 0): + raiseOSError(osLastError()) + else: + # will truncate if Off is a 32-bit type! + if ftruncate(f.fd.cint, length.Off) == -1: + raiseOSError(osLastError()) + proc close*(f: AsyncFile) = ## Closes the file specified. unregister(f.fd) @@ -498,4 +519,4 @@ proc readToStream*(f: AsyncFile, fs: FutureStream[string]) {.async.} = break await fs.write(data) - fs.complete() \ No newline at end of file + fs.complete() diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index 1ec751a64..9f73bc3cf 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -244,6 +244,17 @@ when defineSsl: else: raiseSSLError("Socket has been disconnected") +proc dial*(address: string, port: Port, protocol = IPPROTO_TCP, + buffered = true): Future[AsyncSocket] {.async.} = + ## Establishes connection to the specified ``address``:``port`` pair via the + ## specified protocol. The procedure iterates through possible + ## resolutions of the ``address`` until it succeeds, meaning that it + ## seamlessly works with both IPv4 and IPv6. + ## Returns AsyncSocket ready to send or receive data. + let asyncFd = await asyncdispatch.dial(address, port, protocol) + let sockType = protocol.toSockType() + let domain = getSockDomain(asyncFd.SocketHandle) + result = newAsyncSocket(asyncFd, domain, sockType, protocol, buffered) proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} = ## Connects ``socket`` to server at ``address:port``. @@ -636,9 +647,12 @@ when defineSsl: sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut) proc wrapConnectedSocket*(ctx: SslContext, socket: AsyncSocket, - handshake: SslHandshakeType) = + handshake: SslHandshakeType, + hostname: string = nil) = ## Wraps a connected socket in an SSL context. This function effectively ## turns ``socket`` into an SSL socket. + ## ``hostname`` should be specified so that the client knows which hostname + ## the server certificate should be validated against. ## ## This should be called on a connected socket, and will perform ## an SSL handshake immediately. @@ -649,6 +663,10 @@ when defineSsl: case handshake of handshakeAsClient: + if not hostname.isNil and not isIpAddress(hostname): + # Set the SNI address for this connection. This call can fail if + # we're not using TLSv1+. + discard SSL_set_tlsext_host_name(socket.sslHandle, hostname) sslSetConnectState(socket.sslHandle) of handshakeAsServer: sslSetAcceptState(socket.sslHandle) diff --git a/lib/pure/base64.nim b/lib/pure/base64.nim index eee03d7ae..4b0d08292 100644 --- a/lib/pure/base64.nim +++ b/lib/pure/base64.nim @@ -44,21 +44,23 @@ const cb64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" -template encodeInternal(s: expr, lineLen: int, newLine: string): stmt {.immediate.} = +template encodeInternal(s: typed, lineLen: int, newLine: string): untyped = ## encodes `s` into base64 representation. After `lineLen` characters, a ## `newline` is added. var total = ((len(s) + 2) div 3) * 4 - var numLines = (total + lineLen - 1) div lineLen + let numLines = (total + lineLen - 1) div lineLen if numLines > 0: inc(total, (numLines - 1) * newLine.len) result = newString(total) - var i = 0 - var r = 0 - var currLine = 0 + var + i = 0 + r = 0 + currLine = 0 while i < s.len - 2: - var a = ord(s[i]) - var b = ord(s[i+1]) - var c = ord(s[i+2]) + let + a = ord(s[i]) + b = ord(s[i+1]) + c = ord(s[i+2]) result[r] = cb64[a shr 2] result[r+1] = cb64[((a and 3) shl 4) or ((b and 0xF0) shr 4)] result[r+2] = cb64[((b and 0x0F) shl 2) or ((c and 0xC0) shr 6)] @@ -74,8 +76,9 @@ template encodeInternal(s: expr, lineLen: int, newLine: string): stmt {.immediat currLine = 0 if i < s.len-1: - var a = ord(s[i]) - var b = ord(s[i+1]) + let + a = ord(s[i]) + b = ord(s[i+1]) result[r] = cb64[a shr 2] result[r+1] = cb64[((a and 3) shl 4) or ((b and 0xF0) shr 4)] result[r+2] = cb64[((b and 0x0F) shl 2)] @@ -83,7 +86,7 @@ template encodeInternal(s: expr, lineLen: int, newLine: string): stmt {.immediat if r+4 != result.len: setLen(result, r+4) elif i < s.len: - var a = ord(s[i]) + let a = ord(s[i]) result[r] = cb64[a shr 2] result[r+1] = cb64[(a and 3) shl 4] result[r+2] = '=' @@ -127,15 +130,17 @@ proc decode*(s: string): string = # total is an upper bound, as we will skip arbitrary whitespace: result = newString(total) - var i = 0 - var r = 0 + var + i = 0 + r = 0 while true: while s[i] in Whitespace: inc(i) if i < s.len-3: - var a = s[i].decodeByte - var b = s[i+1].decodeByte - var c = s[i+2].decodeByte - var d = s[i+3].decodeByte + let + a = s[i].decodeByte + b = s[i+1].decodeByte + c = s[i+2].decodeByte + d = s[i+3].decodeByte result[r] = chr((a shl 2) and 0xff or ((b shr 4) and 0x03)) result[r+1] = chr((b shl 4) and 0xff or ((c shr 2) and 0x0F)) @@ -169,4 +174,4 @@ when isMainModule: for t in items(tests): assert decode(encode(t)) == t assert decode(encode(t, lineLen=40)) == t - assert decode(encode(t, lineLen=76)) == t \ No newline at end of file + assert decode(encode(t, lineLen=76)) == t diff --git a/lib/pure/collections/tables.nim b/lib/pure/collections/tables.nim index b6c00966f..323af5a38 100644 --- a/lib/pure/collections/tables.nim +++ b/lib/pure/collections/tables.nim @@ -269,6 +269,18 @@ proc del*[A, B](t: var Table[A, B], key: A) = ## deletes `key` from hash table `t`. delImpl() +proc take*[A, B](t: var Table[A, B], key: A, val: var B): bool = + ## Deletes the ``key`` from the table. + ## Returns ``true``, if the ``key`` existed, and sets ``val`` to the + ## mapping of the key. Otherwise, returns ``false``, and the ``val`` is + ## unchanged. + var hc: Hash + var index = rawGet(t, key, hc) + result = index >= 0 + if result: + shallowCopy(val, t.data[index].val) + delImplIdx(t, index) + proc enlarge[A, B](t: var Table[A, B]) = var n: KeyValuePairSeq[A, B] newSeq(n, len(t.data) * growthFactor) @@ -424,6 +436,13 @@ proc del*[A, B](t: TableRef[A, B], key: A) = ## deletes `key` from hash table `t`. t[].del(key) +proc take*[A, B](t: TableRef[A, B], key: A, val: var B): bool = + ## Deletes the ``key`` from the table. + ## Returns ``true``, if the ``key`` existed, and sets ``val`` to the + ## mapping of the key. Otherwise, returns ``false``, and the ``val`` is + ## unchanged. + result = t[].take(key, val) + proc newTable*[A, B](initialSize=64): TableRef[A, B] = new(result) result[] = initTable[A, B](initialSize) @@ -625,7 +644,7 @@ proc `==`*[A, B](s, t: OrderedTable[A, B]): bool = while ht >= 0 and hs >= 0: var nxtt = t.data[ht].next var nxts = s.data[hs].next - if isFilled(t.data[ht].hcode) and isFilled(s.data[hs].hcode): + if isFilled(t.data[ht].hcode) and isFilled(s.data[hs].hcode): if (s.data[hs].key != t.data[ht].key) and (s.data[hs].val != t.data[ht].val): return false ht = nxtt @@ -785,7 +804,7 @@ proc sort*[A, B](t: OrderedTableRef[A, B], t[].sort(cmp) proc del*[A, B](t: var OrderedTable[A, B], key: A) = - ## deletes `key` from ordered hash table `t`. O(n) comlexity. + ## deletes `key` from ordered hash table `t`. O(n) complexity. var n: OrderedKeyValuePairSeq[A, B] newSeq(n, len(t.data)) var h = t.first @@ -804,7 +823,7 @@ proc del*[A, B](t: var OrderedTable[A, B], key: A) = h = nxt proc del*[A, B](t: var OrderedTableRef[A, B], key: A) = - ## deletes `key` from ordered hash table `t`. O(n) comlexity. + ## deletes `key` from ordered hash table `t`. O(n) complexity. t[].del(key) # ------------------------------ count tables ------------------------------- @@ -829,7 +848,7 @@ proc clear*[A](t: CountTableRef[A]) = proc clear*[A](t: var CountTable[A]) = ## Resets the table so that it is empty. clearImpl() - + iterator pairs*[A](t: CountTable[A]): (A, int) = ## iterates over any (key, value) pair in the table `t`. for h in 0..high(t.data): @@ -1256,17 +1275,17 @@ when isMainModule: var b = newOrderedTable[string, string](initialSize=2) b.add("wrong?", "foo") b.add("wrong?", "foo2") - assert a == b + assert a == b block: #5482 - var a = {"wrong?": "foo", "wrong?": "foo2"}.newOrderedTable() + var a = {"wrong?": "foo", "wrong?": "foo2"}.newOrderedTable() var b = newOrderedTable[string, string](initialSize=2) b.add("wrong?", "foo") b.add("wrong?", "foo2") - assert a == b + assert a == b block: #5487 - var a = {"wrong?": "foo", "wrong?": "foo2"}.newOrderedTable() + var a = {"wrong?": "foo", "wrong?": "foo2"}.newOrderedTable() var b = newOrderedTable[string, string]() # notice, default size! b.add("wrong?", "foo") b.add("wrong?", "foo2") @@ -1279,13 +1298,13 @@ when isMainModule: b.add("wrong?", "foo2") assert a == b - block: - var a = {"wrong?": "foo", "wrong?": "foo2"}.newOrderedTable() - var b = [("wrong?","foo"), ("wrong?", "foo2")].newOrderedTable() + block: + var a = {"wrong?": "foo", "wrong?": "foo2"}.newOrderedTable() + var b = [("wrong?","foo"), ("wrong?", "foo2")].newOrderedTable() var c = newOrderedTable[string, string]() # notice, default size! c.add("wrong?", "foo") - c.add("wrong?", "foo2") + c.add("wrong?", "foo2") assert a == b assert a == c - + diff --git a/lib/pure/cookies.nim b/lib/pure/cookies.nim index 8090cd49d..7d850798c 100644 --- a/lib/pure/cookies.nim +++ b/lib/pure/cookies.nim @@ -39,7 +39,7 @@ proc setCookie*(key, value: string, domain = "", path = "", if domain != "": result.add("; Domain=" & domain) if path != "": result.add("; Path=" & path) if expires != "": result.add("; Expires=" & expires) - if secure: result.add("; secure") + if secure: result.add("; Secure") if httpOnly: result.add("; HttpOnly") proc setCookie*(key, value: string, expires: TimeInfo, @@ -50,7 +50,7 @@ proc setCookie*(key, value: string, expires: TimeInfo, ## ## **Note:** UTC is assumed as the timezone for ``expires``. return setCookie(key, value, domain, path, - format(expires, "ddd',' dd MMM yyyy HH:mm:ss 'UTC'"), + format(expires, "ddd',' dd MMM yyyy HH:mm:ss 'GMT'"), noname, secure, httpOnly) when isMainModule: diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim index 62c7e2067..4f43177a8 100644 --- a/lib/pure/httpclient.nim +++ b/lib/pure/httpclient.nim @@ -434,7 +434,7 @@ proc `[]=`*(p: var MultipartData, name: string, ## "<html><head></head><body><p>test</p></body></html>") p.add(name, file.content, file.name, file.contentType) -proc format(p: MultipartData): tuple[header, body: string] = +proc format(p: MultipartData): tuple[contentType, body: string] = if p == nil or p.content == nil or p.content.len == 0: return ("", "") @@ -449,7 +449,7 @@ proc format(p: MultipartData): tuple[header, body: string] = if not found: break - result.header = "Content-Type: multipart/form-data; boundary=" & bound & "\c\L" + result.contentType = "multipart/form-data; boundary=" & bound result.body = "" for s in p.content: result.body.add("--" & bound & "\c\L" & s) @@ -512,7 +512,7 @@ proc request*(url: string, httpMethod: string, extraHeaders = "", raise newException(HttpRequestError, "The proxy server rejected a CONNECT request, " & "so a secure connection could not be established.") - sslContext.wrapConnectedSocket(s, handshakeAsClient) + sslContext.wrapConnectedSocket(s, handshakeAsClient, hostUrl.hostname) else: raise newException(HttpRequestError, "SSL support not available. Cannot connect via proxy over SSL") else: @@ -640,7 +640,7 @@ proc post*(url: string, extraHeaders = "", body = "", ## ``multipart/form-data`` POSTs comfortably. ## ## **Deprecated since version 0.15.0**: use ``HttpClient.post`` instead. - let (mpHeaders, mpBody) = format(multipart) + let (mpContentType, mpBody) = format(multipart) template withNewLine(x): untyped = if x.len > 0 and not x.endsWith("\c\L"): @@ -650,9 +650,12 @@ proc post*(url: string, extraHeaders = "", body = "", var xb = mpBody.withNewLine() & body - var xh = extraHeaders.withNewLine() & mpHeaders.withNewLine() & + var xh = extraHeaders.withNewLine() & withNewLine("Content-Length: " & $len(xb)) + if not multipart.isNil: + xh.add(withNewLine("Content-Type: " & mpContentType)) + result = request(url, httpPOST, xh, xb, sslContext, timeout, userAgent, proxy) var lastURL = url @@ -1030,32 +1033,39 @@ proc newConnection(client: HttpClient | AsyncHttpClient, if client.currentURL.hostname != url.hostname or client.currentURL.scheme != url.scheme or client.currentURL.port != url.port: + let isSsl = url.scheme.toLowerAscii() == "https" + + if isSsl and not defined(ssl): + raise newException(HttpRequestError, + "SSL support is not available. Cannot connect over SSL.") + if client.connected: client.close() - when client is HttpClient: - client.socket = newSocket() - elif client is AsyncHttpClient: - client.socket = newAsyncSocket() - else: {.fatal: "Unsupported client type".} - # TODO: I should be able to write 'net.Port' here... let port = if url.port == "": - if url.scheme.toLower() == "https": + if isSsl: nativesockets.Port(443) else: nativesockets.Port(80) else: nativesockets.Port(url.port.parseInt) - if url.scheme.toLower() == "https": - when defined(ssl): - client.sslContext.wrapSocket(client.socket) - else: - raise newException(HttpRequestError, - "SSL support is not available. Cannot connect over SSL.") + when client is HttpClient: + client.socket = await net.dial(url.hostname, port) + elif client is AsyncHttpClient: + client.socket = await asyncnet.dial(url.hostname, port) + else: {.fatal: "Unsupported client type".} + + when defined(ssl): + if isSsl: + try: + client.sslContext.wrapConnectedSocket( + client.socket, handshakeAsClient, url.hostname) + except: + client.socket.close() + raise getCurrentException() - await client.socket.connect(url.hostname, port) client.currentURL = url client.connected = true @@ -1093,7 +1103,8 @@ proc requestAux(client: HttpClient | AsyncHttpClient, url: string, raise newException(HttpRequestError, "The proxy server rejected a CONNECT request, " & "so a secure connection could not be established.") - client.sslContext.wrapConnectedSocket(client.socket, handshakeAsClient) + client.sslContext.wrapConnectedSocket( + client.socket, handshakeAsClient, requestUrl.hostname) client.proxy = nil else: raise newException(HttpRequestError, @@ -1188,7 +1199,7 @@ proc post*(client: HttpClient | AsyncHttpClient, url: string, body = "", ## ## This procedure will follow redirects up to a maximum number of redirects ## specified in ``client.maxRedirects``. - let (mpHeader, mpBody) = format(multipart) + let (mpContentType, mpBody) = format(multipart) # TODO: Support FutureStream for `body` parameter. template withNewLine(x): untyped = if x.len > 0 and not x.endsWith("\c\L"): @@ -1199,7 +1210,7 @@ proc post*(client: HttpClient | AsyncHttpClient, url: string, body = "", var headers = newHttpHeaders() if multipart != nil: - headers["Content-Type"] = mpHeader.split(": ")[1] + headers["Content-Type"] = mpContentType headers["Content-Length"] = $len(xb) result = await client.requestAux(url, $HttpPOST, xb, headers) diff --git a/lib/pure/includes/asynccommon.nim b/lib/pure/includes/asynccommon.nim new file mode 100644 index 000000000..a7d2f803f --- /dev/null +++ b/lib/pure/includes/asynccommon.nim @@ -0,0 +1,201 @@ +template newAsyncNativeSocketImpl(domain, sockType, protocol) = + let handle = newNativeSocket(domain, sockType, protocol) + if handle == osInvalidSocket: + raiseOSError(osLastError()) + handle.setBlocking(false) + when defined(macosx): + handle.setSockOptInt(SOL_SOCKET, SO_NOSIGPIPE, 1) + result = handle.AsyncFD + register(result) + +proc newAsyncNativeSocket*(domain: cint, sockType: cint, + protocol: cint): AsyncFD = + newAsyncNativeSocketImpl(domain, sockType, protocol) + +proc newAsyncNativeSocket*(domain: Domain = Domain.AF_INET, + sockType: SockType = SOCK_STREAM, + protocol: Protocol = IPPROTO_TCP): AsyncFD = + newAsyncNativeSocketImpl(domain, sockType, protocol) + +when defined(windows) or defined(nimdoc): + proc bindToDomain(handle: SocketHandle, domain: Domain) = + # Extracted into a separate proc, because connect() on Windows requires + # the socket to be initially bound. + template doBind(saddr) = + if bindAddr(handle, cast[ptr SockAddr](addr(saddr)), + sizeof(saddr).SockLen) < 0'i32: + raiseOSError(osLastError()) + + if domain == Domain.AF_INET6: + var saddr: Sockaddr_in6 + saddr.sin6_family = int16(toInt(domain)) + doBind(saddr) + else: + var saddr: Sockaddr_in + saddr.sin_family = int16(toInt(domain)) + doBind(saddr) + + proc doConnect(socket: AsyncFD, addrInfo: ptr AddrInfo): Future[void] = + let retFuture = newFuture[void]("doConnect") + result = retFuture + + var ol = PCustomOverlapped() + GC_ref(ol) + ol.data = CompletionData(fd: socket, cb: + proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) = + if not retFuture.finished: + if errcode == OSErrorCode(-1): + retFuture.complete() + else: + retFuture.fail(newException(OSError, osErrorMsg(errcode))) + ) + + let ret = connectEx(socket.SocketHandle, addrInfo.ai_addr, + cint(addrInfo.ai_addrlen), nil, 0, nil, + cast[POVERLAPPED](ol)) + if ret: + # Request to connect completed immediately. + retFuture.complete() + # We don't deallocate ``ol`` here because even though this completed + # immediately poll will still be notified about its completion and it + # will free ``ol``. + else: + let lastError = osLastError() + if lastError.int32 != ERROR_IO_PENDING: + # With ERROR_IO_PENDING ``ol`` will be deallocated in ``poll``, + # and the future will be completed/failed there, too. + GC_unref(ol) + retFuture.fail(newException(OSError, osErrorMsg(lastError))) +else: + proc doConnect(socket: AsyncFD, addrInfo: ptr AddrInfo): Future[void] = + let retFuture = newFuture[void]("doConnect") + result = retFuture + + proc cb(fd: AsyncFD): bool = + let ret = SocketHandle(fd).getSockOptInt( + cint(SOL_SOCKET), cint(SO_ERROR)) + if ret == 0: + # We have connected. + retFuture.complete() + return true + elif ret == EINTR: + # interrupted, keep waiting + return false + else: + retFuture.fail(newException(OSError, osErrorMsg(OSErrorCode(ret)))) + return true + + let ret = connect(socket.SocketHandle, + addrInfo.ai_addr, + addrInfo.ai_addrlen.Socklen) + if ret == 0: + # Request to connect completed immediately. + retFuture.complete() + else: + let lastError = osLastError() + if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS: + addWrite(socket, cb) + else: + retFuture.fail(newException(OSError, osErrorMsg(lastError))) + +template asyncAddrInfoLoop(addrInfo: ptr AddrInfo, fd: untyped, + protocol: Protocol = IPPROTO_RAW) = + ## Iterates through the AddrInfo linked list asynchronously + ## until the connection can be established. + const shouldCreateFd = not declared(fd) + + when shouldCreateFd: + let sockType = protocol.toSockType() + + var fdPerDomain: array[low(Domain).ord..high(Domain).ord, AsyncFD] + for i in low(fdPerDomain)..high(fdPerDomain): + fdPerDomain[i] = osInvalidSocket.AsyncFD + template closeUnusedFds(domainToKeep = -1) {.dirty.} = + for i, fd in fdPerDomain: + if fd != osInvalidSocket.AsyncFD and i != domainToKeep: + fd.closeSocket() + + var lastException: ref Exception + var curAddrInfo = addrInfo + var domain: Domain + when shouldCreateFd: + var curFd: AsyncFD + else: + var curFd = fd + proc tryNextAddrInfo(fut: Future[void]) {.gcsafe.} = + if fut == nil or fut.failed: + if fut != nil: + lastException = fut.readError() + + while curAddrInfo != nil: + let domainOpt = curAddrInfo.ai_family.toKnownDomain() + if domainOpt.isSome: + domain = domainOpt.unsafeGet() + break + curAddrInfo = curAddrInfo.ai_next + + if curAddrInfo == nil: + freeAddrInfo(addrInfo) + when shouldCreateFd: + closeUnusedFds() + if lastException != nil: + retFuture.fail(lastException) + else: + retFuture.fail(newException( + IOError, "Couldn't resolve address: " & address)) + return + + when shouldCreateFd: + curFd = fdPerDomain[ord(domain)] + if curFd == osInvalidSocket.AsyncFD: + try: + curFd = newAsyncNativeSocket(domain, sockType, protocol) + except: + freeAddrInfo(addrInfo) + closeUnusedFds() + raise getCurrentException() + when defined(windows): + curFd.SocketHandle.bindToDomain(domain) + fdPerDomain[ord(domain)] = curFd + + doConnect(curFd, curAddrInfo).callback = tryNextAddrInfo + curAddrInfo = curAddrInfo.ai_next + else: + freeAddrInfo(addrInfo) + when shouldCreateFd: + closeUnusedFds(ord(domain)) + retFuture.complete(curFd) + else: + retFuture.complete() + + tryNextAddrInfo(nil) + +proc dial*(address: string, port: Port, + protocol: Protocol = IPPROTO_TCP): Future[AsyncFD] = + ## Establishes connection to the specified ``address``:``port`` pair via the + ## specified protocol. The procedure iterates through possible + ## resolutions of the ``address`` until it succeeds, meaning that it + ## seamlessly works with both IPv4 and IPv6. + ## Returns the async file descriptor, registered in the dispatcher of + ## the current thread, ready to send or receive data. + let retFuture = newFuture[AsyncFD]("dial") + result = retFuture + let sockType = protocol.toSockType() + + let aiList = getAddrInfo(address, port, Domain.AF_UNSPEC, sockType, protocol) + asyncAddrInfoLoop(aiList, noFD, protocol) + +proc connect*(socket: AsyncFD, address: string, port: Port, + domain = Domain.AF_INET): Future[void] = + let retFuture = newFuture[void]("connect") + result = retFuture + + when defined(windows): + verifyPresence(socket) + else: + assert getSockDomain(socket.SocketHandle) == domain + + let aiList = getAddrInfo(address, port, domain) + when defined(windows): + socket.SocketHandle.bindToDomain(domain) + asyncAddrInfoLoop(aiList, socket) diff --git a/lib/pure/json.nim b/lib/pure/json.nim index bacb182b4..564f952d3 100644 --- a/lib/pure/json.nim +++ b/lib/pure/json.nim @@ -14,25 +14,56 @@ ## JSON is based on a subset of the JavaScript Programming Language, ## Standard ECMA-262 3rd Edition - December 1999. ## -## Usage example: +## Dynamically retrieving fields from JSON +## ======================================= ## -## .. code-block:: nim -## let -## small_json = """{"test": 1.3, "key2": true}""" -## jobj = parseJson(small_json) -## assert (jobj.kind == JObject)\ -## jobj["test"] = newJFloat(0.7) # create or update -## echo($jobj["test"].fnum) -## echo($jobj["key2"].bval) -## echo jobj{"missing key"}.getFNum(0.1) # read a float value using a default -## jobj{"a", "b", "c"} = newJFloat(3.3) # created nested keys +## This module allows you to access fields in a parsed JSON object in two +## different ways, one of them is described in this section. ## -## Results in: +## The ``parseJson`` procedure takes a string containing JSON and returns a +## ``JsonNode`` object. This is an object variant and it is either a +## ``JObject``, ``JArray``, ``JString``, ``JInt``, ``JFloat``, ``JBool`` or +## ``JNull``. You +## check the kind of this object variant by using the ``kind`` accessor. ## -## .. code-block:: nim +## For a ``JsonNode`` who's kind is ``JObject``, you can acess its fields using +## the ``[]`` operator. The following example shows how to do this: +## +## .. code-block:: Nim +## let jsonNode = parseJson("""{"key": 3.14}""") +## doAssert jsonNode.kind == JObject +## doAssert jsonNode["key"].kind == JFloat +## +## Retrieving the value of a JSON node can then be achieved using one of the +## helper procedures, which include: +## +## * ``getNum`` +## * ``getFNum`` +## * ``getStr`` +## * ``getBVal`` +## +## To retrieve the value of ``"key"`` you can do the following: +## +## .. code-block:: Nim +## doAssert jsonNode["key"].getFNum() == 3.14 +## +## The ``[]`` operator will raise an exception when the specified field does +## not exist. If you wish to avoid this behaviour you can use the ``{}`` +## operator instead, it will simply return ``nil`` when the field is not found. +## The ``get``-family of procedures will return a default value when called on +## ``nil``. +## +## Unmarshalling JSON into a type +## ============================== ## -## 1.3000000000000000e+00 -## true +## This module allows you to access fields in a parsed JSON object in two +## different ways, one of them is described in this section. +## +## This is done using the ``to`` macro. Take a look at +## `its documentation <#to.m,JsonNode,typedesc>`_ to see an example of its use. +## +## Creating JSON +## ============= ## ## This module can also be used to comfortably create JSON using the `%*` ## operator: @@ -124,6 +155,9 @@ type state: seq[ParserState] filename: string + JsonKindError* = object of ValueError ## raised by the ``to`` macro if the + ## JSON kind is incorrect. + {.deprecated: [TJsonEventKind: JsonEventKind, TJsonError: JsonError, TJsonParser: JsonParser, TTokKind: TokKind].} @@ -1179,22 +1213,22 @@ when not defined(js): proc parseJson*(s: Stream, filename: string): JsonNode = ## Parses from a stream `s` into a `JsonNode`. `filename` is only needed ## for nice error messages. - ## If `s` contains extra data, it will raising `JsonParsingError`. + ## If `s` contains extra data, it will raise `JsonParsingError`. var p: JsonParser p.open(s, filename) defer: p.close() discard getTok(p) # read first token result = p.parseJson() - eat(p, tkEof) # check there are no exstra data + eat(p, tkEof) # check if there is no extra data proc parseJson*(buffer: string): JsonNode = ## Parses JSON from `buffer`. - ## If `buffer` contains extra data, it will raising `JsonParsingError`. + ## If `buffer` contains extra data, it will raise `JsonParsingError`. result = parseJson(newStringStream(buffer), "input") proc parseFile*(filename: string): JsonNode = ## Parses `file` into a `JsonNode`. - ## If `file` contains extra data, it will raising `JsonParsingError`. + ## If `file` contains extra data, it will raise `JsonParsingError`. var stream = newFileStream(filename, fmRead) if stream == nil: raise newException(IOError, "cannot read from file: " & filename) @@ -1272,6 +1306,465 @@ else: proc parseJson*(buffer: string): JsonNode = return parseNativeJson(buffer).convertObject() +# -- Json deserialiser macro. -- + +proc createJsonIndexer(jsonNode: NimNode, + index: string | int | NimNode): NimNode + {.compileTime.} = + when index is string: + let indexNode = newStrLitNode(index) + elif index is int: + let indexNode = newIntLitNode(index) + elif index is NimNode: + let indexNode = index + + result = newNimNode(nnkBracketExpr).add( + jsonNode, + indexNode + ) + +template verifyJsonKind(node: JsonNode, kinds: set[JsonNodeKind], + ast: string) = + if node.kind notin kinds: + let msg = "Incorrect JSON kind. Wanted '$1' in '$2' but got '$3'." % [ + $kinds, + ast, + $node.kind + ] + raise newException(JsonKindError, msg) + +proc getEnum(node: JsonNode, ast: string, T: typedesc): T = + when T is SomeInteger: + # TODO: I shouldn't need this proc. + proc convert[T](x: BiggestInt): T = T(x) + verifyJsonKind(node, {JInt}, ast) + return convert[T](node.getNum()) + else: + verifyJsonKind(node, {JString}, ast) + return parseEnum[T](node.getStr()) + +proc toIdentNode(typeNode: NimNode): NimNode = + ## Converts a Sym type node (returned by getType et al.) into an + ## Ident node. Placing Sym type nodes inside the resulting code AST is + ## unsound (according to @Araq) so this is necessary. + case typeNode.kind + of nnkSym: + return newIdentNode($typeNode) + of nnkBracketExpr: + result = typeNode + for i in 0..<len(result): + result[i] = newIdentNode($result[i]) + of nnkIdent: + return typeNode + else: + doAssert false, "Cannot convert typeNode to an ident node: " & $typeNode.kind + +proc createGetEnumCall(jsonNode, kindType: NimNode): NimNode = + # -> getEnum(`jsonNode`, `kindType`) + let getEnumSym = bindSym("getEnum") + let astStrLit = toStrLit(jsonNode) + let getEnumCall = newCall(getEnumSym, jsonNode, astStrLit, kindType) + return getEnumCall + +proc createOfBranchCond(ofBranch, getEnumCall: NimNode): NimNode = + ## Creates an expression that acts as the condition for an ``of`` branch. + var cond = newIdentNode("false") + for ofCond in ofBranch: + if ofCond.kind == nnkRecList: + break + + let comparison = infix(getEnumCall, "==", ofCond) + cond = infix(cond, "or", comparison) + + return cond + +proc processObjField(field, jsonNode: NimNode): seq[NimNode] {.compileTime.} +proc processOfBranch(ofBranch, jsonNode, kindType, + kindJsonNode: NimNode): seq[NimNode] {.compileTime.} = + ## Processes each field inside of an object's ``of`` branch. + ## For each field a new ExprColonExpr node is created and put in the + ## resulting list. + ## + ## Sample ``ofBranch`` AST: + ## + ## .. code-block::plain + ## OfBranch of 0, 1: + ## IntLit 0 foodPos: float + ## IntLit 1 enemyPos: float + ## RecList + ## Sym "foodPos" + ## Sym "enemyPos" + result = @[] + let getEnumCall = createGetEnumCall(kindJsonNode, kindType) + + for branchField in ofBranch[^1]: + let objFields = processObjField(branchField, jsonNode) + + for objField in objFields: + let exprColonExpr = newNimNode(nnkExprColonExpr) + result.add(exprColonExpr) + # Add the name of the field. + exprColonExpr.add(toIdentNode(objField[0])) + + # Add the value of the field. + let cond = createOfBranchCond(ofBranch, getEnumCall) + exprColonExpr.add(newIfStmt( + (cond, objField[1]) + )) + +proc processElseBranch(recCaseNode, elseBranch, jsonNode, kindType, + kindJsonNode: NimNode): seq[NimNode] {.compileTime.} = + ## Processes each field inside of a variant object's ``else`` branch. + ## + ## ..code-block::plain + ## Else + ## RecList + ## Sym "other" + result = @[] + let getEnumCall = createGetEnumCall(kindJsonNode, kindType) + + # We need to build up a list of conditions from each ``of`` branch so that + # we can then negate it to get ``else``. + var cond = newIdentNode("false") + for i in 1 .. <len(recCaseNode): + if recCaseNode[i].kind == nnkElse: + break + + cond = infix(cond, "or", createOfBranchCond(recCaseNode[i], getEnumCall)) + + # Negate the condition. + cond = prefix(cond, "not") + + for branchField in elseBranch[^1]: + let objFields = processObjField(branchField, jsonNode) + + for objField in objFields: + let exprColonExpr = newNimNode(nnkExprColonExpr) + result.add(exprColonExpr) + # Add the name of the field. + exprColonExpr.add(toIdentNode(objField[0])) + + # Add the value of the field. + let ifStmt = newIfStmt((cond, objField[1])) + exprColonExpr.add(ifStmt) + +proc createConstructor(typeSym, jsonNode: NimNode): NimNode {.compileTime.} +proc processObjField(field, jsonNode: NimNode): seq[NimNode] = + ## Process a field from a ``RecList``. + ## + ## The field will typically be a simple ``Sym`` node, but for object variants + ## it may also be a ``RecCase`` in which case things become complicated. + result = @[] + case field.kind + of nnkSym: + # Ordinary field. For example, `name: string`. + let exprColonExpr = newNimNode(nnkExprColonExpr) + result.add(exprColonExpr) + + # Add the field name. + exprColonExpr.add(toIdentNode(field)) + + # Add the field value. + # -> jsonNode["`field`"] + let indexedJsonNode = createJsonIndexer(jsonNode, $field) + exprColonExpr.add(createConstructor(getTypeInst(field), indexedJsonNode)) + + of nnkRecCase: + # A "case" field that introduces a variant. + let exprColonExpr = newNimNode(nnkExprColonExpr) + result.add(exprColonExpr) + + # Add the "case" field name (usually "kind"). + exprColonExpr.add(toIdentNode(field[0])) + + # -> jsonNode["`field[0]`"] + let kindJsonNode = createJsonIndexer(jsonNode, $field[0]) + + # Add the "case" field's value. + let kindType = toIdentNode(getTypeInst(field[0])) + let getEnumSym = bindSym("getEnum") + let astStrLit = toStrLit(kindJsonNode) + let getEnumCall = newCall(getEnumSym, kindJsonNode, astStrLit, kindType) + exprColonExpr.add(getEnumCall) + + # Iterate through each `of` branch. + for i in 1 .. <field.len: + case field[i].kind + of nnkOfBranch: + result.add processOfBranch(field[i], jsonNode, kindType, kindJsonNode) + of nnkElse: + result.add processElseBranch(field, field[i], jsonNode, kindType, kindJsonNode) + else: + doAssert false, "Expected OfBranch or Else node kinds, got: " & $field[i].kind + else: + doAssert false, "Unable to process object field: " & $field.kind + + doAssert result.len > 0 + +proc processType(typeName: NimNode, obj: NimNode, + jsonNode: NimNode, isRef: bool): NimNode {.compileTime.} = + ## Process a type such as ``Sym "float"`` or ``ObjectTy ...``. + ## + ## Sample ``ObjectTy``: + ## + ## .. code-block::plain + ## ObjectTy + ## Empty + ## Empty + ## RecList + ## Sym "events" + case obj.kind + of nnkObjectTy: + # Create object constructor. + result = newNimNode(nnkObjConstr) + result.add(typeName) # Name of the type to construct. + + # Process each object field and add it as an exprColonExpr + expectKind(obj[2], nnkRecList) + for field in obj[2]: + let nodes = processObjField(field, jsonNode) + result.add(nodes) + + # Object might be null. So we need to check for that. + if isRef: + result = quote do: + verifyJsonKind(`jsonNode`, {JObject, JNull}, astToStr(`jsonNode`)) + if `jsonNode`.kind == JNull: + nil + else: + `result` + else: + result = quote do: + verifyJsonKind(`jsonNode`, {JObject}, astToStr(`jsonNode`)); + `result` + + of nnkEnumTy: + let instType = toIdentNode(getTypeInst(typeName)) + let getEnumCall = createGetEnumCall(jsonNode, instType) + result = quote do: + ( + `getEnumCall` + ) + of nnkSym: + case ($typeName).normalize + of "float": + result = quote do: + ( + verifyJsonKind(`jsonNode`, {JFloat, JInt}, astToStr(`jsonNode`)); + if `jsonNode`.kind == JFloat: `jsonNode`.fnum else: `jsonNode`.num.float + ) + of "string": + result = quote do: + ( + verifyJsonKind(`jsonNode`, {JString, JNull}, astToStr(`jsonNode`)); + if `jsonNode`.kind == JNull: nil else: `jsonNode`.str + ) + of "int": + result = quote do: + ( + verifyJsonKind(`jsonNode`, {JInt}, astToStr(`jsonNode`)); + `jsonNode`.num.int + ) + of "biggestint": + result = quote do: + ( + verifyJsonKind(`jsonNode`, {JInt}, astToStr(`jsonNode`)); + `jsonNode`.num + ) + of "bool": + result = quote do: + ( + verifyJsonKind(`jsonNode`, {JBool}, astToStr(`jsonNode`)); + `jsonNode`.bval + ) + else: + doAssert false, "Unable to process nnkSym " & $typeName + else: + doAssert false, "Unable to process type: " & $obj.kind + + doAssert(not result.isNil(), "processType not initialised.") + +proc createConstructor(typeSym, jsonNode: NimNode): NimNode = + ## Accepts a type description, i.e. "ref Type", "seq[Type]", "Type" etc. + ## + ## The ``jsonNode`` refers to the node variable that we are deserialising. + ## + ## Returns an object constructor node. + # echo("--createConsuctor-- \n", treeRepr(typeSym)) + # echo() + + case typeSym.kind + of nnkBracketExpr: + var bracketName = ($typeSym[0]).normalize + case bracketName + of "ref": + # Ref type. + var typeName = $typeSym[1] + # Remove the `:ObjectType` suffix. + if typeName.endsWith(":ObjectType"): + typeName = typeName[0 .. ^12] + + let obj = getType(typeSym[1]) + result = processType(newIdentNode(typeName), obj, jsonNode, true) + of "seq": + let seqT = typeSym[1] + let forLoopI = newIdentNode("i") + let indexerNode = createJsonIndexer(jsonNode, forLoopI) + let constructorNode = createConstructor(seqT, indexerNode) + + # Create a statement expression containing a for loop. + result = quote do: + ( + var list: `typeSym` = @[]; + # if `jsonNode`.kind != JArray: + # # TODO: Improve error message. + # raise newException(ValueError, "Expected a list") + for `forLoopI` in 0 .. <`jsonNode`.len: list.add(`constructorNode`); + list + ) + else: + # Generic type. + let obj = getType(typeSym) + result = processType(typeSym, obj, jsonNode, false) + of nnkSym: + let obj = getType(typeSym) + if obj.kind == nnkBracketExpr: + # When `Sym "Foo"` turns out to be a `ref object`. + result = createConstructor(obj, jsonNode) + else: + result = processType(typeSym, obj, jsonNode, false) + else: + doAssert false, "Unable to create constructor for: " & $typeSym.kind + + doAssert(not result.isNil(), "Constructor not initialised.") + +proc postProcess(node: NimNode): NimNode +proc postProcessValue(value: NimNode): NimNode = + ## Looks for object constructors and calls the ``postProcess`` procedure + ## on them. Otherwise it just returns the node as-is. + case value.kind + of nnkObjConstr: + result = postProcess(value) + else: + result = value + for i in 0 .. <len(result): + result[i] = postProcessValue(result[i]) + +proc postProcessExprColonExpr(exprColonExpr, resIdent: NimNode): NimNode = + ## Transform each field mapping in the ExprColonExpr into a simple + ## field assignment. Special processing is performed if the field mapping + ## has an if statement. + ## + ## ..code-block::plain + ## field: (if true: 12) -> if true: `resIdent`.field = 12 + expectKind(exprColonExpr, nnkExprColonExpr) + let fieldName = exprColonExpr[0] + let fieldValue = exprColonExpr[1] + case fieldValue.kind + of nnkIfStmt: + doAssert fieldValue.len == 1, "Cannot postProcess two ElifBranches." + expectKind(fieldValue[0], nnkElifBranch) + + let cond = fieldValue[0][0] + let bodyValue = postProcessValue(fieldValue[0][1]) + doAssert(bodyValue.kind != nnkNilLit) + result = + quote do: + if `cond`: + `resIdent`.`fieldName` = `bodyValue` + else: + let fieldValue = postProcessValue(fieldValue) + doAssert(fieldValue.kind != nnkNilLit) + result = + quote do: + `resIdent`.`fieldName` = `fieldValue` + + +proc postProcess(node: NimNode): NimNode = + ## The ``createConstructor`` proc creates a ObjConstr node which contains + ## if statements for fields that may not be assignable (due to an object + ## variant). Nim doesn't handle this, but may do in the future. + ## + ## For simplicity, we post process the object constructor into multiple + ## assignments. + ## + ## For example: + ## + ## ..code-block::plain + ## Object( (var res = Object(); + ## field: if true: 12 -> if true: res.field = 12; + ## ) res) + result = newNimNode(nnkStmtListExpr) + + expectKind(node, nnkObjConstr) + + # Create the type. + # -> var res = Object() + var resIdent = genSym(nskVar, "res") + # TODO: Placing `node[0]` inside quote is buggy + var resType = toIdentNode(node[0]) + + result.add( + quote do: + var `resIdent` = `resType`(); + ) + + # Process each ExprColonExpr. + for i in 1..<len(node): + result.add postProcessExprColonExpr(node[i], resIdent) + + # Return the `res` variable. + result.add( + quote do: + `resIdent` + ) + + +macro to*(node: JsonNode, T: typedesc): untyped = + ## `Unmarshals`:idx: the specified node into the object type specified. + ## + ## Known limitations: + ## + ## * Heterogeneous arrays are not supported. + ## * Sets in object variants are not supported. + ## + ## Example: + ## + ## .. code-block:: Nim + ## let jsonNode = parseJson(""" + ## { + ## "person": { + ## "name": "Nimmer", + ## "age": 21 + ## }, + ## "list": [1, 2, 3, 4] + ## } + ## """) + ## + ## type + ## Person = object + ## name: string + ## age: int + ## + ## Data = object + ## person: Person + ## list: seq[int] + ## + ## var data = to(jsonNode, Data) + ## doAssert data.person.name == "Nimmer" + ## doAssert data.person.age == 21 + ## doAssert data.list == @[1, 2, 3, 4] + + let typeNode = getType(T) + expectKind(typeNode, nnkBracketExpr) + doAssert(($typeNode[0]).normalize == "typedesc") + + result = createConstructor(typeNode[1], node) + # TODO: Rename postProcessValue and move it (?) + result = postProcessValue(result) + + # echo(toStrLit(result)) + when false: import os var s = newFileStream(paramStr(1), fmRead) @@ -1300,6 +1793,7 @@ when false: # To get that we shall use, obj["json"] when isMainModule: + # Note: Macro tests are in tests/stdlib/tjsonmacro.nim let testJson = parseJson"""{ "a": [1, 2, 3, 4], "b": "asd", "c": "\ud83c\udf83", "d": "\u00E6"}""" # nil passthrough diff --git a/lib/pure/nativesockets.nim b/lib/pure/nativesockets.nim index 0a7ffb3b3..7568408a6 100644 --- a/lib/pure/nativesockets.nim +++ b/lib/pure/nativesockets.nim @@ -12,7 +12,7 @@ # TODO: Clean up the exports a bit and everything else in general. -import os +import os, options when hostOS == "solaris": {.passl: "-lsocket -lnsl".} @@ -52,9 +52,11 @@ type Domain* = enum ## domain, which specifies the protocol family of the ## created socket. Other domains than those that are listed ## here are unsupported. - AF_UNIX, ## for local socket (using a file). Unsupported on Windows. + AF_UNSPEC = 0, ## unspecified domain (can be detected automatically by + ## some procedures, such as getaddrinfo) + AF_UNIX = 1, ## for local socket (using a file). Unsupported on Windows. AF_INET = 2, ## for network protocol IPv4 or - AF_INET6 = 23 ## for network protocol IPv6. + AF_INET6 = when defined(macosx): 30 else: 23 ## for network protocol IPv6. SockType* = enum ## second argument to `socket` proc SOCK_STREAM = 1, ## reliable stream-oriented service or Stream Sockets @@ -113,7 +115,7 @@ proc `==`*(a, b: Port): bool {.borrow.} proc `$`*(p: Port): string {.borrow.} ## returns the port number as a string -proc toInt*(domain: Domain): cint +proc toInt*(domain: Domain): cshort ## Converts the Domain enum to a platform-dependent ``cint``. proc toInt*(typ: SockType): cint @@ -123,12 +125,21 @@ proc toInt*(p: Protocol): cint ## Converts the Protocol enum to a platform-dependent ``cint``. when not useWinVersion: - proc toInt(domain: Domain): cint = + proc toInt(domain: Domain): cshort = case domain - of AF_UNIX: result = posix.AF_UNIX - of AF_INET: result = posix.AF_INET - of AF_INET6: result = posix.AF_INET6 - else: discard + of AF_UNSPEC: result = posix.AF_UNSPEC.cshort + of AF_UNIX: result = posix.AF_UNIX.cshort + of AF_INET: result = posix.AF_INET.cshort + of AF_INET6: result = posix.AF_INET6.cshort + + proc toKnownDomain*(family: cint): Option[Domain] = + ## Converts the platform-dependent ``cint`` to the Domain or none(), + ## if the ``cint`` is not known. + result = if family == posix.AF_UNSPEC: some(Domain.AF_UNSPEC) + elif family == posix.AF_UNIX: some(Domain.AF_UNIX) + elif family == posix.AF_INET: some(Domain.AF_INET) + elif family == posix.AF_INET6: some(Domain.AF_INET6) + else: none(Domain) proc toInt(typ: SockType): cint = case typ @@ -136,7 +147,6 @@ when not useWinVersion: of SOCK_DGRAM: result = posix.SOCK_DGRAM of SOCK_SEQPACKET: result = posix.SOCK_SEQPACKET of SOCK_RAW: result = posix.SOCK_RAW - else: discard proc toInt(p: Protocol): cint = case p @@ -146,18 +156,33 @@ when not useWinVersion: of IPPROTO_IPV6: result = posix.IPPROTO_IPV6 of IPPROTO_RAW: result = posix.IPPROTO_RAW of IPPROTO_ICMP: result = posix.IPPROTO_ICMP - else: discard else: - proc toInt(domain: Domain): cint = + proc toInt(domain: Domain): cshort = result = toU16(ord(domain)) + proc toKnownDomain*(family: cint): Option[Domain] = + ## Converts the platform-dependent ``cint`` to the Domain or none(), + ## if the ``cint`` is not known. + result = if family == winlean.AF_UNSPEC: some(Domain.AF_UNSPEC) + elif family == winlean.AF_INET: some(Domain.AF_INET) + elif family == winlean.AF_INET6: some(Domain.AF_INET6) + else: none(Domain) + proc toInt(typ: SockType): cint = result = cint(ord(typ)) proc toInt(p: Protocol): cint = result = cint(ord(p)) +proc toSockType*(protocol: Protocol): SockType = + result = case protocol + of IPPROTO_TCP: + SOCK_STREAM + of IPPROTO_UDP: + SOCK_DGRAM + of IPPROTO_IP, IPPROTO_IPV6, IPPROTO_RAW, IPPROTO_ICMP: + SOCK_RAW proc newNativeSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM, @@ -239,7 +264,7 @@ template ntohl*(x: int32): untyped {.deprecated.} = ## **Warning**: This template is deprecated since 0.14.0, IPv4 ## addresses are now treated as unsigned integers. Please use the unsigned ## version of this template. - cast[int32](ntohl(cast[uint32](x))) + cast[int32](nativesockets.ntohl(cast[uint32](x))) proc ntohs*(x: uint16): uint16 = ## Converts 16-bit unsigned integers from network to host byte order. On @@ -255,7 +280,7 @@ template ntohs*(x: int16): untyped {.deprecated.} = ## **Warning**: This template is deprecated since 0.14.0, where port ## numbers became unsigned integers. Please use the unsigned version of ## this template. - cast[int16](ntohs(cast[uint16](x))) + cast[int16](nativesockets.ntohs(cast[uint16](x))) template htonl*(x: int32): untyped {.deprecated.} = ## Converts 32-bit integers from host to network byte order. On machines @@ -392,14 +417,14 @@ proc getHostname*(): string {.tags: [ReadIOEffect].} = proc getSockDomain*(socket: SocketHandle): Domain = ## returns the socket's domain (AF_INET or AF_INET6). - var name: SockAddr + var name: Sockaddr_in6 var namelen = sizeof(name).SockLen if getsockname(socket, cast[ptr SockAddr](addr(name)), addr(namelen)) == -1'i32: raiseOSError(osLastError()) - if name.sa_family == nativeAfInet: + if name.sin6_family == nativeAfInet: result = AF_INET - elif name.sa_family == nativeAfInet6: + elif name.sin6_family == nativeAfInet6: result = AF_INET6 else: raiseOSError(osLastError(), "unknown socket family in getSockFamily") @@ -410,17 +435,23 @@ proc getAddrString*(sockAddr: ptr SockAddr): string = if sockAddr.sa_family == nativeAfInet: result = $inet_ntoa(cast[ptr Sockaddr_in](sockAddr).sin_addr) elif sockAddr.sa_family == nativeAfInet6: + let addrLen = when not useWinVersion: posix.INET6_ADDRSTRLEN + else: 46 # it's actually 46 in both cases + result = newString(addrLen) + let addr6 = addr cast[ptr Sockaddr_in6](sockAddr).sin6_addr when not useWinVersion: - # TODO: Windows - result = newString(posix.INET6_ADDRSTRLEN) - let addr6 = addr cast[ptr Sockaddr_in6](sockAddr).sin6_addr - discard posix.inet_ntop(posix.AF_INET6, addr6, result.cstring, - result.len.int32) + if posix.inet_ntop(posix.AF_INET6, addr6, addr result[0], + result.len.int32) == nil: + raiseOSError(osLastError()) if posix.IN6_IS_ADDR_V4MAPPED(addr6) != 0: result = result.substr("::ffff:".len) + else: + if winlean.inet_ntop(winlean.AF_INET6, addr6, addr result[0], + result.len.int32) == nil: + raiseOSError(osLastError()) + setLen(result, len(cstring(result))) else: - raiseOSError(osLastError(), "unknown socket family in getAddrString") - + raise newException(IOError, "Unknown socket family in getAddrString") proc getSockName*(socket: SocketHandle): Port = ## returns the socket's associated port number. diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 56f8b9399..629e916fa 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -66,7 +66,7 @@ ## {.deadCodeElim: on.} -import nativesockets, os, strutils, parseutils, times, sets +import nativesockets, os, strutils, parseutils, times, sets, options export Port, `$`, `==` export Domain, SockType, Protocol @@ -237,6 +237,180 @@ proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM, raiseOSError(osLastError()) result = newSocket(fd, domain, sockType, protocol, buffered) +proc parseIPv4Address(address_str: string): IpAddress = + ## Parses IPv4 adresses + ## Raises EInvalidValue on errors + var + byteCount = 0 + currentByte:uint16 = 0 + seperatorValid = false + + result.family = IpAddressFamily.IPv4 + + for i in 0 .. high(address_str): + if address_str[i] in strutils.Digits: # Character is a number + currentByte = currentByte * 10 + + cast[uint16](ord(address_str[i]) - ord('0')) + if currentByte > 255'u16: + raise newException(ValueError, + "Invalid IP Address. Value is out of range") + seperatorValid = true + elif address_str[i] == '.': # IPv4 address separator + if not seperatorValid or byteCount >= 3: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + result.address_v4[byteCount] = cast[uint8](currentByte) + currentByte = 0 + byteCount.inc + seperatorValid = false + else: + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid character") + + if byteCount != 3 or not seperatorValid: + raise newException(ValueError, "Invalid IP Address") + result.address_v4[byteCount] = cast[uint8](currentByte) + +proc parseIPv6Address(address_str: string): IpAddress = + ## Parses IPv6 adresses + ## Raises EInvalidValue on errors + result.family = IpAddressFamily.IPv6 + if address_str.len < 2: + raise newException(ValueError, "Invalid IP Address") + + var + groupCount = 0 + currentGroupStart = 0 + currentShort:uint32 = 0 + seperatorValid = true + dualColonGroup = -1 + lastWasColon = false + v4StartPos = -1 + byteCount = 0 + + for i,c in address_str: + if c == ':': + if not seperatorValid: + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid seperator") + if lastWasColon: + if dualColonGroup != -1: + raise newException(ValueError, + "Invalid IP Address. Address contains more than one \"::\" seperator") + dualColonGroup = groupCount + seperatorValid = false + elif i != 0 and i != high(address_str): + if groupCount >= 8: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) + result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) + currentShort = 0 + groupCount.inc() + if dualColonGroup != -1: seperatorValid = false + elif i == 0: # only valid if address starts with :: + if address_str[1] != ':': + raise newException(ValueError, + "Invalid IP Address. Address may not start with \":\"") + else: # i == high(address_str) - only valid if address ends with :: + if address_str[high(address_str)-1] != ':': + raise newException(ValueError, + "Invalid IP Address. Address may not end with \":\"") + lastWasColon = true + currentGroupStart = i + 1 + elif c == '.': # Switch to parse IPv4 mode + if i < 3 or not seperatorValid or groupCount >= 7: + raise newException(ValueError, "Invalid IP Address") + v4StartPos = currentGroupStart + currentShort = 0 + seperatorValid = false + break + elif c in strutils.HexDigits: + if c in strutils.Digits: # Normal digit + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0')) + elif c >= 'a' and c <= 'f': # Lower case hex + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10 + else: # Upper case hex + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10 + if currentShort > 65535'u32: + raise newException(ValueError, + "Invalid IP Address. Value is out of range") + lastWasColon = false + seperatorValid = true + else: + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid character") + + + if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff + if seperatorValid: # Copy remaining data + if groupCount >= 8: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) + result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) + groupCount.inc() + else: # Must parse IPv4 address + for i,c in address_str[v4StartPos..high(address_str)]: + if c in strutils.Digits: # Character is a number + currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0')) + if currentShort > 255'u32: + raise newException(ValueError, + "Invalid IP Address. Value is out of range") + seperatorValid = true + elif c == '.': # IPv4 address separator + if not seperatorValid or byteCount >= 3: + raise newException(ValueError, "Invalid IP Address") + result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) + currentShort = 0 + byteCount.inc() + seperatorValid = false + else: # Invalid character + raise newException(ValueError, + "Invalid IP Address. Address contains an invalid character") + + if byteCount != 3 or not seperatorValid: + raise newException(ValueError, "Invalid IP Address") + result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) + groupCount += 2 + + # Shift and fill zeros in case of :: + if groupCount > 8: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + elif groupCount < 8: # must fill + if dualColonGroup == -1: + raise newException(ValueError, + "Invalid IP Address. The address consists of too few groups") + var toFill = 8 - groupCount # The number of groups to fill + var toShift = groupCount - dualColonGroup # Nr of known groups after :: + for i in 0..2*toShift-1: # shift + result.address_v6[15-i] = result.address_v6[groupCount*2-i-1] + for i in 0..2*toFill-1: # fill with 0s + result.address_v6[dualColonGroup*2+i] = 0 + elif dualColonGroup != -1: + raise newException(ValueError, + "Invalid IP Address. The address consists of too many groups") + +proc parseIpAddress*(address_str: string): IpAddress = + ## Parses an IP address + ## Raises EInvalidValue on error + if address_str == nil: + raise newException(ValueError, "IP Address string is nil") + if address_str.contains(':'): + return parseIPv6Address(address_str) + else: + return parseIPv4Address(address_str) + +proc isIpAddress*(address_str: string): bool {.tags: [].} = + ## Checks if a string is an IP address + ## Returns true if it is, false otherwise + try: + discard parseIpAddress(address_str) + except ValueError: + return false + return true + when defineSsl: CRYPTO_malloc_init() SslLibraryInit() @@ -438,9 +612,12 @@ when defineSsl: raiseSSLError() proc wrapConnectedSocket*(ctx: SSLContext, socket: Socket, - handshake: SslHandshakeType) = + handshake: SslHandshakeType, + hostname: string = nil) = ## Wraps a connected socket in an SSL context. This function effectively ## turns ``socket`` into an SSL socket. + ## ``hostname`` should be specified so that the client knows which hostname + ## the server certificate should be validated against. ## ## This should be called on a connected socket, and will perform ## an SSL handshake immediately. @@ -450,6 +627,10 @@ when defineSsl: wrapSocket(ctx, socket) case handshake of handshakeAsClient: + if not hostname.isNil and not isIpAddress(hostname): + # Discard result in case OpenSSL version doesn't support SNI, or we're + # not using TLSv1+ + discard SSL_set_tlsext_host_name(socket.sslHandle, hostname) let ret = SSLConnect(socket.sslHandle) socketError(socket, ret) of handshakeAsServer: @@ -669,7 +850,7 @@ proc close*(socket: Socket) = ## Closes a socket. try: when defineSsl: - if socket.isSSL: + if socket.isSSL and socket.sslHandle != nil: ErrClearError() # As we are closing the underlying socket immediately afterwards, # it is valid, under the TLS standard, to perform a unidirectional @@ -1302,181 +1483,63 @@ proc `$`*(address: IpAddress): string = mask = mask shr 4 printedLastGroup = true -proc parseIPv4Address(address_str: string): IpAddress = - ## Parses IPv4 adresses - ## Raises EInvalidValue on errors - var - byteCount = 0 - currentByte:uint16 = 0 - seperatorValid = false +proc dial*(address: string, port: Port, + protocol = IPPROTO_TCP, buffered = true): Socket + {.tags: [ReadIOEffect, WriteIOEffect].} = + ## Establishes connection to the specified ``address``:``port`` pair via the + ## specified protocol. The procedure iterates through possible + ## resolutions of the ``address`` until it succeeds, meaning that it + ## seamlessly works with both IPv4 and IPv6. + ## Returns Socket ready to send or receive data. + let sockType = protocol.toSockType() + + let aiList = getAddrInfo(address, port, AF_UNSPEC, sockType, protocol) + + var fdPerDomain: array[low(Domain).ord..high(Domain).ord, SocketHandle] + for i in low(fdPerDomain)..high(fdPerDomain): + fdPerDomain[i] = osInvalidSocket + template closeUnusedFds(domainToKeep = -1) {.dirty.} = + for i, fd in fdPerDomain: + if fd != osInvalidSocket and i != domainToKeep: + fd.close() - result.family = IpAddressFamily.IPv4 - - for i in 0 .. high(address_str): - if address_str[i] in strutils.Digits: # Character is a number - currentByte = currentByte * 10 + - cast[uint16](ord(address_str[i]) - ord('0')) - if currentByte > 255'u16: - raise newException(ValueError, - "Invalid IP Address. Value is out of range") - seperatorValid = true - elif address_str[i] == '.': # IPv4 address separator - if not seperatorValid or byteCount >= 3: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - result.address_v4[byteCount] = cast[uint8](currentByte) - currentByte = 0 - byteCount.inc - seperatorValid = false - else: - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid character") - - if byteCount != 3 or not seperatorValid: - raise newException(ValueError, "Invalid IP Address") - result.address_v4[byteCount] = cast[uint8](currentByte) - -proc parseIPv6Address(address_str: string): IpAddress = - ## Parses IPv6 adresses - ## Raises EInvalidValue on errors - result.family = IpAddressFamily.IPv6 - if address_str.len < 2: - raise newException(ValueError, "Invalid IP Address") - - var - groupCount = 0 - currentGroupStart = 0 - currentShort:uint32 = 0 - seperatorValid = true - dualColonGroup = -1 - lastWasColon = false - v4StartPos = -1 - byteCount = 0 - - for i,c in address_str: - if c == ':': - if not seperatorValid: - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid seperator") - if lastWasColon: - if dualColonGroup != -1: - raise newException(ValueError, - "Invalid IP Address. Address contains more than one \"::\" seperator") - dualColonGroup = groupCount - seperatorValid = false - elif i != 0 and i != high(address_str): - if groupCount >= 8: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) - result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) - currentShort = 0 - groupCount.inc() - if dualColonGroup != -1: seperatorValid = false - elif i == 0: # only valid if address starts with :: - if address_str[1] != ':': - raise newException(ValueError, - "Invalid IP Address. Address may not start with \":\"") - else: # i == high(address_str) - only valid if address ends with :: - if address_str[high(address_str)-1] != ':': - raise newException(ValueError, - "Invalid IP Address. Address may not end with \":\"") - lastWasColon = true - currentGroupStart = i + 1 - elif c == '.': # Switch to parse IPv4 mode - if i < 3 or not seperatorValid or groupCount >= 7: - raise newException(ValueError, "Invalid IP Address") - v4StartPos = currentGroupStart - currentShort = 0 - seperatorValid = false + var success = false + var lastError: OSErrorCode + var it = aiList + var domain: Domain + var lastFd: SocketHandle + while it != nil: + let domainOpt = it.ai_family.toKnownDomain() + if domainOpt.isNone: + it = it.ai_next + continue + domain = domainOpt.unsafeGet() + lastFd = fdPerDomain[ord(domain)] + if lastFd == osInvalidSocket: + lastFd = newNativeSocket(domain, sockType, protocol) + if lastFd == osInvalidSocket: + # we always raise if socket creation failed, because it means a + # network system problem (e.g. not enough FDs), and not an unreachable + # address. + let err = osLastError() + freeAddrInfo(aiList) + closeUnusedFds() + raiseOSError(err) + fdPerDomain[ord(domain)] = lastFd + if connect(lastFd, it.ai_addr, it.ai_addrlen.SockLen) == 0'i32: + success = true break - elif c in strutils.HexDigits: - if c in strutils.Digits: # Normal digit - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0')) - elif c >= 'a' and c <= 'f': # Lower case hex - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10 - else: # Upper case hex - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10 - if currentShort > 65535'u32: - raise newException(ValueError, - "Invalid IP Address. Value is out of range") - lastWasColon = false - seperatorValid = true - else: - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid character") - - - if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff - if seperatorValid: # Copy remaining data - if groupCount >= 8: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) - result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) - groupCount.inc() - else: # Must parse IPv4 address - for i,c in address_str[v4StartPos..high(address_str)]: - if c in strutils.Digits: # Character is a number - currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0')) - if currentShort > 255'u32: - raise newException(ValueError, - "Invalid IP Address. Value is out of range") - seperatorValid = true - elif c == '.': # IPv4 address separator - if not seperatorValid or byteCount >= 3: - raise newException(ValueError, "Invalid IP Address") - result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) - currentShort = 0 - byteCount.inc() - seperatorValid = false - else: # Invalid character - raise newException(ValueError, - "Invalid IP Address. Address contains an invalid character") - - if byteCount != 3 or not seperatorValid: - raise newException(ValueError, "Invalid IP Address") - result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) - groupCount += 2 - - # Shift and fill zeros in case of :: - if groupCount > 8: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - elif groupCount < 8: # must fill - if dualColonGroup == -1: - raise newException(ValueError, - "Invalid IP Address. The address consists of too few groups") - var toFill = 8 - groupCount # The number of groups to fill - var toShift = groupCount - dualColonGroup # Nr of known groups after :: - for i in 0..2*toShift-1: # shift - result.address_v6[15-i] = result.address_v6[groupCount*2-i-1] - for i in 0..2*toFill-1: # fill with 0s - result.address_v6[dualColonGroup*2+i] = 0 - elif dualColonGroup != -1: - raise newException(ValueError, - "Invalid IP Address. The address consists of too many groups") - + lastError = osLastError() + it = it.ai_next + freeAddrInfo(aiList) + closeUnusedFds(ord(domain)) -proc parseIpAddress*(address_str: string): IpAddress = - ## Parses an IP address - ## Raises EInvalidValue on error - if address_str == nil: - raise newException(ValueError, "IP Address string is nil") - if address_str.contains(':'): - return parseIPv6Address(address_str) + if success: + result = newSocket(lastFd, domain, sockType, protocol) + elif lastError != 0.OSErrorCode: + raiseOSError(lastError) else: - return parseIPv4Address(address_str) - -proc isIpAddress*(address_str: string): bool {.tags: [].} = - ## Checks if a string is an IP address - ## Returns true if it is, false otherwise - try: - discard parseIpAddress(address_str) - except ValueError: - return false - return true - + raise newException(IOError, "Couldn't resolve address: " & address) proc connect*(socket: Socket, address: string, port = Port(0)) {.tags: [ReadIOEffect].} = diff --git a/lib/pure/oids.nim b/lib/pure/oids.nim index e4c97b260..60b53dbe0 100644 --- a/lib/pure/oids.nim +++ b/lib/pure/oids.nim @@ -69,10 +69,9 @@ var proc genOid*(): Oid = ## generates a new OID. proc rand(): cint {.importc: "rand", header: "<stdlib.h>", nodecl.} - proc gettime(dummy: ptr cint): cint {.importc: "time", header: "<time.h>".} proc srand(seed: cint) {.importc: "srand", header: "<stdlib.h>", nodecl.} - var t = gettime(nil) + var t = getTime().int32 var i = int32(atomicInc(incr)) diff --git a/lib/pure/os.nim b/lib/pure/os.nim index 82acb2a59..98b6aa309 100644 --- a/lib/pure/os.nim +++ b/lib/pure/os.nim @@ -1716,7 +1716,7 @@ template rawToFormalFileInfo(rawInfo, path, formalInfo): untyped = formalInfo.permissions.incl(formalMode) formalInfo.id = (rawInfo.st_dev, rawInfo.st_ino) formalInfo.size = rawInfo.st_size - formalInfo.linkCount = rawInfo.st_Nlink + formalInfo.linkCount = rawInfo.st_Nlink.BiggestInt formalInfo.lastAccessTime = rawInfo.st_atime formalInfo.lastWriteTime = rawInfo.st_mtime formalInfo.creationTime = rawInfo.st_ctime diff --git a/lib/pure/ospaths.nim b/lib/pure/ospaths.nim index 71991e35a..7720fb2a6 100644 --- a/lib/pure/ospaths.nim +++ b/lib/pure/ospaths.nim @@ -569,7 +569,7 @@ when declared(getEnv) or defined(nimscript): ## ``["exe", "cmd", "bat"]``, on Posix ``[""]``. proc findExe*(exe: string, followSymlinks: bool = true; - extensions=ExeExts): string {. + extensions: openarray[string]=ExeExts): string {. tags: [ReadDirEffect, ReadEnvEffect, ReadIOEffect].} = ## Searches for `exe` in the current working directory and then ## in directories listed in the ``PATH`` environment variable. diff --git a/lib/pure/osproc.nim b/lib/pure/osproc.nim index 0f37f8fe0..c94a65a63 100644 --- a/lib/pure/osproc.nim +++ b/lib/pure/osproc.nim @@ -209,9 +209,16 @@ proc waitForExit*(p: Process, timeout: int = -1): int {.rtl, ## ## **Warning**: Be careful when using waitForExit for processes created without ## poParentStreams because they may fill output buffers, causing deadlock. + ## + ## On posix, if the process has exited because of a signal, 128 + signal + ## number will be returned. + proc peekExitCode*(p: Process): int {.tags: [].} ## return -1 if the process is still running. Otherwise the process' exit code + ## + ## On posix, if the process has exited because of a signal, 128 + signal + ## number will be returned. proc inputStream*(p: Process): Stream {.rtl, extern: "nosp$1", tags: [].} ## returns ``p``'s input stream for writing to. @@ -328,7 +335,8 @@ proc execProcesses*(cmds: openArray[string], if afterRunEvent != nil: afterRunEvent(i, p) close(p) -proc select*(readfds: var seq[Process], timeout = 500): int {.benign.} +proc select*(readfds: var seq[Process], timeout = 500): int + {.benign, deprecated.} ## `select` with a sensible Nim interface. `timeout` is in milliseconds. ## Specify -1 for no timeout. Returns the number of processes that are ## ready to read from. The processes that are ready to be read from are @@ -336,6 +344,9 @@ proc select*(readfds: var seq[Process], timeout = 500): int {.benign.} ## ## **Warning**: This function may give unexpected or completely wrong ## results on Windows. + ## + ## **Deprecated since version 0.17.0**: This procedure isn't cross-platform + ## and so should not be used in newly written code. when not defined(useNimRtl): proc execProcess(command: string, @@ -679,6 +690,16 @@ elif not defined(useNimRtl): readIdx = 0 writeIdx = 1 + proc isExitStatus(status: cint): bool = + WIFEXITED(status) or WIFSIGNALED(status) + + proc exitStatus(status: cint): cint = + if WIFSIGNALED(status): + # like the shell! + 128 + WTERMSIG(status) + else: + WEXITSTATUS(status) + proc envToCStringArray(t: StringTableRef): cstringArray = result = cast[cstringArray](alloc0((t.len + 1) * sizeof(cstring))) var i = 0 @@ -967,7 +988,7 @@ elif not defined(useNimRtl): var status : cint = 1 ret = waitpid(p.id, status, WNOHANG) if ret == int(p.id): - if WIFEXITED(status): + if isExitStatus(status): p.exitStatus = status return false else: @@ -990,7 +1011,9 @@ elif not defined(useNimRtl): import kqueue, times proc waitForExit(p: Process, timeout: int = -1): int = - if p.exitStatus != -3: return((p.exitStatus and 0xFF00) shr 8) + if p.exitStatus != -3: + return exitStatus(p.exitStatus) + if timeout == -1: var status : cint = 1 if waitpid(p.id, status, 0) < 0: @@ -1041,7 +1064,7 @@ elif not defined(useNimRtl): finally: discard posix.close(kqFD) - result = ((p.exitStatus and 0xFF00) shr 8) + result = exitStatus(p.exitStatus) else: import times @@ -1077,7 +1100,9 @@ elif not defined(useNimRtl): # ``waitPid`` fails if the process is not running anymore. But then # ``running`` probably set ``p.exitStatus`` for us. Since ``p.exitStatus`` is # initialized with -3, wrong success exit codes are prevented. - if p.exitStatus != -3: return((p.exitStatus and 0xFF00) shr 8) + if p.exitStatus != -3: + return exitStatus(p.exitStatus) + if timeout == -1: var status : cint = 1 if waitpid(p.id, status, 0) < 0: @@ -1151,17 +1176,19 @@ elif not defined(useNimRtl): if sigprocmask(SIG_UNBLOCK, nmask, omask) == -1: raiseOSError(osLastError()) - result = ((p.exitStatus and 0xFF00) shr 8) + result = exitStatus(p.exitStatus) proc peekExitCode(p: Process): int = var status = cint(0) result = -1 - if p.exitStatus != -3: return((p.exitStatus and 0xFF00) shr 8) + if p.exitStatus != -3: + return exitStatus(p.exitStatus) + var ret = waitpid(p.id, status, WNOHANG) if ret > 0: - if WIFEXITED(status): + if isExitStatus(status): p.exitStatus = status - result = (status and 0xFF00) shr 8 + result = exitStatus(status) proc createStream(stream: var Stream, handle: var FileHandle, fileMode: FileMode) = @@ -1189,7 +1216,8 @@ elif not defined(useNimRtl): proc execCmd(command: string): int = when defined(linux): - result = csystem(command) shr 8 + let tmp = csystem(command) + result = if tmp == -1: tmp else: exitStatus(tmp) else: result = csystem(command) diff --git a/lib/pure/parseutils.nim b/lib/pure/parseutils.nim index 8d53a0360..b78e8d000 100644 --- a/lib/pure/parseutils.nim +++ b/lib/pure/parseutils.nim @@ -201,7 +201,7 @@ proc parseWhile*(s: string, token: var string, validChars: set[char], proc captureBetween*(s: string, first: char, second = '\0', start = 0): string = ## Finds the first occurrence of ``first``, then returns everything from there - ## up to ``second``(if ``second`` is '\0', then ``first`` is used). + ## up to ``second`` (if ``second`` is '\0', then ``first`` is used). var i = skipUntil(s, first, start)+1+start result = "" discard s.parseUntil(result, if second == '\0': first else: second, i) diff --git a/lib/pure/strutils.nim b/lib/pure/strutils.nim index 9383675f4..458c22f3a 100644 --- a/lib/pure/strutils.nim +++ b/lib/pure/strutils.nim @@ -1881,6 +1881,8 @@ proc formatFloat*(f: float, format: FloatFormatMode = ffDefault, ## of significant digits to be printed. ## `precision`'s default value is the maximum number of meaningful digits ## after the decimal point for Nim's ``float`` type. + ## + ## If ``precision == 0``, it tries to format it nicely. result = formatBiggestFloat(f, format, precision, decimalSep) proc trimZeros*(x: var string) {.noSideEffect.} = diff --git a/lib/pure/times.nim b/lib/pure/times.nim index 1b088c0ac..bad003a3e 100644 --- a/lib/pure/times.nim +++ b/lib/pure/times.nim @@ -47,15 +47,26 @@ type dMon, dTue, dWed, dThu, dFri, dSat, dSun when defined(posix) and not defined(JS): - type - TimeImpl {.importc: "time_t", header: "<time.h>".} = int - Time* = distinct TimeImpl ## distinct type that represents a time - ## measured as number of seconds since the epoch + when defined(linux) and defined(amd64): + type + TimeImpl {.importc: "time_t", header: "<time.h>".} = clong + Time* = distinct TimeImpl ## distinct type that represents a time + ## measured as number of seconds since the epoch + + Timeval {.importc: "struct timeval", + header: "<sys/select.h>".} = object ## struct timeval + tv_sec: clong ## Seconds. + tv_usec: clong ## Microseconds. + else: + type + TimeImpl {.importc: "time_t", header: "<time.h>".} = int + Time* = distinct TimeImpl ## distinct type that represents a time + ## measured as number of seconds since the epoch - Timeval {.importc: "struct timeval", - header: "<sys/select.h>".} = object ## struct timeval - tv_sec: int ## Seconds. - tv_usec: int ## Microseconds. + Timeval {.importc: "struct timeval", + header: "<sys/select.h>".} = object ## struct timeval + tv_sec: int ## Seconds. + tv_usec: int ## Microseconds. # we cannot import posix.nim here, because posix.nim depends on times.nim. # Ok, we could, but I don't want circular dependencies. @@ -1103,7 +1114,7 @@ when not defined(JS): when defined(freebsd) or defined(netbsd) or defined(openbsd) or defined(macosx): type - StructTM {.importc: "struct tm", final.} = object + StructTM {.importc: "struct tm".} = object second {.importc: "tm_sec".}, minute {.importc: "tm_min".}, hour {.importc: "tm_hour".}, @@ -1116,7 +1127,7 @@ when not defined(JS): gmtoff {.importc: "tm_gmtoff".}: clong else: type - StructTM {.importc: "struct tm", final.} = object + StructTM {.importc: "struct tm".} = object second {.importc: "tm_sec".}, minute {.importc: "tm_min".}, hour {.importc: "tm_hour".}, @@ -1126,6 +1137,9 @@ when not defined(JS): weekday {.importc: "tm_wday".}, yearday {.importc: "tm_yday".}, isdst {.importc: "tm_isdst".}: cint + when defined(linux) and defined(amd64): + gmtoff {.importc: "tm_gmtoff".}: clong + zone {.importc: "tm_zone".}: cstring type TimeInfoPtr = ptr StructTM Clock {.importc: "clock_t".} = distinct int diff --git a/lib/pure/uri.nim b/lib/pure/uri.nim index ba745cfd3..c7e0ed1da 100644 --- a/lib/pure/uri.nim +++ b/lib/pure/uri.nim @@ -50,6 +50,7 @@ proc add*(url: var Url, a: Url) {.deprecated.} = proc parseAuthority(authority: string, result: var Uri) = var i = 0 var inPort = false + var inIPv6 = false while true: case authority[i] of '@': @@ -59,7 +60,14 @@ proc parseAuthority(authority: string, result: var Uri) = result.hostname.setLen(0) inPort = false of ':': - inPort = true + if inIPv6: + result.hostname.add(authority[i]) + else: + inPort = true + of '[': + inIPv6 = true + of ']': + inIPv6 = false of '\0': break else: if inPort: @@ -346,6 +354,17 @@ when isMainModule: doAssert($test == str) block: + # IPv6 address + let str = "foo://[::1]:1234/bar?baz=true&qux#quux" + let uri = parseUri(str) + doAssert uri.scheme == "foo" + doAssert uri.hostname == "::1" + doAssert uri.port == "1234" + doAssert uri.path == "/bar" + doAssert uri.query == "baz=true&qux" + doAssert uri.anchor == "quux" + + block: let str = "urn:example:animal:ferret:nose" let test = parseUri(str) doAssert test.scheme == "urn" |