diff options
Diffstat (limited to 'lib/pure')
60 files changed, 7413 insertions, 1151 deletions
diff --git a/lib/pure/algorithm.nim b/lib/pure/algorithm.nim index c0acae138..b83daf245 100644 --- a/lib/pure/algorithm.nim +++ b/lib/pure/algorithm.nim @@ -114,7 +114,7 @@ proc lowerBound*[T](a: openArray[T], key: T, cmp: proc(x,y: T): int {.closure.}) proc lowerBound*[T](a: openArray[T], key: T): int = lowerBound(a, key, cmp[T]) proc merge[T](a, b: var openArray[T], lo, m, hi: int, cmp: proc (x, y: T): int {.closure.}, order: SortOrder) = - template `<-` (a, b: expr) = + template `<-` (a, b) = when false: a = b elif onlySafeCode: @@ -206,7 +206,7 @@ proc sorted*[T](a: openArray[T], cmp: proc(x, y: T): int {.closure.}, result[i] = a[i] sort(result, cmp, order) -template sortedByIt*(seq1, op: expr): expr = +template sortedByIt*(seq1, op: untyped): untyped = ## Convenience template around the ``sorted`` proc to reduce typing. ## ## The template injects the ``it`` variable which you can use directly in an @@ -231,7 +231,7 @@ template sortedByIt*(seq1, op: expr): expr = ## ## echo people.sortedByIt((it.age, it.name)) ## - var result {.gensym.} = sorted(seq1, proc(x, y: type(seq1[0])): int = + var result = sorted(seq1, proc(x, y: type(seq1[0])): int = var it {.inject.} = x let a = op it = y diff --git a/lib/pure/asyncdispatch.nim b/lib/pure/asyncdispatch.nim index cc337452f..79bc1b96d 100644 --- a/lib/pure/asyncdispatch.nim +++ b/lib/pure/asyncdispatch.nim @@ -9,9 +9,9 @@ include "system/inclrtl" -import os, oids, tables, strutils, macros, times +import os, oids, tables, strutils, macros, times, heapqueue -import nativesockets, net +import nativesockets, net, queues export Port, SocketFlag @@ -155,6 +155,9 @@ type when not defined(release): var currentID = 0 + +proc callSoon*(cbproc: proc ()) {.gcsafe.} + proc newFuture*[T](fromProc: string = "unspecified"): Future[T] = ## Creates a new future. ## @@ -257,7 +260,7 @@ proc `callback=`*(future: FutureBase, cb: proc () {.closure,gcsafe.}) = ## passes ``future`` as a param to the callback. future.cb = cb if future.finished: - future.cb() + callSoon(future.cb) proc `callback=`*[T](future: Future[T], cb: proc (future: Future[T]) {.closure,gcsafe.}) = @@ -352,28 +355,88 @@ proc `or`*[T, Y](fut1: Future[T], fut2: Future[Y]): Future[void] = fut2.callback = cb return retFuture +proc all*[T](futs: varargs[Future[T]]): auto = + ## Returns a future which will complete once + ## all futures in ``futs`` complete. + ## + ## If the awaited futures are not ``Future[void]``, the returned future + ## will hold the values of all awaited futures in a sequence. + ## + ## If the awaited futures *are* ``Future[void]``, + ## this proc returns ``Future[void]``. + + when T is void: + var + retFuture = newFuture[void]("asyncdispatch.all") + completedFutures = 0 + + let totalFutures = len(futs) + + for fut in futs: + fut.callback = proc(f: Future[T]) = + inc(completedFutures) + + if completedFutures == totalFutures: + retFuture.complete() + + return retFuture + + else: + var + retFuture = newFuture[seq[T]]("asyncdispatch.all") + retValues = newSeq[T](len(futs)) + completedFutures = 0 + + for i, fut in futs: + proc setCallback(i: int) = + fut.callback = proc(f: Future[T]) = + retValues[i] = f.read() + inc(completedFutures) + + if completedFutures == len(retValues): + retFuture.complete(retValues) + + setCallback(i) + + return retFuture + type PDispatcherBase = ref object of RootRef - timers: seq[tuple[finishAt: float, fut: Future[void]]] - -proc processTimers(p: PDispatcherBase) = - var oldTimers = p.timers - p.timers = @[] - for t in oldTimers: - if epochTime() >= t.finishAt: - t.fut.complete() - else: - p.timers.add(t) + timers: HeapQueue[tuple[finishAt: float, fut: Future[void]]] + callbacks: Queue[proc ()] + +proc processTimers(p: PDispatcherBase) {.inline.} = + while p.timers.len > 0 and epochTime() >= p.timers[0].finishAt: + p.timers.pop().fut.complete() + +proc processPendingCallbacks(p: PDispatcherBase) = + while p.callbacks.len > 0: + var cb = p.callbacks.dequeue() + cb() + +proc adjustedTimeout(p: PDispatcherBase, timeout: int): int {.inline.} = + # If dispatcher has active timers this proc returns the timeout + # of the nearest timer. Returns `timeout` otherwise. + result = timeout + if p.timers.len > 0: + let timerTimeout = p.timers[0].finishAt + let curTime = epochTime() + if timeout == -1 or (curTime + (timeout / 1000)) > timerTimeout: + result = int((timerTimeout - curTime) * 1000) + if result < 0: result = 0 when defined(windows) or defined(nimdoc): import winlean, sets, hashes type - CompletionKey = Dword + CompletionKey = ULONG_PTR CompletionData* = object fd*: AsyncFD # TODO: Rename this. cb*: proc (fd: AsyncFD, bytesTransferred: Dword, errcode: OSErrorCode) {.closure,gcsafe.} + cell*: ForeignCell # we need this `cell` to protect our `cb` environment, + # when using RegisterWaitForSingleObject, because + # waiting is done in different thread. PDispatcher* = ref object of PDispatcherBase ioPort: Handle @@ -385,6 +448,15 @@ when defined(windows) or defined(nimdoc): PCustomOverlapped* = ref CustomOverlapped AsyncFD* = distinct int + + PostCallbackData = object + ioPort: Handle + handleFd: AsyncFD + waitFd: Handle + ovl: PCustomOverlapped + PostCallbackDataPtr = ptr PostCallbackData + + Callback = proc (fd: AsyncFD): bool {.closure,gcsafe.} {.deprecated: [TCompletionKey: CompletionKey, TAsyncFD: AsyncFD, TCustomOverlapped: CustomOverlapped, TCompletionData: CompletionData].} @@ -396,7 +468,8 @@ when defined(windows) or defined(nimdoc): new result result.ioPort = createIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, 1) result.handles = initSet[AsyncFD]() - result.timers = @[] + result.timers.newHeapQueue() + result.callbacks = initQueue[proc ()](64) var gDisp{.threadvar.}: PDispatcher ## Global dispatcher proc getGlobalDispatcher*(): PDispatcher = @@ -423,15 +496,17 @@ when defined(windows) or defined(nimdoc): proc poll*(timeout = 500) = ## Waits for completion events and processes them. let p = getGlobalDispatcher() - if p.handles.len == 0 and p.timers.len == 0: + if p.handles.len == 0 and p.timers.len == 0 and p.callbacks.len == 0: raise newException(ValueError, "No handles or timers registered in dispatcher.") - let llTimeout = - if timeout == -1: winlean.INFINITE - else: timeout.int32 + let at = p.adjustedTimeout(timeout) + var llTimeout = + if at == -1: winlean.INFINITE + else: at.int32 + var lpNumberOfBytesTransferred: Dword - var lpCompletionKey: ULONG + var lpCompletionKey: ULONG_PTR var customOverlapped: PCustomOverlapped let res = getQueuedCompletionStatus(p.ioPort, addr lpNumberOfBytesTransferred, addr lpCompletionKey, @@ -445,6 +520,13 @@ when defined(windows) or defined(nimdoc): customOverlapped.data.cb(customOverlapped.data.fd, lpNumberOfBytesTransferred, OSErrorCode(-1)) + + # If cell.data != nil, then system.protect(rawEnv(cb)) was called, + # so we need to dispose our `cb` environment, because it is not needed + # anymore. + if customOverlapped.data.cell.data != nil: + system.dispose(customOverlapped.data.cell) + GC_unref(customOverlapped) else: let errCode = osLastError() @@ -452,6 +534,8 @@ when defined(windows) or defined(nimdoc): assert customOverlapped.data.fd == lpCompletionKey.AsyncFD customOverlapped.data.cb(customOverlapped.data.fd, lpNumberOfBytesTransferred, errCode) + if customOverlapped.data.cell.data != nil: + system.dispose(customOverlapped.data.cell) GC_unref(customOverlapped) else: if errCode.int32 == WAIT_TIMEOUT: @@ -461,6 +545,8 @@ when defined(windows) or defined(nimdoc): # Timer processing. processTimers(p) + # Callback queue processing + processPendingCallbacks(p) var connectExPtr: pointer = nil var acceptExPtr: pointer = nil @@ -651,34 +737,16 @@ when defined(windows) or defined(nimdoc): retFuture.complete("") else: retFuture.fail(newException(OSError, osErrorMsg(err))) - elif ret == 0 and bytesReceived == 0 and dataBuf.buf[0] == '\0': - # We have to ensure that the buffer is empty because WSARecv will tell - # us immediately when it was disconnected, even when there is still - # data in the buffer. - # We want to give the user as much data as we can. So we only return - # the empty string (which signals a disconnection) when there is - # nothing left to read. - retFuture.complete("") - # TODO: "For message-oriented sockets, where a zero byte message is often - # allowable, a failure with an error code of WSAEDISCON is used to - # indicate graceful closure." - # ~ http://msdn.microsoft.com/en-us/library/ms741688%28v=vs.85%29.aspx - else: - # Request to read completed immediately. - # From my tests bytesReceived isn't reliable. - let realSize = - if bytesReceived == 0: - size - else: - bytesReceived - var data = newString(realSize) - assert realSize <= size - copyMem(addr data[0], addr dataBuf.buf[0], realSize) - #dealloc dataBuf.buf - retFuture.complete($data) - # We don't deallocate ``ol`` here because even though this completed - # immediately poll will still be notified about its completion and it will - # free ``ol``. + elif ret == 0: + # Request completed immediately. + if bytesReceived != 0: + var data = newString(bytesReceived) + assert bytesReceived <= size + copyMem(addr data[0], addr dataBuf.buf[0], bytesReceived) + retFuture.complete($data) + else: + if hasOverlappedIoCompleted(cast[POVERLAPPED](ol)): + retFuture.complete("") return retFuture proc recvInto*(socket: AsyncFD, buf: cstring, size: int, @@ -741,31 +809,14 @@ when defined(windows) or defined(nimdoc): retFuture.complete(0) else: retFuture.fail(newException(OSError, osErrorMsg(err))) - elif ret == 0 and bytesReceived == 0 and dataBuf.buf[0] == '\0': - # We have to ensure that the buffer is empty because WSARecv will tell - # us immediately when it was disconnected, even when there is still - # data in the buffer. - # We want to give the user as much data as we can. So we only return - # the empty string (which signals a disconnection) when there is - # nothing left to read. - retFuture.complete(0) - # TODO: "For message-oriented sockets, where a zero byte message is often - # allowable, a failure with an error code of WSAEDISCON is used to - # indicate graceful closure." - # ~ http://msdn.microsoft.com/en-us/library/ms741688%28v=vs.85%29.aspx - else: - # Request to read completed immediately. - # From my tests bytesReceived isn't reliable. - let realSize = - if bytesReceived == 0: - size - else: - bytesReceived - assert realSize <= size - retFuture.complete(realSize) - # We don't deallocate ``ol`` here because even though this completed - # immediately poll will still be notified about its completion and it will - # free ``ol``. + elif ret == 0: + # Request completed immediately. + if bytesReceived != 0: + assert bytesReceived <= size + retFuture.complete(bytesReceived) + else: + if hasOverlappedIoCompleted(cast[POVERLAPPED](ol)): + retFuture.complete(bytesReceived) return retFuture proc send*(socket: AsyncFD, data: string, @@ -811,6 +862,101 @@ when defined(windows) or defined(nimdoc): # free ``ol``. return retFuture + proc sendTo*(socket: AsyncFD, data: pointer, size: int, saddr: ptr SockAddr, + saddrLen: Socklen, + flags = {SocketFlag.SafeDisconn}): Future[void] = + ## Sends ``data`` to specified destination ``saddr``, using + ## socket ``socket``. The returned future will complete once all data + ## has been sent. + verifyPresence(socket) + var retFuture = newFuture[void]("sendTo") + var dataBuf: TWSABuf + dataBuf.buf = cast[cstring](data) + dataBuf.len = size.ULONG + var bytesSent = 0.Dword + var lowFlags = 0.Dword + + # we will preserve address in our stack + var staddr: array[128, char] # SOCKADDR_STORAGE size is 128 bytes + var stalen: cint = cint(saddrLen) + zeroMem(addr(staddr[0]), 128) + copyMem(addr(staddr[0]), saddr, saddrLen) + + var ol = PCustomOverlapped() + GC_ref(ol) + ol.data = CompletionData(fd: socket, cb: + proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) = + if not retFuture.finished: + if errcode == OSErrorCode(-1): + retFuture.complete() + else: + retFuture.fail(newException(OSError, osErrorMsg(errcode))) + ) + + let ret = WSASendTo(socket.SocketHandle, addr dataBuf, 1, addr bytesSent, + lowFlags, cast[ptr SockAddr](addr(staddr[0])), + stalen, cast[POVERLAPPED](ol), nil) + if ret == -1: + let err = osLastError() + if err.int32 != ERROR_IO_PENDING: + GC_unref(ol) + retFuture.fail(newException(OSError, osErrorMsg(err))) + else: + retFuture.complete() + # We don't deallocate ``ol`` here because even though this completed + # immediately poll will still be notified about its completion and it will + # free ``ol``. + return retFuture + + proc recvFromInto*(socket: AsyncFD, data: pointer, size: int, + saddr: ptr SockAddr, saddrLen: ptr SockLen, + flags = {SocketFlag.SafeDisconn}): Future[int] = + ## Receives a datagram data from ``socket`` into ``buf``, which must + ## be at least of size ``size``, address of datagram's sender will be + ## stored into ``saddr`` and ``saddrLen``. Returned future will complete + ## once one datagram has been received, and will return size of packet + ## received. + verifyPresence(socket) + var retFuture = newFuture[int]("recvFromInto") + + var dataBuf = TWSABuf(buf: cast[cstring](data), len: size.ULONG) + + var bytesReceived = 0.Dword + var lowFlags = 0.Dword + + var ol = PCustomOverlapped() + GC_ref(ol) + ol.data = CompletionData(fd: socket, cb: + proc (fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) = + if not retFuture.finished: + if errcode == OSErrorCode(-1): + assert bytesCount <= size + retFuture.complete(bytesCount) + else: + # datagram sockets don't have disconnection, + # so we can just raise an exception + retFuture.fail(newException(OSError, osErrorMsg(errcode))) + ) + + let res = WSARecvFrom(socket.SocketHandle, addr dataBuf, 1, + addr bytesReceived, addr lowFlags, + saddr, cast[ptr cint](saddrLen), + cast[POVERLAPPED](ol), nil) + if res == -1: + let err = osLastError() + if err.int32 != ERROR_IO_PENDING: + GC_unref(ol) + retFuture.fail(newException(OSError, osErrorMsg(err))) + else: + # Request completed immediately. + if bytesReceived != 0: + assert bytesReceived <= size + retFuture.complete(bytesReceived) + else: + if hasOverlappedIoCompleted(cast[POVERLAPPED](ol)): + retFuture.complete(bytesReceived) + return retFuture + proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}): Future[tuple[address: string, client: AsyncFD]] = ## Accepts a new connection. Returns a future containing the client socket @@ -837,7 +983,7 @@ when defined(windows) or defined(nimdoc): let dwLocalAddressLength = Dword(sizeof (Sockaddr_in) + 16) let dwRemoteAddressLength = Dword(sizeof(Sockaddr_in) + 16) - template completeAccept(): stmt {.immediate, dirty.} = + template completeAccept() {.dirty.} = var listenSock = socket let setoptRet = setsockopt(clientSock, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, addr listenSock, @@ -857,7 +1003,7 @@ when defined(windows) or defined(nimdoc): client: clientSock.AsyncFD) ) - template failAccept(errcode): stmt = + template failAccept(errcode) = if flags.isDisconnectionError(errcode): var newAcceptFut = acceptAddr(socket, flags) newAcceptFut.callback = @@ -923,6 +1069,126 @@ when defined(windows) or defined(nimdoc): ## Unregisters ``fd``. getGlobalDispatcher().handles.excl(fd) + {.push stackTrace:off.} + proc waitableCallback(param: pointer, + timerOrWaitFired: WINBOOL): void {.stdcall.} = + var p = cast[PostCallbackDataPtr](param) + discard postQueuedCompletionStatus(p.ioPort, timerOrWaitFired.Dword, + ULONG_PTR(p.handleFd), + cast[pointer](p.ovl)) + {.pop.} + + template registerWaitableEvent(mask) = + let p = getGlobalDispatcher() + var flags = (WT_EXECUTEINWAITTHREAD or WT_EXECUTEONLYONCE).Dword + var hEvent = wsaCreateEvent() + if hEvent == 0: + raiseOSError(osLastError()) + var pcd = cast[PostCallbackDataPtr](allocShared0(sizeof(PostCallbackData))) + pcd.ioPort = p.ioPort + pcd.handleFd = fd + var ol = PCustomOverlapped() + GC_ref(ol) + + ol.data = CompletionData(fd: fd, cb: + proc(fd: AsyncFD, bytesCount: Dword, errcode: OSErrorCode) = + # we excluding our `fd` because cb(fd) can register own handler + # for this `fd` + p.handles.excl(fd) + # unregisterWait() is called before callback, because appropriate + # winsockets function can re-enable event. + # https://msdn.microsoft.com/en-us/library/windows/desktop/ms741576(v=vs.85).aspx + if unregisterWait(pcd.waitFd) == 0: + let err = osLastError() + if err.int32 != ERROR_IO_PENDING: + raiseOSError(osLastError()) + if cb(fd): + # callback returned `true`, so we free all allocated resources + deallocShared(cast[pointer](pcd)) + if not wsaCloseEvent(hEvent): + raiseOSError(osLastError()) + # pcd.ovl will be unrefed in poll(). + else: + # callback returned `false` we need to continue + if p.handles.contains(fd): + # new callback was already registered with `fd`, so we free all + # allocated resources. This happens because in callback `cb` + # addRead/addWrite was called with same `fd`. + deallocShared(cast[pointer](pcd)) + if not wsaCloseEvent(hEvent): + raiseOSError(osLastError()) + else: + # we need to include `fd` again + p.handles.incl(fd) + # and register WaitForSingleObject again + if not registerWaitForSingleObject(addr(pcd.waitFd), hEvent, + cast[WAITORTIMERCALLBACK](waitableCallback), + cast[pointer](pcd), INFINITE, flags): + # pcd.ovl will be unrefed in poll() + discard wsaCloseEvent(hEvent) + deallocShared(cast[pointer](pcd)) + raiseOSError(osLastError()) + else: + # we ref pcd.ovl one more time, because it will be unrefed in + # poll() + GC_ref(pcd.ovl) + ) + # We need to protect our callback environment value, so GC will not free it + # accidentally. + ol.data.cell = system.protect(rawEnv(ol.data.cb)) + + # This is main part of `hacky way` is using WSAEventSelect, so `hEvent` + # will be signaled when appropriate `mask` events will be triggered. + if wsaEventSelect(fd.SocketHandle, hEvent, mask) != 0: + GC_unref(ol) + deallocShared(cast[pointer](pcd)) + discard wsaCloseEvent(hEvent) + raiseOSError(osLastError()) + + pcd.ovl = ol + if not registerWaitForSingleObject(addr(pcd.waitFd), hEvent, + cast[WAITORTIMERCALLBACK](waitableCallback), + cast[pointer](pcd), INFINITE, flags): + GC_unref(ol) + deallocShared(cast[pointer](pcd)) + discard wsaCloseEvent(hEvent) + raiseOSError(osLastError()) + p.handles.incl(fd) + + proc addRead*(fd: AsyncFD, cb: Callback) = + ## Start watching the file descriptor for read availability and then call + ## the callback ``cb``. + ## + ## This is not ``pure`` mechanism for Windows Completion Ports (IOCP), + ## so if you can avoid it, please do it. Use `addRead` only if really + ## need it (main usecase is adaptation of `unix like` libraries to be + ## asynchronous on Windows). + ## If you use this function, you dont need to use asyncdispatch.recv() + ## or asyncdispatch.accept(), because they are using IOCP, please use + ## nativesockets.recv() and nativesockets.accept() instead. + ## + ## Be sure your callback ``cb`` returns ``true``, if you want to remove + ## watch of `read` notifications, and ``false``, if you want to continue + ## receiving notifies. + registerWaitableEvent(FD_READ or FD_ACCEPT or FD_OOB or FD_CLOSE) + + proc addWrite*(fd: AsyncFD, cb: Callback) = + ## Start watching the file descriptor for write availability and then call + ## the callback ``cb``. + ## + ## This is not ``pure`` mechanism for Windows Completion Ports (IOCP), + ## so if you can avoid it, please do it. Use `addWrite` only if really + ## need it (main usecase is adaptation of `unix like` libraries to be + ## asynchronous on Windows). + ## If you use this function, you dont need to use asyncdispatch.send() + ## or asyncdispatch.connect(), because they are using IOCP, please use + ## nativesockets.send() and nativesockets.connect() instead. + ## + ## Be sure your callback ``cb`` returns ``true``, if you want to remove + ## watch of `write` notifications, and ``false``, if you want to continue + ## receiving notifies. + registerWaitableEvent(FD_WRITE or FD_CONNECT or FD_CLOSE) + initAll() else: import selectors @@ -956,7 +1222,8 @@ else: proc newDispatcher*(): PDispatcher = new result result.selector = newSelector() - result.timers = @[] + result.timers.newHeapQueue() + result.callbacks = initQueue[proc ()](64) var gDisp{.threadvar.}: PDispatcher ## Global dispatcher proc getGlobalDispatcher*(): PDispatcher = @@ -1014,7 +1281,7 @@ else: proc poll*(timeout = 500) = let p = getGlobalDispatcher() - for info in p.selector.select(timeout): + for info in p.selector.select(p.adjustedTimeout(timeout)): let data = PData(info.key.data) assert data.fd == info.key.fd.AsyncFD #echo("In poll ", data.fd.cint) @@ -1052,7 +1319,10 @@ else: # (e.g. socket disconnected). discard + # Timer processing. processTimers(p) + # Callback queue processing + processPendingCallbacks(p) proc connect*(socket: AsyncFD, address: string, port: Port, domain = AF_INET): Future[void] = @@ -1184,6 +1454,60 @@ else: addWrite(socket, cb) return retFuture + proc sendTo*(socket: AsyncFD, data: pointer, size: int, saddr: ptr SockAddr, + saddrLen: SockLen, + flags = {SocketFlag.SafeDisconn}): Future[void] = + ## Sends ``data`` of size ``size`` in bytes to specified destination + ## (``saddr`` of size ``saddrLen`` in bytes, using socket ``socket``. + ## The returned future will complete once all data has been sent. + var retFuture = newFuture[void]("sendTo") + + # we will preserve address in our stack + var staddr: array[128, char] # SOCKADDR_STORAGE size is 128 bytes + var stalen = saddrLen + zeroMem(addr(staddr[0]), 128) + copyMem(addr(staddr[0]), saddr, saddrLen) + + proc cb(sock: AsyncFD): bool = + result = true + let res = sendto(sock.SocketHandle, data, size, MSG_NOSIGNAL, + cast[ptr SockAddr](addr(staddr[0])), stalen) + if res < 0: + let lastError = osLastError() + if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}: + retFuture.fail(newException(OSError, osErrorMsg(lastError))) + else: + result = false # We still want this callback to be called. + else: + retFuture.complete() + + addWrite(socket, cb) + return retFuture + + proc recvFromInto*(socket: AsyncFD, data: pointer, size: int, + saddr: ptr SockAddr, saddrLen: ptr SockLen, + flags = {SocketFlag.SafeDisconn}): Future[int] = + ## Receives a datagram data from ``socket`` into ``data``, which must + ## be at least of size ``size`` in bytes, address of datagram's sender + ## will be stored into ``saddr`` and ``saddrLen``. Returned future will + ## complete once one datagram has been received, and will return size + ## of packet received. + var retFuture = newFuture[int]("recvFromInto") + proc cb(sock: AsyncFD): bool = + result = true + let res = recvfrom(sock.SocketHandle, data, size.cint, flags.toOSFlags(), + saddr, saddrLen) + if res < 0: + let lastError = osLastError() + if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}: + retFuture.fail(newException(OSError, osErrorMsg(lastError))) + else: + result = false + else: + retFuture.complete(res) + addRead(socket, cb) + return retFuture + proc acceptAddr*(socket: AsyncFD, flags = {SocketFlag.SafeDisconn}): Future[tuple[address: string, client: AsyncFD]] = var retFuture = newFuture[tuple[address: string, @@ -1215,7 +1539,25 @@ proc sleepAsync*(ms: int): Future[void] = ## ``ms`` milliseconds. var retFuture = newFuture[void]("sleepAsync") let p = getGlobalDispatcher() - p.timers.add((epochTime() + (ms / 1000), retFuture)) + p.timers.push((epochTime() + (ms / 1000), retFuture)) + return retFuture + +proc withTimeout*[T](fut: Future[T], timeout: int): Future[bool] = + ## Returns a future which will complete once ``fut`` completes or after + ## ``timeout`` milliseconds has elapsed. + ## + ## If ``fut`` completes first the returned future will hold true, + ## otherwise, if ``timeout`` milliseconds has elapsed first, the returned + ## future will hold false. + + var retFuture = newFuture[bool]("asyncdispatch.`withTimeout`") + var timeoutFuture = sleepAsync(timeout) + fut.callback = + proc () = + if not retFuture.finished: retFuture.complete(true) + timeoutFuture.callback = + proc () = + if not retFuture.finished: retFuture.complete(false) return retFuture proc accept*(socket: AsyncFD, @@ -1248,7 +1590,7 @@ proc skipStmtList(node: NimNode): NimNode {.compileTime.} = result = node[0] template createCb(retFutureSym, iteratorNameSym, - name: expr): stmt {.immediate.} = + name: untyped) = var nameIterVar = iteratorNameSym #{.push stackTrace: off.} proc cb {.closure,gcsafe.} = @@ -1316,6 +1658,23 @@ proc generateExceptionCheck(futSym, ) result.add elseNode +template useVar(result: var NimNode, futureVarNode: NimNode, valueReceiver, + rootReceiver: expr, fromNode: NimNode) = + ## Params: + ## futureVarNode: The NimNode which is a symbol identifying the Future[T] + ## variable to yield. + ## fromNode: Used for better debug information (to give context). + ## valueReceiver: The node which defines an expression that retrieves the + ## future's value. + ## + ## rootReceiver: ??? TODO + # -> yield future<x> + result.add newNimNode(nnkYieldStmt, fromNode).add(futureVarNode) + # -> future<x>.read + valueReceiver = newDotExpr(futureVarNode, newIdentNode("read")) + result.add generateExceptionCheck(futureVarNode, tryStmt, rootReceiver, + fromNode) + template createVar(result: var NimNode, futSymName: string, asyncProc: NimNode, valueReceiver, rootReceiver: expr, @@ -1323,9 +1682,7 @@ template createVar(result: var NimNode, futSymName: string, result = newNimNode(nnkStmtList, fromNode) var futSym = genSym(nskVar, "future") result.add newVarStmt(futSym, asyncProc) # -> var future<x> = y - result.add newNimNode(nnkYieldStmt, fromNode).add(futSym) # -> yield future<x> - valueReceiver = newDotExpr(futSym, newIdentNode("read")) # -> future<x>.read - result.add generateExceptionCheck(futSym, tryStmt, rootReceiver, fromNode) + useVar(result, futSym, valueReceiver, rootReceiver, fromNode) proc processBody(node, retFutureSym: NimNode, subTypeIsVoid: bool, @@ -1342,19 +1699,23 @@ proc processBody(node, retFutureSym: NimNode, else: result.add newCall(newIdentNode("complete"), retFutureSym) else: - result.add newCall(newIdentNode("complete"), retFutureSym, - node[0].processBody(retFutureSym, subTypeIsVoid, tryStmt)) + let x = node[0].processBody(retFutureSym, subTypeIsVoid, tryStmt) + if x.kind == nnkYieldStmt: result.add x + else: + result.add newCall(newIdentNode("complete"), retFutureSym, x) result.add newNimNode(nnkReturnStmt, node).add(newNilLit()) return # Don't process the children of this return stmt of nnkCommand, nnkCall: if node[0].kind == nnkIdent and node[0].ident == !"await": case node[1].kind - of nnkIdent, nnkInfix: + of nnkIdent, nnkInfix, nnkDotExpr: # await x + # await x or y result = newNimNode(nnkYieldStmt, node).add(node[1]) # -> yield x of nnkCall, nnkCommand: # await foo(p, x) + # await foo p, x var futureValue: NimNode result.createVar("future" & $node[1][0].toStrLit, node[1], futureValue, futureValue, node) @@ -1511,38 +1872,40 @@ proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} = # -> complete(retFuture, result) var iteratorNameSym = genSym(nskIterator, $prc[0].getName & "Iter") var procBody = prc[6].processBody(retFutureSym, subtypeIsVoid, nil) - if not subtypeIsVoid: - procBody.insert(0, newNimNode(nnkPragma).add(newIdentNode("push"), - newNimNode(nnkExprColonExpr).add(newNimNode(nnkBracketExpr).add( - newIdentNode("warning"), newIdentNode("resultshadowed")), - newIdentNode("off")))) # -> {.push warning[resultshadowed]: off.} - - procBody.insert(1, newNimNode(nnkVarSection, prc[6]).add( - newIdentDefs(newIdentNode("result"), baseType))) # -> var result: T - - procBody.insert(2, newNimNode(nnkPragma).add( - newIdentNode("pop"))) # -> {.pop.}) - - procBody.add( - newCall(newIdentNode("complete"), - retFutureSym, newIdentNode("result"))) # -> complete(retFuture, result) - else: - # -> complete(retFuture) - procBody.add(newCall(newIdentNode("complete"), retFutureSym)) + # don't do anything with forward bodies (empty) + if procBody.kind != nnkEmpty: + if not subtypeIsVoid: + procBody.insert(0, newNimNode(nnkPragma).add(newIdentNode("push"), + newNimNode(nnkExprColonExpr).add(newNimNode(nnkBracketExpr).add( + newIdentNode("warning"), newIdentNode("resultshadowed")), + newIdentNode("off")))) # -> {.push warning[resultshadowed]: off.} + + procBody.insert(1, newNimNode(nnkVarSection, prc[6]).add( + newIdentDefs(newIdentNode("result"), baseType))) # -> var result: T + + procBody.insert(2, newNimNode(nnkPragma).add( + newIdentNode("pop"))) # -> {.pop.}) + + procBody.add( + newCall(newIdentNode("complete"), + retFutureSym, newIdentNode("result"))) # -> complete(retFuture, result) + else: + # -> complete(retFuture) + procBody.add(newCall(newIdentNode("complete"), retFutureSym)) - var closureIterator = newProc(iteratorNameSym, [newIdentNode("FutureBase")], - procBody, nnkIteratorDef) - closureIterator[4] = newNimNode(nnkPragma, prc[6]).add(newIdentNode("closure")) - outerProcBody.add(closureIterator) + var closureIterator = newProc(iteratorNameSym, [newIdentNode("FutureBase")], + procBody, nnkIteratorDef) + closureIterator[4] = newNimNode(nnkPragma, prc[6]).add(newIdentNode("closure")) + outerProcBody.add(closureIterator) - # -> createCb(retFuture) - #var cbName = newIdentNode("cb") - var procCb = newCall(bindSym"createCb", retFutureSym, iteratorNameSym, - newStrLitNode(prc[0].getName)) - outerProcBody.add procCb + # -> createCb(retFuture) + #var cbName = newIdentNode("cb") + var procCb = getAst createCb(retFutureSym, iteratorNameSym, + newStrLitNode(prc[0].getName)) + outerProcBody.add procCb - # -> return retFuture - outerProcBody.add newNimNode(nnkReturnStmt, prc[6][prc[6].len-1]).add(retFutureSym) + # -> return retFuture + outerProcBody.add newNimNode(nnkReturnStmt, prc[6][prc[6].len-1]).add(retFutureSym) result = prc @@ -1550,19 +1913,19 @@ proc asyncSingleProc(prc: NimNode): NimNode {.compileTime.} = for i in 0 .. <result[4].len: if result[4][i].kind == nnkIdent and result[4][i].ident == !"async": result[4].del(i) + result[4] = newEmptyNode() if subtypeIsVoid: # Add discardable pragma. if returnType.kind == nnkEmpty: # Add Future[void] result[3][0] = parseExpr("Future[void]") - - result[6] = outerProcBody - + if procBody.kind != nnkEmpty: + result[6] = outerProcBody #echo(treeRepr(result)) - #if prc[0].getName == "hubConnectionLoop": + #if prc[0].getName == "testInfix": # echo(toStrLit(result)) -macro async*(prc: stmt): stmt {.immediate.} = +macro async*(prc: untyped): untyped = ## Macro which processes async procedures into the appropriate ## iterators and yield statements. if prc.kind == nnkStmtList: @@ -1571,6 +1934,8 @@ macro async*(prc: stmt): stmt {.immediate.} = result.add asyncSingleProc(oneProc) else: result = asyncSingleProc(prc) + when defined(nimDumpAsync): + echo repr result proc recvLine*(socket: AsyncFD): Future[string] {.async.} = ## Reads a line of data from ``socket``. Returned future will complete once @@ -1611,6 +1976,11 @@ proc recvLine*(socket: AsyncFD): Future[string] {.async.} = return add(result, c) +proc callSoon*(cbproc: proc ()) = + ## Schedule `cbproc` to be called as soon as possible. + ## The callback is called when control returns to the event loop. + getGlobalDispatcher().callbacks.enqueue(cbproc) + proc runForever*() = ## Begins a never ending global dispatcher poll loop. while true: diff --git a/lib/pure/asyncfile.nim b/lib/pure/asyncfile.nim index c7b9fac18..5df606ea8 100644 --- a/lib/pure/asyncfile.nim +++ b/lib/pure/asyncfile.nim @@ -162,7 +162,7 @@ proc read*(f: AsyncFile, size: int): Future[string] = # Request completed immediately. var bytesRead: DWord let overlappedRes = getOverlappedResult(f.fd.Handle, - cast[POverlapped](ol)[], bytesRead, false.WinBool) + cast[POverlapped](ol), bytesRead, false.WinBool) if not overlappedRes.bool: let err = osLastError() if err.int32 == ERROR_HANDLE_EOF: @@ -282,7 +282,7 @@ proc write*(f: AsyncFile, data: string): Future[void] = # Request completed immediately. var bytesWritten: DWord let overlappedRes = getOverlappedResult(f.fd.Handle, - cast[POverlapped](ol)[], bytesWritten, false.WinBool) + cast[POverlapped](ol), bytesWritten, false.WinBool) if not overlappedRes.bool: retFuture.fail(newException(OSError, osErrorMsg(osLastError()))) else: @@ -316,6 +316,7 @@ proc write*(f: AsyncFile, data: string): Future[void] = proc close*(f: AsyncFile) = ## Closes the file specified. + unregister(f.fd) when defined(windows) or defined(nimdoc): if not closeHandle(f.fd.Handle).bool: raiseOSError(osLastError()) diff --git a/lib/pure/asynchttpserver.nim b/lib/pure/asynchttpserver.nim index 590b52c1a..6a7326e83 100644 --- a/lib/pure/asynchttpserver.nim +++ b/lib/pure/asynchttpserver.nim @@ -25,12 +25,21 @@ ## ## waitFor server.serve(Port(8080), cb) -import strtabs, asyncnet, asyncdispatch, parseutils, uri, strutils +import tables, asyncnet, asyncdispatch, parseutils, uri, strutils +import httpcore + +export httpcore except parseHeader + +# TODO: If it turns out that the decisions that asynchttpserver makes +# explicitly, about whether to close the client sockets or upgrade them are +# wrong, then add a return value which determines what to do for the callback. +# Also, maybe move `client` out of `Request` object and into the args for +# the proc. type Request* = object client*: AsyncSocket # TODO: Separate this into a Response object? reqMethod*: string - headers*: StringTableRef + headers*: HttpHeaders protocol*: tuple[orig: string, major, minor: int] url*: Uri hostname*: string ## The hostname of the client that made the request. @@ -39,83 +48,29 @@ type AsyncHttpServer* = ref object socket: AsyncSocket reuseAddr: bool - - HttpCode* = enum - Http100 = "100 Continue", - Http101 = "101 Switching Protocols", - Http200 = "200 OK", - Http201 = "201 Created", - Http202 = "202 Accepted", - Http204 = "204 No Content", - Http205 = "205 Reset Content", - Http206 = "206 Partial Content", - Http300 = "300 Multiple Choices", - Http301 = "301 Moved Permanently", - Http302 = "302 Found", - Http303 = "303 See Other", - Http304 = "304 Not Modified", - Http305 = "305 Use Proxy", - Http307 = "307 Temporary Redirect", - Http400 = "400 Bad Request", - Http401 = "401 Unauthorized", - Http403 = "403 Forbidden", - Http404 = "404 Not Found", - Http405 = "405 Method Not Allowed", - Http406 = "406 Not Acceptable", - Http407 = "407 Proxy Authentication Required", - Http408 = "408 Request Timeout", - Http409 = "409 Conflict", - Http410 = "410 Gone", - Http411 = "411 Length Required", - Http412 = "412 Precondition Failed", - Http413 = "413 Request Entity Too Large", - Http414 = "414 Request-URI Too Long", - Http415 = "415 Unsupported Media Type", - Http416 = "416 Requested Range Not Satisfiable", - Http417 = "417 Expectation Failed", - Http418 = "418 I'm a teapot", - Http500 = "500 Internal Server Error", - Http501 = "501 Not Implemented", - Http502 = "502 Bad Gateway", - Http503 = "503 Service Unavailable", - Http504 = "504 Gateway Timeout", - Http505 = "505 HTTP Version Not Supported" - - HttpVersion* = enum - HttpVer11, - HttpVer10 + reusePort: bool {.deprecated: [TRequest: Request, PAsyncHttpServer: AsyncHttpServer, THttpCode: HttpCode, THttpVersion: HttpVersion].} -proc `==`*(protocol: tuple[orig: string, major, minor: int], - ver: HttpVersion): bool = - let major = - case ver - of HttpVer11, HttpVer10: 1 - let minor = - case ver - of HttpVer11: 1 - of HttpVer10: 0 - result = protocol.major == major and protocol.minor == minor - -proc newAsyncHttpServer*(reuseAddr = true): AsyncHttpServer = +proc newAsyncHttpServer*(reuseAddr = true, reusePort = false): AsyncHttpServer = ## Creates a new ``AsyncHttpServer`` instance. new result result.reuseAddr = reuseAddr + result.reusePort = reusePort -proc addHeaders(msg: var string, headers: StringTableRef) = +proc addHeaders(msg: var string, headers: HttpHeaders) = for k, v in headers: msg.add(k & ": " & v & "\c\L") -proc sendHeaders*(req: Request, headers: StringTableRef): Future[void] = +proc sendHeaders*(req: Request, headers: HttpHeaders): Future[void] = ## Sends the specified headers to the requesting client. var msg = "" addHeaders(msg, headers) return req.client.send(msg) proc respond*(req: Request, code: HttpCode, content: string, - headers: StringTableRef = nil): Future[void] = + headers: HttpHeaders = nil): Future[void] = ## Responds to the request with the specified ``HttpCode``, headers and ## content. ## @@ -128,16 +83,6 @@ proc respond*(req: Request, code: HttpCode, content: string, msg.add(content) result = req.client.send(msg) -proc parseHeader(line: string): tuple[key, value: string] = - var i = 0 - i = line.parseUntil(result.key, ':') - inc(i) # skip : - if i < len(line): - i += line.skipWhiteSpace(i) - i += line.parseUntil(result.value, {'\c', '\L'}, i) - else: - result.value = "" - proc parseProtocol(protocol: string): tuple[orig: string, major, minor: int] = var i = protocol.skipIgnoreCase("HTTP/") if i != 5: @@ -156,7 +101,7 @@ proc processClient(client: AsyncSocket, address: string, Future[void] {.closure, gcsafe.}) {.async.} = var request: Request request.url = initUri() - request.headers = newStringTable(modeCaseInsensitive) + request.headers = newHttpHeaders() var lineFut = newFutureVar[string]("asynchttpserver.processClient") lineFut.mget() = newStringOfCap(80) var key, value = "" @@ -165,7 +110,7 @@ proc processClient(client: AsyncSocket, address: string, # GET /path HTTP/1.1 # Header: val # \n - request.headers.clear(modeCaseInsensitive) + request.headers.clear() request.body = "" request.hostname.shallowCopy(address) assert client != nil @@ -208,29 +153,34 @@ proc processClient(client: AsyncSocket, address: string, if lineFut.mget == "\c\L": break let (key, value) = parseHeader(lineFut.mget) request.headers[key] = value + # Ensure the client isn't trying to DoS us. + if request.headers.len > headerLimit: + await client.sendStatus("400 Bad Request") + request.client.close() + return if request.reqMethod == "post": # Check for Expect header if request.headers.hasKey("Expect"): - if request.headers.getOrDefault("Expect").toLower == "100-continue": + if "100-continue" in request.headers["Expect"]: await client.sendStatus("100 Continue") else: await client.sendStatus("417 Expectation Failed") - # Read the body - # - Check for Content-length header - if request.headers.hasKey("Content-Length"): - var contentLength = 0 - if parseInt(request.headers.getOrDefault("Content-Length"), - contentLength) == 0: - await request.respond(Http400, "Bad Request. Invalid Content-Length.") - continue - else: - request.body = await client.recv(contentLength) - assert request.body.len == contentLength - else: - await request.respond(Http400, "Bad Request. No Content-Length.") + # Read the body + # - Check for Content-length header + if request.headers.hasKey("Content-Length"): + var contentLength = 0 + if parseInt(request.headers["Content-Length"], + contentLength) == 0: + await request.respond(Http400, "Bad Request. Invalid Content-Length.") continue + else: + request.body = await client.recv(contentLength) + assert request.body.len == contentLength + elif request.reqMethod == "post": + await request.respond(Http400, "Bad Request. No Content-Length.") + continue case request.reqMethod of "get", "post", "head", "put", "delete", "trace", "options", @@ -240,6 +190,9 @@ proc processClient(client: AsyncSocket, address: string, await request.respond(Http400, "Invalid request method. Got: " & request.reqMethod) + if "upgrade" in request.headers.getOrDefault("connection"): + return + # Persistent connections if (request.protocol == HttpVer11 and request.headers.getOrDefault("connection").normalize != "close") or @@ -264,6 +217,8 @@ proc serve*(server: AsyncHttpServer, port: Port, server.socket = newAsyncSocket() if server.reuseAddr: server.socket.setSockOpt(OptReuseAddr, true) + if server.reusePort: + server.socket.setSockOpt(OptReusePort, true) server.socket.bindAddr(port, address) server.socket.listen() @@ -287,7 +242,7 @@ when not defined(testing) and isMainModule: #echo(req.headers) let headers = {"Date": "Tue, 29 Apr 2014 23:40:08 GMT", "Content-type": "text/plain; charset=utf-8"} - await req.respond(Http200, "Hello World", headers.newStringTable()) + await req.respond(Http200, "Hello World", headers.newHttpHeaders()) asyncCheck server.serve(Port(5555), cb) runForever() diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index 6b19a48be..a1988f4a6 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -11,7 +11,7 @@ ## asynchronous dispatcher defined in the ``asyncdispatch`` module. ## ## SSL -## --- +## ---- ## ## SSL can be enabled by compiling with the ``-d:ssl`` flag. ## @@ -62,7 +62,9 @@ import os export SOBool -when defined(ssl): +const defineSsl = defined(ssl) or defined(nimdoc) + +when defineSsl: import openssl type @@ -79,7 +81,7 @@ type of false: nil case isSsl: bool of true: - when defined(ssl): + when defineSsl: sslHandle: SslPtr sslContext: SslContext bioIn: BIO @@ -125,7 +127,7 @@ proc newAsyncSocket*(domain, sockType, protocol: cint, Domain(domain), SockType(sockType), Protocol(protocol), buffered) -when defined(ssl): +when defineSsl: proc getSslError(handle: SslPtr, err: cint): cint = assert err < 0 var ret = SSLGetError(handle, err.cint) @@ -186,7 +188,7 @@ proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} = ## or an error occurs. await connect(socket.fd.AsyncFD, address, port, socket.domain) if socket.isSsl: - when defined(ssl): + when defineSsl: let flags = {SocketFlag.SafeDisconn} sslSetConnectState(socket.sslHandle) sslLoop(socket, flags, sslDoHandshake(socket.sslHandle)) @@ -197,7 +199,7 @@ template readInto(buf: cstring, size: int, socket: AsyncSocket, ## this is a template and not a proc. var res = 0 if socket.isSsl: - when defined(ssl): + when defineSsl: # SSL mode. sslLoop(socket, flags, sslRead(socket.sslHandle, buf, size.cint)) @@ -274,7 +276,7 @@ proc send*(socket: AsyncSocket, data: string, ## data has been sent. assert socket != nil if socket.isSsl: - when defined(ssl): + when defineSsl: var copy = data sslLoop(socket, flags, sslWrite(socket.sslHandle, addr copy[0], copy.len.cint)) @@ -426,9 +428,6 @@ proc recvLine*(socket: AsyncSocket, ## ## **Warning**: ``recvLine`` on unbuffered sockets assumes that the protocol ## uses ``\r\L`` to delimit a new line. - template addNLIfEmpty(): stmt = - if result.len == 0: - result.add("\c\L") assert SocketFlag.Peek notin flags ## TODO: # TODO: Optimise this @@ -456,7 +455,8 @@ proc bindAddr*(socket: AsyncSocket, port = Port(0), address = "") {. of AF_INET6: realaddr = "::" of AF_INET: realaddr = "0.0.0.0" else: - raiseOSError("Unknown socket address family and no address specified to bindAddr") + raise newException(ValueError, + "Unknown socket address family and no address specified to bindAddr") var aiList = getAddrInfo(realaddr, port, socket.domain) if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.Socklen) < 0'i32: @@ -468,7 +468,7 @@ proc close*(socket: AsyncSocket) = ## Closes the socket. defer: socket.fd.AsyncFD.closeSocket() - when defined(ssl): + when defineSsl: if socket.isSSL: let res = SslShutdown(socket.sslHandle) SSLFree(socket.sslHandle) @@ -478,7 +478,7 @@ proc close*(socket: AsyncSocket) = raiseSslError() socket.closed = true # TODO: Add extra debugging checks for this. -when defined(ssl): +when defineSsl: proc wrapSocket*(ctx: SslContext, socket: AsyncSocket) = ## Wraps a socket in an SSL context. This function effectively turns ## ``socket`` into an SSL socket. @@ -487,7 +487,7 @@ when defined(ssl): ## prone to security vulnerabilities. socket.isSsl = true socket.sslContext = ctx - socket.sslHandle = SSLNew(SSLCTX(socket.sslContext)) + socket.sslHandle = SSLNew(socket.sslContext.context) if socket.sslHandle == nil: raiseSslError() diff --git a/lib/pure/base64.nim b/lib/pure/base64.nim index 32d37ce02..2fc0b1c5e 100644 --- a/lib/pure/base64.nim +++ b/lib/pure/base64.nim @@ -90,6 +90,8 @@ template encodeInternal(s: expr, lineLen: int, newLine: string): stmt {.immediat if r+4 != result.len: setLen(result, r+4) else: + if r != result.len: + setLen(result, r) #assert(r == result.len) discard @@ -162,4 +164,3 @@ when isMainModule: "asure.", longText] for t in items(tests): assert decode(encode(t)) == t - diff --git a/lib/pure/collections/LockFreeHash.nim b/lib/pure/collections/LockFreeHash.nim index a3ead81e3..954d62491 100644 --- a/lib/pure/collections/LockFreeHash.nim +++ b/lib/pure/collections/LockFreeHash.nim @@ -52,7 +52,7 @@ when sizeof(int) == 4: # 32bit {.deprecated: [TRaw: Raw].} elif sizeof(int) == 8: # 64bit type - Raw = range[0..4611686018427387903] + Raw = range[0'i64..4611686018427387903'i64] ## The range of uint values that can be stored directly in a value slot ## when on a 64 bit platform {.deprecated: [TRaw: Raw].} @@ -74,7 +74,7 @@ type copyDone: int next: PConcTable[K,V] data: EntryArr -{.deprecated: [TEntry: Entry, TEntryArr: EntryArr.} +{.deprecated: [TEntry: Entry, TEntryArr: EntryArr].} proc setVal[K,V](table: var PConcTable[K,V], key: int, val: int, expVal: int, match: bool): int @@ -244,9 +244,9 @@ proc copySlot[K,V](idx: int, oldTbl: var PConcTable[K,V], newTbl: var PConcTable echo("oldVal is Tomb!!!, should not happen") if pop(oldVal) != 0: result = setVal(newTbl, pop(oldKey), pop(oldVal), 0, true) == 0 - if result: + #if result: #echo("Copied a Slot! idx= " & $idx & " key= " & $oldKey & " val= " & $oldVal) - else: + #else: #echo("copy slot failed") # Our copy is done so we disable the old slot while not ok: diff --git a/lib/pure/collections/chains.nim b/lib/pure/collections/chains.nim new file mode 100644 index 000000000..6b2ecd272 --- /dev/null +++ b/lib/pure/collections/chains.nim @@ -0,0 +1,44 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Andreas Rumpf +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +## Template based implementation of singly and doubly linked lists. +## The involved types should have 'prev' or 'next' fields and the +## list header should have 'head' or 'tail' fields. + +template prepend*(header, node) = + when compiles(header.head): + when compiles(node.prev): + if header.head != nil: + header.head.prev = node + node.next = header.head + header.head = node + when compiles(header.tail): + if header.tail == nil: + header.tail = node + +template append*(header, node) = + when compiles(header.head): + if header.head == nil: + header.head = node + when compiles(header.tail): + when compiles(node.prev): + node.prev = header.tail + if header.tail != nil: + header.tail.next = node + header.tail = node + +template unlink*(header, node) = + if node.next != nil: + node.next.prev = node.prev + if node.prev != nil: + node.prev.next = node.next + if header.head == node: + header.head = node.prev + if header.tail == node: + header.tail = node.next diff --git a/lib/pure/collections/heapqueue.nim b/lib/pure/collections/heapqueue.nim new file mode 100644 index 000000000..149a1c9fc --- /dev/null +++ b/lib/pure/collections/heapqueue.nim @@ -0,0 +1,107 @@ +##[ Heap queue algorithm (a.k.a. priority queue). Ported from Python heapq. + +Heaps are arrays for which a[k] <= a[2*k+1] and a[k] <= a[2*k+2] for +all k, counting elements from 0. For the sake of comparison, +non-existing elements are considered to be infinite. The interesting +property of a heap is that a[0] is always its smallest element. + +]## + +type HeapQueue*[T] = distinct seq[T] + +proc newHeapQueue*[T](): HeapQueue[T] {.inline.} = HeapQueue[T](newSeq[T]()) +proc newHeapQueue*[T](h: var HeapQueue[T]) {.inline.} = h = HeapQueue[T](newSeq[T]()) + +proc len*[T](h: HeapQueue[T]): int {.inline.} = seq[T](h).len +proc `[]`*[T](h: HeapQueue[T], i: int): T {.inline.} = seq[T](h)[i] +proc `[]=`[T](h: var HeapQueue[T], i: int, v: T) {.inline.} = seq[T](h)[i] = v +proc add[T](h: var HeapQueue[T], v: T) {.inline.} = seq[T](h).add(v) + +proc heapCmp[T](x, y: T): bool {.inline.} = + return (x < y) + +# 'heap' is a heap at all indices >= startpos, except possibly for pos. pos +# is the index of a leaf with a possibly out-of-order value. Restore the +# heap invariant. +proc siftdown[T](heap: var HeapQueue[T], startpos, p: int) = + var pos = p + var newitem = heap[pos] + # Follow the path to the root, moving parents down until finding a place + # newitem fits. + while pos > startpos: + let parentpos = (pos - 1) shr 1 + let parent = heap[parentpos] + if heapCmp(newitem, parent): + heap[pos] = parent + pos = parentpos + else: + break + heap[pos] = newitem + +proc siftup[T](heap: var HeapQueue[T], p: int) = + let endpos = len(heap) + var pos = p + let startpos = pos + let newitem = heap[pos] + # Bubble up the smaller child until hitting a leaf. + var childpos = 2*pos + 1 # leftmost child position + while childpos < endpos: + # Set childpos to index of smaller child. + let rightpos = childpos + 1 + if rightpos < endpos and not heapCmp(heap[childpos], heap[rightpos]): + childpos = rightpos + # Move the smaller child up. + heap[pos] = heap[childpos] + pos = childpos + childpos = 2*pos + 1 + # The leaf at pos is empty now. Put newitem there, and bubble it up + # to its final resting place (by sifting its parents down). + heap[pos] = newitem + siftdown(heap, startpos, pos) + +proc push*[T](heap: var HeapQueue[T], item: T) = + ## Push item onto heap, maintaining the heap invariant. + (seq[T](heap)).add(item) + siftdown(heap, 0, len(heap)-1) + +proc pop*[T](heap: var HeapQueue[T]): T = + ## Pop the smallest item off the heap, maintaining the heap invariant. + let lastelt = seq[T](heap).pop() + if heap.len > 0: + result = heap[0] + heap[0] = lastelt + siftup(heap, 0) + else: + result = lastelt + +proc replace*[T](heap: var HeapQueue[T], item: T): T = + ## Pop and return the current smallest value, and add the new item. + ## This is more efficient than pop() followed by push(), and can be + ## more appropriate when using a fixed-size heap. Note that the value + ## returned may be larger than item! That constrains reasonable uses of + ## this routine unless written as part of a conditional replacement: + + ## if item > heap[0]: + ## item = replace(heap, item) + result = heap[0] + heap[0] = item + siftup(heap, 0) + +proc pushpop*[T](heap: var HeapQueue[T], item: T): T = + ## Fast version of a push followed by a pop. + if heap.len > 0 and heapCmp(heap[0], item): + swap(item, heap[0]) + siftup(heap, 0) + return item + +when isMainModule: + # Simple sanity test + var heap = newHeapQueue[int]() + let data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 0] + for item in data: + push(heap, item) + doAssert(heap[0] == 0) + var sort = newSeq[int]() + while heap.len > 0: + sort.add(pop(heap)) + doAssert(sort == @[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) diff --git a/lib/pure/collections/queues.nim b/lib/pure/collections/queues.nim index b9bf33bff..399e4d413 100644 --- a/lib/pure/collections/queues.nim +++ b/lib/pure/collections/queues.nim @@ -8,56 +8,142 @@ # ## Implementation of a `queue`:idx:. The underlying implementation uses a ``seq``. +## +## None of the procs that get an individual value from the queue can be used +## on an empty queue. +## If compiled with `boundChecks` option, those procs will raise an `IndexError` +## on such access. This should not be relied upon, as `-d:release` will +## disable those checks and may return garbage or crash the program. +## +## As such, a check to see if the queue is empty is needed before any +## access, unless your program logic guarantees it indirectly. +## +## .. code-block:: Nim +## proc foo(a, b: Positive) = # assume random positive values for `a` and `b` +## var q = initQueue[int]() # initializes the object +## for i in 1 ..< a: q.add i # populates the queue +## +## if b < q.len: # checking before indexed access +## echo "The element at index position ", b, " is ", q[b] +## +## # The following two lines don't need any checking on access due to the +## # logic of the program, but that would not be the case if `a` could be 0. +## assert q.front == 1 +## assert q.back == a +## +## while q.len > 0: # checking if the queue is empty +## echo q.pop() +## ## Note: For inter thread communication use ## a `Channel <channels.html>`_ instead. import math type - Queue*[T] = object ## a queue + Queue*[T] = object ## A queue. data: seq[T] rd, wr, count, mask: int {.deprecated: [TQueue: Queue].} -proc initQueue*[T](initialSize=4): Queue[T] = - ## creates a new queue. `initialSize` needs to be a power of 2. +proc initQueue*[T](initialSize: int = 4): Queue[T] = + ## Create a new queue. + ## Optionally, the initial capacity can be reserved via `initialSize` as a + ## performance optimization. The length of a newly created queue will still + ## be 0. + ## + ## `initialSize` needs to be a power of two. If you need to accept runtime + ## values for this you could use the ``nextPowerOfTwo`` proc from the + ## `math <math.html>`_ module. assert isPowerOfTwo(initialSize) result.mask = initialSize-1 newSeq(result.data, initialSize) -proc len*[T](q: Queue[T]): int = - ## returns the number of elements of `q`. +proc len*[T](q: Queue[T]): int {.inline.}= + ## Return the number of elements of `q`. result = q.count +template emptyCheck(q) = + # Bounds check for the regular queue access. + when compileOption("boundChecks"): + if unlikely(q.count < 1): + raise newException(IndexError, "Empty queue.") + +template xBoundsCheck(q, i) = + # Bounds check for the array like accesses. + when compileOption("boundChecks"): # d:release should disable this. + if unlikely(i >= q.count): # x < q.low is taken care by the Natural parameter + raise newException(IndexError, + "Out of bounds: " & $i & " > " & $(q.count - 1)) + +proc front*[T](q: Queue[T]): T {.inline.}= + ## Return the oldest element of `q`. Equivalent to `q.pop()` but does not + ## remove it from the queue. + emptyCheck(q) + result = q.data[q.rd] + +proc back*[T](q: Queue[T]): T {.inline.} = + ## Return the newest element of `q` but does not remove it from the queue. + emptyCheck(q) + result = q.data[q.wr - 1 and q.mask] + +proc `[]`*[T](q: Queue[T], i: Natural) : T {.inline.} = + ## Access the i-th element of `q` by order of insertion. + ## q[0] is the oldest (the next one q.pop() will extract), + ## q[^1] is the newest (last one added to the queue). + xBoundsCheck(q, i) + return q.data[q.rd + i and q.mask] + +proc `[]`*[T](q: var Queue[T], i: Natural): var T {.inline.} = + ## Access the i-th element of `q` and returns a mutable + ## reference to it. + xBoundsCheck(q, i) + return q.data[q.rd + i and q.mask] + +proc `[]=`* [T] (q: var Queue[T], i: Natural, val : T) {.inline.} = + ## Change the i-th element of `q`. + xBoundsCheck(q, i) + q.data[q.rd + i and q.mask] = val + iterator items*[T](q: Queue[T]): T = - ## yields every element of `q`. + ## Yield every element of `q`. var i = q.rd - var c = q.count - while c > 0: - dec c + for c in 0 ..< q.count: yield q.data[i] i = (i + 1) and q.mask iterator mitems*[T](q: var Queue[T]): var T = - ## yields every element of `q`. + ## Yield every element of `q`. var i = q.rd - var c = q.count - while c > 0: - dec c + for c in 0 ..< q.count: yield q.data[i] i = (i + 1) and q.mask +iterator pairs*[T](q: Queue[T]): tuple[key: int, val: T] = + ## Yield every (position, value) of `q`. + var i = q.rd + for c in 0 ..< q.count: + yield (c, q.data[i]) + i = (i + 1) and q.mask + +proc contains*[T](q: Queue[T], item: T): bool {.inline.} = + ## Return true if `item` is in `q` or false if not found. Usually used + ## via the ``in`` operator. It is the equivalent of ``q.find(item) >= 0``. + ## + ## .. code-block:: Nim + ## if x in q: + ## assert q.contains x + for e in q: + if e == item: return true + return false + proc add*[T](q: var Queue[T], item: T) = - ## adds an `item` to the end of the queue `q`. + ## Add an `item` to the end of the queue `q`. var cap = q.mask+1 - if q.count >= cap: - var n: seq[T] - newSeq(n, cap*2) - var i = 0 - for x in items(q): + if unlikely(q.count >= cap): + var n = newSeq[T](cap*2) + for i, x in q: # don't use copyMem because the GC and because it's slower. shallowCopy(n[i], x) - inc i shallowCopy(q.data, n) q.mask = cap*2 - 1 q.wr = q.count @@ -66,37 +152,104 @@ proc add*[T](q: var Queue[T], item: T) = q.data[q.wr] = item q.wr = (q.wr + 1) and q.mask -proc enqueue*[T](q: var Queue[T], item: T) = - ## alias for the ``add`` operation. - add(q, item) - -proc dequeue*[T](q: var Queue[T]): T = - ## removes and returns the first element of the queue `q`. - assert q.count > 0 +proc default[T](t: typedesc[T]): T {.inline.} = discard +proc pop*[T](q: var Queue[T]): T {.inline, discardable.} = + ## Remove and returns the first (oldest) element of the queue `q`. + emptyCheck(q) dec q.count result = q.data[q.rd] + q.data[q.rd] = default(type(result)) q.rd = (q.rd + 1) and q.mask +proc enqueue*[T](q: var Queue[T], item: T) = + ## Alias for the ``add`` operation. + q.add(item) + +proc dequeue*[T](q: var Queue[T]): T = + ## Alias for the ``pop`` operation. + q.pop() + proc `$`*[T](q: Queue[T]): string = - ## turns a queue into its string representation. + ## Turn a queue into its string representation. result = "[" - for x in items(q): + for x in items(q): # Don't remove the items here for reasons that don't fit in this margin. if result.len > 1: result.add(", ") result.add($x) result.add("]") when isMainModule: - var q = initQueue[int]() + var q = initQueue[int](1) q.add(123) q.add(9) - q.add(4) - var first = q.dequeue + q.enqueue(4) + var first = q.dequeue() q.add(56) q.add(6) - var second = q.dequeue + var second = q.pop() q.add(789) assert first == 123 assert second == 9 assert($q == "[4, 56, 6, 789]") + assert q[0] == q.front and q.front == 4 + assert q[^1] == q.back and q.back == 789 + q[0] = 42 + q[^1] = 7 + + assert 6 in q and 789 notin q + assert q.find(6) >= 0 + assert q.find(789) < 0 + + for i in -2 .. 10: + if i in q: + assert q.contains(i) and q.find(i) >= 0 + else: + assert(not q.contains(i) and q.find(i) < 0) + + when compileOption("boundChecks"): + try: + echo q[99] + assert false + except IndexError: + discard + + try: + assert q.len == 4 + for i in 0 ..< 5: q.pop() + assert false + except IndexError: + discard + + # grabs some types of resize error. + q = initQueue[int]() + for i in 1 .. 4: q.add i + q.pop() + q.pop() + for i in 5 .. 8: q.add i + assert $q == "[3, 4, 5, 6, 7, 8]" + + # Similar to proc from the documentation example + proc foo(a, b: Positive) = # assume random positive values for `a` and `b`. + var q = initQueue[int]() + assert q.len == 0 + for i in 1 .. a: q.add i + + if b < q.len: # checking before indexed access. + assert q[b] == b + 1 + + # The following two lines don't need any checking on access due to the logic + # of the program, but that would not be the case if `a` could be 0. + assert q.front == 1 + assert q.back == a + + while q.len > 0: # checking if the queue is empty + assert q.pop() > 0 + + #foo(0,0) + foo(8,5) + foo(10,9) + foo(1,1) + foo(2,1) + foo(1,5) + foo(3,2) diff --git a/lib/pure/collections/rtarrays.nim b/lib/pure/collections/rtarrays.nim index 9d8085643..2abe9d1f8 100644 --- a/lib/pure/collections/rtarrays.nim +++ b/lib/pure/collections/rtarrays.nim @@ -18,7 +18,7 @@ type RtArray*[T] = object ## L: Natural spart: seq[T] - apart: array [ArrayPartSize, T] + apart: array[ArrayPartSize, T] UncheckedArray* {.unchecked.}[T] = array[0..100_000_000, T] template usesSeqPart(x): expr = x.L > ArrayPartSize diff --git a/lib/pure/collections/sequtils.nim b/lib/pure/collections/sequtils.nim index 0e3824a81..f458b7636 100644 --- a/lib/pure/collections/sequtils.nim +++ b/lib/pure/collections/sequtils.nim @@ -20,6 +20,8 @@ ## **Note**: This interface will change as soon as the compiler supports ## closures and proper coroutines. +include "system/inclrtl" + when not defined(nimhygiene): {.pragma: dirty.} @@ -355,7 +357,7 @@ proc insert*[T](dest: var seq[T], src: openArray[T], pos=0) = inc(j) -template filterIt*(seq1, pred: expr): expr = +template filterIt*(seq1, pred: untyped): untyped = ## Returns a new sequence with all the items that fulfilled the predicate. ## ## Unlike the `proc` version, the predicate needs to be an expression using @@ -369,12 +371,12 @@ template filterIt*(seq1, pred: expr): expr = ## notAcceptable = filterIt(temperatures, it > 50 or it < -10) ## assert acceptable == @[-2.0, 24.5, 44.31] ## assert notAcceptable == @[-272.15, 99.9, -113.44] - var result {.gensym.} = newSeq[type(seq1[0])]() + var result = newSeq[type(seq1[0])]() for it {.inject.} in items(seq1): if pred: result.add(it) result -template keepItIf*(varSeq: seq, pred: expr) = +template keepItIf*(varSeq: seq, pred: untyped) = ## Convenience template around the ``keepIf`` proc to reduce typing. ## ## Unlike the `proc` version, the predicate needs to be an expression using @@ -409,7 +411,7 @@ proc all*[T](seq1: seq[T], pred: proc(item: T): bool {.closure.}): bool = return false return true -template allIt*(seq1, pred: expr): bool {.immediate.} = +template allIt*(seq1, pred: untyped): bool = ## Checks if every item fulfills the predicate. ## ## Example: @@ -418,7 +420,7 @@ template allIt*(seq1, pred: expr): bool {.immediate.} = ## let numbers = @[1, 4, 5, 8, 9, 7, 4] ## assert allIt(numbers, it < 10) == true ## assert allIt(numbers, it < 9) == false - var result {.gensym.} = true + var result = true for it {.inject.} in items(seq1): if not pred: result = false @@ -440,7 +442,7 @@ proc any*[T](seq1: seq[T], pred: proc(item: T): bool {.closure.}): bool = return true return false -template anyIt*(seq1, pred: expr): bool {.immediate.} = +template anyIt*(seq1, pred: untyped): bool = ## Checks if some item fulfills the predicate. ## ## Example: @@ -449,14 +451,14 @@ template anyIt*(seq1, pred: expr): bool {.immediate.} = ## let numbers = @[1, 4, 5, 8, 9, 7, 4] ## assert anyIt(numbers, it > 8) == true ## assert anyIt(numbers, it > 9) == false - var result {.gensym.} = false + var result = false for it {.inject.} in items(seq1): if pred: result = true break result -template toSeq*(iter: expr): expr {.immediate.} = +template toSeq*(iter: untyped): untyped {.oldimmediate.} = ## Transforms any iterator into a sequence. ## ## Example: @@ -482,7 +484,7 @@ template toSeq*(iter: expr): expr {.immediate.} = result.add(x) result -template foldl*(sequence, operation: expr): expr = +template foldl*(sequence, operation: untyped): untyped = ## Template to fold a sequence from left to right, returning the accumulation. ## ## The sequence is required to have at least a single element. Debug versions @@ -510,7 +512,7 @@ template foldl*(sequence, operation: expr): expr = ## assert concatenation == "nimiscool" let s = sequence assert s.len > 0, "Can't fold empty sequences" - var result {.gensym.}: type(s[0]) + var result: type(s[0]) result = s[0] for i in 1..<s.len: let @@ -519,7 +521,7 @@ template foldl*(sequence, operation: expr): expr = result = operation result -template foldl*(sequence, operation: expr, first): expr = +template foldl*(sequence, operation, first): untyped = ## Template to fold a sequence from left to right, returning the accumulation. ## ## This version of ``foldl`` gets a starting parameter. This makes it possible @@ -535,7 +537,7 @@ template foldl*(sequence, operation: expr, first): expr = ## numbers = @[0, 8, 1, 5] ## digits = foldl(numbers, a & (chr(b + ord('0'))), "") ## assert digits == "0815" - var result {.gensym.}: type(first) + var result: type(first) result = first for x in items(sequence): let @@ -544,7 +546,7 @@ template foldl*(sequence, operation: expr, first): expr = result = operation result -template foldr*(sequence, operation: expr): expr = +template foldr*(sequence, operation: untyped): untyped = ## Template to fold a sequence from right to left, returning the accumulation. ## ## The sequence is required to have at least a single element. Debug versions @@ -572,7 +574,7 @@ template foldr*(sequence, operation: expr): expr = ## assert concatenation == "nimiscool" let s = sequence assert s.len > 0, "Can't fold empty sequences" - var result {.gensym.}: type(s[0]) + var result: type(s[0]) result = sequence[s.len - 1] for i in countdown(s.len - 2, 0): let @@ -581,7 +583,7 @@ template foldr*(sequence, operation: expr): expr = result = operation result -template mapIt*(seq1, typ, op: expr): expr {.deprecated.}= +template mapIt*(seq1, typ, op: untyped): untyped = ## Convenience template around the ``map`` proc to reduce typing. ## ## The template injects the ``it`` variable which you can use directly in an @@ -596,13 +598,13 @@ template mapIt*(seq1, typ, op: expr): expr {.deprecated.}= ## assert strings == @["4", "8", "12", "16"] ## **Deprecated since version 0.12.0:** Use the ``mapIt(seq1, op)`` ## template instead. - var result {.gensym.}: seq[typ] = @[] + var result: seq[typ] = @[] for it {.inject.} in items(seq1): result.add(op) result -template mapIt*(seq1, op: expr): expr = +template mapIt*(seq1, op: untyped): untyped = ## Convenience template around the ``map`` proc to reduce typing. ## ## The template injects the ``it`` variable which you can use directly in an @@ -631,7 +633,7 @@ template mapIt*(seq1, op: expr): expr = result.add(op) result -template applyIt*(varSeq, op: expr) = +template applyIt*(varSeq, op: untyped) = ## Convenience template around the mutable ``apply`` proc to reduce typing. ## ## The template injects the ``it`` variable which you can use directly in an @@ -648,7 +650,7 @@ template applyIt*(varSeq, op: expr) = -template newSeqWith*(len: int, init: expr): expr = +template newSeqWith*(len: int, init: untyped): untyped = ## creates a new sequence, calling `init` to initialize each value. Example: ## ## .. code-block:: @@ -657,10 +659,10 @@ template newSeqWith*(len: int, init: expr): expr = ## seq2D[1][0] = true ## seq2D[0][1] = true ## - ## import math + ## import random ## var seqRand = newSeqWith(20, random(10)) ## echo seqRand - var result {.gensym.} = newSeq[type(init)](len) + var result = newSeq[type(init)](len) for i in 0 .. <len: result[i] = init result diff --git a/lib/pure/collections/sets.nim b/lib/pure/collections/sets.nim index 20f73ded3..552e41ef7 100644 --- a/lib/pure/collections/sets.nim +++ b/lib/pure/collections/sets.nim @@ -58,7 +58,7 @@ proc isValid*[A](s: HashSet[A]): bool = ## initialized. Example: ## ## .. code-block :: - ## proc savePreferences(options: Set[string]) = + ## proc savePreferences(options: HashSet[string]) = ## assert options.isValid, "Pass an initialized set!" ## # Do stuff here, may crash in release builds! result = not s.data.isNil @@ -72,7 +72,7 @@ proc len*[A](s: HashSet[A]): int = ## ## .. code-block:: ## - ## var values: Set[int] + ## var values: HashSet[int] ## assert(not values.isValid) ## assert values.len == 0 result = s.counter @@ -256,11 +256,13 @@ proc incl*[A](s: var HashSet[A], other: HashSet[A]) = assert other.isValid, "The set `other` needs to be initialized." for item in other: incl(s, item) -template doWhile(a: expr, b: stmt): stmt = +template doWhile(a, b) = while true: b if not a: break +proc default[T](t: typedesc[T]): T {.inline.} = discard + proc excl*[A](s: var HashSet[A], key: A) = ## Excludes `key` from the set `s`. ## @@ -277,12 +279,14 @@ proc excl*[A](s: var HashSet[A], key: A) = var msk = high(s.data) if i >= 0: s.data[i].hcode = 0 + s.data[i].key = default(type(s.data[i].key)) dec(s.counter) while true: # KnuthV3 Algo6.4R adapted for i=i+1 instead of i=i-1 var j = i # The correctness of this depends on (h+1) in nextTry, var r = j # though may be adaptable to other simple sequences. s.data[i].hcode = 0 # mark current EMPTY - doWhile ((i >= r and r > j) or (r > j and j > i) or (j > i and i >= r)): + s.data[i].key = default(type(s.data[i].key)) + doWhile((i >= r and r > j) or (r > j and j > i) or (j > i and i >= r)): i = (i + 1) and msk # increment mod table size if isEmpty(s.data[i].hcode): # end of collision cluster; So all done return @@ -334,7 +338,7 @@ proc init*[A](s: var HashSet[A], initialSize=64) = ## existing values and calling `excl() <#excl,TSet[A],A>`_ on them. Example: ## ## .. code-block :: - ## var a: Set[int] + ## var a: HashSet[int] ## a.init(4) ## a.incl(2) ## a.init @@ -367,7 +371,7 @@ proc toSet*[A](keys: openArray[A]): HashSet[A] = result = initSet[A](rightSize(keys.len)) for key in items(keys): result.incl(key) -template dollarImpl(): stmt {.dirty.} = +template dollarImpl() {.dirty.} = result = "{" for key in items(s): if result.len > 1: result.add(", ") @@ -611,7 +615,7 @@ proc card*[A](s: OrderedSet[A]): int {.inline.} = ## <http://en.wikipedia.org/wiki/Cardinality>`_ of a set. result = s.counter -template forAllOrderedPairs(yieldStmt: stmt) {.dirty, immediate.} = +template forAllOrderedPairs(yieldStmt: untyped) {.dirty.} = var h = s.first while h >= 0: var nxt = s.data[h].next diff --git a/lib/pure/collections/sharedtables.nim b/lib/pure/collections/sharedtables.nim index 20e1bb7a9..17600b272 100644 --- a/lib/pure/collections/sharedtables.nim +++ b/lib/pure/collections/sharedtables.nim @@ -45,6 +45,59 @@ template withLock(t, x: untyped) = x release(t.lock) +template withValue*[A, B](t: var SharedTable[A, B], key: A, + value, body: untyped) = + ## retrieves the value at ``t[key]``. + ## `value` can be modified in the scope of the ``withValue`` call. + ## + ## .. code-block:: nim + ## + ## sharedTable.withValue(key, value) do: + ## # block is executed only if ``key`` in ``t`` + ## # value is threadsafe in block + ## value.name = "username" + ## value.uid = 1000 + ## + acquire(t.lock) + try: + var hc: Hash + var index = rawGet(t, key, hc) + let hasKey = index >= 0 + if hasKey: + var value {.inject.} = addr(t.data[index].val) + body + finally: + release(t.lock) + +template withValue*[A, B](t: var SharedTable[A, B], key: A, + value, body1, body2: untyped) = + ## retrieves the value at ``t[key]``. + ## `value` can be modified in the scope of the ``withValue`` call. + ## + ## .. code-block:: nim + ## + ## sharedTable.withValue(key, value) do: + ## # block is executed only if ``key`` in ``t`` + ## # value is threadsafe in block + ## value.name = "username" + ## value.uid = 1000 + ## do: + ## # block is executed when ``key`` not in ``t`` + ## raise newException(KeyError, "Key not found") + ## + acquire(t.lock) + try: + var hc: Hash + var index = rawGet(t, key, hc) + let hasKey = index >= 0 + if hasKey: + var value {.inject.} = addr(t.data[index].val) + body1 + else: + body2 + finally: + release(t.lock) + proc mget*[A, B](t: var SharedTable[A, B], key: A): var B = ## retrieves the value at ``t[key]``. The value can be modified. ## If `key` is not in `t`, the ``KeyError`` exception is raised. diff --git a/lib/pure/collections/tableimpl.nim b/lib/pure/collections/tableimpl.nim index e4ec05b1c..be3507137 100644 --- a/lib/pure/collections/tableimpl.nim +++ b/lib/pure/collections/tableimpl.nim @@ -72,14 +72,14 @@ proc rawInsert[X, A, B](t: var X, data: var KeyValuePairSeq[A, B], key: A, val: B, hc: Hash, h: Hash) = rawInsertImpl() -template addImpl(enlarge) {.dirty, immediate.} = +template addImpl(enlarge) {.dirty.} = if mustRehash(t.dataLen, t.counter): enlarge(t) var hc: Hash var j = rawGetDeep(t, key, hc) rawInsert(t, t.data, key, val, hc, j) inc(t.counter) -template maybeRehashPutImpl(enlarge) {.dirty, immediate.} = +template maybeRehashPutImpl(enlarge) {.oldimmediate, dirty.} = if mustRehash(t.dataLen, t.counter): enlarge(t) index = rawGetKnownHC(t, key, hc) @@ -87,13 +87,13 @@ template maybeRehashPutImpl(enlarge) {.dirty, immediate.} = rawInsert(t, t.data, key, val, hc, index) inc(t.counter) -template putImpl(enlarge) {.dirty, immediate.} = +template putImpl(enlarge) {.oldimmediate, dirty.} = var hc: Hash var index = rawGet(t, key, hc) if index >= 0: t.data[index].val = val else: maybeRehashPutImpl(enlarge) -template mgetOrPutImpl(enlarge) {.dirty, immediate.} = +template mgetOrPutImpl(enlarge) {.dirty.} = var hc: Hash var index = rawGet(t, key, hc) if index < 0: @@ -102,7 +102,7 @@ template mgetOrPutImpl(enlarge) {.dirty, immediate.} = # either way return modifiable val result = t.data[index].val -template hasKeyOrPutImpl(enlarge) {.dirty, immediate.} = +template hasKeyOrPutImpl(enlarge) {.dirty.} = var hc: Hash var index = rawGet(t, key, hc) if index < 0: @@ -110,18 +110,24 @@ template hasKeyOrPutImpl(enlarge) {.dirty, immediate.} = maybeRehashPutImpl(enlarge) else: result = true -template delImpl() {.dirty, immediate.} = +proc default[T](t: typedesc[T]): T {.inline.} = discard + +template delImpl() {.dirty.} = var hc: Hash var i = rawGet(t, key, hc) let msk = maxHash(t) if i >= 0: t.data[i].hcode = 0 + t.data[i].key = default(type(t.data[i].key)) + t.data[i].val = default(type(t.data[i].val)) dec(t.counter) block outer: while true: # KnuthV3 Algo6.4R adapted for i=i+1 instead of i=i-1 var j = i # The correctness of this depends on (h+1) in nextTry, var r = j # though may be adaptable to other simple sequences. t.data[i].hcode = 0 # mark current EMPTY + t.data[i].key = default(type(t.data[i].key)) + t.data[i].val = default(type(t.data[i].val)) while true: i = (i + 1) and msk # increment mod table size if isEmpty(t.data[i].hcode): # end of collision cluster; So all done @@ -133,3 +139,10 @@ template delImpl() {.dirty, immediate.} = t.data[j] = t.data[i] else: shallowCopy(t.data[j], t.data[i]) # data[j] will be marked EMPTY next loop + +template clearImpl() {.dirty.} = + for i in 0 .. <t.data.len: + t.data[i].hcode = 0 + t.data[i].key = default(type(t.data[i].key)) + t.data[i].val = default(type(t.data[i].val)) + t.counter = 0 diff --git a/lib/pure/collections/tables.nim b/lib/pure/collections/tables.nim index 2ed0d2034..9308095aa 100644 --- a/lib/pure/collections/tables.nim +++ b/lib/pure/collections/tables.nim @@ -16,6 +16,39 @@ ## semantics, this means that ``=`` performs a copy of the hash table. ## For **reference** semantics use the ``Ref`` variant: ``TableRef``, ## ``OrderedTableRef``, ``CountTableRef``. +## To give an example, when `a` is a Table, then `var b = a` gives `b` +## as a new independent table. b is initialised with the contents of `a`. +## Changing `b` does not affect `a` and vice versa: +## +## .. code-block:: +## import tables +## +## var +## a = {1: "one", 2: "two"}.toTable # creates a Table +## b = a +## +## echo a, b # output: {1: one, 2: two}{1: one, 2: two} +## +## b[3] = "three" +## echo a, b # output: {1: one, 2: two}{1: one, 2: two, 3: three} +## echo a == b # output: false +## +## On the other hand, when `a` is a TableRef instead, then changes to `b` also affect `a`. +## Both `a` and `b` reference the same data structure: +## +## .. code-block:: +## import tables +## +## var +## a = {1: "one", 2: "two"}.newTable # creates a TableRef +## b = a +## +## echo a, b # output: {1: one, 2: two}{1: one, 2: two} +## +## b[3] = "three" +## echo a, b # output: {1: one, 2: two, 3: three}{1: one, 2: two, 3: three} +## echo a == b # output: true +## ## ## If you are using simple standard types like ``int`` or ``string`` for the ## keys of the table you won't have any problems, but as soon as you try to use @@ -80,11 +113,15 @@ type {.deprecated: [TTable: Table, PTable: TableRef].} -template maxHash(t): expr {.immediate.} = high(t.data) -template dataLen(t): expr = len(t.data) +template maxHash(t): untyped = high(t.data) +template dataLen(t): untyped = len(t.data) include tableimpl +proc clear*[A, B](t: Table[A, B] | TableRef[A, B]) = + ## Resets the table so that it is empty. + clearImpl() + proc rightSize*(count: Natural): int {.inline.} = ## Return the value of `initialSize` to support `count` items. ## @@ -98,7 +135,7 @@ proc len*[A, B](t: Table[A, B]): int = ## returns the number of keys in `t`. result = t.counter -template get(t, key): untyped {.immediate.} = +template get(t, key): untyped = ## retrieves the value at ``t[key]``. The value can be modified. ## If `key` is not in `t`, the ``KeyError`` exception is raised. mixin rawGet @@ -111,7 +148,7 @@ template get(t, key): untyped {.immediate.} = else: raise newException(KeyError, "key not found") -template getOrDefaultImpl(t, key): untyped {.immediate.} = +template getOrDefaultImpl(t, key): untyped = mixin rawGet var hc: Hash var index = rawGet(t, key, hc) @@ -136,6 +173,51 @@ proc mget*[A, B](t: var Table[A, B], key: A): var B {.deprecated.} = proc getOrDefault*[A, B](t: Table[A, B], key: A): B = getOrDefaultImpl(t, key) +template withValue*[A, B](t: var Table[A, B], key: A, + value, body: untyped) = + ## retrieves the value at ``t[key]``. + ## `value` can be modified in the scope of the ``withValue`` call. + ## + ## .. code-block:: nim + ## + ## sharedTable.withValue(key, value) do: + ## # block is executed only if ``key`` in ``t`` + ## value.name = "username" + ## value.uid = 1000 + ## + mixin rawGet + var hc: Hash + var index = rawGet(t, key, hc) + let hasKey = index >= 0 + if hasKey: + var value {.inject.} = addr(t.data[index].val) + body + +template withValue*[A, B](t: var Table[A, B], key: A, + value, body1, body2: untyped) = + ## retrieves the value at ``t[key]``. + ## `value` can be modified in the scope of the ``withValue`` call. + ## + ## .. code-block:: nim + ## + ## table.withValue(key, value) do: + ## # block is executed only if ``key`` in ``t`` + ## value.name = "username" + ## value.uid = 1000 + ## do: + ## # block is executed when ``key`` not in ``t`` + ## raise newException(KeyError, "Key not found") + ## + mixin rawGet + var hc: Hash + var index = rawGet(t, key, hc) + let hasKey = index >= 0 + if hasKey: + var value {.inject.} = addr(t.data[index].val) + body1 + else: + body2 + iterator allValues*[A, B](t: Table[A, B]; key: A): B = ## iterates over any value in the table `t` that belongs to the given `key`. var h: Hash = hash(key) and high(t.data) @@ -229,7 +311,7 @@ proc toTable*[A, B](pairs: openArray[(A, result = initTable[A, B](rightSize(pairs.len)) for key, val in items(pairs): result[key] = val -template dollarImpl(): stmt {.dirty.} = +template dollarImpl(): untyped {.dirty.} = if t.len == 0: result = "{:}" else: @@ -249,18 +331,17 @@ proc hasKey*[A, B](t: TableRef[A, B], key: A): bool = ## returns true iff `key` is in the table `t`. result = t[].hasKey(key) -template equalsImpl() = +template equalsImpl(t) = if s.counter == t.counter: # different insertion orders mean different 'data' seqs, so we have # to use the slow route here: for key, val in s: - # prefix notation leads to automatic dereference in case of PTable if not t.hasKey(key): return false - if t[key] != val: return false + if t.getOrDefault(key) != val: return false return true proc `==`*[A, B](s, t: Table[A, B]): bool = - equalsImpl() + equalsImpl(t) proc indexBy*[A, B, C](collection: A, index: proc(x: B): C): Table[C, B] = ## Index the collection with the proc provided. @@ -350,7 +431,7 @@ proc `$`*[A, B](t: TableRef[A, B]): string = proc `==`*[A, B](s, t: TableRef[A, B]): bool = if isNil(s): result = isNil(t) elif isNil(t): result = false - else: equalsImpl() + else: equalsImpl(t[]) proc newTableFrom*[A, B, C](collection: A, index: proc(x: B): C): TableRef[C, B] = ## Index the collection with the proc provided. @@ -376,7 +457,13 @@ proc len*[A, B](t: OrderedTable[A, B]): int {.inline.} = ## returns the number of keys in `t`. result = t.counter -template forAllOrderedPairs(yieldStmt: stmt) {.dirty, immediate.} = +proc clear*[A, B](t: OrderedTable[A, B] | OrderedTableRef[A, B]) = + ## Resets the table so that it is empty. + clearImpl() + t.first = -1 + t.last = -1 + +template forAllOrderedPairs(yieldStmt: untyped) {.oldimmediate, dirty.} = var h = t.first while h >= 0: var nxt = t.data[h].next @@ -562,7 +649,7 @@ proc len*[A, B](t: OrderedTableRef[A, B]): int {.inline.} = ## returns the number of keys in `t`. result = t.counter -template forAllOrderedPairs(yieldStmt: stmt) {.dirty, immediate.} = +template forAllOrderedPairs(yieldStmt: untyped) {.oldimmediate, dirty.} = var h = t.first while h >= 0: var nxt = t.data[h].next @@ -663,6 +750,27 @@ proc sort*[A, B](t: OrderedTableRef[A, B], ## contrast to the `sort` for count tables). t[].sort(cmp) +proc del*[A, B](t: var OrderedTable[A, B], key: A) = + ## deletes `key` from ordered hash table `t`. O(n) comlexity. + var prev = -1 + let hc = hash(key) + forAllOrderedPairs: + if t.data[h].hcode == hc: + if t.first == h: + t.first = t.data[h].next + else: + t.data[prev].next = t.data[h].next + var zeroValue : type(t.data[h]) + t.data[h] = zeroValue + dec t.counter + break + else: + prev = h + +proc del*[A, B](t: var OrderedTableRef[A, B], key: A) = + ## deletes `key` from ordered hash table `t`. O(n) comlexity. + t[].del(key) + # ------------------------------ count tables ------------------------------- type @@ -678,6 +786,11 @@ proc len*[A](t: CountTable[A]): int = ## returns the number of keys in `t`. result = t.counter +proc clear*[A](t: CountTable[A] | CountTableRef[A]) = + ## Resets the table so that it is empty. + clearImpl() + t.counter = 0 + iterator pairs*[A](t: CountTable[A]): (A, int) = ## iterates over any (key, value) pair in the table `t`. for h in 0..high(t.data): @@ -711,7 +824,7 @@ proc rawGet[A](t: CountTable[A], key: A): int = h = nextTry(h, high(t.data)) result = -1 - h # < 0 => MISSING; insert idx = -1 - result -template ctget(t, key: untyped): untyped {.immediate.} = +template ctget(t, key: untyped): untyped = var index = rawGet(t, key) if index >= 0: result = t.data[index].val else: @@ -984,6 +1097,26 @@ when isMainModule: s3[p1] = 30_000 s3[p2] = 45_000 + block: # Ordered table should preserve order after deletion + var + s4 = initOrderedTable[int, int]() + s4[1] = 1 + s4[2] = 2 + s4[3] = 3 + + var prev = 0 + for i in s4.values: + doAssert(prev < i) + prev = i + + s4.del(2) + doAssert(2 notin s4) + doAssert(s4.len == 2) + prev = 0 + for i in s4.values: + doAssert(prev < i) + prev = i + var t1 = initCountTable[string]() t2 = initCountTable[string]() @@ -1040,3 +1173,12 @@ when isMainModule: doAssert 0 == t.getOrDefault(testKey) t.inc(testKey,3) doAssert 3 == t.getOrDefault(testKey) + + # Clear tests + var clearTable = newTable[int, string]() + clearTable[42] = "asd" + clearTable[123123] = "piuyqwb " + doAssert clearTable[42] == "asd" + clearTable.clear() + doAssert(not clearTable.hasKey(123123)) + doAssert clearTable.getOrDefault(42) == nil diff --git a/lib/pure/concurrency/cpuload.nim b/lib/pure/concurrency/cpuload.nim index 22598b5c9..b0fd002ed 100644 --- a/lib/pure/concurrency/cpuload.nim +++ b/lib/pure/concurrency/cpuload.nim @@ -79,6 +79,8 @@ proc advice*(s: var ThreadPoolState): ThreadPoolAdvice = inc s.calls when not defined(testing) and isMainModule: + import random + proc busyLoop() = while true: discard random(80) diff --git a/lib/pure/future.nim b/lib/pure/future.nim index 3793edc8b..67975cfcb 100644 --- a/lib/pure/future.nim +++ b/lib/pure/future.nim @@ -29,28 +29,24 @@ proc createProcType(p, b: NimNode): NimNode {.compileTime.} = of nnkExprColonExpr: identDefs.add ident[0] identDefs.add ident[1] - of nnkIdent: + else: identDefs.add newIdentNode("i" & $i) identDefs.add(ident) - else: - error("Incorrect type list in proc type declaration.") identDefs.add newEmptyNode() formalParams.add identDefs - of nnkIdent: + else: var identDefs = newNimNode(nnkIdentDefs) identDefs.add newIdentNode("i0") identDefs.add(p) identDefs.add newEmptyNode() formalParams.add identDefs - else: - error("Incorrect type list in proc type declaration.") result.add formalParams result.add newEmptyNode() #echo(treeRepr(result)) #echo(result.toStrLit()) -macro `=>`*(p, b: expr): expr {.immediate.} = +macro `=>`*(p, b: untyped): untyped = ## Syntax sugar for anonymous procedures. ## ## .. code-block:: nim @@ -111,7 +107,7 @@ macro `=>`*(p, b: expr): expr {.immediate.} = #echo(result.toStrLit()) #return result # TODO: Bug? -macro `->`*(p, b: expr): expr {.immediate.} = +macro `->`*(p, b: untyped): untyped = ## Syntax sugar for procedure types. ## ## .. code-block:: nim @@ -129,7 +125,7 @@ macro `->`*(p, b: expr): expr {.immediate.} = type ListComprehension = object var lc*: ListComprehension -macro `[]`*(lc: ListComprehension, comp, typ: expr): expr = +macro `[]`*(lc: ListComprehension, comp, typ: untyped): untyped = ## List comprehension, returns a sequence. `comp` is the actual list ## comprehension, for example ``x | (x <- 1..10, x mod 2 == 0)``. `typ` is ## the type that will be stored inside the result seq. diff --git a/lib/pure/htmlparser.nim b/lib/pure/htmlparser.nim index d620e816e..1fe0b297b 100644 --- a/lib/pure/htmlparser.nim +++ b/lib/pure/htmlparser.nim @@ -433,7 +433,7 @@ proc htmlTag*(n: XmlNode): HtmlTag = proc htmlTag*(s: string): HtmlTag = ## converts `s` to a ``HtmlTag``. If `s` is no HTML tag, ``tagUnknown`` is ## returned. - let s = if allLower(s): s else: s.toLower + let s = if allLower(s): s else: toLowerAscii(s) result = toHtmlTag(s) proc entityToUtf8*(entity: string): string = @@ -464,12 +464,18 @@ proc untilElementEnd(x: var XmlParser, result: XmlNode, case x.kind of xmlElementStart, xmlElementOpen: case result.htmlTag - of tagLi, tagP, tagDt, tagDd, tagInput, tagOption: - # some tags are common to have no ``</end>``, like ``<li>``: + of tagP, tagInput, tagOption: + # some tags are common to have no ``</end>``, like ``<li>`` but + # allow ``<p>`` in `<dd>`, `<dt>` and ``<li>`` in next case if htmlTag(x.elemName) in {tagLi, tagP, tagDt, tagDd, tagInput, tagOption}: errors.add(expected(x, result)) break + of tagDd, tagDt, tagLi: + if htmlTag(x.elemName) in {tagLi, tagDt, tagDd, tagInput, + tagOption}: + errors.add(expected(x, result)) + break of tagTd, tagTh: if htmlTag(x.elemName) in {tagTr, tagTd, tagTh, tagTfoot, tagThead}: errors.add(expected(x, result)) @@ -513,13 +519,13 @@ proc parse(x: var XmlParser, errors: var seq[string]): XmlNode = errors.add(errorMsg(x)) next(x) of xmlElementStart: - result = newElement(x.elemName.toLower) + result = newElement(toLowerAscii(x.elemName)) next(x) untilElementEnd(x, result, errors) of xmlElementEnd: errors.add(errorMsg(x, "unexpected ending tag: " & x.elemName)) of xmlElementOpen: - result = newElement(x.elemName.toLower) + result = newElement(toLowerAscii(x.elemName)) next(x) result.attrs = newStringTable() while true: diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim index 603763386..778ca2cbb 100644 --- a/lib/pure/httpclient.nim +++ b/lib/pure/httpclient.nim @@ -79,15 +79,18 @@ ## constructor should be used for this purpose. However, ## currently only basic authentication is supported. -import net, strutils, uri, parseutils, strtabs, base64, os, mimetypes, math +import net, strutils, uri, parseutils, strtabs, base64, os, mimetypes, + math, random, httpcore import asyncnet, asyncdispatch import nativesockets +export httpcore except parseHeader # TODO: The ``except`` doesn't work + type Response* = tuple[ version: string, status: string, - headers: StringTableRef, + headers: HttpHeaders, body: string] Proxy* = ref object @@ -164,7 +167,7 @@ proc parseChunks(s: Socket, timeout: int): string = # Trailer headers will only be sent if the request specifies that we want # them: http://tools.ietf.org/html/rfc2616#section-3.6.1 -proc parseBody(s: Socket, headers: StringTableRef, timeout: int): string = +proc parseBody(s: Socket, headers: HttpHeaders, httpVersion: string, timeout: int): string = result = "" if headers.getOrDefault"Transfer-Encoding" == "chunked": result = parseChunks(s, timeout) @@ -190,7 +193,7 @@ proc parseBody(s: Socket, headers: StringTableRef, timeout: int): string = # -REGION- Connection: Close # (http://tools.ietf.org/html/rfc2616#section-4.4) NR.5 - if headers.getOrDefault"Connection" == "close": + if headers.getOrDefault"Connection" == "close" or httpVersion == "1.0": var buf = "" while true: buf = newString(4000) @@ -204,7 +207,7 @@ proc parseResponse(s: Socket, getBody: bool, timeout: int): Response = var linei = 0 var fullyRead = false var line = "" - result.headers = newStringTable(modeCaseInsensitive) + result.headers = newHttpHeaders() while true: line = "" linei = 0 @@ -239,10 +242,14 @@ proc parseResponse(s: Socket, getBody: bool, timeout: int): Response = inc(linei) # Skip : result.headers[name] = line[linei.. ^1].strip() + # Ensure the server isn't trying to DoS us. + if result.headers.len > headerLimit: + httpError("too many headers") + if not fullyRead: httpError("Connection was closed before full request has been made") if getBody: - result.body = parseBody(s, result.headers, timeout) + result.body = parseBody(s, result.headers, result.version, timeout) else: result.body = "" @@ -335,7 +342,7 @@ proc addFiles*(p: var MultipartData, xs: openarray[tuple[name, file: string]]): var m = newMimetypes() for name, file in xs.items: var contentType: string - let (dir, fName, ext) = splitFile(file) + let (_, fName, ext) = splitFile(file) if ext.len > 0: contentType = m.getMimetype(ext[1..ext.high], nil) p.add(name, readFile(file), fName & ext, contentType) @@ -392,6 +399,59 @@ proc request*(url: string, httpMethod: string, extraHeaders = "", var hostUrl = if proxy == nil: r else: parseUri(url) var headers = substr(httpMethod, len("http")) # TODO: Use generateHeaders further down once it supports proxies. + + var s = newSocket() + defer: s.close() + if s == nil: raiseOSError(osLastError()) + var port = net.Port(80) + if r.scheme == "https": + when defined(ssl): + sslContext.wrapSocket(s) + port = net.Port(443) + else: + raise newException(HttpRequestError, + "SSL support is not available. Cannot connect over SSL.") + if r.port != "": + port = net.Port(r.port.parseInt) + + + # get the socket ready. If we are connecting through a proxy to SSL, + # send the appropiate CONNECT header. If not, simply connect to the proper + # host (which may still be the proxy, for normal HTTP) + if proxy != nil and hostUrl.scheme == "https": + when defined(ssl): + var connectHeaders = "CONNECT " + let targetPort = if hostUrl.port == "": 443 else: hostUrl.port.parseInt + connectHeaders.add(hostUrl.hostname) + connectHeaders.add(":" & $targetPort) + connectHeaders.add(" HTTP/1.1\c\L") + connectHeaders.add("Host: " & hostUrl.hostname & ":" & $targetPort & "\c\L") + if proxy.auth != "": + let auth = base64.encode(proxy.auth, newline = "") + connectHeaders.add("Proxy-Authorization: basic " & auth & "\c\L") + connectHeaders.add("\c\L") + if timeout == -1: + s.connect(r.hostname, port) + else: + s.connect(r.hostname, port, timeout) + + s.send(connectHeaders) + let connectResult = parseResponse(s, false, timeout) + if not connectResult.status.startsWith("200"): + raise newException(HttpRequestError, + "The proxy server rejected a CONNECT request, " & + "so a secure connection could not be established.") + sslContext.wrapConnectedSocket(s, handshakeAsClient) + else: + raise newException(HttpRequestError, "SSL support not available. Cannot connect via proxy over SSL") + else: + if timeout == -1: + s.connect(r.hostname, port) + else: + s.connect(r.hostname, port, timeout) + + + # now that the socket is ready, prepare the headers if proxy == nil: headers.add ' ' if r.path[0] != '/': headers.add '/' @@ -415,29 +475,13 @@ proc request*(url: string, httpMethod: string, extraHeaders = "", add(headers, "Proxy-Authorization: basic " & auth & "\c\L") add(headers, extraHeaders) add(headers, "\c\L") - var s = newSocket() - if s == nil: raiseOSError(osLastError()) - var port = net.Port(80) - if r.scheme == "https": - when defined(ssl): - sslContext.wrapSocket(s) - port = net.Port(443) - else: - raise newException(HttpRequestError, - "SSL support is not available. Cannot connect over SSL.") - if r.port != "": - port = net.Port(r.port.parseInt) - if timeout == -1: - s.connect(r.hostname, port) - else: - s.connect(r.hostname, port, timeout) + # headers are ready. send them, await the result, and close the socket. s.send(headers) if body != "": s.send(body) result = parseResponse(s, httpMethod != "httpHEAD", timeout) - s.close() proc request*(url: string, httpMethod = httpGET, extraHeaders = "", body = "", sslContext = defaultSSLContext, timeout = -1, @@ -455,7 +499,7 @@ proc redirection(status: string): bool = if status.startsWith(i): return true -proc getNewLocation(lastURL: string, headers: StringTableRef): string = +proc getNewLocation(lastURL: string, headers: HttpHeaders): string = result = headers.getOrDefault"Location" if result == "": httpError("location header expected") # Relative URLs. (Not part of the spec, but soon will be.) @@ -624,10 +668,10 @@ proc newAsyncHttpClient*(userAgent = defUserAgent, ## ``sslContext`` specifies the SSL context to use for HTTPS requests. new result result.headers = newStringTable(modeCaseInsensitive) - result.userAgent = defUserAgent + result.userAgent = userAgent result.maxRedirects = maxRedirects when defined(ssl): - result.sslContext = net.SslContext(sslContext) + result.sslContext = sslContext proc close*(client: AsyncHttpClient) = ## Closes any connections held by the HTTP client. @@ -678,7 +722,8 @@ proc parseChunks(client: AsyncHttpClient): Future[string] {.async.} = # them: http://tools.ietf.org/html/rfc2616#section-3.6.1 proc parseBody(client: AsyncHttpClient, - headers: StringTableRef): Future[string] {.async.} = + headers: HttpHeaders, + httpVersion: string): Future[string] {.async.} = result = "" if headers.getOrDefault"Transfer-Encoding" == "chunked": result = await parseChunks(client) @@ -700,7 +745,7 @@ proc parseBody(client: AsyncHttpClient, # -REGION- Connection: Close # (http://tools.ietf.org/html/rfc2616#section-4.4) NR.5 - if headers.getOrDefault"Connection" == "close": + if headers.getOrDefault"Connection" == "close" or httpVersion == "1.0": var buf = "" while true: buf = await client.socket.recvFull(4000) @@ -713,7 +758,7 @@ proc parseResponse(client: AsyncHttpClient, var linei = 0 var fullyRead = false var line = "" - result.headers = newStringTable(modeCaseInsensitive) + result.headers = newHttpHeaders() while true: linei = 0 line = await client.socket.recvLine() @@ -748,10 +793,13 @@ proc parseResponse(client: AsyncHttpClient, inc(linei) # Skip : result.headers[name] = line[linei.. ^1].strip() + if result.headers.len > headerLimit: + httpError("too many headers") + if not fullyRead: httpError("Connection was closed before full request has been made") if getBody: - result.body = await parseBody(client, result.headers) + result.body = await parseBody(client, result.headers, result.version) else: result.body = "" diff --git a/lib/pure/httpcore.nim b/lib/pure/httpcore.nim new file mode 100644 index 000000000..562d16c19 --- /dev/null +++ b/lib/pure/httpcore.nim @@ -0,0 +1,198 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Dominik Picheta +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +## Contains functionality shared between the ``httpclient`` and +## ``asynchttpserver`` modules. + +import tables, strutils, parseutils + +type + HttpHeaders* = ref object + table*: TableRef[string, seq[string]] + + HttpHeaderValues* = distinct seq[string] + + HttpCode* = enum + Http100 = "100 Continue", + Http101 = "101 Switching Protocols", + Http200 = "200 OK", + Http201 = "201 Created", + Http202 = "202 Accepted", + Http203 = "203 Non-Authoritative Information", + Http204 = "204 No Content", + Http205 = "205 Reset Content", + Http206 = "206 Partial Content", + Http300 = "300 Multiple Choices", + Http301 = "301 Moved Permanently", + Http302 = "302 Found", + Http303 = "303 See Other", + Http304 = "304 Not Modified", + Http305 = "305 Use Proxy", + Http307 = "307 Temporary Redirect", + Http400 = "400 Bad Request", + Http401 = "401 Unauthorized", + Http403 = "403 Forbidden", + Http404 = "404 Not Found", + Http405 = "405 Method Not Allowed", + Http406 = "406 Not Acceptable", + Http407 = "407 Proxy Authentication Required", + Http408 = "408 Request Timeout", + Http409 = "409 Conflict", + Http410 = "410 Gone", + Http411 = "411 Length Required", + Http412 = "412 Precondition Failed", + Http413 = "413 Request Entity Too Large", + Http414 = "414 Request-URI Too Long", + Http415 = "415 Unsupported Media Type", + Http416 = "416 Requested Range Not Satisfiable", + Http417 = "417 Expectation Failed", + Http418 = "418 I'm a teapot", + Http421 = "421 Misdirected Request", + Http422 = "422 Unprocessable Entity", + Http426 = "426 Upgrade Required", + Http428 = "428 Precondition Required", + Http429 = "429 Too Many Requests", + Http431 = "431 Request Header Fields Too Large", + Http451 = "451 Unavailable For Legal Reasons", + Http500 = "500 Internal Server Error", + Http501 = "501 Not Implemented", + Http502 = "502 Bad Gateway", + Http503 = "503 Service Unavailable", + Http504 = "504 Gateway Timeout", + Http505 = "505 HTTP Version Not Supported" + + HttpVersion* = enum + HttpVer11, + HttpVer10 + +const headerLimit* = 10_000 + +proc newHttpHeaders*(): HttpHeaders = + new result + result.table = newTable[string, seq[string]]() + +proc newHttpHeaders*(keyValuePairs: + openarray[tuple[key: string, val: string]]): HttpHeaders = + var pairs: seq[tuple[key: string, val: seq[string]]] = @[] + for pair in keyValuePairs: + pairs.add((pair.key.toLower(), @[pair.val])) + new result + result.table = newTable[string, seq[string]](pairs) + +proc clear*(headers: HttpHeaders) = + headers.table.clear() + +proc `[]`*(headers: HttpHeaders, key: string): HttpHeaderValues = + ## Returns the values associated with the given ``key``. If the returned + ## values are passed to a procedure expecting a ``string``, the first + ## value is automatically picked. If there are + ## no values associated with the key, an exception is raised. + ## + ## To access multiple values of a key, use the overloaded ``[]`` below or + ## to get all of them access the ``table`` field directly. + return headers.table[key.toLower].HttpHeaderValues + +converter toString*(values: HttpHeaderValues): string = + return seq[string](values)[0] + +proc `[]`*(headers: HttpHeaders, key: string, i: int): string = + ## Returns the ``i``'th value associated with the given key. If there are + ## no values associated with the key or the ``i``'th value doesn't exist, + ## an exception is raised. + return headers.table[key.toLower][i] + +proc `[]=`*(headers: HttpHeaders, key, value: string) = + ## Sets the header entries associated with ``key`` to the specified value. + ## Replaces any existing values. + headers.table[key.toLower] = @[value] + +proc `[]=`*(headers: HttpHeaders, key: string, value: seq[string]) = + ## Sets the header entries associated with ``key`` to the specified list of + ## values. + ## Replaces any existing values. + headers.table[key.toLower] = value + +proc add*(headers: HttpHeaders, key, value: string) = + ## Adds the specified value to the specified key. Appends to any existing + ## values associated with the key. + if not headers.table.hasKey(key.toLower): + headers.table[key.toLower] = @[value] + else: + headers.table[key.toLower].add(value) + +iterator pairs*(headers: HttpHeaders): tuple[key, value: string] = + ## Yields each key, value pair. + for k, v in headers.table: + for value in v: + yield (k, value) + +proc contains*(values: HttpHeaderValues, value: string): bool = + ## Determines if ``value`` is one of the values inside ``values``. Comparison + ## is performed without case sensitivity. + for val in seq[string](values): + if val.toLower == value.toLower: return true + +proc hasKey*(headers: HttpHeaders, key: string): bool = + return headers.table.hasKey(key.toLower()) + +proc getOrDefault*(headers: HttpHeaders, key: string, + default = @[""].HttpHeaderValues): HttpHeaderValues = + ## Returns the values associated with the given ``key``. If there are no + ## values associated with the key, then ``default`` is returned. + if headers.hasKey(key): + return headers[key] + else: + return default + +proc len*(headers: HttpHeaders): int = return headers.table.len + +proc parseList(line: string, list: var seq[string], start: int): int = + var i = 0 + var current = "" + while line[start + i] notin {'\c', '\l', '\0'}: + i += line.skipWhitespace(start + i) + i += line.parseUntil(current, {'\c', '\l', ','}, start + i) + list.add(current) + if line[start + i] == ',': + i.inc # Skip , + current.setLen(0) + +proc parseHeader*(line: string): tuple[key: string, value: seq[string]] = + ## Parses a single raw header HTTP line into key value pairs. + ## + ## Used by ``asynchttpserver`` and ``httpclient`` internally and should not + ## be used by you. + result.value = @[] + var i = 0 + i = line.parseUntil(result.key, ':') + inc(i) # skip : + if i < len(line): + i += parseList(line, result.value, i) + else: + result.value = @[] + +proc `==`*(protocol: tuple[orig: string, major, minor: int], + ver: HttpVersion): bool = + let major = + case ver + of HttpVer11, HttpVer10: 1 + let minor = + case ver + of HttpVer11: 1 + of HttpVer10: 0 + result = protocol.major == major and protocol.minor == minor + +when isMainModule: + var test = newHttpHeaders() + test["Connection"] = @["Upgrade", "Close"] + doAssert test["Connection", 0] == "Upgrade" + doAssert test["Connection", 1] == "Close" + test.add("Connection", "Test") + doAssert test["Connection", 2] == "Test" + doAssert "upgrade" in test["Connection"] diff --git a/lib/pure/ioselectors.nim b/lib/pure/ioselectors.nim new file mode 100644 index 000000000..a5d5d2c01 --- /dev/null +++ b/lib/pure/ioselectors.nim @@ -0,0 +1,255 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Eugene Kabanov +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +## This module allows high-level and efficient I/O multiplexing. +## +## Supported OS primitives: ``epoll``, ``kqueue``, ``poll`` and +## Windows ``select``. +## +## To use threadsafe version of this module, it needs to be compiled +## with both ``-d:threadsafe`` and ``--threads:on`` options. +## +## Supported features: files, sockets, pipes, timers, processes, signals +## and user events. +## +## Fully supported OS: MacOSX, FreeBSD, OpenBSD, NetBSD, Linux. +## +## Partially supported OS: Windows (only sockets and user events), +## Solaris (files, sockets, handles and user events). +## +## TODO: ``/dev/poll``, ``event ports`` and filesystem events. + +import os + +const hasThreadSupport = compileOption("threads") and defined(threadsafe) + +const supportedPlatform = defined(macosx) or defined(freebsd) or + defined(netbsd) or defined(openbsd) or + defined(linux) + +const bsdPlatform = defined(macosx) or defined(freebsd) or + defined(netbsd) or defined(openbsd) + + +when defined(nimdoc): + type + Selector*[T] = ref object + ## An object which holds descriptors to be checked for read/write status + + Event* {.pure.} = enum + ## An enum which hold event types + Read, ## Descriptor is available for read + Write, ## Descriptor is available for write + Timer, ## Timer descriptor is completed + Signal, ## Signal is raised + Process, ## Process is finished + Vnode, ## Currently not supported + User, ## User event is raised + Error ## Error happens while waiting, for descriptor + + ReadyKey*[T] = object + ## An object which holds result for descriptor + fd* : int ## file/socket descriptor + events*: set[Event] ## set of events + data*: T ## application-defined data + + SelectEvent* = object + ## An object which holds user defined event + + proc newSelector*[T](): Selector[T] = + ## Creates a new selector + + proc close*[T](s: Selector[T]) = + ## Closes selector + + proc registerHandle*[T](s: Selector[T], fd: SocketHandle, events: set[Event], + data: T) = + ## Registers file/socket descriptor ``fd`` to selector ``s`` + ## with events set in ``events``. The ``data`` is application-defined + ## data, which to be passed when event happens. + + proc updateHandle*[T](s: Selector[T], fd: SocketHandle, events: set[Event]) = + ## Update file/socket descriptor ``fd``, registered in selector + ## ``s`` with new events set ``event``. + + proc registerTimer*[T](s: Selector[T], timeout: int, oneshot: bool, + data: T): int {.discardable.} = + ## Registers timer notification with ``timeout`` in milliseconds + ## to selector ``s``. + ## If ``oneshot`` is ``true`` timer will be notified only once. + ## Set ``oneshot`` to ``false`` if your want periodic notifications. + ## The ``data`` is application-defined data, which to be passed, when + ## time limit expired. + + proc registerSignal*[T](s: Selector[T], signal: int, + data: T): int {.discardable.} = + ## Registers Unix signal notification with ``signal`` to selector + ## ``s``. The ``data`` is application-defined data, which to be + ## passed, when signal raises. + ## + ## This function is not supported for ``Windows``. + + proc registerProcess*[T](s: Selector[T], pid: int, + data: T): int {.discardable.} = + ## Registers process id (pid) notification when process has + ## exited to selector ``s``. + ## The ``data`` is application-defined data, which to be passed, when + ## process with ``pid`` has exited. + + proc registerEvent*[T](s: Selector[T], ev: SelectEvent, data: T) = + ## Registers selector event ``ev`` to selector ``s``. + ## ``data`` application-defined data, which to be passed, when + ## ``ev`` happens. + + proc newSelectEvent*(): SelectEvent = + ## Creates new event ``SelectEvent``. + + proc setEvent*(ev: SelectEvent) = + ## Trigger event ``ev``. + + proc close*(ev: SelectEvent) = + ## Closes selector event ``ev``. + + proc unregister*[T](s: Selector[T], ev: SelectEvent) = + ## Unregisters event ``ev`` from selector ``s``. + + proc unregister*[T](s: Selector[T], fd: int|SocketHandle|cint) = + ## Unregisters file/socket descriptor ``fd`` from selector ``s``. + + proc flush*[T](s: Selector[T]) = + ## Flushes all changes was made to kernel pool/queue. + ## This function is usefull only for BSD and MacOS, because + ## kqueue supports bulk changes to be made. + ## On Linux/Windows and other Posix compatible operation systems, + ## ``flush`` is alias for `discard`. + + proc selectInto*[T](s: Selector[T], timeout: int, + results: var openarray[ReadyKey[T]]): int = + ## Process call waiting for events registered in selector ``s``. + ## The ``timeout`` argument specifies the minimum number of milliseconds + ## the function will be blocked, if no events are not ready. Specifying a + ## timeout of ``-1`` causes function to block indefinitely. + ## All available events will be stored in ``results`` array. + ## + ## Function returns number of triggered events. + + proc select*[T](s: Selector[T], timeout: int): seq[ReadyKey[T]] = + ## Process call waiting for events registered in selector ``s``. + ## The ``timeout`` argument specifies the minimum number of milliseconds + ## the function will be blocked, if no events are not ready. Specifying a + ## timeout of -1 causes function to block indefinitely. + ## + ## Function returns sequence of triggered events. + + template isEmpty*[T](s: Selector[T]): bool = + ## Returns ``true``, if there no registered events or descriptors + ## in selector. + + template withData*[T](s: Selector[T], fd: SocketHandle, value, + body: untyped) = + ## retrieves the application-data assigned with descriptor ``fd`` + ## to ``value``. This ``value`` can be modified in the scope of + ## the ``withData`` call. + ## + ## .. code-block:: nim + ## + ## s.withData(fd, value) do: + ## # block is executed only if ``fd`` registered in selector ``s`` + ## value.uid = 1000 + ## + + template withData*[T](s: Selector[T], fd: SocketHandle, value, + body1, body2: untyped) = + ## retrieves the application-data assigned with descriptor ``fd`` + ## to ``value``. This ``value`` can be modified in the scope of + ## the ``withData`` call. + ## + ## .. code-block:: nim + ## + ## s.withData(fd, value) do: + ## # block is executed only if ``fd`` registered in selector ``s``. + ## value.uid = 1000 + ## do: + ## # block is executed if ``fd`` not registered in selector ``s``. + ## raise + ## + +else: + when hasThreadSupport: + import locks + + type + SharedArray {.unchecked.}[T] = array[0..100, T] + + proc allocSharedArray[T](nsize: int): ptr SharedArray[T] = + result = cast[ptr SharedArray[T]](allocShared0(sizeof(T) * nsize)) + + proc deallocSharedArray[T](sa: ptr SharedArray[T]) = + deallocShared(cast[pointer](sa)) + type + Event* {.pure.} = enum + Read, Write, Timer, Signal, Process, Vnode, User, Error, Oneshot + + ReadyKey*[T] = object + fd* : int + events*: set[Event] + data*: T + + SelectorKey[T] = object + ident: int + events: set[Event] + param: int + key: ReadyKey[T] + + when not defined(windows): + import posix + proc setNonBlocking(fd: cint) {.inline.} = + var x = fcntl(fd, F_GETFL, 0) + if x == -1: + raiseOSError(osLastError()) + else: + var mode = x or O_NONBLOCK + if fcntl(fd, F_SETFL, mode) == -1: + raiseOSError(osLastError()) + + template setKey(s, pident, pkeyfd, pevents, pparam, pdata) = + var skey = addr(s.fds[pident]) + skey.ident = pident + skey.events = pevents + skey.param = pparam + skey.key.fd = pkeyfd + skey.key.data = pdata + + when supportedPlatform: + template blockSignals(newmask: var Sigset, oldmask: var Sigset) = + when hasThreadSupport: + if posix.pthread_sigmask(SIG_BLOCK, newmask, oldmask) == -1: + raiseOSError(osLastError()) + else: + if posix.sigprocmask(SIG_BLOCK, newmask, oldmask) == -1: + raiseOSError(osLastError()) + + template unblockSignals(newmask: var Sigset, oldmask: var Sigset) = + when hasThreadSupport: + if posix.pthread_sigmask(SIG_UNBLOCK, newmask, oldmask) == -1: + raiseOSError(osLastError()) + else: + if posix.sigprocmask(SIG_UNBLOCK, newmask, oldmask) == -1: + raiseOSError(osLastError()) + + when defined(linux): + include ioselects/ioselectors_epoll + elif bsdPlatform: + include ioselects/ioselectors_kqueue + elif defined(windows): + include ioselects/ioselectors_select + elif defined(solaris): + include ioselects/ioselectors_poll # need to replace it with event ports + else: + include ioselects/ioselectors_poll diff --git a/lib/pure/ioselects/ioselectors_epoll.nim b/lib/pure/ioselects/ioselectors_epoll.nim new file mode 100644 index 000000000..92b2cdc07 --- /dev/null +++ b/lib/pure/ioselects/ioselectors_epoll.nim @@ -0,0 +1,461 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Eugene Kabanov +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +# This module implements Linux epoll(). + +import posix, times + +# Maximum number of events that can be returned +const MAX_EPOLL_RESULT_EVENTS = 64 + +type + SignalFdInfo* {.importc: "struct signalfd_siginfo", + header: "<sys/signalfd.h>", pure, final.} = object + ssi_signo*: uint32 + ssi_errno*: int32 + ssi_code*: int32 + ssi_pid*: uint32 + ssi_uid*: uint32 + ssi_fd*: int32 + ssi_tid*: uint32 + ssi_band*: uint32 + ssi_overrun*: uint32 + ssi_trapno*: uint32 + ssi_status*: int32 + ssi_int*: int32 + ssi_ptr*: uint64 + ssi_utime*: uint64 + ssi_stime*: uint64 + ssi_addr*: uint64 + pad* {.importc: "__pad".}: array[0..47, uint8] + + eventFdData {.importc: "eventfd_t", + header: "<sys/eventfd.h>", pure, final.} = uint64 + epoll_data {.importc: "union epoll_data", header: "<sys/epoll.h>", + pure, final.} = object + u64 {.importc: "u64".}: uint64 + epoll_event {.importc: "struct epoll_event", + header: "<sys/epoll.h>", pure, final.} = object + events: uint32 # Epoll events + data: epoll_data # User data variable + +const + EPOLL_CTL_ADD = 1 # Add a file descriptor to the interface. + EPOLL_CTL_DEL = 2 # Remove a file descriptor from the interface. + EPOLL_CTL_MOD = 3 # Change file descriptor epoll_event structure. + EPOLLIN = 0x00000001 + EPOLLOUT = 0x00000004 + EPOLLERR = 0x00000008 + EPOLLHUP = 0x00000010 + EPOLLRDHUP = 0x00002000 + EPOLLONESHOT = 1 shl 30 + +proc epoll_create(size: cint): cint + {.importc: "epoll_create", header: "<sys/epoll.h>".} +proc epoll_ctl(epfd: cint; op: cint; fd: cint; event: ptr epoll_event): cint + {.importc: "epoll_ctl", header: "<sys/epoll.h>".} +proc epoll_wait(epfd: cint; events: ptr epoll_event; maxevents: cint; + timeout: cint): cint + {.importc: "epoll_wait", header: "<sys/epoll.h>".} +proc timerfd_create(clock_id: ClockId, flags: cint): cint + {.cdecl, importc: "timerfd_create", header: "<sys/timerfd.h>".} +proc timerfd_settime(ufd: cint, flags: cint, + utmr: var Itimerspec, otmr: var Itimerspec): cint + {.cdecl, importc: "timerfd_settime", header: "<sys/timerfd.h>".} +proc signalfd(fd: cint, mask: var Sigset, flags: cint): cint + {.cdecl, importc: "signalfd", header: "<sys/signalfd.h>".} +proc eventfd(count: cuint, flags: cint): cint + {.cdecl, importc: "eventfd", header: "<sys/eventfd.h>".} +proc ulimit(cmd: cint): clong + {.importc: "ulimit", header: "<ulimit.h>", varargs.} + +when hasThreadSupport: + type + SelectorImpl[T] = object + epollFD : cint + maxFD : int + fds: ptr SharedArray[SelectorKey[T]] + count: int + Selector*[T] = ptr SelectorImpl[T] +else: + type + SelectorImpl[T] = object + epollFD : cint + maxFD : int + fds: seq[SelectorKey[T]] + count: int + Selector*[T] = ref SelectorImpl[T] +type + SelectEventImpl = object + efd: cint + SelectEvent* = ptr SelectEventImpl + +proc newSelector*[T](): Selector[T] = + var maxFD = int(ulimit(4, 0)) + doAssert(maxFD > 0) + + var epollFD = epoll_create(MAX_EPOLL_RESULT_EVENTS) + if epollFD < 0: + raiseOsError(osLastError()) + + when hasThreadSupport: + result = cast[Selector[T]](allocShared0(sizeof(SelectorImpl[T]))) + result.epollFD = epollFD + result.maxFD = maxFD + result.fds = allocSharedArray[SelectorKey[T]](maxFD) + else: + result = Selector[T]() + result.epollFD = epollFD + result.maxFD = maxFD + result.fds = newSeq[SelectorKey[T]](maxFD) + +proc close*[T](s: Selector[T]) = + if posix.close(s.epollFD) != 0: + raiseOSError(osLastError()) + when hasThreadSupport: + deallocSharedArray(s.fds) + deallocShared(cast[pointer](s)) + +proc newSelectEvent*(): SelectEvent = + let fdci = eventfd(0, 0) + if fdci == -1: + raiseOSError(osLastError()) + setNonBlocking(fdci) + result = cast[SelectEvent](allocShared0(sizeof(SelectEventImpl))) + result.efd = fdci + +proc setEvent*(ev: SelectEvent) = + var data : uint64 = 1 + if posix.write(ev.efd, addr data, sizeof(uint64)) == -1: + raiseOSError(osLastError()) + +proc close*(ev: SelectEvent) = + discard posix.close(ev.efd) + deallocShared(cast[pointer](ev)) + +template checkFd(s, f) = + if f >= s.maxFD: + raise newException(ValueError, "Maximum file descriptors exceeded") + +proc registerHandle*[T](s: Selector[T], fd: SocketHandle, + events: set[Event], data: T) = + let fdi = int(fd) + s.checkFd(fdi) + doAssert(s.fds[fdi].ident == 0) + s.setKey(fdi, fdi, events, 0, data) + if events != {}: + var epv = epoll_event(events: EPOLLRDHUP) + epv.data.u64 = fdi.uint + if Event.Read in events: epv.events = epv.events or EPOLLIN + if Event.Write in events: epv.events = epv.events or EPOLLOUT + if epoll_ctl(s.epollFD, EPOLL_CTL_ADD, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + inc(s.count) + +proc updateHandle*[T](s: Selector[T], fd: SocketHandle, events: set[Event]) = + let maskEvents = {Event.Timer, Event.Signal, Event.Process, Event.Vnode, + Event.User, Event.Oneshot, Event.Error} + let fdi = int(fd) + s.checkFd(fdi) + var pkey = addr(s.fds[fdi]) + doAssert(pkey.ident != 0) + doAssert(pkey.events * maskEvents == {}) + if pkey.events != events: + var epv = epoll_event(events: EPOLLRDHUP) + epv.data.u64 = fdi.uint + + if Event.Read in events: epv.events = epv.events or EPOLLIN + if Event.Write in events: epv.events = epv.events or EPOLLOUT + + if pkey.events == {}: + if epoll_ctl(s.epollFD, EPOLL_CTL_ADD, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + inc(s.count) + else: + if events != {}: + if epoll_ctl(s.epollFD, EPOLL_CTL_MOD, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + else: + if epoll_ctl(s.epollFD, EPOLL_CTL_DEL, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + dec(s.count) + pkey.events = events + +proc unregister*[T](s: Selector[T], fd: int|SocketHandle) = + let fdi = int(fd) + s.checkFd(fdi) + var pkey = addr(s.fds[fdi]) + doAssert(pkey.ident != 0) + + if pkey.events != {}: + if pkey.events * {Event.Read, Event.Write} != {}: + var epv = epoll_event() + if epoll_ctl(s.epollFD, EPOLL_CTL_DEL, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + dec(s.count) + elif Event.Timer in pkey.events: + var epv = epoll_event() + if epoll_ctl(s.epollFD, EPOLL_CTL_DEL, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + discard posix.close(fdi.cint) + dec(s.count) + elif Event.Signal in pkey.events: + var epv = epoll_event() + if epoll_ctl(s.epollFD, EPOLL_CTL_DEL, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + var nmask, omask: Sigset + discard sigemptyset(nmask) + discard sigemptyset(omask) + discard sigaddset(nmask, cint(s.fds[fdi].param)) + unblockSignals(nmask, omask) + discard posix.close(fdi.cint) + dec(s.count) + elif Event.Process in pkey.events: + var epv = epoll_event() + if epoll_ctl(s.epollFD, EPOLL_CTL_DEL, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + var nmask, omask: Sigset + discard sigemptyset(nmask) + discard sigemptyset(omask) + discard sigaddset(nmask, SIGCHLD) + unblockSignals(nmask, omask) + discard posix.close(fdi.cint) + dec(s.count) + pkey.ident = 0 + pkey.events = {} + +proc unregister*[T](s: Selector[T], ev: SelectEvent) = + let fdi = int(ev.efd) + s.checkFd(fdi) + var pkey = addr(s.fds[fdi]) + doAssert(pkey.ident != 0) + doAssert(Event.User in pkey.events) + pkey.ident = 0 + pkey.events = {} + var epv = epoll_event() + if epoll_ctl(s.epollFD, EPOLL_CTL_DEL, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + dec(s.count) + +proc registerTimer*[T](s: Selector[T], timeout: int, oneshot: bool, + data: T): int {.discardable.} = + var + new_ts: Itimerspec + old_ts: Itimerspec + let fdi = timerfd_create(CLOCK_MONOTONIC, 0).int + if fdi == -1: + raiseOSError(osLastError()) + setNonBlocking(fdi.cint) + + s.checkFd(fdi) + doAssert(s.fds[fdi].ident == 0) + + var events = {Event.Timer} + var epv = epoll_event(events: EPOLLIN or EPOLLRDHUP) + epv.data.u64 = fdi.uint + if oneshot: + new_ts.it_interval.tv_sec = 0.Time + new_ts.it_interval.tv_nsec = 0 + new_ts.it_value.tv_sec = (timeout div 1_000).Time + new_ts.it_value.tv_nsec = (timeout %% 1_000) * 1_000_000 + incl(events, Event.Oneshot) + epv.events = epv.events or EPOLLONESHOT + else: + new_ts.it_interval.tv_sec = (timeout div 1000).Time + new_ts.it_interval.tv_nsec = (timeout %% 1_000) * 1_000_000 + new_ts.it_value.tv_sec = new_ts.it_interval.tv_sec + new_ts.it_value.tv_nsec = new_ts.it_interval.tv_nsec + + if timerfd_settime(fdi.cint, cint(0), new_ts, old_ts) == -1: + raiseOSError(osLastError()) + if epoll_ctl(s.epollFD, EPOLL_CTL_ADD, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + s.setKey(fdi, fdi, events, 0, data) + inc(s.count) + result = fdi + +proc registerSignal*[T](s: Selector[T], signal: int, + data: T): int {.discardable.} = + var + nmask: Sigset + omask: Sigset + + discard sigemptyset(nmask) + discard sigemptyset(omask) + discard sigaddset(nmask, cint(signal)) + blockSignals(nmask, omask) + + let fdi = signalfd(-1, nmask, 0).int + if fdi == -1: + raiseOSError(osLastError()) + setNonBlocking(fdi.cint) + + s.checkFd(fdi) + doAssert(s.fds[fdi].ident == 0) + + var epv = epoll_event(events: EPOLLIN or EPOLLRDHUP) + epv.data.u64 = fdi.uint + if epoll_ctl(s.epollFD, EPOLL_CTL_ADD, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + s.setKey(fdi, signal, {Event.Signal}, signal, data) + inc(s.count) + result = fdi + +proc registerProcess*[T](s: Selector, pid: int, + data: T): int {.discardable.} = + var + nmask: Sigset + omask: Sigset + + discard sigemptyset(nmask) + discard sigemptyset(omask) + discard sigaddset(nmask, posix.SIGCHLD) + blockSignals(nmask, omask) + + let fdi = signalfd(-1, nmask, 0).int + if fdi == -1: + raiseOSError(osLastError()) + setNonBlocking(fdi.cint) + + s.checkFd(fdi) + doAssert(s.fds[fdi].ident == 0) + + var epv = epoll_event(events: EPOLLIN or EPOLLRDHUP) + epv.data.u64 = fdi.uint + epv.events = EPOLLIN or EPOLLRDHUP + if epoll_ctl(s.epollFD, EPOLL_CTL_ADD, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + s.setKey(fdi, pid, {Event.Process, Event.Oneshot}, pid, data) + inc(s.count) + result = fdi + +proc registerEvent*[T](s: Selector[T], ev: SelectEvent, data: T) = + let fdi = int(ev.efd) + doAssert(s.fds[fdi].ident == 0) + s.setKey(fdi, fdi, {Event.User}, 0, data) + var epv = epoll_event(events: EPOLLIN or EPOLLRDHUP) + epv.data.u64 = ev.efd.uint + if epoll_ctl(s.epollFD, EPOLL_CTL_ADD, ev.efd, addr epv) == -1: + raiseOSError(osLastError()) + inc(s.count) + +proc flush*[T](s: Selector[T]) = + discard + +proc selectInto*[T](s: Selector[T], timeout: int, + results: var openarray[ReadyKey[T]]): int = + var + resTable: array[MAX_EPOLL_RESULT_EVENTS, epoll_event] + maxres = MAX_EPOLL_RESULT_EVENTS + events: set[Event] = {} + i, k: int + + if maxres > len(results): + maxres = len(results) + + let count = epoll_wait(s.epollFD, addr(resTable[0]), maxres.cint, + timeout.cint) + if count < 0: + result = 0 + let err = osLastError() + if cint(err) != EINTR: + raiseOSError(err) + elif count == 0: + result = 0 + else: + i = 0 + k = 0 + while i < count: + let fdi = int(resTable[i].data.u64) + let pevents = resTable[i].events + var skey = addr(s.fds[fdi]) + doAssert(skey.ident != 0) + events = {} + + if (pevents and EPOLLERR) != 0 or (pevents and EPOLLHUP) != 0: + events.incl(Event.Error) + if (pevents and EPOLLOUT) != 0: + events.incl(Event.Write) + if (pevents and EPOLLIN) != 0: + if Event.Read in skey.events: + events.incl(Event.Read) + elif Event.Timer in skey.events: + var data: uint64 = 0 + if posix.read(fdi.cint, addr data, sizeof(uint64)) != sizeof(uint64): + raiseOSError(osLastError()) + events = {Event.Timer} + elif Event.Signal in skey.events: + var data = SignalFdInfo() + if posix.read(fdi.cint, addr data, + sizeof(SignalFdInfo)) != sizeof(SignalFdInfo): + raiseOsError(osLastError()) + events = {Event.Signal} + elif Event.Process in skey.events: + var data = SignalFdInfo() + if posix.read(fdi.cint, addr data, + sizeof(SignalFdInfo)) != sizeof(SignalFdInfo): + raiseOsError(osLastError()) + if cast[int](data.ssi_pid) == skey.param: + events = {Event.Process} + else: + inc(i) + continue + elif Event.User in skey.events: + var data: uint = 0 + if posix.read(fdi.cint, addr data, sizeof(uint)) != sizeof(uint): + let err = osLastError() + if err == OSErrorCode(EAGAIN): + inc(i) + continue + else: + raiseOSError(err) + events = {Event.User} + + skey.key.events = events + results[k] = skey.key + inc(k) + + if Event.Oneshot in skey.events: + var epv = epoll_event() + if epoll_ctl(s.epollFD, EPOLL_CTL_DEL, fdi.cint, addr epv) == -1: + raiseOSError(osLastError()) + discard posix.close(fdi.cint) + skey.ident = 0 + skey.events = {} + dec(s.count) + inc(i) + result = k + +proc select*[T](s: Selector[T], timeout: int): seq[ReadyKey[T]] = + result = newSeq[ReadyKey[T]](MAX_EPOLL_RESULT_EVENTS) + let count = selectInto(s, timeout, result) + result.setLen(count) + +template isEmpty*[T](s: Selector[T]): bool = + (s.count == 0) + +template withData*[T](s: Selector[T], fd: SocketHandle, value, + body: untyped) = + mixin checkFd + let fdi = int(fd) + s.checkFd(fdi) + if s.fds[fdi].ident != 0: + var value = addr(s.fds[fdi].key.data) + body + +template withData*[T](s: Selector[T], fd: SocketHandle, value, body1, + body2: untyped) = + mixin checkFd + let fdi = int(fd) + s.checkFd(fdi) + if s.fds[fdi].ident != 0: + var value = addr(s.fds[fdi].key.data) + body1 + else: + body2 diff --git a/lib/pure/ioselects/ioselectors_kqueue.nim b/lib/pure/ioselects/ioselectors_kqueue.nim new file mode 100644 index 000000000..3e86f19aa --- /dev/null +++ b/lib/pure/ioselects/ioselectors_kqueue.nim @@ -0,0 +1,443 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Eugene Kabanov +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +# This module implements BSD kqueue(). + +import posix, times, kqueue + +const + # Maximum number of cached changes. + MAX_KQUEUE_CHANGE_EVENTS = 64 + # Maximum number of events that can be returned. + MAX_KQUEUE_RESULT_EVENTS = 64 + # SIG_IGN and SIG_DFL declared in posix.nim as variables, but we need them + # to be constants and GC-safe. + SIG_DFL = cast[proc(x: cint) {.noconv,gcsafe.}](0) + SIG_IGN = cast[proc(x: cint) {.noconv,gcsafe.}](1) + +when defined(macosx) or defined(freebsd): + when defined(macosx): + const MAX_DESCRIPTORS_ID = 29 # KERN_MAXFILESPERPROC (MacOS) + else: + const MAX_DESCRIPTORS_ID = 27 # KERN_MAXFILESPERPROC (FreeBSD) + proc sysctl(name: ptr cint, namelen: cuint, oldp: pointer, oldplen: ptr int, + newp: pointer, newplen: int): cint + {.importc: "sysctl",header: """#include <sys/types.h> + #include <sys/sysctl.h>"""} +elif defined(netbsd) or defined(openbsd): + # OpenBSD and NetBSD don't have KERN_MAXFILESPERPROC, so we are using + # KERN_MAXFILES, because KERN_MAXFILES is always bigger, + # than KERN_MAXFILESPERPROC. + const MAX_DESCRIPTORS_ID = 7 # KERN_MAXFILES + proc sysctl(name: ptr cint, namelen: cuint, oldp: pointer, oldplen: ptr int, + newp: pointer, newplen: int): cint + {.importc: "sysctl",header: """#include <sys/param.h> + #include <sys/sysctl.h>"""} + +when hasThreadSupport: + type + SelectorImpl[T] = object + kqFD : cint + maxFD : int + changesTable: array[MAX_KQUEUE_CHANGE_EVENTS, KEvent] + changesCount: int + fds: ptr SharedArray[SelectorKey[T]] + count: int + changesLock: Lock + Selector*[T] = ptr SelectorImpl[T] +else: + type + SelectorImpl[T] = object + kqFD : cint + maxFD : int + changesTable: array[MAX_KQUEUE_CHANGE_EVENTS, KEvent] + changesCount: int + fds: seq[SelectorKey[T]] + count: int + Selector*[T] = ref SelectorImpl[T] + +type + SelectEventImpl = object + rfd: cint + wfd: cint +# SelectEvent is declared as `ptr` to be placed in `shared memory`, +# so you can share one SelectEvent handle between threads. +type SelectEvent* = ptr SelectEventImpl + +proc newSelector*[T](): Selector[T] = + var maxFD = 0.cint + var size = sizeof(cint) + var namearr = [1.cint, MAX_DESCRIPTORS_ID.cint] + # Obtain maximum number of file descriptors for process + if sysctl(addr(namearr[0]), 2, cast[pointer](addr maxFD), addr size, + nil, 0) != 0: + raiseOsError(osLastError()) + + var kqFD = kqueue() + if kqFD < 0: + raiseOsError(osLastError()) + + when hasThreadSupport: + result = cast[Selector[T]](allocShared0(sizeof(SelectorImpl[T]))) + result.kqFD = kqFD + result.maxFD = maxFD.int + result.fds = allocSharedArray[SelectorKey[T]](maxFD) + initLock(result.changesLock) + else: + result = Selector[T]() + result.kqFD = kqFD + result.maxFD = maxFD.int + result.fds = newSeq[SelectorKey[T]](maxFD) + +proc close*[T](s: Selector[T]) = + if posix.close(s.kqFD) != 0: + raiseOSError(osLastError()) + when hasThreadSupport: + deinitLock(s.changesLock) + deallocSharedArray(s.fds) + deallocShared(cast[pointer](s)) + +proc newSelectEvent*(): SelectEvent = + var fds: array[2, cint] + if posix.pipe(fds) == -1: + raiseOSError(osLastError()) + setNonBlocking(fds[0]) + setNonBlocking(fds[1]) + result = cast[SelectEvent](allocShared0(sizeof(SelectEventImpl))) + result.rfd = fds[0] + result.wfd = fds[1] + +proc setEvent*(ev: SelectEvent) = + var data: uint64 = 1 + if posix.write(ev.wfd, addr data, sizeof(uint64)) != sizeof(uint64): + raiseOSError(osLastError()) + +proc close*(ev: SelectEvent) = + discard posix.close(cint(ev.rfd)) + discard posix.close(cint(ev.wfd)) + deallocShared(cast[pointer](ev)) + +template checkFd(s, f) = + if f >= s.maxFD: + raise newException(ValueError, "Maximum file descriptors exceeded") + +when hasThreadSupport: + template withChangeLock[T](s: Selector[T], body: untyped) = + acquire(s.changesLock) + {.locks: [s.changesLock].}: + try: + body + finally: + release(s.changesLock) +else: + template withChangeLock(s, body: untyped) = + body + +template modifyKQueue[T](s: Selector[T], nident: uint, nfilter: cshort, + nflags: cushort, nfflags: cuint, ndata: int, + nudata: pointer) = + mixin withChangeLock + s.withChangeLock(): + s.changesTable[s.changesCount] = KEvent(ident: nident, + filter: nfilter, flags: nflags, + fflags: nfflags, data: ndata, + udata: nudata) + inc(s.changesCount) + if s.changesCount == MAX_KQUEUE_CHANGE_EVENTS: + if kevent(s.kqFD, addr(s.changesTable[0]), cint(s.changesCount), + nil, 0, nil) == -1: + raiseOSError(osLastError()) + s.changesCount = 0 + +proc registerHandle*[T](s: Selector[T], fd: SocketHandle, + events: set[Event], data: T) = + let fdi = int(fd) + s.checkFd(fdi) + doAssert(s.fds[fdi].ident == 0) + s.setKey(fdi, fdi, events, 0, data) + if events != {}: + if Event.Read in events: + modifyKQueue(s, fdi.uint, EVFILT_READ, EV_ADD, 0, 0, nil) + inc(s.count) + if Event.Write in events: + modifyKQueue(s, fdi.uint, EVFILT_WRITE, EV_ADD, 0, 0, nil) + inc(s.count) + +proc updateHandle*[T](s: Selector[T], fd: SocketHandle, + events: set[Event]) = + let maskEvents = {Event.Timer, Event.Signal, Event.Process, Event.Vnode, + Event.User, Event.Oneshot, Event.Error} + let fdi = int(fd) + s.checkFd(fdi) + var pkey = addr(s.fds[fdi]) + doAssert(pkey.ident != 0) + doAssert(pkey.events * maskEvents == {}) + + if pkey.events != events: + if (Event.Read in pkey.events) and (Event.Read notin events): + modifyKQueue(s, fdi.uint, EVFILT_READ, EV_DELETE, 0, 0, nil) + dec(s.count) + if (Event.Write in pkey.events) and (Event.Write notin events): + modifyKQueue(s, fdi.uint, EVFILT_WRITE, EV_DELETE, 0, 0, nil) + dec(s.count) + if (Event.Read notin pkey.events) and (Event.Read in events): + modifyKQueue(s, fdi.uint, EVFILT_READ, EV_ADD, 0, 0, nil) + inc(s.count) + if (Event.Write notin pkey.events) and (Event.Write in events): + modifyKQueue(s, fdi.uint, EVFILT_WRITE, EV_ADD, 0, 0, nil) + inc(s.count) + pkey.events = events + +proc registerTimer*[T](s: Selector[T], timeout: int, oneshot: bool, + data: T): int {.discardable.} = + var fdi = posix.socket(posix.AF_INET, posix.SOCK_STREAM, + posix.IPPROTO_TCP).int + if fdi == -1: + raiseOsError(osLastError()) + + s.checkFd(fdi) + doAssert(s.fds[fdi].ident == 0) + + let events = if oneshot: {Event.Timer, Event.Oneshot} else: {Event.Timer} + let flags: cushort = if oneshot: EV_ONESHOT or EV_ADD else: EV_ADD + + s.setKey(fdi, fdi, events, 0, data) + # EVFILT_TIMER on Open/Net(BSD) has granularity of only milliseconds, + # but MacOS and FreeBSD allow use `0` as `fflags` to use milliseconds + # too + modifyKQueue(s, fdi.uint, EVFILT_TIMER, flags, 0, cint(timeout), nil) + inc(s.count) + result = fdi + +proc registerSignal*[T](s: Selector[T], signal: int, + data: T): int {.discardable.} = + var fdi = posix.socket(posix.AF_INET, posix.SOCK_STREAM, + posix.IPPROTO_TCP).int + if fdi == -1: + raiseOsError(osLastError()) + + s.checkFd(fdi) + doAssert(s.fds[fdi].ident == 0) + + s.setKey(fdi, signal, {Event.Signal}, signal, data) + var nmask, omask: Sigset + discard sigemptyset(nmask) + discard sigemptyset(omask) + discard sigaddset(nmask, cint(signal)) + blockSignals(nmask, omask) + # to be compatible with linux semantic we need to "eat" signals + posix.signal(cint(signal), SIG_IGN) + modifyKQueue(s, signal.uint, EVFILT_SIGNAL, EV_ADD, 0, 0, + cast[pointer](fdi)) + inc(s.count) + result = fdi + +proc registerProcess*[T](s: Selector[T], pid: int, + data: T): int {.discardable.} = + var fdi = posix.socket(posix.AF_INET, posix.SOCK_STREAM, + posix.IPPROTO_TCP).int + if fdi == -1: + raiseOsError(osLastError()) + + s.checkFd(fdi) + doAssert(s.fds[fdi].ident == 0) + + var kflags: cushort = EV_ONESHOT or EV_ADD + setKey(s, fdi, pid, {Event.Process, Event.Oneshot}, pid, data) + modifyKQueue(s, pid.uint, EVFILT_PROC, kflags, NOTE_EXIT, 0, + cast[pointer](fdi)) + inc(s.count) + result = fdi + +proc registerEvent*[T](s: Selector[T], ev: SelectEvent, data: T) = + let fdi = ev.rfd.int + doAssert(s.fds[fdi].ident == 0) + setKey(s, fdi, fdi, {Event.User}, 0, data) + modifyKQueue(s, fdi.uint, EVFILT_READ, EV_ADD, 0, 0, nil) + inc(s.count) + +proc unregister*[T](s: Selector[T], fd: int|SocketHandle) = + let fdi = int(fd) + s.checkFd(fdi) + var pkey = addr(s.fds[fdi]) + doAssert(pkey.ident != 0) + + if pkey.events != {}: + if pkey.events * {Event.Read, Event.Write} != {}: + if Event.Read in pkey.events: + modifyKQueue(s, fdi.uint, EVFILT_READ, EV_DELETE, 0, 0, nil) + dec(s.count) + if Event.Write in pkey.events: + modifyKQueue(s, fdi.uint, EVFILT_WRITE, EV_DELETE, 0, 0, nil) + dec(s.count) + elif Event.Timer in pkey.events: + discard posix.close(cint(pkey.key.fd)) + modifyKQueue(s, fdi.uint, EVFILT_TIMER, EV_DELETE, 0, 0, nil) + dec(s.count) + elif Event.Signal in pkey.events: + var nmask, omask: Sigset + var signal = cint(pkey.param) + discard sigemptyset(nmask) + discard sigemptyset(omask) + discard sigaddset(nmask, signal) + unblockSignals(nmask, omask) + posix.signal(signal, SIG_DFL) + discard posix.close(cint(pkey.key.fd)) + modifyKQueue(s, fdi.uint, EVFILT_SIGNAL, EV_DELETE, 0, 0, nil) + dec(s.count) + elif Event.Process in pkey.events: + discard posix.close(cint(pkey.key.fd)) + modifyKQueue(s, fdi.uint, EVFILT_PROC, EV_DELETE, 0, 0, nil) + dec(s.count) + elif Event.User in pkey.events: + modifyKQueue(s, fdi.uint, EVFILT_READ, EV_DELETE, 0, 0, nil) + dec(s.count) + pkey.ident = 0 + pkey.events = {} + +proc unregister*[T](s: Selector[T], ev: SelectEvent) = + let fdi = int(ev.rfd) + s.checkFd(fdi) + var pkey = addr(s.fds[fdi]) + doAssert(pkey.ident != 0) + doAssert(Event.User in pkey.events) + pkey.ident = 0 + pkey.events = {} + modifyKQueue(s, fdi.uint, EVFILT_READ, EV_DELETE, 0, 0, nil) + dec(s.count) + +proc flush*[T](s: Selector[T]) = + s.withChangeLock(): + var tv = Timespec() + if kevent(s.kqFD, addr(s.changesTable[0]), cint(s.changesCount), + nil, 0, addr tv) == -1: + raiseOSError(osLastError()) + s.changesCount = 0 + +proc selectInto*[T](s: Selector[T], timeout: int, + results: var openarray[ReadyKey[T]]): int = + var + tv: Timespec + resTable: array[MAX_KQUEUE_RESULT_EVENTS, KEvent] + ptv = addr tv + maxres = MAX_KQUEUE_RESULT_EVENTS + + if timeout != -1: + if timeout >= 1000: + tv.tv_sec = (timeout div 1_000).Time + tv.tv_nsec = (timeout %% 1_000) * 1_000_000 + else: + tv.tv_sec = 0.Time + tv.tv_nsec = timeout * 1_000_000 + else: + ptv = nil + + if maxres > len(results): + maxres = len(results) + + var count = 0 + s.withChangeLock(): + count = kevent(s.kqFD, addr(s.changesTable[0]), cint(s.changesCount), + addr(resTable[0]), cint(maxres), ptv) + s.changesCount = 0 + + if count < 0: + result = 0 + let err = osLastError() + if cint(err) != EINTR: + raiseOSError(err) + elif count == 0: + result = 0 + else: + var i = 0 + var k = 0 + var pkey: ptr SelectorKey[T] + while i < count: + let kevent = addr(resTable[i]) + if (kevent.flags and EV_ERROR) == 0: + case kevent.filter: + of EVFILT_READ: + pkey = addr(s.fds[kevent.ident.int]) + pkey.key.events = {Event.Read} + if Event.User in pkey.events: + var data: uint64 = 0 + if posix.read(kevent.ident.cint, addr data, + sizeof(uint64)) != sizeof(uint64): + let err = osLastError() + if err == OSErrorCode(EAGAIN): + # someone already consumed event data + inc(i) + continue + else: + raiseOSError(osLastError()) + pkey.key.events = {Event.User} + of EVFILT_WRITE: + pkey = addr(s.fds[kevent.ident.int]) + pkey.key.events = {Event.Write} + of EVFILT_TIMER: + pkey = addr(s.fds[kevent.ident.int]) + if Event.Oneshot in pkey.events: + if posix.close(cint(pkey.ident)) == -1: + raiseOSError(osLastError()) + pkey.ident = 0 + pkey.events = {} + dec(s.count) + pkey.key.events = {Event.Timer} + of EVFILT_VNODE: + pkey = addr(s.fds[kevent.ident.int]) + pkey.key.events = {Event.Vnode} + of EVFILT_SIGNAL: + pkey = addr(s.fds[cast[int](kevent.udata)]) + pkey.key.events = {Event.Signal} + of EVFILT_PROC: + pkey = addr(s.fds[cast[int](kevent.udata)]) + if posix.close(cint(pkey.ident)) == -1: + raiseOSError(osLastError()) + pkey.ident = 0 + pkey.events = {} + dec(s.count) + pkey.key.events = {Event.Process} + else: + raise newException(ValueError, "Unsupported kqueue filter in queue") + + if (kevent.flags and EV_EOF) != 0: + pkey.key.events.incl(Event.Error) + + results[k] = pkey.key + inc(k) + inc(i) + result = k + +proc select*[T](s: Selector[T], timeout: int): seq[ReadyKey[T]] = + result = newSeq[ReadyKey[T]](MAX_KQUEUE_RESULT_EVENTS) + let count = selectInto(s, timeout, result) + result.setLen(count) + +template isEmpty*[T](s: Selector[T]): bool = + (s.count == 0) + +template withData*[T](s: Selector[T], fd: SocketHandle, value, + body: untyped) = + mixin checkFd + let fdi = int(fd) + s.checkFd(fdi) + if s.fds[fdi].ident != 0: + var value = addr(s.fds[fdi].key.data) + body + +template withData*[T](s: Selector[T], fd: SocketHandle, value, body1, + body2: untyped) = + mixin checkFd + let fdi = int(fd) + s.checkFd(fdi) + if s.fds[fdi].ident != 0: + var value = addr(s.fds[fdi].key.data) + body1 + else: + body2 diff --git a/lib/pure/ioselects/ioselectors_poll.nim b/lib/pure/ioselects/ioselectors_poll.nim new file mode 100644 index 000000000..d2a0a1273 --- /dev/null +++ b/lib/pure/ioselects/ioselectors_poll.nim @@ -0,0 +1,295 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Eugene Kabanov +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +# This module implements Posix poll(). + +import posix, times + +# Maximum number of events that can be returned +const MAX_POLL_RESULT_EVENTS = 64 + +when hasThreadSupport: + type + SelectorImpl[T] = object + maxFD : int + pollcnt: int + fds: ptr SharedArray[SelectorKey[T]] + pollfds: ptr SharedArray[TPollFd] + count: int + lock: Lock + Selector*[T] = ptr SelectorImpl[T] +else: + type + SelectorImpl[T] = object + maxFD : int + pollcnt: int + fds: seq[SelectorKey[T]] + pollfds: seq[TPollFd] + count: int + Selector*[T] = ref SelectorImpl[T] + +type + SelectEventImpl = object + rfd: cint + wfd: cint + SelectEvent* = ptr SelectEventImpl + +var RLIMIT_NOFILE {.importc: "RLIMIT_NOFILE", + header: "<sys/resource.h>".}: cint +type + rlimit {.importc: "struct rlimit", + header: "<sys/resource.h>", pure, final.} = object + rlim_cur: int + rlim_max: int +proc getrlimit(resource: cint, rlp: var rlimit): cint + {.importc: "getrlimit",header: "<sys/resource.h>".} + +when hasThreadSupport: + template withPollLock[T](s: Selector[T], body: untyped) = + acquire(s.lock) + {.locks: [s.lock].}: + try: + body + finally: + release(s.lock) +else: + template withPollLock(s, body: untyped) = + body + +proc newSelector*[T](): Selector[T] = + var a = rlimit() + if getrlimit(RLIMIT_NOFILE, a) != 0: + raiseOsError(osLastError()) + var maxFD = int(a.rlim_max) + + when hasThreadSupport: + result = cast[Selector[T]](allocShared0(sizeof(SelectorImpl[T]))) + result.maxFD = maxFD + result.fds = allocSharedArray[SelectorKey[T]](maxFD) + result.pollfds = allocSharedArray[TPollFd](maxFD) + initLock(result.lock) + else: + result = Selector[T]() + result.maxFD = maxFD + result.fds = newSeq[SelectorKey[T]](maxFD) + result.pollfds = newSeq[TPollFd](maxFD) + +proc close*[T](s: Selector[T]) = + when hasThreadSupport: + deinitLock(s.lock) + deallocSharedArray(s.fds) + deallocSharedArray(s.pollfds) + deallocShared(cast[pointer](s)) + +template pollAdd[T](s: Selector[T], sock: cint, events: set[Event]) = + withPollLock(s): + var pollev: cshort = 0 + if Event.Read in events: pollev = pollev or POLLIN + if Event.Write in events: pollev = pollev or POLLOUT + s.pollfds[s.pollcnt].fd = cint(sock) + s.pollfds[s.pollcnt].events = pollev + inc(s.count) + inc(s.pollcnt) + +template pollUpdate[T](s: Selector[T], sock: cint, events: set[Event]) = + withPollLock(s): + var i = 0 + var pollev: cshort = 0 + if Event.Read in events: pollev = pollev or POLLIN + if Event.Write in events: pollev = pollev or POLLOUT + + while i < s.pollcnt: + if s.pollfds[i].fd == sock: + s.pollfds[i].events = pollev + break + inc(i) + + if i == s.pollcnt: + raise newException(ValueError, "Descriptor is not registered in queue") + +template pollRemove[T](s: Selector[T], sock: cint) = + withPollLock(s): + var i = 0 + while i < s.pollcnt: + if s.pollfds[i].fd == sock: + if i == s.pollcnt - 1: + s.pollfds[i].fd = 0 + s.pollfds[i].events = 0 + s.pollfds[i].revents = 0 + else: + while i < (s.pollcnt - 1): + s.pollfds[i].fd = s.pollfds[i + 1].fd + s.pollfds[i].events = s.pollfds[i + 1].events + inc(i) + break + inc(i) + dec(s.pollcnt) + dec(s.count) + +template checkFd(s, f) = + if f >= s.maxFD: + raise newException(ValueError, "Maximum file descriptors exceeded") + +proc registerHandle*[T](s: Selector[T], fd: SocketHandle, + events: set[Event], data: T) = + var fdi = int(fd) + s.checkFd(fdi) + doAssert(s.fds[fdi].ident == 0) + s.setKey(fdi, fdi, events, 0, data) + if events != {}: s.pollAdd(fdi.cint, events) + +proc updateHandle*[T](s: Selector[T], fd: SocketHandle, + events: set[Event]) = + let maskEvents = {Event.Timer, Event.Signal, Event.Process, Event.Vnode, + Event.User, Event.Oneshot, Event.Error} + let fdi = int(fd) + s.checkFd(fdi) + var pkey = addr(s.fds[fdi]) + doAssert(pkey.ident != 0) + doAssert(pkey.events * maskEvents == {}) + + if pkey.events != events: + if pkey.events == {}: + s.pollAdd(fd.cint, events) + else: + if events != {}: + s.pollUpdate(fd.cint, events) + else: + s.pollRemove(fd.cint) + pkey.events = events + +proc registerEvent*[T](s: Selector[T], ev: SelectEvent, data: T) = + var fdi = int(ev.rfd) + doAssert(s.fds[fdi].ident == 0) + var events = {Event.User} + setKey(s, fdi, fdi, events, 0, data) + events.incl(Event.Read) + s.pollAdd(fdi.cint, events) + +proc flush*[T](s: Selector[T]) = discard + +proc unregister*[T](s: Selector[T], fd: int|SocketHandle) = + let fdi = int(fd) + s.checkFd(fdi) + var pkey = addr(s.fds[fdi]) + doAssert(pkey.ident != 0) + pkey.ident = 0 + pkey.events = {} + s.pollRemove(fdi.cint) + +proc unregister*[T](s: Selector[T], ev: SelectEvent) = + let fdi = int(ev.rfd) + s.checkFd(fdi) + var pkey = addr(s.fds[fdi]) + doAssert(pkey.ident != 0) + doAssert(Event.User in pkey.events) + pkey.ident = 0 + pkey.events = {} + s.pollRemove(fdi.cint) + +proc newSelectEvent*(): SelectEvent = + var fds: array[2, cint] + if posix.pipe(fds) == -1: + raiseOSError(osLastError()) + setNonBlocking(fds[0]) + setNonBlocking(fds[1]) + result = cast[SelectEvent](allocShared0(sizeof(SelectEventImpl))) + result.rfd = fds[0] + result.wfd = fds[1] + +proc setEvent*(ev: SelectEvent) = + var data: uint64 = 1 + if posix.write(ev.wfd, addr data, sizeof(uint64)) != sizeof(uint64): + raiseOSError(osLastError()) + +proc close*(ev: SelectEvent) = + discard posix.close(cint(ev.rfd)) + discard posix.close(cint(ev.wfd)) + deallocShared(cast[pointer](ev)) + +proc selectInto*[T](s: Selector[T], timeout: int, + results: var openarray[ReadyKey[T]]): int = + var maxres = MAX_POLL_RESULT_EVENTS + if maxres > len(results): + maxres = len(results) + + s.withPollLock(): + let count = posix.poll(addr(s.pollfds[0]), Tnfds(s.pollcnt), timeout) + if count < 0: + result = 0 + let err = osLastError() + if err.cint == EINTR: + discard + else: + raiseOSError(osLastError()) + elif count == 0: + result = 0 + else: + var i = 0 + var k = 0 + var rindex = 0 + while (i < s.pollcnt) and (k < count) and (rindex < maxres): + let revents = s.pollfds[i].revents + if revents != 0: + let fd = s.pollfds[i].fd + var skey = addr(s.fds[fd]) + skey.key.events = {} + + if (revents and POLLIN) != 0: + skey.key.events.incl(Event.Read) + if Event.User in skey.events: + var data: uint64 = 0 + if posix.read(fd, addr data, sizeof(int)) != sizeof(int): + let err = osLastError() + if err != OSErrorCode(EAGAIN): + raiseOSError(osLastError()) + else: + # someone already consumed event data + inc(i) + continue + skey.key.events = {Event.User} + if (revents and POLLOUT) != 0: + skey.key.events.incl(Event.Write) + if (revents and POLLERR) != 0 or (revents and POLLHUP) != 0 or + (revents and POLLNVAL) != 0: + skey.key.events.incl(Event.Error) + results[rindex] = skey.key + s.pollfds[i].revents = 0 + inc(rindex) + inc(k) + inc(i) + result = k + +proc select*[T](s: Selector[T], timeout: int): seq[ReadyKey[T]] = + result = newSeq[ReadyKey[T]](MAX_POLL_RESULT_EVENTS) + let count = selectInto(s, timeout, result) + result.setLen(count) + +template isEmpty*[T](s: Selector[T]): bool = + (s.count == 0) + +template withData*[T](s: Selector[T], fd: SocketHandle, value, + body: untyped) = + mixin checkFd + let fdi = int(fd) + s.checkFd(fdi) + if s.fds[fdi].ident != 0: + var value = addr(s.fds[fdi].key.data) + body + +template withData*[T](s: Selector[T], fd: SocketHandle, value, body1, + body2: untyped) = + mixin checkFd + let fdi = int(fd) + s.checkFd(fdi) + if s.fds[fdi].ident != 0: + var value = addr(s.fds[fdi].key.data) + body1 + else: + body2 diff --git a/lib/pure/ioselects/ioselectors_select.nim b/lib/pure/ioselects/ioselectors_select.nim new file mode 100644 index 000000000..f8099f9a0 --- /dev/null +++ b/lib/pure/ioselects/ioselectors_select.nim @@ -0,0 +1,416 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Eugene Kabanov +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +# This module implements Posix and Windows select(). + +import times, nativesockets + +when defined(windows): + import winlean + when defined(gcc): + {.passL: "-lws2_32".} + elif defined(vcc): + {.passL: "ws2_32.lib".} + const platformHeaders = """#include <winsock2.h> + #include <windows.h>""" + const EAGAIN = WSAEWOULDBLOCK +else: + const platformHeaders = """#include <sys/select.h> + #include <sys/time.h> + #include <sys/types.h> + #include <unistd.h>""" +type + Fdset {.importc: "fd_set", header: platformHeaders, pure, final.} = object +var + FD_SETSIZE {.importc: "FD_SETSIZE", header: platformHeaders.}: cint + +proc IOFD_SET(fd: SocketHandle, fdset: ptr Fdset) + {.cdecl, importc: "FD_SET", header: platformHeaders, inline.} +proc IOFD_CLR(fd: SocketHandle, fdset: ptr Fdset) + {.cdecl, importc: "FD_CLR", header: platformHeaders, inline.} +proc IOFD_ZERO(fdset: ptr Fdset) + {.cdecl, importc: "FD_ZERO", header: platformHeaders, inline.} + +when defined(windows): + proc IOFD_ISSET(fd: SocketHandle, fdset: ptr Fdset): cint + {.stdcall, importc: "FD_ISSET", header: platformHeaders, inline.} + proc ioselect(nfds: cint, readFds, writeFds, exceptFds: ptr Fdset, + timeout: ptr Timeval): cint + {.stdcall, importc: "select", header: platformHeaders.} +else: + proc IOFD_ISSET(fd: SocketHandle, fdset: ptr Fdset): cint + {.cdecl, importc: "FD_ISSET", header: platformHeaders, inline.} + proc ioselect(nfds: cint, readFds, writeFds, exceptFds: ptr Fdset, + timeout: ptr Timeval): cint + {.cdecl, importc: "select", header: platformHeaders.} + +when hasThreadSupport: + type + SelectorImpl[T] = object + rSet: FdSet + wSet: FdSet + eSet: FdSet + maxFD: int + fds: ptr SharedArray[SelectorKey[T]] + count: int + lock: Lock + Selector*[T] = ptr SelectorImpl[T] +else: + type + SelectorImpl[T] = object + rSet: FdSet + wSet: FdSet + eSet: FdSet + maxFD: int + fds: seq[SelectorKey[T]] + count: int + Selector*[T] = ref SelectorImpl[T] + +type + SelectEventImpl = object + rsock: SocketHandle + wsock: SocketHandle + SelectEvent* = ptr SelectEventImpl + +when hasThreadSupport: + template withSelectLock[T](s: Selector[T], body: untyped) = + acquire(s.lock) + {.locks: [s.lock].}: + try: + body + finally: + release(s.lock) +else: + template withSelectLock[T](s: Selector[T], body: untyped) = + body + +proc newSelector*[T](): Selector[T] = + when hasThreadSupport: + result = cast[Selector[T]](allocShared0(sizeof(SelectorImpl[T]))) + result.fds = allocSharedArray[SelectorKey[T]](FD_SETSIZE) + initLock result.lock + else: + result = Selector[T]() + result.fds = newSeq[SelectorKey[T]](FD_SETSIZE) + + IOFD_ZERO(addr result.rSet) + IOFD_ZERO(addr result.wSet) + IOFD_ZERO(addr result.eSet) + +proc close*[T](s: Selector[T]) = + when hasThreadSupport: + deallocSharedArray(s.fds) + deallocShared(cast[pointer](s)) + +when defined(windows): + proc newSelectEvent*(): SelectEvent = + var ssock = newNativeSocket() + var wsock = newNativeSocket() + var rsock: SocketHandle = INVALID_SOCKET + var saddr = Sockaddr_in() + + saddr.sin_family = winlean.AF_INET + saddr.sin_port = 0 + saddr.sin_addr.s_addr = INADDR_ANY + if bindAddr(ssock, cast[ptr SockAddr](addr(saddr)), + sizeof(saddr).SockLen) < 0'i32: + raiseOSError(osLastError()) + + if winlean.listen(ssock, 1) == -1: + raiseOSError(osLastError()) + + var namelen = sizeof(saddr).SockLen + if getsockname(ssock, cast[ptr SockAddr](addr(saddr)), + addr(namelen)) == -1'i32: + raiseOSError(osLastError()) + + saddr.sin_addr.s_addr = 0x0100007F + if winlean.connect(wsock, cast[ptr SockAddr](addr(saddr)), + sizeof(saddr).SockLen) == -1: + raiseOSError(osLastError()) + namelen = sizeof(saddr).SockLen + rsock = winlean.accept(ssock, cast[ptr SockAddr](addr(saddr)), + cast[ptr SockLen](addr(namelen))) + if rsock == SocketHandle(-1): + raiseOSError(osLastError()) + + if winlean.closesocket(ssock) == -1: + raiseOSError(osLastError()) + + var mode = clong(1) + if ioctlsocket(rsock, FIONBIO, addr(mode)) == -1: + raiseOSError(osLastError()) + mode = clong(1) + if ioctlsocket(wsock, FIONBIO, addr(mode)) == -1: + raiseOSError(osLastError()) + + result = cast[SelectEvent](allocShared0(sizeof(SelectEventImpl))) + result.rsock = rsock + result.wsock = wsock + + proc setEvent*(ev: SelectEvent) = + var data: int = 1 + if winlean.send(ev.wsock, cast[pointer](addr data), + cint(sizeof(int)), 0) != sizeof(int): + raiseOSError(osLastError()) + + proc close*(ev: SelectEvent) = + discard winlean.closesocket(ev.rsock) + discard winlean.closesocket(ev.wsock) + deallocShared(cast[pointer](ev)) + +else: + proc newSelectEvent*(): SelectEvent = + var fds: array[2, cint] + if posix.pipe(fds) == -1: + raiseOSError(osLastError()) + setNonBlocking(fds[0]) + setNonBlocking(fds[1]) + result = cast[SelectEvent](allocShared0(sizeof(SelectEventImpl))) + result.rsock = SocketHandle(fds[0]) + result.wsock = SocketHandle(fds[1]) + + proc setEvent*(ev: SelectEvent) = + var data: uint64 = 1 + if posix.write(cint(ev.wsock), addr data, sizeof(uint64)) != sizeof(uint64): + raiseOSError(osLastError()) + + proc close*(ev: SelectEvent) = + discard posix.close(cint(ev.rsock)) + discard posix.close(cint(ev.wsock)) + deallocShared(cast[pointer](ev)) + +proc setKey[T](s: Selector[T], fd: SocketHandle, events: set[Event], data: T) = + var i = 0 + let fdi = int(fd) + while i < FD_SETSIZE: + if s.fds[i].ident == 0: + var pkey = addr(s.fds[i]) + pkey.ident = fdi + pkey.events = events + pkey.key.fd = fd.int + pkey.key.events = {} + pkey.key.data = data + break + inc(i) + if i == FD_SETSIZE: + raise newException(ValueError, "Maximum numbers of fds exceeded") + +proc getKey[T](s: Selector[T], fd: SocketHandle): ptr SelectorKey[T] = + var i = 0 + let fdi = int(fd) + while i < FD_SETSIZE: + if s.fds[i].ident == fdi: + result = addr(s.fds[i]) + break + inc(i) + doAssert(i < FD_SETSIZE, "Descriptor not registered in queue") + +proc delKey[T](s: Selector[T], fd: SocketHandle) = + var i = 0 + while i < FD_SETSIZE: + if s.fds[i].ident == fd.int: + s.fds[i].ident = 0 + s.fds[i].events = {} + break + inc(i) + doAssert(i < FD_SETSIZE, "Descriptor not registered in queue") + +proc registerHandle*[T](s: Selector[T], fd: SocketHandle, + events: set[Event], data: T) = + when not defined(windows): + let fdi = int(fd) + s.withSelectLock(): + s.setKey(fd, events, data) + when not defined(windows): + if fdi > s.maxFD: s.maxFD = fdi + if Event.Read in events: + IOFD_SET(fd, addr s.rSet) + inc(s.count) + if Event.Write in events: + IOFD_SET(fd, addr s.wSet) + IOFD_SET(fd, addr s.eSet) + inc(s.count) + +proc registerEvent*[T](s: Selector[T], ev: SelectEvent, data: T) = + when not defined(windows): + let fdi = int(ev.rsock) + s.withSelectLock(): + s.setKey(ev.rsock, {Event.User}, data) + when not defined(windows): + if fdi > s.maxFD: s.maxFD = fdi + IOFD_SET(ev.rsock, addr s.rSet) + inc(s.count) + +proc updateHandle*[T](s: Selector[T], fd: SocketHandle, + events: set[Event]) = + let maskEvents = {Event.Timer, Event.Signal, Event.Process, Event.Vnode, + Event.User, Event.Oneshot, Event.Error} + s.withSelectLock(): + var pkey = s.getKey(fd) + doAssert(pkey.events * maskEvents == {}) + if pkey.events != events: + if (Event.Read in pkey.events) and (Event.Read notin events): + IOFD_CLR(fd, addr s.rSet) + dec(s.count) + if (Event.Write in pkey.events) and (Event.Write notin events): + IOFD_CLR(fd, addr s.wSet) + IOFD_CLR(fd, addr s.eSet) + dec(s.count) + if (Event.Read notin pkey.events) and (Event.Read in events): + IOFD_SET(fd, addr s.rSet) + inc(s.count) + if (Event.Write notin pkey.events) and (Event.Write in events): + IOFD_SET(fd, addr s.wSet) + IOFD_SET(fd, addr s.eSet) + inc(s.count) + pkey.events = events + +proc unregister*[T](s: Selector[T], fd: SocketHandle) = + s.withSelectLock(): + var pkey = s.getKey(fd) + if Event.Read in pkey.events: + IOFD_CLR(fd, addr s.rSet) + dec(s.count) + if Event.Write in pkey.events: + IOFD_CLR(fd, addr s.wSet) + IOFD_CLR(fd, addr s.eSet) + dec(s.count) + s.delKey(fd) + +proc unregister*[T](s: Selector[T], ev: SelectEvent) = + let fd = ev.rsock + s.withSelectLock(): + IOFD_CLR(fd, addr s.rSet) + dec(s.count) + s.delKey(fd) + +proc selectInto*[T](s: Selector[T], timeout: int, + results: var openarray[ReadyKey[T]]): int = + var tv = Timeval() + var ptv = addr tv + var rset, wset, eset: FdSet + + if timeout != -1: + tv.tv_sec = timeout.int32 div 1_000 + tv.tv_usec = (timeout.int32 %% 1_000) * 1_000 + else: + ptv = nil + + s.withSelectLock(): + rset = s.rSet + wset = s.wSet + eset = s.eSet + + var count = ioselect(cint(s.maxFD) + 1, addr(rset), addr(wset), + addr(eset), ptv) + if count < 0: + result = 0 + when defined(windows): + raiseOSError(osLastError()) + else: + let err = osLastError() + if cint(err) != EINTR: + raiseOSError(err) + elif count == 0: + result = 0 + else: + var rindex = 0 + var i = 0 + var k = 0 + + while (i < FD_SETSIZE) and (k < count): + if s.fds[i].ident != 0: + var flag = false + var pkey = addr(s.fds[i]) + pkey.key.events = {} + let fd = SocketHandle(pkey.ident) + if IOFD_ISSET(fd, addr rset) != 0: + if Event.User in pkey.events: + var data: uint64 = 0 + if recv(fd, cast[pointer](addr(data)), + sizeof(uint64).cint, 0) != sizeof(uint64): + let err = osLastError() + if cint(err) != EAGAIN: + raiseOSError(err) + else: + inc(i) + inc(k) + continue + else: + flag = true + pkey.key.events = {Event.User} + else: + flag = true + pkey.key.events = {Event.Read} + if IOFD_ISSET(fd, addr wset) != 0: + pkey.key.events.incl(Event.Write) + if IOFD_ISSET(fd, addr eset) != 0: + pkey.key.events.incl(Event.Error) + flag = true + if flag: + results[rindex] = pkey.key + inc(rindex) + inc(k) + inc(i) + result = rindex + +proc select*[T](s: Selector[T], timeout: int): seq[ReadyKey[T]] = + result = newSeq[ReadyKey[T]](FD_SETSIZE) + var count = selectInto(s, timeout, result) + result.setLen(count) + +proc flush*[T](s: Selector[T]) = discard + +template isEmpty*[T](s: Selector[T]): bool = + (s.count == 0) + +when hasThreadSupport: + template withSelectLock[T](s: Selector[T], body: untyped) = + acquire(s.lock) + {.locks: [s.lock].}: + try: + body + finally: + release(s.lock) +else: + template withSelectLock[T](s: Selector[T], body: untyped) = + body + +template withData*[T](s: Selector[T], fd: SocketHandle, value, + body: untyped) = + mixin withSelectLock + s.withSelectLock(): + var value: ptr T + let fdi = int(fd) + var i = 0 + while i < FD_SETSIZE: + if s.fds[i].ident == fdi: + value = addr(s.fds[i].key.data) + break + inc(i) + if i != FD_SETSIZE: + body + +template withData*[T](s: Selector[T], fd: SocketHandle, value, + body1, body2: untyped) = + mixin withSelectLock + s.withSelectLock(): + var value: ptr T + let fdi = int(fd) + var i = 0 + while i < FD_SETSIZE: + if s.fds[i].ident == fdi: + value = addr(s.fds[i].key.data) + break + inc(i) + if i != FD_SETSIZE: + body1 + else: + body2 diff --git a/lib/pure/json.nim b/lib/pure/json.nim index b0179cd82..19947fbc2 100644 --- a/lib/pure/json.nim +++ b/lib/pure/json.nim @@ -20,7 +20,8 @@ ## let ## small_json = """{"test": 1.3, "key2": true}""" ## jobj = parseJson(small_json) -## assert (jobj.kind == JObject) +## assert (jobj.kind == JObject)\ +## jobj["test"] = newJFloat(0.7) # create or update ## echo($jobj["test"].fnum) ## echo($jobj["key2"].bval) ## @@ -49,6 +50,10 @@ ## "age": herAge ## } ## ] +## +## var j2 = %* {"name": "Isaac", "books": ["Robot Dreams"]} +## j2["details"] = %* {"age":35, "pi":3.1415} +## echo j2 import hashes, tables, strutils, lexbase, streams, unicode, macros @@ -56,6 +61,11 @@ import export tables.`$` +when defined(nimJsonGet): + {.pragma: deprecatedGet, deprecated.} +else: + {.pragma: deprecatedGet.} + type JsonEventKind* = enum ## enumeration of all events that may occur when parsing jsonError, ## an error occurred during parsing @@ -116,7 +126,7 @@ type TJsonParser: JsonParser, TTokKind: TokKind].} const - errorMessages: array [JsonError, string] = [ + errorMessages: array[JsonError, string] = [ "no error", "invalid token", "string expected", @@ -129,7 +139,7 @@ const "EOF expected", "expression expected" ] - tokToStr: array [TokKind, string] = [ + tokToStr: array[TokKind, string] = [ "invalid token", "EOF", "string literal", @@ -702,17 +712,28 @@ proc `%`*(b: bool): JsonNode = proc `%`*(keyVals: openArray[tuple[key: string, val: JsonNode]]): JsonNode = ## Generic constructor for JSON data. Creates a new `JObject JsonNode` - new(result) - result.kind = JObject - result.fields = initTable[string, JsonNode](4) + if keyvals.len == 0: return newJArray() + result = newJObject() for key, val in items(keyVals): result.fields[key] = val -proc `%`*(elements: openArray[JsonNode]): JsonNode = +template `%`*(j: JsonNode): JsonNode = j + +proc `%`*[T](elements: openArray[T]): JsonNode = ## Generic constructor for JSON data. Creates a new `JArray JsonNode` - new(result) - result.kind = JArray - newSeq(result.elems, elements.len) - for i, p in pairs(elements): result.elems[i] = p + result = newJArray() + for elem in elements: result.add(%elem) + +proc `%`*(o: object): JsonNode = + ## Generic constructor for JSON data. Creates a new `JObject JsonNode` + result = newJObject() + for k, v in o.fieldPairs: result[k] = %v + +proc `%`*(o: ref object): JsonNode = + ## Generic constructor for JSON data. Creates a new `JObject JsonNode` + if o.isNil: + result = newJNull() + else: + result = %(o[]) proc toJson(x: NimNode): NimNode {.compiletime.} = case x.kind @@ -731,6 +752,9 @@ proc toJson(x: NimNode): NimNode {.compiletime.} = result = newNimNode(nnkTableConstr) x.expectLen(0) + of nnkNilLit: + result = newCall("newJNull") + else: result = x @@ -799,16 +823,23 @@ proc len*(n: JsonNode): int = of JObject: result = n.fields.len else: discard -proc `[]`*(node: JsonNode, name: string): JsonNode {.inline.} = +proc `[]`*(node: JsonNode, name: string): JsonNode {.inline, deprecatedGet.} = ## Gets a field from a `JObject`, which must not be nil. - ## If the value at `name` does not exist, returns nil + ## If the value at `name` does not exist, raises KeyError. + ## + ## **Note:** The behaviour of this procedure changed in version 0.14.0. To + ## get a list of usages and to restore the old behaviour of this procedure, + ## compile with the ``-d:nimJsonGet`` flag. assert(not isNil(node)) assert(node.kind == JObject) - result = node.fields.getOrDefault(name) + when defined(nimJsonGet): + if not node.fields.hasKey(name): return nil + result = node.fields[name] proc `[]`*(node: JsonNode, index: int): JsonNode {.inline.} = ## Gets the node at `index` in an Array. Result is undefined if `index` - ## is out of bounds + ## is out of bounds, but as long as array bound checks are enabled it will + ## result in an exception. assert(not isNil(node)) assert(node.kind == JArray) return node.elems[index] @@ -818,6 +849,16 @@ proc hasKey*(node: JsonNode, key: string): bool = assert(node.kind == JObject) result = node.fields.hasKey(key) +proc contains*(node: JsonNode, key: string): bool = + ## Checks if `key` exists in `node`. + assert(node.kind == JObject) + node.fields.hasKey(key) + +proc contains*(node: JsonNode, val: JsonNode): bool = + ## Checks if `val` exists in array `node`. + assert(node.kind == JArray) + find(node.elems, val) >= 0 + proc existsKey*(node: JsonNode, key: string): bool {.deprecated.} = node.hasKey(key) ## Deprecated for `hasKey` @@ -838,20 +879,28 @@ proc `[]=`*(obj: JsonNode, key: string, val: JsonNode) {.inline.} = proc `{}`*(node: JsonNode, keys: varargs[string]): JsonNode = ## Traverses the node and gets the given value. If any of the - ## keys do not exist, returns nil. Also returns nil if one of the - ## intermediate data structures is not an object + ## keys do not exist, returns ``nil``. Also returns ``nil`` if one of the + ## intermediate data structures is not an object. result = node for key in keys: - if isNil(result) or result.kind!=JObject: + if isNil(result) or result.kind != JObject: return nil - result=result[key] + result = result.fields.getOrDefault(key) + +proc getOrDefault*(node: JsonNode, key: string): JsonNode = + ## Gets a field from a `node`. If `node` is nil or not an object or + ## value at `key` does not exist, returns nil + if not isNil(node) and node.kind == JObject: + result = node.fields.getOrDefault(key) + +template simpleGetOrDefault*{`{}`(node, [key])}(node: JsonNode, key: string): JsonNode = node.getOrDefault(key) proc `{}=`*(node: JsonNode, keys: varargs[string], value: JsonNode) = ## Traverses the node and tries to set the value at the given location - ## to `value` If any of the keys are missing, they are added + ## to ``value``. If any of the keys are missing, they are added. var node = node for i in 0..(keys.len-2): - if isNil(node[keys[i]]): + if not node.hasKey(keys[i]): node[keys[i]] = newJObject() node = node[keys[i]] node[keys[keys.len-1]] = value @@ -972,7 +1021,7 @@ proc toPretty(result: var string, node: JsonNode, indent = 2, ml = true, result.add("null") proc pretty*(node: JsonNode, indent = 2): string = - ## Converts `node` to its JSON Representation, with indentation and + ## Returns a JSON Representation of `node`, with indentation and ## on multiple lines. result = "" toPretty(result, node, indent) @@ -982,7 +1031,9 @@ proc toUgly*(result: var string, node: JsonNode) = ## regard for human readability. Meant to improve ``$`` string ## conversion performance. ## - ## This provides higher efficiency than the ``toPretty`` procedure as it + ## JSON representation is stored in the passed `result` + ## + ## This provides higher efficiency than the ``pretty`` procedure as it ## does **not** attempt to format the resulting JSON to make it human readable. var comma = false case node.kind: @@ -1076,7 +1127,7 @@ proc parseJson(p: var JsonParser): JsonNode = discard getTok(p) while p.tok != tkCurlyRi: if p.tok != tkString: - raiseParseErr(p, "string literal as key expected") + raiseParseErr(p, "string literal as key") var key = p.a discard getTok(p) eat(p, tkColon) @@ -1217,16 +1268,6 @@ when false: # To get that we shall use, obj["json"] when isMainModule: - when not defined(js): - var parsed = parseFile("tests/testdata/jsontest.json") - - try: - discard parsed["key2"][12123] - doAssert(false) - except IndexError: doAssert(true) - - var parsed2 = parseFile("tests/testdata/jsontest2.json") - doAssert(parsed2{"repository", "description"}.str=="IRC Library for Haskell", "Couldn't fetch via multiply nested key using {}") let testJson = parseJson"""{ "a": [1, 2, 3, 4], "b": "asd", "c": "\ud83c\udf83", "d": "\u00E6"}""" # nil passthrough @@ -1314,3 +1355,35 @@ when isMainModule: var j4 = %*{"test": nil} doAssert j4 == %{"test": newJNull()} + + let seqOfNodes = @[%1, %2] + let jSeqOfNodes = %seqOfNodes + doAssert(jSeqOfNodes[1].num == 2) + + type MyObj = object + a, b: int + s: string + f32: float32 + f64: float64 + next: ref MyObj + var m: MyObj + m.s = "hi" + m.a = 5 + let jMyObj = %m + doAssert(jMyObj["a"].num == 5) + doAssert(jMyObj["s"].str == "hi") + + # Test loading of file. + when not defined(js): + echo("99% of tests finished. Going to try loading file.") + var parsed = parseFile("tests/testdata/jsontest.json") + + try: + discard parsed["key2"][12123] + doAssert(false) + except IndexError: doAssert(true) + + var parsed2 = parseFile("tests/testdata/jsontest2.json") + doAssert(parsed2{"repository", "description"}.str=="IRC Library for Haskell", "Couldn't fetch via multiply nested key using {}") + + echo("Tests succeeded!") diff --git a/lib/pure/logging.nim b/lib/pure/logging.nim index f602ce31d..6a27e6af8 100644 --- a/lib/pure/logging.nim +++ b/lib/pure/logging.nim @@ -54,14 +54,15 @@ type lvlAll, ## all levels active lvlDebug, ## debug level (and any above) active lvlInfo, ## info level (and any above) active + lvlNotice, ## info notice (and any above) active lvlWarn, ## warn level (and any above) active lvlError, ## error level (and any above) active lvlFatal, ## fatal level (and any above) active lvlNone ## no levels active const - LevelNames*: array [Level, string] = [ - "DEBUG", "DEBUG", "INFO", "WARN", "ERROR", "FATAL", "NONE" + LevelNames*: array[Level, string] = [ + "DEBUG", "DEBUG", "INFO", "NOTICE", "WARN", "ERROR", "FATAL", "NONE" ] defaultFmtStr* = "$levelname " ## default format string @@ -258,22 +259,47 @@ template log*(level: Level, args: varargs[string, `$`]) = template debug*(args: varargs[string, `$`]) = ## Logs a debug message to all registered handlers. + ## + ## Messages that are useful to the application developer only and are usually + ## turned off in release. log(lvlDebug, args) template info*(args: varargs[string, `$`]) = ## Logs an info message to all registered handlers. + ## + ## Messages that are generated during the normal operation of an application + ## and are of no particular importance. Useful to aggregate for potential + ## later analysis. log(lvlInfo, args) +template notice*(args: varargs[string, `$`]) = + ## Logs an notice message to all registered handlers. + ## + ## Semantically very similar to `info`, but meant to be messages you want to + ## be actively notified about (depending on your application). + ## These could be, for example, grouped by hour and mailed out. + log(lvlNotice, args) + template warn*(args: varargs[string, `$`]) = ## Logs a warning message to all registered handlers. + ## + ## A non-error message that may indicate a potential problem rising or + ## impacted performance. log(lvlWarn, args) template error*(args: varargs[string, `$`]) = ## Logs an error message to all registered handlers. + ## + ## A application-level error condition. For example, some user input generated + ## an exception. The application will continue to run, but functionality or + ## data was impacted, possibly visible to users. log(lvlError, args) template fatal*(args: varargs[string, `$`]) = ## Logs a fatal error message to all registered handlers. + ## + ## A application-level fatal condition. FATAL usually means that the application + ## cannot go on and will exit (but this logging event will not do that for you). log(lvlFatal, args) proc addHandler*(handler: Logger) = diff --git a/lib/pure/matchers.nim b/lib/pure/matchers.nim index 5c28f65a0..7022c21d9 100644 --- a/lib/pure/matchers.nim +++ b/lib/pure/matchers.nim @@ -8,6 +8,10 @@ # ## This module contains various string matchers for email addresses, etc. +## +## **Warning:** This module is deprecated since version 0.14.0. +{.deprecated.} + {.deadCodeElim: on.} {.push debugger:off .} # the user does not want to trace a part diff --git a/lib/pure/math.nim b/lib/pure/math.nim index 84c8d3b11..4ef169b4f 100644 --- a/lib/pure/math.nim +++ b/lib/pure/math.nim @@ -39,8 +39,6 @@ proc fac*(n: int): int {.noSideEffect.} = when defined(Posix) and not defined(haiku): {.passl: "-lm".} -when not defined(js) and not defined(nimscript): - import times const PI* = 3.1415926535897932384626433 ## the circle constant PI (Ludolph's number) @@ -119,192 +117,247 @@ proc sum*[T](x: openArray[T]): T {.noSideEffect.} = ## If `x` is empty, 0 is returned. for i in items(x): result = result + i -proc random*(max: int): int {.benign.} - ## Returns a random number in the range 0..max-1. The sequence of - ## random number is always the same, unless `randomize` is called - ## which initializes the random number generator with a "random" - ## number, i.e. a tickcount. - -proc random*(max: float): float {.benign.} - ## Returns a random number in the range 0..<max. The sequence of - ## random number is always the same, unless `randomize` is called - ## which initializes the random number generator with a "random" - ## number, i.e. a tickcount. This has a 16-bit resolution on windows - ## and a 48-bit resolution on other platforms. - -when not defined(nimscript): - proc randomize*() {.benign.} - ## Initializes the random number generator with a "random" - ## number, i.e. a tickcount. Note: Does nothing for the JavaScript target, - ## as JavaScript does not support this. Nor does it work for NimScript. - -proc randomize*(seed: int) {.benign.} - ## Initializes the random number generator with a specific seed. - ## Note: Does nothing for the JavaScript target, - ## as JavaScript does not support this. - {.push noSideEffect.} when not defined(JS): - proc sqrt*(x: float): float {.importc: "sqrt", header: "<math.h>".} + proc sqrt*(x: float32): float32 {.importc: "sqrtf", header: "<math.h>".} + proc sqrt*(x: float64): float64 {.importc: "sqrt", header: "<math.h>".} ## Computes the square root of `x`. - proc cbrt*(x: float): float {.importc: "cbrt", header: "<math.h>".} + proc cbrt*(x: float32): float32 {.importc: "cbrtf", header: "<math.h>".} + proc cbrt*(x: float64): float64 {.importc: "cbrt", header: "<math.h>".} ## Computes the cubic root of `x` - proc ln*(x: float): float {.importc: "log", header: "<math.h>".} + proc ln*(x: float32): float32 {.importc: "logf", header: "<math.h>".} + proc ln*(x: float64): float64 {.importc: "log", header: "<math.h>".} ## Computes the natural log of `x` - proc log10*(x: float): float {.importc: "log10", header: "<math.h>".} + proc log10*(x: float32): float32 {.importc: "log10f", header: "<math.h>".} + proc log10*(x: float64): float64 {.importc: "log10", header: "<math.h>".} ## Computes the common logarithm (base 10) of `x` - proc log2*(x: float): float = return ln(x) / ln(2.0) + proc log2*[T: float32|float64](x: T): T = return ln(x) / ln(2.0) ## Computes the binary logarithm (base 2) of `x` - proc exp*(x: float): float {.importc: "exp", header: "<math.h>".} + proc exp*(x: float32): float32 {.importc: "expf", header: "<math.h>".} + proc exp*(x: float64): float64 {.importc: "exp", header: "<math.h>".} ## Computes the exponential function of `x` (pow(E, x)) - proc frexp*(x: float, exponent: var int): float {. - importc: "frexp", header: "<math.h>".} - ## Split a number into mantissa and exponent. - ## `frexp` calculates the mantissa m (a float greater than or equal to 0.5 - ## and less than 1) and the integer value n such that `x` (the original - ## float value) equals m * 2**n. frexp stores n in `exponent` and returns - ## m. - - proc round*(x: float): int {.importc: "lrint", header: "<math.h>".} - ## Converts a float to an int by rounding. - - proc arccos*(x: float): float {.importc: "acos", header: "<math.h>".} + proc arccos*(x: float32): float32 {.importc: "acosf", header: "<math.h>".} + proc arccos*(x: float64): float64 {.importc: "acos", header: "<math.h>".} ## Computes the arc cosine of `x` - proc arcsin*(x: float): float {.importc: "asin", header: "<math.h>".} + proc arcsin*(x: float32): float32 {.importc: "asinf", header: "<math.h>".} + proc arcsin*(x: float64): float64 {.importc: "asin", header: "<math.h>".} ## Computes the arc sine of `x` - proc arctan*(x: float): float {.importc: "atan", header: "<math.h>".} + proc arctan*(x: float32): float32 {.importc: "atanf", header: "<math.h>".} + proc arctan*(x: float64): float64 {.importc: "atan", header: "<math.h>".} ## Calculate the arc tangent of `y` / `x` - proc arctan2*(y, x: float): float {.importc: "atan2", header: "<math.h>".} + proc arctan2*(y, x: float32): float32 {.importc: "atan2f", header: "<math.h>".} + proc arctan2*(y, x: float64): float64 {.importc: "atan2", header: "<math.h>".} ## Calculate the arc tangent of `y` / `x`. ## `atan2` returns the arc tangent of `y` / `x`; it produces correct ## results even when the resulting angle is near pi/2 or -pi/2 ## (`x` near 0). - proc cos*(x: float): float {.importc: "cos", header: "<math.h>".} + proc cos*(x: float32): float32 {.importc: "cosf", header: "<math.h>".} + proc cos*(x: float64): float64 {.importc: "cos", header: "<math.h>".} ## Computes the cosine of `x` - proc cosh*(x: float): float {.importc: "cosh", header: "<math.h>".} + + proc cosh*(x: float32): float32 {.importc: "coshf", header: "<math.h>".} + proc cosh*(x: float64): float64 {.importc: "cosh", header: "<math.h>".} ## Computes the hyperbolic cosine of `x` - proc hypot*(x, y: float): float {.importc: "hypot", header: "<math.h>".} + + proc hypot*(x, y: float32): float32 {.importc: "hypotf", header: "<math.h>".} + proc hypot*(x, y: float64): float64 {.importc: "hypot", header: "<math.h>".} ## Computes the hypotenuse of a right-angle triangle with `x` and ## `y` as its base and height. Equivalent to ``sqrt(x*x + y*y)``. - proc sinh*(x: float): float {.importc: "sinh", header: "<math.h>".} + proc sinh*(x: float32): float32 {.importc: "sinhf", header: "<math.h>".} + proc sinh*(x: float64): float64 {.importc: "sinh", header: "<math.h>".} ## Computes the hyperbolic sine of `x` - proc sin*(x: float): float {.importc: "sin", header: "<math.h>".} + proc sin*(x: float32): float32 {.importc: "sinf", header: "<math.h>".} + proc sin*(x: float64): float64 {.importc: "sin", header: "<math.h>".} ## Computes the sine of `x` - proc tan*(x: float): float {.importc: "tan", header: "<math.h>".} + + proc tan*(x: float32): float32 {.importc: "tanf", header: "<math.h>".} + proc tan*(x: float64): float64 {.importc: "tan", header: "<math.h>".} ## Computes the tangent of `x` - proc tanh*(x: float): float {.importc: "tanh", header: "<math.h>".} + proc tanh*(x: float32): float32 {.importc: "tanhf", header: "<math.h>".} + proc tanh*(x: float64): float64 {.importc: "tanh", header: "<math.h>".} ## Computes the hyperbolic tangent of `x` - proc pow*(x, y: float): float {.importc: "pow", header: "<math.h>".} - ## Computes `x` to power of `y`. - proc erf*(x: float): float {.importc: "erf", header: "<math.h>".} + proc pow*(x, y: float32): float32 {.importc: "powf", header: "<math.h>".} + proc pow*(x, y: float64): float64 {.importc: "pow", header: "<math.h>".} + ## computes x to power raised of y. + + proc erf*(x: float32): float32 {.importc: "erff", header: "<math.h>".} + proc erf*(x: float64): float64 {.importc: "erf", header: "<math.h>".} ## The error function - proc erfc*(x: float): float {.importc: "erfc", header: "<math.h>".} + proc erfc*(x: float32): float32 {.importc: "erfcf", header: "<math.h>".} + proc erfc*(x: float64): float64 {.importc: "erfc", header: "<math.h>".} ## The complementary error function - proc lgamma*(x: float): float {.importc: "lgamma", header: "<math.h>".} + proc lgamma*(x: float32): float32 {.importc: "lgammaf", header: "<math.h>".} + proc lgamma*(x: float64): float64 {.importc: "lgamma", header: "<math.h>".} ## Natural log of the gamma function - proc tgamma*(x: float): float {.importc: "tgamma", header: "<math.h>".} + proc tgamma*(x: float32): float32 {.importc: "tgammaf", header: "<math.h>".} + proc tgamma*(x: float64): float64 {.importc: "tgamma", header: "<math.h>".} ## The gamma function - # C procs: - when defined(vcc) and false: - # The "secure" random, available from Windows XP - # https://msdn.microsoft.com/en-us/library/sxtz2fa8.aspx - # Present in some variants of MinGW but not enough to justify - # `when defined(windows)` yet - proc rand_s(val: var cuint) {.importc: "rand_s", header: "<stdlib.h>".} - # To behave like the normal version - proc rand(): cuint = rand_s(result) - else: - proc srand(seed: cint) {.importc: "srand", header: "<stdlib.h>".} - proc rand(): cint {.importc: "rand", header: "<stdlib.h>".} - - when not defined(windows): - proc srand48(seed: clong) {.importc: "srand48", header: "<stdlib.h>".} - proc drand48(): float {.importc: "drand48", header: "<stdlib.h>".} - proc random(max: float): float = - result = drand48() * max - else: - when defined(vcc): # Windows with Visual C - proc random(max: float): float = - # we are hardcoding this because - # importc-ing macros is extremely problematic - # and because the value is publicly documented - # on MSDN and very unlikely to change - # See https://msdn.microsoft.com/en-us/library/296az74e.aspx - const rand_max = 4294967295 # UINT_MAX - result = (float(rand()) / float(rand_max)) * max - proc randomize() = discard - proc randomize(seed: int) = discard - else: # Windows with another compiler - proc random(max: float): float = - # we are hardcoding this because - # importc-ing macros is extremely problematic - # and because the value is publicly documented - # on MSDN and very unlikely to change - const rand_max = 32767 - result = (float(rand()) / float(rand_max)) * max - - when not defined(vcc): # the above code for vcc uses `discard` instead - # this is either not Windows or is Windows without vcc - when not defined(nimscript): - proc randomize() = - randomize(cast[int](epochTime())) - proc randomize(seed: int) = - srand(cint(seed)) # rand_s doesn't use srand - when declared(srand48): srand48(seed) - - proc random(max: int): int = - result = int(rand()) mod max - - proc trunc*(x: float): float {.importc: "trunc", header: "<math.h>".} - ## Truncates `x` to the decimal point - ## - ## .. code-block:: nim - ## echo trunc(PI) # 3.0 - proc floor*(x: float): float {.importc: "floor", header: "<math.h>".} + proc floor*(x: float32): float32 {.importc: "floorf", header: "<math.h>".} + proc floor*(x: float64): float64 {.importc: "floor", header: "<math.h>".} ## Computes the floor function (i.e., the largest integer not greater than `x`) ## ## .. code-block:: nim ## echo floor(-3.5) ## -4.0 - proc ceil*(x: float): float {.importc: "ceil", header: "<math.h>".} + + proc ceil*(x: float32): float32 {.importc: "ceilf", header: "<math.h>".} + proc ceil*(x: float64): float64 {.importc: "ceil", header: "<math.h>".} ## Computes the ceiling function (i.e., the smallest integer not less than `x`) ## ## .. code-block:: nim ## echo ceil(-2.1) ## -2.0 - proc fmod*(x, y: float): float {.importc: "fmod", header: "<math.h>".} + when defined(windows) and defined(vcc): + # MSVC 2010 don't have trunc/truncf + # this implementation was inspired by Go-lang Math.Trunc + proc truncImpl(f: float64): float64 = + const + mask : uint64 = 0x7FF + shift: uint64 = 64 - 12 + bias : uint64 = 0x3FF + + if f < 1: + if f < 0: return -truncImpl(-f) + elif f == 0: return f # Return -0 when f == -0 + else: return 0 + + var x = cast[uint64](f) + let e = (x shr shift) and mask - bias + + # Keep the top 12+e bits, the integer part; clear the rest. + if e < 64-12: + x = x and (not (1'u64 shl (64'u64-12'u64-e) - 1'u64)) + + result = cast[float64](x) + + proc truncImpl(f: float32): float32 = + const + mask : uint32 = 0xFF + shift: uint32 = 32 - 9 + bias : uint32 = 0x7F + + if f < 1: + if f < 0: return -truncImpl(-f) + elif f == 0: return f # Return -0 when f == -0 + else: return 0 + + var x = cast[uint32](f) + let e = (x shr shift) and mask - bias + + # Keep the top 9+e bits, the integer part; clear the rest. + if e < 32-9: + x = x and (not (1'u32 shl (32'u32-9'u32-e) - 1'u32)) + + result = cast[float32](x) + + proc trunc*(x: float64): float64 = + if classify(x) in {fcZero, fcNegZero, fcNan, fcInf, fcNegInf}: return x + result = truncImpl(x) + + proc trunc*(x: float32): float32 = + if classify(x) in {fcZero, fcNegZero, fcNan, fcInf, fcNegInf}: return x + result = truncImpl(x) + + proc round0[T: float32|float64](x: T): T = + ## Windows compilers prior to MSVC 2012 do not implement 'round', + ## 'roundl' or 'roundf'. + result = if x < 0.0: ceil(x - T(0.5)) else: floor(x + T(0.5)) + else: + proc round0(x: float32): float32 {.importc: "roundf", header: "<math.h>".} + proc round0(x: float64): float64 {.importc: "round", header: "<math.h>".} + ## Rounds a float to zero decimal places. Used internally by the round + ## function when the specified number of places is 0. + + proc trunc*(x: float32): float32 {.importc: "truncf", header: "<math.h>".} + proc trunc*(x: float64): float64 {.importc: "trunc", header: "<math.h>".} + ## Truncates `x` to the decimal point + ## + ## .. code-block:: nim + ## echo trunc(PI) # 3.0 + + proc fmod*(x, y: float32): float32 {.importc: "fmodf", header: "<math.h>".} + proc fmod*(x, y: float64): float64 {.importc: "fmod", header: "<math.h>".} ## Computes the remainder of `x` divided by `y` ## ## .. code-block:: nim ## echo fmod(-2.5, 0.3) ## -0.1 else: - proc mathrandom(): float {.importc: "Math.random", nodecl.} - proc floor*(x: float): float {.importc: "Math.floor", nodecl.} - proc ceil*(x: float): float {.importc: "Math.ceil", nodecl.} - proc random(max: int): int = - result = int(floor(mathrandom() * float(max))) - proc random(max: float): float = - result = float(mathrandom() * float(max)) - proc randomize() = discard - proc randomize(seed: int) = discard - - proc sqrt*(x: float): float {.importc: "Math.sqrt", nodecl.} - proc ln*(x: float): float {.importc: "Math.log", nodecl.} - proc log10*(x: float): float = return ln(x) / ln(10.0) - proc log2*(x: float): float = return ln(x) / ln(2.0) - - proc exp*(x: float): float {.importc: "Math.exp", nodecl.} - proc round*(x: float): int {.importc: "Math.round", nodecl.} - proc pow*(x, y: float): float {.importc: "Math.pow", nodecl.} - - proc frexp*(x: float, exponent: var int): float = + proc floor*(x: float32): float32 {.importc: "Math.floor", nodecl.} + proc floor*(x: float64): float64 {.importc: "Math.floor", nodecl.} + proc ceil*(x: float32): float32 {.importc: "Math.ceil", nodecl.} + proc ceil*(x: float64): float64 {.importc: "Math.ceil", nodecl.} + + proc sqrt*(x: float32): float32 {.importc: "Math.sqrt", nodecl.} + proc sqrt*(x: float64): float64 {.importc: "Math.sqrt", nodecl.} + proc ln*(x: float32): float32 {.importc: "Math.log", nodecl.} + proc ln*(x: float64): float64 {.importc: "Math.log", nodecl.} + proc log10*[T: float32|float64](x: T): T = return ln(x) / ln(10.0) + proc log2*[T: float32|float64](x: T): T = return ln(x) / ln(2.0) + + proc exp*(x: float32): float32 {.importc: "Math.exp", nodecl.} + proc exp*(x: float64): float64 {.importc: "Math.exp", nodecl.} + proc round0(x: float): float {.importc: "Math.round", nodecl.} + + proc pow*(x, y: float32): float32 {.importC: "Math.pow", nodecl.} + proc pow*(x, y: float64): float64 {.importc: "Math.pow", nodecl.} + + proc arccos*(x: float32): float32 {.importc: "Math.acos", nodecl.} + proc arccos*(x: float64): float64 {.importc: "Math.acos", nodecl.} + proc arcsin*(x: float32): float32 {.importc: "Math.asin", nodecl.} + proc arcsin*(x: float64): float64 {.importc: "Math.asin", nodecl.} + proc arctan*(x: float32): float32 {.importc: "Math.atan", nodecl.} + proc arctan*(x: float64): float64 {.importc: "Math.atan", nodecl.} + proc arctan2*(y, x: float32): float32 {.importC: "Math.atan2", nodecl.} + proc arctan2*(y, x: float64): float64 {.importc: "Math.atan2", nodecl.} + + proc cos*(x: float32): float32 {.importc: "Math.cos", nodecl.} + proc cos*(x: float64): float64 {.importc: "Math.cos", nodecl.} + proc cosh*(x: float32): float32 = return (exp(x)+exp(-x))*0.5 + proc cosh*(x: float64): float64 = return (exp(x)+exp(-x))*0.5 + proc hypot*[T: float32|float64](x, y: T): T = return sqrt(x*x + y*y) + proc sinh*[T: float32|float64](x: T): T = return (exp(x)-exp(-x))*0.5 + proc sin*(x: float32): float32 {.importc: "Math.sin", nodecl.} + proc sin*(x: float64): float64 {.importc: "Math.sin", nodecl.} + proc tan*(x: float32): float32 {.importc: "Math.tan", nodecl.} + proc tan*(x: float64): float64 {.importc: "Math.tan", nodecl.} + proc tanh*[T: float32|float64](x: T): T = + var y = exp(2.0*x) + return (y-1.0)/(y+1.0) + +proc round*[T: float32|float64](x: T, places: int = 0): T = + ## Round a floating point number. + ## + ## If `places` is 0 (or omitted), round to the nearest integral value + ## following normal mathematical rounding rules (e.g. `round(54.5) -> 55.0`). + ## If `places` is greater than 0, round to the given number of decimal + ## places, e.g. `round(54.346, 2) -> 54.35`. + ## If `places` is negative, round to the left of the decimal place, e.g. + ## `round(537.345, -1) -> 540.0` + if places == 0: + result = round0(x) + else: + var mult = pow(10.0, places.T) + result = round0(x*mult)/mult + +when not defined(JS): + proc frexp*(x: float32, exponent: var int): float32 {. + importc: "frexp", header: "<math.h>".} + proc frexp*(x: float64, exponent: var int): float64 {. + importc: "frexp", header: "<math.h>".} + ## Split a number into mantissa and exponent. + ## `frexp` calculates the mantissa m (a float greater than or equal to 0.5 + ## and less than 1) and the integer value n such that `x` (the original + ## float value) equals m * 2**n. frexp stores n in `exponent` and returns + ## m. +else: + proc frexp*[T: float32|float64](x: T, exponent: var int): T = if x == 0.0: exponent = 0 result = 0.0 @@ -315,20 +368,22 @@ else: exponent = round(ex) result = x / pow(2.0, ex) - proc arccos*(x: float): float {.importc: "Math.acos", nodecl.} - proc arcsin*(x: float): float {.importc: "Math.asin", nodecl.} - proc arctan*(x: float): float {.importc: "Math.atan", nodecl.} - proc arctan2*(y, x: float): float {.importc: "Math.atan2", nodecl.} - - proc cos*(x: float): float {.importc: "Math.cos", nodecl.} - proc cosh*(x: float): float = return (exp(x)+exp(-x))*0.5 - proc hypot*(x, y: float): float = return sqrt(x*x + y*y) - proc sinh*(x: float): float = return (exp(x)-exp(-x))*0.5 - proc sin*(x: float): float {.importc: "Math.sin", nodecl.} - proc tan*(x: float): float {.importc: "Math.tan", nodecl.} - proc tanh*(x: float): float = - var y = exp(2.0*x) - return (y-1.0)/(y+1.0) +proc splitDecimal*[T: float32|float64](x: T): tuple[intpart: T, floatpart: T] = + ## Breaks `x` into an integral and a fractional part. + ## + ## Returns a tuple containing intpart and floatpart representing + ## the integer part and the fractional part respectively. + ## + ## Both parts have the same sign as `x`. Analogous to the `modf` + ## function in C. + var + absolute: T + absolute = abs(x) + result.intpart = floor(absolute) + result.floatpart = absolute - result.intpart + if x < 0: + result.intpart = -result.intpart + result.floatpart = -result.floatpart {.pop.} @@ -340,7 +395,7 @@ proc radToDeg*[T: float32|float64](d: T): T {.inline.} = ## Convert from radians to degrees result = T(d) / RadPerDeg -proc `mod`*(x, y: float): float = +proc `mod`*[T: float32|float64](x, y: T): T = ## Computes the modulo operation for float operators. Equivalent ## to ``x - y * floor(x/y)``. Note that the remainder will always ## have the same sign as the divisor. @@ -349,21 +404,13 @@ proc `mod`*(x, y: float): float = ## echo (4.0 mod -3.1) # -2.2 result = if y == 0.0: x else: x - y * (x/y).floor -proc random*[T](x: Slice[T]): T = - ## For a slice `a .. b` returns a value in the range `a .. b-1`. - result = random(x.b - x.a) + x.a - -proc random*[T](a: openArray[T]): T = - ## returns a random element from the openarray `a`. - result = a[random(a.low..a.len)] - {.pop.} {.pop.} proc `^`*[T](x, y: T): T = ## Computes ``x`` to the power ``y`. ``x`` must be non-negative, use ## `pow <#pow,float,float>` for negative exponents. - assert y >= 0 + assert y >= T(0) var (x, y) = (x, y) result = 1 @@ -391,24 +438,6 @@ proc lcm*[T](x, y: T): T = x div gcd(x, y) * y when isMainModule and not defined(JS): - proc gettime(dummy: ptr cint): cint {.importc: "time", header: "<time.h>".} - - # Verifies random seed initialization. - let seed = gettime(nil) - randomize(seed) - const SIZE = 10 - var buf : array[0..SIZE, int] - # Fill the buffer with random values - for i in 0..SIZE-1: - buf[i] = random(high(int)) - # Check that the second random calls are the same for each position. - randomize(seed) - for i in 0..SIZE-1: - assert buf[i] == random(high(int)), "non deterministic random seeding" - - when not defined(testing): - echo "random values equal after reseeding" - # Check for no side effect annotation proc mySqrt(num: float): float {.noSideEffect.} = return sqrt(num) @@ -418,3 +447,65 @@ when isMainModule and not defined(JS): assert(lgamma(1.0) == 0.0) # ln(1.0) == 0.0 assert(erf(6.0) > erf(5.0)) assert(erfc(6.0) < erfc(5.0)) +when isMainModule: + # Function for approximate comparison of floats + proc `==~`(x, y: float): bool = (abs(x-y) < 1e-9) + + block: # round() tests + # Round to 0 decimal places + doAssert round(54.652) ==~ 55.0 + doAssert round(54.352) ==~ 54.0 + doAssert round(-54.652) ==~ -55.0 + doAssert round(-54.352) ==~ -54.0 + doAssert round(0.0) ==~ 0.0 + # Round to positive decimal places + doAssert round(-547.652, 1) ==~ -547.7 + doAssert round(547.652, 1) ==~ 547.7 + doAssert round(-547.652, 2) ==~ -547.65 + doAssert round(547.652, 2) ==~ 547.65 + # Round to negative decimal places + doAssert round(547.652, -1) ==~ 550.0 + doAssert round(547.652, -2) ==~ 500.0 + doAssert round(547.652, -3) ==~ 1000.0 + doAssert round(547.652, -4) ==~ 0.0 + doAssert round(-547.652, -1) ==~ -550.0 + doAssert round(-547.652, -2) ==~ -500.0 + doAssert round(-547.652, -3) ==~ -1000.0 + doAssert round(-547.652, -4) ==~ 0.0 + + block: # splitDecimal() tests + doAssert splitDecimal(54.674).intpart ==~ 54.0 + doAssert splitDecimal(54.674).floatpart ==~ 0.674 + doAssert splitDecimal(-693.4356).intpart ==~ -693.0 + doAssert splitDecimal(-693.4356).floatpart ==~ -0.4356 + doAssert splitDecimal(0.0).intpart ==~ 0.0 + doAssert splitDecimal(0.0).floatpart ==~ 0.0 + + block: # trunc tests for vcc + doAssert(trunc(-1.1) == -1) + doAssert(trunc(1.1) == 1) + doAssert(trunc(-0.1) == -0) + doAssert(trunc(0.1) == 0) + + #special case + doAssert(classify(trunc(1e1000000)) == fcInf) + doAssert(classify(trunc(-1e1000000)) == fcNegInf) + doAssert(classify(trunc(0.0/0.0)) == fcNan) + doAssert(classify(trunc(0.0)) == fcZero) + + #trick the compiler to produce signed zero + let + f_neg_one = -1.0 + f_zero = 0.0 + f_nan = f_zero / f_zero + + doAssert(classify(trunc(f_neg_one*f_zero)) == fcNegZero) + + doAssert(trunc(-1.1'f32) == -1) + doAssert(trunc(1.1'f32) == 1) + doAssert(trunc(-0.1'f32) == -0) + doAssert(trunc(0.1'f32) == 0) + doAssert(classify(trunc(1e1000000'f32)) == fcInf) + doAssert(classify(trunc(-1e1000000'f32)) == fcNegInf) + doAssert(classify(trunc(f_nan.float32)) == fcNan) + doAssert(classify(trunc(0.0'f32)) == fcZero) diff --git a/lib/pure/memfiles.nim b/lib/pure/memfiles.nim index b9c574944..aa32778c5 100644 --- a/lib/pure/memfiles.nim +++ b/lib/pure/memfiles.nim @@ -42,6 +42,10 @@ type proc mapMem*(m: var MemFile, mode: FileMode = fmRead, mappedSize = -1, offset = 0): pointer = + ## returns a pointer to a mapped portion of MemFile `m` + ## + ## ``mappedSize`` of ``-1`` maps to the whole file, and + ## ``offset`` must be multiples of the PAGE SIZE of your OS var readonly = mode == fmRead when defined(windows): result = mapViewOfFileEx( @@ -68,7 +72,9 @@ proc mapMem*(m: var MemFile, mode: FileMode = fmRead, proc unmapMem*(f: var MemFile, p: pointer, size: int) = ## unmaps the memory region ``(p, <p+size)`` of the mapped file `f`. ## All changes are written back to the file system, if `f` was opened - ## with write access. ``size`` must be of exactly the size that was requested + ## with write access. + ## + ## ``size`` must be of exactly the size that was requested ## via ``mapMem``. when defined(windows): if unmapViewOfFile(p) == 0: raiseOSError(osLastError()) @@ -79,9 +85,17 @@ proc unmapMem*(f: var MemFile, p: pointer, size: int) = proc open*(filename: string, mode: FileMode = fmRead, mappedSize = -1, offset = 0, newFileSize = -1): MemFile = ## opens a memory mapped file. If this fails, ``EOS`` is raised. - ## `newFileSize` can only be set if the file does not exist and is opened - ## with write access (e.g., with fmReadWrite). `mappedSize` and `offset` - ## can be used to map only a slice of the file. Example: + ## + ## ``newFileSize`` can only be set if the file does not exist and is opened + ## with write access (e.g., with fmReadWrite). + ## + ##``mappedSize`` and ``offset`` + ## can be used to map only a slice of the file. + ## + ## ``offset`` must be multiples of the PAGE SIZE of your OS + ## (usually 4K or 8K but is unique to your OS) + ## + ## Example: ## ## .. code-block:: nim ## var @@ -257,12 +271,10 @@ type MemSlice* = object ## represent slice of a MemFile for iteration over deli data*: pointer size*: int -proc c_memcpy(a, b: pointer, n: int) {.importc: "memcpy", header: "<string.h>".} - proc `$`*(ms: MemSlice): string {.inline.} = ## Return a Nim string built from a MemSlice. var buf = newString(ms.size) - c_memcpy(addr(buf[0]), ms.data, ms.size) + copyMem(addr(buf[0]), ms.data, ms.size) buf[ms.size] = '\0' result = buf @@ -287,7 +299,9 @@ iterator memSlices*(mfile: MemFile, delim='\l', eat='\r'): MemSlice {.inline.} = ## iterate over line-like records in a file. However, returned (data,size) ## objects are not Nim strings, bounds checked Nim arrays, or even terminated ## C strings. So, care is required to access the data (e.g., think C mem* - ## functions, not str* functions). Example: + ## functions, not str* functions). + ## + ## Example: ## ## .. code-block:: nim ## var count = 0 @@ -320,7 +334,9 @@ iterator lines*(mfile: MemFile, buf: var TaintedString, delim='\l', eat='\r'): T ## Replace contents of passed buffer with each new line, like ## `readLine(File) <system.html#readLine,File,TaintedString>`_. ## `delim`, `eat`, and delimiting logic is exactly as for - ## `memSlices <#memSlices>`_, but Nim strings are returned. Example: + ## `memSlices <#memSlices>`_, but Nim strings are returned. + ## + ## Example: ## ## .. code-block:: nim ## var buffer: TaintedString = "" @@ -329,7 +345,7 @@ iterator lines*(mfile: MemFile, buf: var TaintedString, delim='\l', eat='\r'): T for ms in memSlices(mfile, delim, eat): buf.setLen(ms.size) - c_memcpy(addr(buf[0]), ms.data, ms.size) + copyMem(addr(buf[0]), ms.data, ms.size) buf[ms.size] = '\0' yield buf @@ -337,7 +353,9 @@ iterator lines*(mfile: MemFile, delim='\l', eat='\r'): TaintedString {.inline.} ## Return each line in a file as a Nim string, like ## `lines(File) <system.html#lines.i,File>`_. ## `delim`, `eat`, and delimiting logic is exactly as for - ## `memSlices <#memSlices>`_, but Nim strings are returned. Example: + ## `memSlices <#memSlices>`_, but Nim strings are returned. + ## + ## Example: ## ## .. code-block:: nim ## for line in lines(memfiles.open("foo")): diff --git a/lib/pure/mersenne.nim b/lib/pure/mersenne.nim index ae0845714..36b597767 100644 --- a/lib/pure/mersenne.nim +++ b/lib/pure/mersenne.nim @@ -1,3 +1,12 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2015 Nim Contributors +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + type MersenneTwister* = object mt: array[0..623, uint32] @@ -5,29 +14,31 @@ type {.deprecated: [TMersenneTwister: MersenneTwister].} -proc newMersenneTwister*(seed: int): MersenneTwister = +proc newMersenneTwister*(seed: uint32): MersenneTwister = result.index = 0 - result.mt[0]= uint32(seed) + result.mt[0] = seed for i in 1..623'u32: - result.mt[i]= (0x6c078965'u32 * (result.mt[i-1] xor (result.mt[i-1] shr 30'u32)) + i) + result.mt[i] = (0x6c078965'u32 * (result.mt[i-1] xor (result.mt[i-1] shr 30'u32)) + i) proc generateNumbers(m: var MersenneTwister) = for i in 0..623: - var y = (m.mt[i] and 0x80000000'u32) + (m.mt[(i+1) mod 624] and 0x7fffffff'u32) + var y = (m.mt[i] and 0x80000000'u32) + + (m.mt[(i+1) mod 624] and 0x7fffffff'u32) m.mt[i] = m.mt[(i+397) mod 624] xor uint32(y shr 1'u32) if (y mod 2'u32) != 0: - m.mt[i] = m.mt[i] xor 0x9908b0df'u32 + m.mt[i] = m.mt[i] xor 0x9908b0df'u32 -proc getNum*(m: var MersenneTwister): int = +proc getNum*(m: var MersenneTwister): uint32 = + ## Returns the next pseudo random number ranging from 0 to high(uint32) if m.index == 0: generateNumbers(m) - var y = m.mt[m.index] - y = y xor (y shr 11'u32) - y = y xor ((7'u32 shl y) and 0x9d2c5680'u32) - y = y xor ((15'u32 shl y) and 0xefc60000'u32) - y = y xor (y shr 18'u32) - m.index = (m.index+1) mod 624 - return int(y) + result = m.mt[m.index] + m.index = (m.index + 1) mod m.mt.len + + result = result xor (result shr 11'u32) + result = result xor ((7'u32 shl result) and 0x9d2c5680'u32) + result = result xor ((15'u32 shl result) and 0xefc60000'u32) + result = result xor (result shr 18'u32) # Test when not defined(testing) and isMainModule: diff --git a/lib/pure/nativesockets.nim b/lib/pure/nativesockets.nim index 043d6d80a..4526afa49 100644 --- a/lib/pure/nativesockets.nim +++ b/lib/pure/nativesockets.nim @@ -27,7 +27,7 @@ else: import posix export fcntl, F_GETFL, O_NONBLOCK, F_SETFL, EAGAIN, EWOULDBLOCK, MSG_NOSIGNAL, EINTR, EINPROGRESS, ECONNRESET, EPIPE, ENETRESET - export Sockaddr_storage + export Sockaddr_storage, Sockaddr_un, Sockaddr_un_path_length export SocketHandle, Sockaddr_in, Addrinfo, INADDR_ANY, SockAddr, SockLen, Sockaddr_in6, @@ -38,7 +38,7 @@ export SOL_SOCKET, SOMAXCONN, SO_ACCEPTCONN, SO_BROADCAST, SO_DEBUG, SO_DONTROUTE, - SO_KEEPALIVE, SO_OOBINLINE, SO_REUSEADDR, + SO_KEEPALIVE, SO_OOBINLINE, SO_REUSEADDR, SO_REUSEPORT, MSG_PEEK when defined(macosx) and not defined(nimdoc): @@ -326,8 +326,13 @@ proc getHostByAddr*(ip: string): Hostent {.tags: [ReadIOEffect].} = cint(AF_INET)) if s == nil: raiseOSError(osLastError()) else: - var s = posix.gethostbyaddr(addr(myaddr), sizeof(myaddr).Socklen, - cint(posix.AF_INET)) + var s = + when defined(android4): + posix.gethostbyaddr(cast[cstring](addr(myaddr)), sizeof(myaddr).cint, + cint(posix.AF_INET)) + else: + posix.gethostbyaddr(addr(myaddr), sizeof(myaddr).Socklen, + cint(posix.AF_INET)) if s == nil: raiseOSError(osLastError(), $hstrerror(h_errno)) diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 330682ca9..bd208761b 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -8,19 +8,77 @@ # ## This module implements a high-level cross-platform sockets interface. +## The procedures implemented in this module are primarily for blocking sockets. +## For asynchronous non-blocking sockets use the ``asyncnet`` module together +## with the ``asyncdispatch`` module. +## +## The first thing you will always need to do in order to start using sockets, +## is to create a new instance of the ``Socket`` type using the ``newSocket`` +## procedure. +## +## SSL +## ==== +## +## In order to use the SSL procedures defined in this module, you will need to +## compile your application with the ``-d:ssl`` flag. +## +## Examples +## ======== +## +## Connecting to a server +## ---------------------- +## +## After you create a socket with the ``newSocket`` procedure, you can easily +## connect it to a server running at a known hostname (or IP address) and port. +## To do so over TCP, use the example below. +## +## .. code-block:: Nim +## var socket = newSocket() +## socket.connect("google.com", Port(80)) +## +## UDP is a connectionless protocol, so UDP sockets don't have to explicitly +## call the ``connect`` procedure. They can simply start sending data +## immediately. +## +## .. code-block:: Nim +## var socket = newSocket() +## socket.sendTo("192.168.0.1", Port(27960), "status\n") +## +## Creating a server +## ----------------- +## +## After you create a socket with the ``newSocket`` procedure, you can create a +## TCP server by calling the ``bindAddr`` and ``listen`` procedures. +## +## .. code-block:: Nim +## var socket = newSocket() +## socket.bindAddr(Port(1234)) +## socket.listen() +## +## You can then begin accepting connections using the ``accept`` procedure. +## +## .. code-block:: Nim +## var client = newSocket() +## var address = "" +## while true: +## socket.acceptAddr(client, address) +## echo("Client connected from: ", address) +## {.deadCodeElim: on.} -import nativesockets, os, strutils, parseutils, times +import nativesockets, os, strutils, parseutils, times, sets export Port, `$`, `==` +export Domain, SockType, Protocol const useWinVersion = defined(Windows) or defined(nimdoc) +const defineSsl = defined(ssl) or defined(nimdoc) -when defined(ssl): +when defineSsl: import openssl # Note: The enumerations are mapped to Window's constants. -when defined(ssl): +when defineSsl: type SslError* = object of Exception @@ -30,7 +88,10 @@ when defined(ssl): SslProtVersion* = enum protSSLv2, protSSLv3, protTLSv1, protSSLv23 - SslContext* = distinct SslCtx + SslContext* = ref object + context*: SslCtx + extraInternalIndex: int + referencedData: HashSet[int] SslAcceptResult* = enum AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess @@ -38,6 +99,10 @@ when defined(ssl): SslHandshakeType* = enum handshakeAsClient, handshakeAsServer + SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string] + + SslServerGetPskFunc* = proc(identity: string): string + {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode, TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext, TSSLAcceptResult: SSLAcceptResult].} @@ -54,7 +119,7 @@ type currPos: int # current index in buffer bufLen: int # current length of buffer of false: nil - when defined(ssl): + when defineSsl: case isSsl: bool of true: sslHandle: SSLPtr @@ -72,7 +137,7 @@ type SOBool* = enum ## Boolean socket options. OptAcceptConn, OptBroadcast, OptDebug, OptDontRoute, OptKeepAlive, - OptOOBInline, OptReuseAddr + OptOOBInline, OptReuseAddr, OptReusePort ReadLineResult* = enum ## result for readLineAsync ReadFullLine, ReadPartialLine, ReadDisconnected, ReadNone @@ -140,6 +205,10 @@ proc newSocket*(fd: SocketHandle, domain: Domain = AF_INET, if buffered: result.currPos = 0 + # Set SO_NOSIGPIPE on OS X. + when defined(macosx): + setSockOptInt(fd, SOL_SOCKET, SO_NOSIGPIPE, 1) + proc newSocket*(domain, sockType, protocol: cint, buffered = true): Socket = ## Creates a new socket. ## @@ -160,13 +229,18 @@ proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM, raiseOSError(osLastError()) result = newSocket(fd, domain, sockType, protocol, buffered) -when defined(ssl): +when defineSsl: CRYPTO_malloc_init() SslLibraryInit() SslLoadErrorStrings() ErrLoadBioStrings() OpenSSL_add_all_algorithms() + type + SslContextExtraInternal = ref object of RootRef + serverGetPskFunc: SslServerGetPskFunc + clientGetPskFunc: SslClientGetPskFunc + proc raiseSSLError*(s = "") = ## Raises a new SSL error. if s != "": @@ -179,6 +253,34 @@ when defined(ssl): var errStr = ErrErrorString(err, nil) raise newException(SSLError, $errStr) + proc getExtraDataIndex*(ctx: SSLContext): int = + ## Retrieves unique index for storing extra data in SSLContext. + result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int + if result < 0: + raiseSSLError() + + proc getExtraData*(ctx: SSLContext, index: int): RootRef = + ## Retrieves arbitrary data stored inside SSLContext. + if index notin ctx.referencedData: + raise newException(IndexError, "No data with that index.") + let res = ctx.context.SSL_CTX_get_ex_data(index.cint) + if cast[int](res) == 0: + raiseSSLError() + return cast[RootRef](res) + + proc setExtraData*(ctx: SSLContext, index: int, data: RootRef) = + ## Stores arbitrary data inside SSLContext. The unique `index` + ## should be retrieved using getSslContextExtraDataIndex. + if index in ctx.referencedData: + GC_unref(getExtraData(ctx, index)) + + if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) == -1: + raiseSSLError() + + if index notin ctx.referencedData: + ctx.referencedData.incl(index) + GC_ref(data) + # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) = if certFile != "" and not existsFile(certFile): @@ -201,7 +303,7 @@ when defined(ssl): raiseSSLError("Verification of private key file failed.") proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer, - certFile = "", keyFile = ""): SSLContext = + certFile = "", keyFile = "", cipherList = "ALL"): SSLContext = ## Creates an SSL context. ## ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 @@ -222,13 +324,13 @@ when defined(ssl): of protSSLv23: newCTX = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support. of protSSLv2: - raiseSslError("SSLv2 is no longer secure and has been deprecated, use protSSLv3") + raiseSslError("SSLv2 is no longer secure and has been deprecated, use protSSLv23") of protSSLv3: - newCTX = SSL_CTX_new(SSLv3_method()) + raiseSslError("SSLv3 is no longer secure and has been deprecated, use protSSLv23") of protTLSv1: newCTX = SSL_CTX_new(TLSv1_method()) - if newCTX.SSLCTXSetCipherList("ALL") != 1: + if newCTX.SSLCTXSetCipherList(cipherList) != 1: raiseSSLError() case verifyMode of CVerifyPeer: @@ -240,7 +342,87 @@ when defined(ssl): discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY) newCTX.loadCertificates(certFile, keyFile) - return SSLContext(newCTX) + + result = SSLContext(context: newCTX, extraInternalIndex: 0, + referencedData: initSet[int]()) + result.extraInternalIndex = getExtraDataIndex(result) + + let extraInternal = new(SslContextExtraInternal) + result.setExtraData(result.extraInternalIndex, extraInternal) + + proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal = + return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex)) + + proc destroyContext*(ctx: SSLContext) = + ## Free memory referenced by SSLContext. + + # We assume here that OpenSSL's internal indexes increase by 1 each time. + # That means we can assume that the next internal index is the length of + # extra data indexes. + for i in ctx.referencedData: + GC_unref(getExtraData(ctx, i).RootRef) + ctx.context.SSL_CTX_free() + + proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) = + ## Sets the identity hint passed to server. + ## + ## Only used in PSK ciphersuites. + if ctx.context.SSL_CTX_use_psk_identity_hint(hint) <= 0: + raiseSSLError() + + proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc = + return ctx.getExtraInternal().clientGetPskFunc + + proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar; + max_psk_len: cuint): cuint {.cdecl.} = + let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0) + let hintString = if hint == nil: nil else: $hint + let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString) + if psk.len.cuint > max_psk_len: + return 0 + if identityString.len.cuint >= max_identity_len: + return 0 + + copyMem(identity, identityString.cstring, pskString.len + 1) # with the last zero byte + copyMem(psk, pskString.cstring, pskString.len) + + return pskString.len.cuint + + proc `clientGetPskFunc=`*(ctx: SSLContext, fun: SslClientGetPskFunc) = + ## Sets function that returns the client identity and the PSK based on identity + ## hint from the server. + ## + ## Only used in PSK ciphersuites. + ctx.getExtraInternal().clientGetPskFunc = fun + assert ctx.extraInternalIndex == 0, + "The pskClientCallback assumes the extraInternalIndex is 0" + ctx.context.SSL_CTX_set_psk_client_callback( + if fun == nil: nil else: pskClientCallback) + + proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc = + return ctx.getExtraInternal().serverGetPskFunc + + proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} = + let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0) + let pskString = (ctx.serverGetPskFunc)($identity) + if psk.len.cint > max_psk_len: + return 0 + copyMem(psk, pskString.cstring, pskString.len) + + return pskString.len.cuint + + proc `serverGetPskFunc=`*(ctx: SSLContext, fun: SslServerGetPskFunc) = + ## Sets function that returns PSK based on the client identity. + ## + ## Only used in PSK ciphersuites. + ctx.getExtraInternal().serverGetPskFunc = fun + ctx.context.SSL_CTX_set_psk_server_callback(if fun == nil: nil + else: pskServerCallback) + + proc getPskIdentity*(socket: Socket): string = + ## Gets the PSK identity provided by the client. + assert socket.isSSL + return $(socket.sslHandle.SSL_get_psk_identity) proc wrapSocket*(ctx: SSLContext, socket: Socket) = ## Wraps a socket in an SSL context. This function effectively turns @@ -255,7 +437,7 @@ when defined(ssl): assert (not socket.isSSL) socket.isSSL = true socket.sslContext = ctx - socket.sslHandle = SSLNew(SSLCTX(socket.sslContext)) + socket.sslHandle = SSLNew(socket.sslContext.context) socket.sslNoHandshake = false socket.sslHasPeekChar = false if socket.sslHandle == nil: @@ -301,7 +483,7 @@ proc socketError*(socket: Socket, err: int = -1, async = false, ## error was caused by no data being available to be read. ## ## If ``err`` is not lower than 0 no exception will be raised. - when defined(ssl): + when defineSsl: if socket.isSSL: if err <= 0: var ret = SSLGetError(socket.sslHandle, err.cint) @@ -334,7 +516,7 @@ proc socketError*(socket: Socket, err: int = -1, async = false, raiseSSLError() else: raiseSSLError("Unknown Error") - if err == -1 and not (when defined(ssl): socket.isSSL else: false): + if err == -1 and not (when defineSsl: socket.isSSL else: false): var lastE = if lastError.int == -1: getSocketError(socket) else: lastError if async: when useWinVersion: @@ -414,7 +596,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string, client.isBuffered = server.isBuffered # Handle SSL. - when defined(ssl): + when defineSsl: if server.isSSL: # We must wrap the client sock in a ssl context. @@ -425,7 +607,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string, # Client socket is set above. address = $inet_ntoa(sockAddress.sin_addr) -when false: #defined(ssl): +when false: #defineSsl: proc acceptAddrSSL*(server: Socket, client: var Socket, address: var string): SSLAcceptResult {. tags: [ReadIOEffect].} = @@ -444,7 +626,7 @@ when false: #defined(ssl): ## ``AcceptNoClient`` will be returned when no client is currently attempting ## to connect. template doHandshake(): stmt = - when defined(ssl): + when defineSsl: if server.isSSL: client.setBlocking(false) # We must wrap the client sock in a ssl context. @@ -495,7 +677,7 @@ proc accept*(server: Socket, client: var Socket, proc close*(socket: Socket) = ## Closes a socket. try: - when defined(ssl): + when defineSsl: if socket.isSSL: ErrClearError() # As we are closing the underlying socket immediately afterwards, @@ -522,6 +704,7 @@ proc toCInt*(opt: SOBool): cint = of OptKeepAlive: SO_KEEPALIVE of OptOOBInline: SO_OOBINLINE of OptReuseAddr: SO_REUSEADDR + of OptReusePort: SO_REUSEPORT proc getSockOpt*(socket: Socket, opt: SOBool, level = SOL_SOCKET): bool {. tags: [ReadIOEffect].} = @@ -547,8 +730,35 @@ proc setSockOpt*(socket: Socket, opt: SOBool, value: bool, level = SOL_SOCKET) { var valuei = cint(if value: 1 else: 0) setSockOptInt(socket.fd, cint(level), toCInt(opt), valuei) +when defined(posix) and not defined(nimdoc): + proc makeUnixAddr(path: string): Sockaddr_un = + result.sun_family = AF_UNIX.toInt + if path.len >= Sockaddr_un_path_length: + raise newException(ValueError, "socket path too long") + copyMem(addr result.sun_path, path.cstring, path.len + 1) + +when defined(posix): + proc connectUnix*(socket: Socket, path: string) = + ## Connects to Unix socket on `path`. + ## This only works on Unix-style systems: Mac OS X, BSD and Linux + when not defined(nimdoc): + var socketAddr = makeUnixAddr(path) + if socket.fd.connect(cast[ptr SockAddr](addr socketAddr), + sizeof(socketAddr).Socklen) != 0'i32: + raiseOSError(osLastError()) + + proc bindUnix*(socket: Socket, path: string) = + ## Binds Unix socket to `path`. + ## This only works on Unix-style systems: Mac OS X, BSD and Linux + when not defined(nimdoc): + var socketAddr = makeUnixAddr(path) + if socket.fd.bindAddr(cast[ptr SockAddr](addr socketAddr), + sizeof(socketAddr).Socklen) != 0'i32: + raiseOSError(osLastError()) + when defined(ssl): - proc handshake*(socket: Socket): bool {.tags: [ReadIOEffect, WriteIOEffect].} = + proc handshake*(socket: Socket): bool + {.tags: [ReadIOEffect, WriteIOEffect], deprecated.} = ## This proc needs to be called on a socket after it connects. This is ## only applicable when using ``connectAsync``. ## This proc performs the SSL handshake. @@ -557,6 +767,8 @@ when defined(ssl): ## ``True`` whenever handshake completed successfully. ## ## A ESSL error is raised on any other errors. + ## + ## **Note:** This procedure is deprecated since version 0.14.0. result = true if socket.isSSL: var ret = SSLConnect(socket.sslHandle) @@ -594,7 +806,7 @@ proc hasDataBuffered*(s: Socket): bool = if s.isBuffered: result = s.bufLen > 0 and s.currPos != s.bufLen - when defined(ssl): + when defineSsl: if s.isSSL and not result: result = s.sslHasPeekChar @@ -608,7 +820,7 @@ proc select(readfd: Socket, timeout = 500): int = proc readIntoBuf(socket: Socket, flags: int32): int = result = 0 - when defined(ssl): + when defineSsl: if socket.isSSL: result = SSLRead(socket.sslHandle, addr(socket.buffer), int(socket.buffer.high)) else: @@ -658,7 +870,7 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect] result = read else: - when defined(ssl): + when defineSsl: if socket.isSSL: if socket.sslHasPeekChar: copyMem(data, addr(socket.sslPeekChar), 1) @@ -696,7 +908,7 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int, if timeout - int(waited * 1000.0) < 1: raise newException(TimeoutError, "Call to '" & funcName & "' timed out.") - when defined(ssl): + when defineSsl: if socket.isSSL: if socket.hasDataBuffered: # sslPeekChar is present. @@ -764,7 +976,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} = c = socket.buffer[socket.currPos] else: - when defined(ssl): + when defineSsl: if socket.isSSL: if not socket.sslHasPeekChar: result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1) @@ -792,11 +1004,11 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1, ## ## **Warning**: Only the ``SafeDisconn`` flag is currently supported. - template addNLIfEmpty(): stmt = + template addNLIfEmpty() = if line.len == 0: line.string.add("\c\L") - template raiseSockError(): stmt {.dirty, immediate.} = + template raiseSockError() {.dirty.} = let lastError = getSocketError(socket) if flags.isDisconnectionError(lastError): setLen(line.string, 0); return socket.socketError(n, lastError = lastError) @@ -872,7 +1084,7 @@ proc send*(socket: Socket, data: pointer, size: int): int {. ## ## **Note**: This is a low-level version of ``send``. You likely should use ## the version below. - when defined(ssl): + when defineSsl: if socket.isSSL: return SSLWrite(socket.sslHandle, cast[cstring](data), size) @@ -943,7 +1155,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, proc isSsl*(socket: Socket): bool = ## Determines whether ``socket`` is a SSL socket. - when defined(ssl): + when defineSsl: result = socket.isSSL else: result = false @@ -1253,7 +1465,7 @@ proc connect*(socket: Socket, address: string, dealloc(aiList) if not success: raiseOSError(lastError) - when defined(ssl): + when defineSsl: if socket.isSSL: # RFC3546 for SNI specifies that IP addresses are not allowed. if not isIpAddress(address): @@ -1314,8 +1526,10 @@ proc connect*(socket: Socket, address: string, port = Port(0), if selectWrite(s, timeout) != 1: raise newException(TimeoutError, "Call to 'connect' timed out.") else: - when defined(ssl): + when defineSsl and not defined(nimdoc): if socket.isSSL: socket.fd.setBlocking(true) + {.warning[Deprecated]: off.} doAssert socket.handshake() + {.warning[Deprecated]: on.} socket.fd.setBlocking(true) diff --git a/lib/pure/nimprof.nim b/lib/pure/nimprof.nim index 5a7deaab0..4289eb049 100644 --- a/lib/pure/nimprof.nim +++ b/lib/pure/nimprof.nim @@ -27,19 +27,22 @@ const tickCountCorrection = 50_000 when not declared(system.StackTrace): - type StackTrace = array [0..20, cstring] + type StackTrace = object + lines: array[0..20, cstring] + files: array[0..20, cstring] {.deprecated: [TStackTrace: StackTrace].} + proc `[]`*(st: StackTrace, i: int): cstring = st.lines[i] # We use a simple hash table of bounded size to keep track of the stack traces: type ProfileEntry = object total: int st: StackTrace - ProfileData = array [0..64*1024-1, ptr ProfileEntry] + ProfileData = array[0..64*1024-1, ptr ProfileEntry] {.deprecated: [TProfileEntry: ProfileEntry, TProfileData: ProfileData].} proc `==`(a, b: StackTrace): bool = - for i in 0 .. high(a): + for i in 0 .. high(a.lines): if a[i] != b[i]: return false result = true @@ -72,7 +75,7 @@ proc hookAux(st: StackTrace, costs: int) = # this is quite performance sensitive! when withThreads: acquire profilingLock inc totalCalls - var last = high(st) + var last = high(st.lines) while last > 0 and isNil(st[last]): dec last var h = hash(pointer(st[last])) and high(profileData) @@ -178,7 +181,7 @@ proc writeProfile() {.noconv.} = var perProc = initCountTable[string]() for i in 0..entries-1: var dups = initSet[string]() - for ii in 0..high(StackTrace): + for ii in 0..high(StackTrace.lines): let procname = profileData[i].st[ii] if isNil(procname): break let p = $procname @@ -193,10 +196,11 @@ proc writeProfile() {.noconv.} = writeLine(f, "Entry: ", i+1, "/", entries, " Calls: ", profileData[i].total // totalCalls, " [sum: ", sum, "; ", sum // totalCalls, "]") - for ii in 0..high(StackTrace): + for ii in 0..high(StackTrace.lines): let procname = profileData[i].st[ii] + let filename = profileData[i].st.files[ii] if isNil(procname): break - writeLine(f, " ", procname, " ", perProc[$procname] // totalCalls) + writeLine(f, " ", $filename & ": " & $procname, " ", perProc[$procname] // totalCalls) close(f) echo "... done" else: diff --git a/lib/pure/oids.nim b/lib/pure/oids.nim index fca10dab6..e4c97b260 100644 --- a/lib/pure/oids.nim +++ b/lib/pure/oids.nim @@ -74,8 +74,7 @@ proc genOid*(): Oid = var t = gettime(nil) - var i = int32(incr) - atomicInc(incr) + var i = int32(atomicInc(incr)) if fuzz == 0: # racy, but fine semantically: diff --git a/lib/pure/os.nim b/lib/pure/os.nim index 470559e17..1e474f4d4 100644 --- a/lib/pure/os.nim +++ b/lib/pure/os.nim @@ -26,7 +26,6 @@ elif defined(posix): else: {.error: "OS module not ported to your operating system!".} -include "system/ansi_c" include ospaths when defined(posix): @@ -37,6 +36,23 @@ when defined(posix): var pathMax {.importc: "PATH_MAX", header: "<stdlib.h>".}: cint +proc c_remove(filename: cstring): cint {. + importc: "remove", header: "<stdio.h>".} +proc c_rename(oldname, newname: cstring): cint {. + importc: "rename", header: "<stdio.h>".} +proc c_system(cmd: cstring): cint {. + importc: "system", header: "<stdlib.h>".} +proc c_strerror(errnum: cint): cstring {. + importc: "strerror", header: "<string.h>".} +proc c_strlen(a: cstring): cint {. + importc: "strlen", header: "<string.h>", noSideEffect.} +proc c_getenv(env: cstring): cstring {. + importc: "getenv", header: "<stdlib.h>".} +proc c_putenv(env: cstring): cint {. + importc: "putenv", header: "<stdlib.h>".} + +var errno {.importc, header: "<errno.h>".}: cint + proc osErrorMsg*(): string {.rtl, extern: "nos$1", deprecated.} = ## Retrieves the operating system's error flag, ``errno``. ## On Windows ``GetLastError`` is checked before ``errno``. @@ -50,18 +66,18 @@ proc osErrorMsg*(): string {.rtl, extern: "nos$1", deprecated.} = if err != 0'i32: when useWinUnicode: var msgbuf: WideCString - if formatMessageW(0x00000100 or 0x00001000 or 0x00000200, + if formatMessageW(0x00000100 or 0x00001000 or 0x00000200 or 0x000000FF, nil, err, 0, addr(msgbuf), 0, nil) != 0'i32: result = $msgbuf if msgbuf != nil: localFree(cast[pointer](msgbuf)) else: var msgbuf: cstring - if formatMessageA(0x00000100 or 0x00001000 or 0x00000200, + if formatMessageA(0x00000100 or 0x00001000 or 0x00000200 or 0x000000FF, nil, err, 0, addr(msgbuf), 0, nil) != 0'i32: result = $msgbuf if msgbuf != nil: localFree(msgbuf) if errno != 0'i32: - result = $os.strerror(errno) + result = $os.c_strerror(errno) {.push warning[deprecated]: off.} proc raiseOSError*(msg: string = "") {.noinline, rtl, extern: "nos$1", @@ -114,7 +130,7 @@ proc osErrorMsg*(errorCode: OSErrorCode): string = if msgbuf != nil: localFree(msgbuf) else: if errorCode != OSErrorCode(0'i32): - result = $os.strerror(errorCode.int32) + result = $os.c_strerror(errorCode.int32) proc raiseOSError*(errorCode: OSErrorCode; additionalInfo = "") {.noinline.} = ## Raises an ``OSError`` exception. The ``errorCode`` will determine the @@ -129,7 +145,7 @@ proc raiseOSError*(errorCode: OSErrorCode; additionalInfo = "") {.noinline.} = if additionalInfo.len == 0: e.msg = osErrorMsg(errorCode) else: - e.msg = additionalInfo & " " & osErrorMsg(errorCode) + e.msg = osErrorMsg(errorCode) & "\nAdditional info: " & additionalInfo if e.msg == "": e.msg = "unknown OS error" raise e @@ -157,24 +173,24 @@ proc osLastError*(): OSErrorCode = when defined(windows): when useWinUnicode: - template wrapUnary(varname, winApiProc, arg: expr) {.immediate.} = + template wrapUnary(varname, winApiProc, arg: untyped) = var varname = winApiProc(newWideCString(arg)) - template wrapBinary(varname, winApiProc, arg, arg2: expr) {.immediate.} = + template wrapBinary(varname, winApiProc, arg, arg2: untyped) = var varname = winApiProc(newWideCString(arg), arg2) proc findFirstFile(a: string, b: var WIN32_FIND_DATA): Handle = result = findFirstFileW(newWideCString(a), b) - template findNextFile(a, b: expr): expr = findNextFileW(a, b) - template getCommandLine(): expr = getCommandLineW() + template findNextFile(a, b: untyped): untyped = findNextFileW(a, b) + template getCommandLine(): untyped = getCommandLineW() - template getFilename(f: expr): expr = + template getFilename(f: untyped): untyped = $cast[WideCString](addr(f.cFilename[0])) else: - template findFirstFile(a, b: expr): expr = findFirstFileA(a, b) - template findNextFile(a, b: expr): expr = findNextFileA(a, b) - template getCommandLine(): expr = getCommandLineA() + template findFirstFile(a, b: untyped): untyped = findFirstFileA(a, b) + template findNextFile(a, b: untyped): untyped = findNextFileA(a, b) + template getCommandLine(): untyped = getCommandLineA() - template getFilename(f: expr): expr = $f.cFilename + template getFilename(f: untyped): untyped = $f.cFilename proc skipFindData(f: WIN32_FIND_DATA): bool {.inline.} = # Note - takes advantage of null delimiter in the cstring @@ -320,7 +336,7 @@ proc setCurrentDir*(newDir: string) {.inline, tags: [].} = proc expandFilename*(filename: string): string {.rtl, extern: "nos$1", tags: [ReadDirEffect].} = - ## Returns the full path of `filename`, raises OSError in case of an error. + ## Returns the full (`absolute`:idx:) path of the file `filename`, raises OSError in case of an error. when defined(windows): const bufsize = 3072'i32 when useWinUnicode: @@ -762,12 +778,26 @@ iterator envPairs*(): tuple[key, value: TaintedString] {.tags: [ReadEnvEffect].} yield (TaintedString(substr(environment[i], 0, p-1)), TaintedString(substr(environment[i], p+1))) -iterator walkFiles*(pattern: string): string {.tags: [ReadDirEffect].} = - ## Iterate over all the files that match the `pattern`. On POSIX this uses - ## the `glob`:idx: call. - ## - ## `pattern` is OS dependent, but at least the "\*.ext" - ## notation is supported. +# Templates for filtering directories and files +when defined(windows): + template isDir(f: WIN32_FIND_DATA): bool = + (f.dwFileAttributes and FILE_ATTRIBUTE_DIRECTORY) != 0'i32 + template isFile(f: WIN32_FIND_DATA): bool = + not isDir(f) +else: + template isDir(f: string): bool = + dirExists(f) + template isFile(f: string): bool = + fileExists(f) + +template defaultWalkFilter(item): bool = + ## Walk filter used to return true on both + ## files and directories + true + +template walkCommon(pattern: string, filter) = + ## Common code for getting the files and directories with the + ## specified `pattern` when defined(windows): var f: WIN32_FIND_DATA @@ -776,8 +806,7 @@ iterator walkFiles*(pattern: string): string {.tags: [ReadDirEffect].} = if res != -1: defer: findClose(res) while true: - if not skipFindData(f) and - (f.dwFileAttributes and FILE_ATTRIBUTE_DIRECTORY) == 0'i32: + if not skipFindData(f) and filter(f): # Windows bug/gotcha: 't*.nim' matches 'tfoo.nims' -.- so we check # that the file extensions have the same length ... let ff = getFilename(f) @@ -799,7 +828,33 @@ iterator walkFiles*(pattern: string): string {.tags: [ReadDirEffect].} = if res == 0: for i in 0.. f.gl_pathc - 1: assert(f.gl_pathv[i] != nil) - yield $f.gl_pathv[i] + let path = $f.gl_pathv[i] + if filter(path): + yield path + +iterator walkPattern*(pattern: string): string {.tags: [ReadDirEffect].} = + ## Iterate over all the files and directories that match the `pattern`. + ## On POSIX this uses the `glob`:idx: call. + ## + ## `pattern` is OS dependent, but at least the "\*.ext" + ## notation is supported. + walkCommon(pattern, defaultWalkFilter) + +iterator walkFiles*(pattern: string): string {.tags: [ReadDirEffect].} = + ## Iterate over all the files that match the `pattern`. On POSIX this uses + ## the `glob`:idx: call. + ## + ## `pattern` is OS dependent, but at least the "\*.ext" + ## notation is supported. + walkCommon(pattern, isFile) + +iterator walkDirs*(pattern: string): string {.tags: [ReadDirEffect].} = + ## Iterate over all the directories that match the `pattern`. + ## On POSIX this uses the `glob`:idx: call. + ## + ## `pattern` is OS dependent, but at least the "\*.ext" + ## notation is supported. + walkCommon(pattern, isDir) type PathComponent* = enum ## Enumeration specifying a path component. @@ -1067,7 +1122,7 @@ proc parseCmdLine*(c: string): seq[string] {. while true: setLen(a, 0) # eat all delimiting whitespace - while c[i] == ' ' or c[i] == '\t' or c [i] == '\l' or c [i] == '\r' : inc(i) + while c[i] == ' ' or c[i] == '\t' or c[i] == '\l' or c[i] == '\r' : inc(i) when defined(windows): # parse a single argument according to the above rules: if c[i] == '\0': break @@ -1447,13 +1502,13 @@ type lastWriteTime*: Time # Time file was last modified/written to. creationTime*: Time # Time file was created. Not supported on all systems! -template rawToFormalFileInfo(rawInfo, formalInfo): expr = +template rawToFormalFileInfo(rawInfo, formalInfo): untyped = ## Transforms the native file info structure into the one nim uses. ## 'rawInfo' is either a 'TBY_HANDLE_FILE_INFORMATION' structure on Windows, ## or a 'Stat' structure on posix when defined(Windows): - template toTime(e): expr = winTimeToUnixTime(rdFileTime(e)) - template merge(a, b): expr = a or (b shl 32) + template toTime(e: FILETIME): untyped {.gensym.} = winTimeToUnixTime(rdFileTime(e)) # local templates default to bind semantics + template merge(a, b): untyped = a or (b shl 32) formalInfo.id.device = rawInfo.dwVolumeSerialNumber formalInfo.id.file = merge(rawInfo.nFileIndexLow, rawInfo.nFileIndexHigh) formalInfo.size = merge(rawInfo.nFileSizeLow, rawInfo.nFileSizeHigh) @@ -1478,7 +1533,7 @@ template rawToFormalFileInfo(rawInfo, formalInfo): expr = else: - template checkAndIncludeMode(rawMode, formalMode: expr) = + template checkAndIncludeMode(rawMode, formalMode: untyped) = if (rawInfo.st_mode and rawMode) != 0'i32: formalInfo.permissions.incl(formalMode) formalInfo.id = (rawInfo.st_dev, rawInfo.st_ino) diff --git a/lib/pure/ospaths.nim b/lib/pure/ospaths.nim index 9fc816f2f..56671ee62 100644 --- a/lib/pure/ospaths.nim +++ b/lib/pure/ospaths.nim @@ -10,7 +10,7 @@ # Included by the ``os`` module but a module in its own right for NimScript # support. -when isMainModule: +when not declared(os): {.pragma: rtl.} import strutils @@ -556,12 +556,20 @@ when declared(getEnv) or defined(nimscript): yield substr(s, first, last-1) inc(last) - proc findExe*(exe: string): string {. + when not defined(windows) and declared(os): + proc checkSymlink(path: string): bool = + var rawInfo: Stat + if lstat(path, rawInfo) < 0'i32: result = false + else: result = S_ISLNK(rawInfo.st_mode) + + proc findExe*(exe: string, followSymlinks: bool = true): string {. tags: [ReadDirEffect, ReadEnvEffect, ReadIOEffect].} = ## Searches for `exe` in the current working directory and then ## in directories listed in the ``PATH`` environment variable. ## Returns "" if the `exe` cannot be found. On DOS-like platforms, `exe` ## is added the `ExeExt <#ExeExt>`_ file extension if it has none. + ## If the system supports symlinks it also resolves them until it + ## meets the actual file. This behavior can be disabled if desired. result = addFileExt(exe, ExeExt) if existsFile(result): return var path = string(getEnv("PATH")) @@ -572,7 +580,25 @@ when declared(getEnv) or defined(nimscript): result else: var x = expandTilde(candidate) / result - if existsFile(x): return x + if existsFile(x): + when not defined(windows) and declared(os): + while followSymlinks: # doubles as if here + if x.checkSymlink: + var r = newString(256) + var len = readlink(x, r, 256) + if len < 0: + raiseOSError(osLastError()) + if len > 256: + r = newString(len+1) + len = readlink(x, r, len) + setLen(r, len) + if isAbsolute(r): + x = r + else: + x = parentDir(x) / r + else: + break + return x result = "" when defined(nimscript) or (defined(nimdoc) and not declared(os)): diff --git a/lib/pure/osproc.nim b/lib/pure/osproc.nim index 38b0ed4a3..7378520e3 100644 --- a/lib/pure/osproc.nim +++ b/lib/pure/osproc.nim @@ -89,7 +89,7 @@ proc quoteShellWindows*(s: string): string {.noSideEffect, rtl, extern: "nosp$1" result.add("\"") proc quoteShellPosix*(s: string): string {.noSideEffect, rtl, extern: "nosp$1".} = - ## Quote s, so it can be safely passed to POSIX shell. + ## Quote ``s``, so it can be safely passed to POSIX shell. ## Based on Python's pipes.quote const safeUnixChars = {'%', '+', '-', '.', '/', '_', ':', '=', '@', '0'..'9', 'A'..'Z', 'a'..'z'} @@ -104,7 +104,7 @@ proc quoteShellPosix*(s: string): string {.noSideEffect, rtl, extern: "nosp$1".} return "'" & s.replace("'", "'\"'\"'") & "'" proc quoteShell*(s: string): string {.noSideEffect, rtl, extern: "nosp$1".} = - ## Quote s, so it can be safely passed to shell. + ## Quote ``s``, so it can be safely passed to shell. when defined(Windows): return quoteShellWindows(s) elif defined(posix): @@ -175,7 +175,11 @@ proc startCmd*(command: string, options: set[ProcessOption] = { result = startProcess(command=command, options=options + {poEvalCommand}) proc close*(p: Process) {.rtl, extern: "nosp$1", tags: [].} - ## When the process has finished executing, cleanup related handles + ## When the process has finished executing, cleanup related handles. + ## + ## **Warning:** If the process has not finished executing, this will forcibly + ## terminate the process. Doing so may result in zombie processes and + ## `pty leaks <http://stackoverflow.com/questions/27021641/how-to-fix-request-failed-on-channel-0>`_. proc suspend*(p: Process) {.rtl, extern: "nosp$1", tags: [].} ## Suspends the process `p`. @@ -400,15 +404,16 @@ when defined(Windows) and not defined(useNimRtl): result = cast[cstring](alloc0(res.len+1)) copyMem(result, cstring(res), res.len) - proc buildEnv(env: StringTableRef): cstring = + proc buildEnv(env: StringTableRef): tuple[str: cstring, len: int] = var L = 0 for key, val in pairs(env): inc(L, key.len + val.len + 2) - result = cast[cstring](alloc0(L+2)) + var str = cast[cstring](alloc0(L+2)) L = 0 for key, val in pairs(env): var x = key & "=" & val - copyMem(addr(result[L]), cstring(x), x.len+1) # copy \0 + copyMem(addr(str[L]), cstring(x), x.len+1) # copy \0 inc(L, x.len+1) + (str, L) #proc open_osfhandle(osh: Handle, mode: int): int {. # importc: "_open_osfhandle", header: "<fcntl.h>".} @@ -526,13 +531,15 @@ when defined(Windows) and not defined(useNimRtl): else: cmdl = buildCommandLine(command, args) var wd: cstring = nil - var e: cstring = nil + var e = (str: nil.cstring, len: -1) if len(workingDir) > 0: wd = workingDir if env != nil: e = buildEnv(env) if poEchoCmd in options: echo($cmdl) when useWinUnicode: var tmp = newWideCString(cmdl) - var ee = newWideCString(e) + var ee = + if e.str.isNil: nil + else: newWideCString(e.str, e.len) var wwd = newWideCString(wd) var flags = NORMAL_PRIORITY_CLASS or CREATE_UNICODE_ENVIRONMENT if poDemon in options: flags = flags or CREATE_NO_WINDOW @@ -549,7 +556,7 @@ when defined(Windows) and not defined(useNimRtl): if poStdErrToStdOut notin options: fileClose(si.hStdError) - if e != nil: dealloc(e) + if e.str != nil: dealloc(e.str) if success == 0: if poInteractive in result.options: close(result) const errInvalidParameter = 87.int @@ -721,7 +728,7 @@ elif not defined(useNimRtl): env: StringTableRef = nil, options: set[ProcessOption] = {poStdErrToStdOut}): Process = var - pStdin, pStdout, pStderr: array [0..1, cint] + pStdin, pStdout, pStderr: array[0..1, cint] new(result) result.options = options result.exitCode = -3 # for ``waitForExit`` @@ -875,8 +882,9 @@ elif not defined(useNimRtl): var error: cint let sizeRead = read(data.pErrorPipe[readIdx], addr error, sizeof(error)) if sizeRead == sizeof(error): - raiseOSError("Could not find command: '$1'. OS error: $2" % - [$data.sysCommand, $strerror(error)]) + raiseOSError(osLastError(), + "Could not find command: '$1'. OS error: $2" % + [$data.sysCommand, $strerror(error)]) return pid @@ -967,16 +975,168 @@ elif not defined(useNimRtl): if kill(p.id, SIGKILL) != 0'i32: raiseOsError(osLastError()) - proc waitForExit(p: Process, timeout: int = -1): int = - #if waitPid(p.id, p.exitCode, 0) == int(p.id): - # ``waitPid`` fails if the process is not running anymore. But then - # ``running`` probably set ``p.exitCode`` for us. Since ``p.exitCode`` is - # initialized with -3, wrong success exit codes are prevented. - if p.exitCode != -3: return p.exitCode - if waitpid(p.id, p.exitCode, 0) < 0: - p.exitCode = -3 - raiseOSError(osLastError()) - result = int(p.exitCode) shr 8 + when defined(macosx) or defined(freebsd) or defined(netbsd) or + defined(openbsd): + import kqueue, times + + proc waitForExit(p: Process, timeout: int = -1): int = + if p.exitCode != -3: return p.exitCode + if timeout == -1: + if waitpid(p.id, p.exitCode, 0) < 0: + p.exitCode = -3 + raiseOSError(osLastError()) + else: + var kqFD = kqueue() + if kqFD == -1: + raiseOSError(osLastError()) + + var kevIn = KEvent(ident: p.id.uint, filter: EVFILT_PROC, + flags: EV_ADD, fflags: NOTE_EXIT) + var kevOut: KEvent + var tmspec: Timespec + + if timeout >= 1000: + tmspec.tv_sec = (timeout div 1_000).Time + tmspec.tv_nsec = (timeout %% 1_000) * 1_000_000 + else: + tmspec.tv_sec = 0.Time + tmspec.tv_nsec = (timeout * 1_000_000) + + try: + while true: + var count = kevent(kqFD, addr(kevIn), 1, addr(kevOut), 1, + addr(tmspec)) + if count < 0: + let err = osLastError() + if err.cint != EINTR: + raiseOSError(osLastError()) + elif count == 0: + # timeout expired, so we trying to kill process + if posix.kill(p.id, SIGKILL) == -1: + raiseOSError(osLastError()) + if waitpid(p.id, p.exitCode, 0) < 0: + p.exitCode = -3 + raiseOSError(osLastError()) + break + else: + if kevOut.ident == p.id.uint and kevOut.filter == EVFILT_PROC: + if waitpid(p.id, p.exitCode, 0) < 0: + p.exitCode = -3 + raiseOSError(osLastError()) + break + else: + raiseOSError(osLastError()) + finally: + discard posix.close(kqFD) + + result = int(p.exitCode) shr 8 + else: + import times + + const + hasThreadSupport = compileOption("threads") and not defined(nimscript) + + proc waitForExit(p: Process, timeout: int = -1): int = + template adjustTimeout(t, s, e: Timespec) = + var diff: int + var b: Timespec + b.tv_sec = e.tv_sec + b.tv_nsec = e.tv_nsec + e.tv_sec = (e.tv_sec - s.tv_sec).Time + if e.tv_nsec >= s.tv_nsec: + e.tv_nsec -= s.tv_nsec + else: + if e.tv_sec == 0.Time: + raise newException(ValueError, "System time was modified") + else: + diff = s.tv_nsec - e.tv_nsec + e.tv_nsec = 1_000_000_000 - diff + t.tv_sec = (t.tv_sec - e.tv_sec).Time + if t.tv_nsec >= e.tv_nsec: + t.tv_nsec -= e.tv_nsec + else: + t.tv_sec = (int(t.tv_sec) - 1).Time + diff = e.tv_nsec - t.tv_nsec + t.tv_nsec = 1_000_000_000 - diff + s.tv_sec = b.tv_sec + s.tv_nsec = b.tv_nsec + + #if waitPid(p.id, p.exitCode, 0) == int(p.id): + # ``waitPid`` fails if the process is not running anymore. But then + # ``running`` probably set ``p.exitCode`` for us. Since ``p.exitCode`` is + # initialized with -3, wrong success exit codes are prevented. + if p.exitCode != -3: return p.exitCode + if timeout == -1: + if waitpid(p.id, p.exitCode, 0) < 0: + p.exitCode = -3 + raiseOSError(osLastError()) + else: + var nmask, omask: Sigset + var sinfo: SigInfo + var stspec, enspec, tmspec: Timespec + + discard sigemptyset(nmask) + discard sigemptyset(omask) + discard sigaddset(nmask, SIGCHLD) + + when hasThreadSupport: + if pthread_sigmask(SIG_BLOCK, nmask, omask) == -1: + raiseOSError(osLastError()) + else: + if sigprocmask(SIG_BLOCK, nmask, omask) == -1: + raiseOSError(osLastError()) + + if timeout >= 1000: + tmspec.tv_sec = (timeout div 1_000).Time + tmspec.tv_nsec = (timeout %% 1_000) * 1_000_000 + else: + tmspec.tv_sec = 0.Time + tmspec.tv_nsec = (timeout * 1_000_000) + + try: + if clock_gettime(CLOCK_REALTIME, stspec) == -1: + raiseOSError(osLastError()) + while true: + let res = sigtimedwait(nmask, sinfo, tmspec) + if res == SIGCHLD: + if sinfo.si_pid == p.id: + if waitpid(p.id, p.exitCode, 0) < 0: + p.exitCode = -3 + raiseOSError(osLastError()) + break + else: + # we have SIGCHLD, but not for process we are waiting, + # so we need to adjust timeout value and continue + if clock_gettime(CLOCK_REALTIME, enspec) == -1: + raiseOSError(osLastError()) + adjustTimeout(tmspec, stspec, enspec) + elif res < 0: + let err = osLastError() + if err.cint == EINTR: + # we have received another signal, so we need to + # adjust timeout and continue + if clock_gettime(CLOCK_REALTIME, enspec) == -1: + raiseOSError(osLastError()) + adjustTimeout(tmspec, stspec, enspec) + elif err.cint == EAGAIN: + # timeout expired, so we trying to kill process + if posix.kill(p.id, SIGKILL) == -1: + raiseOSError(osLastError()) + if waitpid(p.id, p.exitCode, 0) < 0: + p.exitCode = -3 + raiseOSError(osLastError()) + break + else: + raiseOSError(err) + finally: + when hasThreadSupport: + if pthread_sigmask(SIG_UNBLOCK, nmask, omask) == -1: + raiseOSError(osLastError()) + else: + if sigprocmask(SIG_UNBLOCK, nmask, omask) == -1: + raiseOSError(osLastError()) + + result = int(p.exitCode) shr 8 proc peekExitCode(p: Process): int = if p.exitCode != -3: return p.exitCode diff --git a/lib/pure/oswalkdir.nim b/lib/pure/oswalkdir.nim index 000fe25a3..23ca0566a 100644 --- a/lib/pure/oswalkdir.nim +++ b/lib/pure/oswalkdir.nim @@ -1,3 +1,11 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2015 Andreas Rumpf +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# ## Compile-time only version for walkDir if you need it at compile-time ## for JavaScript. diff --git a/lib/pure/parsecfg.nim b/lib/pure/parsecfg.nim index 9bcac0a50..c648b0703 100644 --- a/lib/pure/parsecfg.nim +++ b/lib/pure/parsecfg.nim @@ -15,17 +15,78 @@ ## This is an example of how a configuration file may look like: ## -## .. include:: doc/mytest.cfg +## .. include:: ../../doc/mytest.cfg ## :literal: ## The file ``examples/parsecfgex.nim`` demonstrates how to use the ## configuration file parser: ## ## .. code-block:: nim -## :file: examples/parsecfgex.nim - +## :file: ../../examples/parsecfgex.nim +## +## Examples +## -------- +## +## This is an example of a configuration file. +## +## :: +## +## charset = "utf-8" +## [Package] +## name = "hello" +## --threads:on +## [Author] +## name = "lihf8515" +## qq = "10214028" +## email = "lihaifeng@wxm.com" +## +## Creating a configuration file. +## ============================== +## .. code-block:: nim +## +## import parsecfg +## var dict=newConfig() +## dict.setSectionKey("","charset","utf-8") +## dict.setSectionKey("Package","name","hello") +## dict.setSectionKey("Package","--threads","on") +## dict.setSectionKey("Author","name","lihf8515") +## dict.setSectionKey("Author","qq","10214028") +## dict.setSectionKey("Author","email","lihaifeng@wxm.com") +## dict.writeConfig("config.ini") +## +## Reading a configuration file. +## ============================= +## .. code-block:: nim +## +## import parsecfg +## var dict = loadConfig("config.ini") +## var charset = dict.getSectionValue("","charset") +## var threads = dict.getSectionValue("Package","--threads") +## var pname = dict.getSectionValue("Package","name") +## var name = dict.getSectionValue("Author","name") +## var qq = dict.getSectionValue("Author","qq") +## var email = dict.getSectionValue("Author","email") +## echo pname & "\n" & name & "\n" & qq & "\n" & email +## +## Modifying a configuration file. +## =============================== +## .. code-block:: nim +## +## import parsecfg +## var dict = loadConfig("config.ini") +## dict.setSectionKey("Author","name","lhf") +## dict.writeConfig("config.ini") +## +## Deleting a section key in a configuration file. +## =============================================== +## .. code-block:: nim +## +## import parsecfg +## var dict = loadConfig("config.ini") +## dict.delSectionKey("Author","email") +## dict.writeConfig("config.ini") import - hashes, strutils, lexbase, streams + hashes, strutils, lexbase, streams, tables include "system/inclrtl" @@ -70,7 +131,7 @@ type # implementation const - SymChars = {'a'..'z', 'A'..'Z', '0'..'9', '_', '\x80'..'\xFF', '.', '/', '\\'} + SymChars = {'a'..'z', 'A'..'Z', '0'..'9', '_', '\x80'..'\xFF', '.', '/', '\\', '-'} proc rawGetTok(c: var CfgParser, tok: var Token) {.gcsafe.} @@ -359,3 +420,138 @@ proc next*(c: var CfgParser): CfgEvent {.rtl, extern: "npc$1".} = result.kind = cfgError result.msg = errorStr(c, "invalid token: " & c.tok.literal) rawGetTok(c, c.tok) + +# ---------------- Configuration file related operations ---------------- +type + Config* = OrderedTableRef[string, OrderedTableRef[string, string]] + +proc newConfig*(): Config = + ## Create a new configuration table. + ## Useful when wanting to create a configuration file. + result = newOrderedTable[string, OrderedTableRef[string, string]]() + +proc loadConfig*(filename: string): Config = + ## Load the specified configuration file into a new Config instance. + var dict = newOrderedTable[string, OrderedTableRef[string, string]]() + var curSection = "" ## Current section, + ## the default value of the current section is "", + ## which means that the current section is a common + var p: CfgParser + var fileStream = newFileStream(filename, fmRead) + if fileStream != nil: + open(p, fileStream, filename) + while true: + var e = next(p) + case e.kind + of cfgEof: + break + of cfgSectionStart: # Only look for the first time the Section + curSection = e.section + of cfgKeyValuePair: + var t = newOrderedTable[string, string]() + if dict.hasKey(curSection): + t = dict[curSection] + t[e.key] = e.value + dict[curSection] = t + of cfgOption: + var c = newOrderedTable[string, string]() + if dict.hasKey(curSection): + c = dict[curSection] + c["--" & e.key] = e.value + dict[curSection] = c + of cfgError: + break + close(p) + result = dict + +proc replace(s: string): string = + var d = "" + var i = 0 + while i < s.len(): + if s[i] == '\\': + d.add(r"\\") + elif s[i] == '\c' and s[i+1] == '\L': + d.add(r"\n") + inc(i) + elif s[i] == '\c': + d.add(r"\n") + elif s[i] == '\L': + d.add(r"\n") + else: + d.add(s[i]) + inc(i) + result = d + +proc writeConfig*(dict: Config, filename: string) = + ## Writes the contents of the table to the specified configuration file. + ## Note: Comment statement will be ignored. + var file: File + if file.open(filename, fmWrite): + try: + var section, key, value, kv, segmentChar:string + for pair in dict.pairs(): + section = pair[0] + if section != "": ## Not general section + if not allCharsInSet(section, SymChars): ## Non system character + file.writeLine("[\"" & section & "\"]") + else: + file.writeLine("[" & section & "]") + for pair2 in pair[1].pairs(): + key = pair2[0] + value = pair2[1] + if key[0] == '-' and key[1] == '-': ## If it is a command key + segmentChar = ":" + if not allCharsInSet(key[2..key.len()-1], SymChars): + kv.add("--\"") + kv.add(key[2..key.len()-1]) + kv.add("\"") + else: + kv = key + else: + segmentChar = "=" + kv = key + if value != "": ## If the key is not empty + if not allCharsInSet(value, SymChars): + kv.add(segmentChar) + kv.add("\"") + kv.add(replace(value)) + kv.add("\"") + else: + kv.add(segmentChar) + kv.add(value) + file.writeLine(kv) + except: + raise + finally: + file.close() + +proc getSectionValue*(dict: Config, section, key: string): string = + ## Gets the Key value of the specified Section. + if dict.haskey(section): + if dict[section].hasKey(key): + result = dict[section][key] + else: + result = "" + else: + result = "" + +proc setSectionKey*(dict: var Config, section, key, value: string) = + ## Sets the Key value of the specified Section. + var t = newOrderedTable[string, string]() + if dict.hasKey(section): + t = dict[section] + t[key] = value + dict[section] = t + +proc delSection*(dict: var Config, section: string) = + ## Deletes the specified section and all of its sub keys. + dict.del(section) + +proc delSectionKey*(dict: var Config, section, key: string) = + ## Delete the key of the specified section. + if dict.haskey(section): + if dict[section].hasKey(key): + if dict[section].len() == 1: + dict.del(section) + else: + dict[section].del(key) diff --git a/lib/pure/parsecsv.nim b/lib/pure/parsecsv.nim index af51e1201..77b145a73 100644 --- a/lib/pure/parsecsv.nim +++ b/lib/pure/parsecsv.nim @@ -25,6 +25,28 @@ ## echo "##", val, "##" ## close(x) ## +## For CSV files with a header row, the header can be read and then used as a +## reference for item access with `rowEntry <#rowEntry.CsvParser.string>`_: +## +## .. code-block:: nim +## import parsecsv +## import os +## # Prepare a file +## var csv_content = """One,Two,Three,Four +## 1,2,3,4 +## 10,20,30,40 +## 100,200,300,400 +## """ +## writeFile("temp.csv", content) +## +## var p: CsvParser +## p.open("temp.csv") +## p.readHeaderRow() +## while p.readRow(): +## echo "new row: " +## for col in items(p.headers): +## echo "##", col, ":", p.rowEntry(col), "##" +## p.close() import lexbase, streams @@ -37,6 +59,9 @@ type sep, quote, esc: char skipWhite: bool currRow: int + headers*: seq[string] ## The columns that are defined in the csv file + ## (read using `readHeaderRow <#readHeaderRow.CsvParser>`_). + ## Used with `rowEntry <#rowEntry.CsvParser.string>`_). CsvError* = object of IOError ## exception that is raised if ## a parsing error occurs @@ -77,6 +102,15 @@ proc open*(my: var CsvParser, input: Stream, filename: string, my.row = @[] my.currRow = 0 +proc open*(my: var CsvParser, filename: string, + separator = ',', quote = '"', escape = '\0', + skipInitialSpace = false) = + ## same as the other `open` but creates the file stream for you. + var s = newFileStream(filename, fmRead) + if s == nil: my.error(0, "cannot open: " & filename) + open(my, s, filename, separator, + quote, escape, skipInitialSpace) + proc parseField(my: var CsvParser, a: var string) = var pos = my.bufpos var buf = my.buf @@ -131,6 +165,8 @@ proc readRow*(my: var CsvParser, columns = 0): bool = ## reads the next row; if `columns` > 0, it expects the row to have ## exactly this many columns. Returns false if the end of the file ## has been encountered else true. + ## + ## Blank lines are skipped. var col = 0 # current column var oldpos = my.bufpos while my.buf[my.bufpos] != '\0': @@ -166,6 +202,22 @@ proc close*(my: var CsvParser) {.inline.} = ## closes the parser `my` and its associated input stream. lexbase.close(my) +proc readHeaderRow*(my: var CsvParser) = + ## Reads the first row and creates a look-up table for column numbers + ## See also `rowEntry <#rowEntry.CsvParser.string>`_. + var present = my.readRow() + if present: + my.headers = my.row + +proc rowEntry*(my: var CsvParser, entry: string): string = + ## Reads a specified `entry` from the current row. + ## + ## Assumes that `readHeaderRow <#readHeaderRow.CsvParser>`_ has already been + ## called. + var index = my.headers.find(entry) + if index >= 0: + result = my.row[index] + when not defined(testing) and isMainModule: import os var s = newFileStream(paramStr(1), fmRead) @@ -178,3 +230,35 @@ when not defined(testing) and isMainModule: echo "##", val, "##" close(x) +when isMainModule: + import os + import strutils + block: # Tests for reading the header row + var content = "One,Two,Three,Four\n1,2,3,4\n10,20,30,40,\n100,200,300,400\n" + writeFile("temp.csv", content) + + var p: CsvParser + p.open("temp.csv") + p.readHeaderRow() + while p.readRow(): + var zeros = repeat('0', p.currRow-2) + doAssert p.rowEntry("One") == "1" & zeros + doAssert p.rowEntry("Two") == "2" & zeros + doAssert p.rowEntry("Three") == "3" & zeros + doAssert p.rowEntry("Four") == "4" & zeros + p.close() + + when not defined(testing): + var parser: CsvParser + parser.open("temp.csv") + parser.readHeaderRow() + while parser.readRow(): + echo "new row: " + for col in items(parser.headers): + echo "##", col, ":", parser.rowEntry(col), "##" + parser.close() + removeFile("temp.csv") + + # Tidy up + removeFile("temp.csv") + diff --git a/lib/pure/parseutils.nim b/lib/pure/parseutils.nim index 698bde42a..fb7d72182 100644 --- a/lib/pure/parseutils.nim +++ b/lib/pure/parseutils.nim @@ -173,6 +173,22 @@ proc parseUntil*(s: string, token: var string, until: char, result = i-start token = substr(s, start, i-1) +proc parseUntil*(s: string, token: var string, until: string, + start = 0): int {.inline.} = + ## parses a token and stores it in ``token``. Returns + ## the number of the parsed characters or 0 in case of an error. A token + ## consists of any character that comes before the `until` token. + var i = start + while i < s.len: + if s[i] == until[0]: + var u = 1 + while i+u < s.len and u < until.len and s[i+u] == until[u]: + inc u + if u >= until.len: break + inc(i) + result = i-start + token = substr(s, start, i-1) + proc parseWhile*(s: string, token: var string, validChars: set[char], start = 0): int {.inline.} = ## parses a token and stores it in ``token``. Returns @@ -234,6 +250,51 @@ proc parseInt*(s: string, number: var int, start = 0): int {. elif result != 0: number = int(res) +# overflowChecks doesn't work with uint64 +proc rawParseUInt(s: string, b: var uint64, start = 0): int = + var + res = 0'u64 + prev = 0'u64 + i = start + if s[i] == '+': inc(i) # Allow + if s[i] in {'0'..'9'}: + b = 0 + while s[i] in {'0'..'9'}: + prev = res + res = res * 10 + (ord(s[i]) - ord('0')).uint64 + if prev > res: + return 0 # overflowChecks emulation + inc(i) + while s[i] == '_': inc(i) # underscores are allowed and ignored + b = res + result = i - start + +proc parseBiggestUInt*(s: string, number: var uint64, start = 0): int {. + rtl, extern: "npuParseBiggestUInt", noSideEffect.} = + ## parses an unsigned integer starting at `start` and stores the value + ## into `number`. + ## Result is the number of processed chars or 0 if there is no integer + ## or overflow detected. + var res: uint64 + # use 'res' for exception safety (don't write to 'number' in case of an + # overflow exception): + result = rawParseUInt(s, res, start) + number = res + +proc parseUInt*(s: string, number: var uint, start = 0): int {. + rtl, extern: "npuParseUInt", noSideEffect.} = + ## parses an unsigned integer starting at `start` and stores the value + ## into `number`. + ## Result is the number of processed chars or 0 if there is no integer or + ## overflow detected. + var res: uint64 + result = parseBiggestUInt(s, res, start) + if (sizeof(uint) <= 4) and + (res > 0xFFFF_FFFF'u64): + raise newException(OverflowError, "overflow") + elif result != 0: + number = uint(res) + proc parseBiggestFloat*(s: string, number: var BiggestFloat, start = 0): int {. magic: "ParseBiggestFloat", importc: "nimParseBiggestFloat", noSideEffect.} ## parses a float starting at `start` and stores the value into `number`. diff --git a/lib/pure/parsexml.nim b/lib/pure/parsexml.nim index f8b2c3d8d..d16a55302 100644 --- a/lib/pure/parsexml.nim +++ b/lib/pure/parsexml.nim @@ -34,7 +34,7 @@ ## document. ## ## .. code-block:: nim -## :file: examples/htmltitle.nim +## :file: ../../examples/htmltitle.nim ## ## ## Example 2: Retrieve all HTML links @@ -45,7 +45,7 @@ ## an HTML document contains. ## ## .. code-block:: nim -## :file: examples/htmlrefs.nim +## :file: ../../examples/htmlrefs.nim ## import @@ -142,6 +142,9 @@ proc kind*(my: XmlParser): XmlEventKind {.inline.} = template charData*(my: XmlParser): string = ## returns the character data for the events: ``xmlCharData``, ## ``xmlWhitespace``, ``xmlComment``, ``xmlCData``, ``xmlSpecial`` + ## Raises an assertion in debug mode if ``my.kind`` is not one + ## of those events. In release mode, this will not trigger an error + ## but the value returned will not be valid. assert(my.kind in {xmlCharData, xmlWhitespace, xmlComment, xmlCData, xmlSpecial}) my.a @@ -149,31 +152,49 @@ template charData*(my: XmlParser): string = template elementName*(my: XmlParser): string = ## returns the element name for the events: ``xmlElementStart``, ## ``xmlElementEnd``, ``xmlElementOpen`` + ## Raises an assertion in debug mode if ``my.kind`` is not one + ## of those events. In release mode, this will not trigger an error + ## but the value returned will not be valid. assert(my.kind in {xmlElementStart, xmlElementEnd, xmlElementOpen}) my.a template entityName*(my: XmlParser): string = ## returns the entity name for the event: ``xmlEntity`` + ## Raises an assertion in debug mode if ``my.kind`` is not + ## ``xmlEntity``. In release mode, this will not trigger an error + ## but the value returned will not be valid. assert(my.kind == xmlEntity) my.a template attrKey*(my: XmlParser): string = ## returns the attribute key for the event ``xmlAttribute`` + ## Raises an assertion in debug mode if ``my.kind`` is not + ## ``xmlAttribute``. In release mode, this will not trigger an error + ## but the value returned will not be valid. assert(my.kind == xmlAttribute) my.a template attrValue*(my: XmlParser): string = ## returns the attribute value for the event ``xmlAttribute`` + ## Raises an assertion in debug mode if ``my.kind`` is not + ## ``xmlAttribute``. In release mode, this will not trigger an error + ## but the value returned will not be valid. assert(my.kind == xmlAttribute) my.b template piName*(my: XmlParser): string = ## returns the processing instruction name for the event ``xmlPI`` + ## Raises an assertion in debug mode if ``my.kind`` is not + ## ``xmlPI``. In release mode, this will not trigger an error + ## but the value returned will not be valid. assert(my.kind == xmlPI) my.a template piRest*(my: XmlParser): string = ## returns the rest of the processing instruction for the event ``xmlPI`` + ## Raises an assertion in debug mode if ``my.kind`` is not + ## ``xmlPI``. In release mode, this will not trigger an error + ## but the value returned will not be valid. assert(my.kind == xmlPI) my.b diff --git a/lib/pure/pegs.nim b/lib/pure/pegs.nim index fead66de2..5c978a2f8 100644 --- a/lib/pure/pegs.nim +++ b/lib/pure/pegs.nim @@ -12,7 +12,7 @@ ## Matching performance is hopefully competitive with optimized regular ## expression engines. ## -## .. include:: ../doc/pegdocs.txt +## .. include:: ../../doc/pegdocs.txt ## include "system/inclrtl" @@ -659,7 +659,7 @@ proc rawMatch*(s: string, p: Peg, start: int, c: var Captures): int {. of pkSearch: var oldMl = c.ml result = 0 - while start+result < s.len: + while start+result <= s.len: var x = rawMatch(s, p.sons[0], start+result, c) if x >= 0: inc(result, x) @@ -671,7 +671,7 @@ proc rawMatch*(s: string, p: Peg, start: int, c: var Captures): int {. var idx = c.ml # reserve a slot for the subpattern inc(c.ml) result = 0 - while start+result < s.len: + while start+result <= s.len: var x = rawMatch(s, p.sons[0], start+result, c) if x >= 0: if idx < MaxSubpatterns: @@ -962,6 +962,50 @@ proc parallelReplace*(s: string, subs: varargs[ # copy the rest: add(result, substr(s, i)) +proc replace*(s: string, sub: Peg, cb: proc( + match: int, cnt: int, caps: openArray[string]): string): string {. + rtl, extern: "npegs$1cb".}= + ## Replaces `sub` in `s` by the resulting strings from the callback. + ## The callback proc receives the index of the current match (starting with 0), + ## the count of captures and an open array with the captures of each match. Examples: + ## + ## .. code-block:: nim + ## + ## proc handleMatches*(m: int, n: int, c: openArray[string]): string = + ## result = "" + ## if m > 0: + ## result.add ", " + ## result.add case n: + ## of 2: c[0].toLower & ": '" & c[1] & "'" + ## of 1: c[0].toLower & ": ''" + ## else: "" + ## + ## let s = "Var1=key1;var2=Key2; VAR3" + ## echo s.replace(peg"{\ident}('='{\ident})* ';'* \s*", handleMatches) + ## + ## Results in: + ## + ## .. code-block:: nim + ## + ## "var1: 'key1', var2: 'Key2', var3: ''" + result = "" + var i = 0 + var caps: array[0..MaxSubpatterns-1, string] + var c: Captures + var m = 0 + while i < s.len: + c.ml = 0 + var x = rawMatch(s, sub, i, c) + if x <= 0: + add(result, s[i]) + inc(i) + else: + fillMatches(s, caps, c) + add(result, cb(m, c.ml, caps)) + inc(i, x) + inc(m) + add(result, substr(s, i)) + proc transformFile*(infile, outfile: string, subs: varargs[tuple[pattern: Peg, repl: string]]) {. rtl, extern: "npegs$1".} = @@ -1789,3 +1833,22 @@ when isMainModule: assert(str.find(empty_test) == 0) assert(str.match(empty_test)) + + proc handleMatches*(m: int, n: int, c: openArray[string]): string = + result = "" + + if m > 0: + result.add ", " + + result.add case n: + of 2: toLowerAscii(c[0]) & ": '" & c[1] & "'" + of 1: toLowerAscii(c[0]) & ": ''" + else: "" + + assert("Var1=key1;var2=Key2; VAR3". + replace(peg"{\ident}('='{\ident})* ';'* \s*", + handleMatches)=="var1: 'key1', var2: 'Key2', var3: ''") + + + doAssert "test1".match(peg"""{@}$""") + doAssert "test2".match(peg"""{(!$ .)*} $""") diff --git a/lib/pure/punycode.nim b/lib/pure/punycode.nim new file mode 100644 index 000000000..ab6501ed1 --- /dev/null +++ b/lib/pure/punycode.nim @@ -0,0 +1,174 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Andreas Rumpf +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +import strutils +import unicode + +# issue #3045 + +const + Base = 36 + TMin = 1 + TMax = 26 + Skew = 38 + Damp = 700 + InitialBias = 72 + InitialN = 128 + Delimiter = '-' + +type + PunyError* = object of Exception + +proc decodeDigit(x: char): int {.raises: [PunyError].} = + if '0' <= x and x <= '9': + result = ord(x) - (ord('0') - 26) + elif 'A' <= x and x <= 'Z': + result = ord(x) - ord('A') + elif 'a' <= x and x <= 'z': + result = ord(x) - ord('a') + else: + raise newException(PunyError, "Bad input") + +proc encodeDigit(digit: int): Rune {.raises: [PunyError].} = + if 0 <= digit and digit < 26: + result = Rune(digit + ord('a')) + elif 26 <= digit and digit < 36: + result = Rune(digit + (ord('0') - 26)) + else: + raise newException(PunyError, "internal error in punycode encoding") + +proc isBasic(c: char): bool = ord(c) < 0x80 +proc isBasic(r: Rune): bool = int(r) < 0x80 + +proc adapt(delta, numPoints: int, first: bool): int = + var d = if first: delta div Damp else: delta div 2 + d += d div numPoints + var k = 0 + while d > ((Base-TMin)*TMax) div 2: + d = d div (Base - TMin) + k += Base + result = k + (Base - TMin + 1) * d div (d + Skew) + +proc encode*(prefix, s: string): string {.raises: [PunyError].} = + ## Encode a string that may contain Unicode. + ## Prepend `prefix` to the result + result = prefix + var (d, n, bias) = (0, InitialN, InitialBias) + var (b, remaining) = (0, 0) + for r in s.runes: + if r.isBasic: + # basic Ascii character + inc b + result.add($r) + else: + # special character + inc remaining + + var h = b + if b > 0: + result.add(Delimiter) # we have some Ascii chars + while remaining != 0: + var m: int = high(int32) + for r in s.runes: + if m > int(r) and int(r) >= n: + m = int(r) + d += (m - n) * (h + 1) + if d < 0: + raise newException(PunyError, "invalid label " & s) + n = m + for r in s.runes: + if int(r) < n: + inc d + if d < 0: + raise newException(PunyError, "invalid label " & s) + continue + if int(r) > n: + continue + var q = d + var k = Base + while true: + var t = k - bias + if t < TMin: + t = TMin + elif t > TMax: + t = TMax + if q < t: + break + result.add($encodeDigit(t + (q - t) mod (Base - t))) + q = (q - t) div (Base - t) + k += Base + result.add($encodeDigit(q)) + bias = adapt(d, h + 1, h == b) + d = 0 + inc h + dec remaining + inc d + inc n + +proc encode*(s: string): string {.raises: [PunyError].} = + ## Encode a string that may contain Unicode. Prefix is empty. + result = encode("", s) + +proc decode*(encoded: string): string {.raises: [PunyError].} = + ## Decode a Punycode-encoded string + var + n = InitialN + i = 0 + bias = InitialBias + var d = rfind(encoded, Delimiter) + result = "" + + if d > 0: + # found Delimiter + for j in 0..<d: + var c = encoded[j] # char + if not c.isBasic: + raise newException(PunyError, "Encoded contains a non-basic char") + result.add(c) # add the character + inc d + else: + d = 0 # set to first index + + while (d < len(encoded)): + var oldi = i + var w = 1 + var k = Base + while true: + if d == len(encoded): + raise newException(PunyError, "Bad input: " & encoded) + var c = encoded[d]; inc d + var digit = int(decodeDigit(c)) + if digit > (high(int32) - i) div w: + raise newException(PunyError, "Too large a value: " & $digit) + i += digit * w + var t: int + if k <= bias: + t = TMin + elif k >= bias + TMax: + t = TMax + else: + t = k - bias + if digit < t: + break + w *= Base - t + k += Base + bias = adapt(i - oldi, runelen(result) + 1, oldi == 0) + + if i div (runelen(result) + 1) > high(int32) - n: + raise newException(PunyError, "Value too large") + + n += i div (runelen(result) + 1) + i = i mod (runelen(result) + 1) + insert(result, $Rune(n), i) + inc i + +when isMainModule: + assert(decode(encode("", "bücher")) == "bücher") + assert(decode(encode("münchen")) == "münchen") + assert encode("xn--", "münchen") == "xn--mnchen-3ya" diff --git a/lib/pure/random.nim b/lib/pure/random.nim new file mode 100644 index 000000000..08da771dc --- /dev/null +++ b/lib/pure/random.nim @@ -0,0 +1,128 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Andreas Rumpf +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +## Nim's standard random number generator. Based on the ``xoroshiro128+`` (xor/rotate/shift/rotate) library. +## * More information: http://xoroshiro.di.unimi.it/ +## * C implementation: http://xoroshiro.di.unimi.it/xoroshiro128plus.c +## + +include "system/inclrtl" +{.push debugger:off.} + +# XXX Expose RandomGenState +when defined(JS): + type ui = uint32 +else: + type ui = uint64 + +type + RandomGenState = object + a0, a1: ui + +when defined(JS): + var state = RandomGenState( + a0: 0x69B4C98Cu32, + a1: 0xFED1DD30u32) # global for backwards compatibility +else: + # racy for multi-threading but good enough for now: + var state = RandomGenState( + a0: 0x69B4C98CB8530805u64, + a1: 0xFED1DD3004688D67CAu64) # global for backwards compatibility + +proc rotl(x, k: ui): ui = + result = (x shl k) or (x shr (ui(64) - k)) + +proc next(s: var RandomGenState): uint64 = + let s0 = s.a0 + var s1 = s.a1 + result = s0 + s1 + s1 = s1 xor s0 + s.a0 = rotl(s0, 55) xor s1 xor (s1 shl 14) # a, b + s.a1 = rotl(s1, 36) # c + +proc skipRandomNumbers(s: var RandomGenState) = + ## This is the jump function for the generator. It is equivalent + ## to 2^64 calls to next(); it can be used to generate 2^64 + ## non-overlapping subsequences for parallel computations. + when defined(JS): + const helper = [0xbeac0467u32, 0xd86b048bu32] + else: + const helper = [0xbeac0467eba5facbu64, 0xd86b048b86aa9922u64] + var + s0 = ui 0 + s1 = ui 0 + for i in 0..high(helper): + for b in 0..< 64: + if (helper[i] and (ui(1) shl ui(b))) != 0: + s0 = s0 xor s.a0 + s1 = s1 xor s.a1 + discard next(s) + s.a0 = s0 + s.a1 = s1 + +proc random*(max: int): int {.benign.} = + ## Returns a random number in the range 0..max-1. The sequence of + ## random number is always the same, unless `randomize` is called + ## which initializes the random number generator with a "random" + ## number, i.e. a tickcount. + result = int(next(state) mod uint64(max)) + +proc random*(max: float): float {.benign.} = + ## Returns a random number in the range 0..<max. The sequence of + ## random number is always the same, unless `randomize` is called + ## which initializes the random number generator with a "random" + ## number, i.e. a tickcount. + let x = next(state) + when defined(JS): + result = (float(x) / float(high(uint32))) * max + else: + let u = (0x3FFu64 shl 52u64) or (x shr 12u64) + result = (cast[float](u) - 1.0) * max + +proc random*[T](x: Slice[T]): T = + ## For a slice `a .. b` returns a value in the range `a .. b-1`. + result = random(x.b - x.a) + x.a + +proc random*[T](a: openArray[T]): T = + ## returns a random element from the openarray `a`. + result = a[random(a.low..a.len)] + +proc randomize*(seed: int) {.benign.} = + ## Initializes the random number generator with a specific seed. + state.a0 = ui(seed shr 16) + state.a1 = ui(seed and 0xffff) + +when not defined(nimscript): + import times + + proc randomize*() {.benign.} = + ## Initializes the random number generator with a "random" + ## number, i.e. a tickcount. Note: Does not work for NimScript. + when defined(JS): + proc getMil(t: Time): int {.importcpp: "getTime", nodecl.} + randomize(getMil times.getTime()) + else: + randomize(int times.getTime()) + +{.pop.} + +when isMainModule: + proc main = + var occur: array[1000, int] + + var x = 8234 + for i in 0..100_000: + x = random(len(occur)) # myrand(x) + inc occur[x] + for i, oc in occur: + if oc < 69: + doAssert false, "too few occurances of " & $i + elif oc > 130: + doAssert false, "too many occurances of " & $i + main() diff --git a/lib/pure/rationals.nim b/lib/pure/rationals.nim index 6fd05dc4b..bf134f2ae 100644 --- a/lib/pure/rationals.nim +++ b/lib/pure/rationals.nim @@ -41,26 +41,26 @@ proc toRational*[T:SomeInteger](x: T): Rational[T] = proc toRationalSub(x: float, n: int): Rational[int] = var - a = 0 - b, c, d = 1 + a = 0'i64 + b, c, d = 1'i64 result = 0 // 1 # rational 0 while b <= n and d <= n: let ac = (a+c) let bd = (b+d) # scale by 1000 so not overflow for high precision - let mediant = (ac/1000) / (bd/1000) + let mediant = (ac.float/1000) / (bd.float/1000) if x == mediant: if bd <= n: - result.num = ac - result.den = bd + result.num = ac.int + result.den = bd.int return result elif d > b: - result.num = c - result.den = d + result.num = c.int + result.den = d.int return result else: - result.num = a - result.den = b + result.num = a.int + result.den = b.int return result elif x > mediant: a = ac @@ -69,8 +69,8 @@ proc toRationalSub(x: float, n: int): Rational[int] = c = ac d = bd if (b > n): - return initRational(c, d) - return initRational(a, b) + return initRational(c.int, d.int) + return initRational(a.int, b.int) proc toRational*(x: float, n: int = high(int)): Rational[int] = ## Calculate the best rational numerator and denominator diff --git a/lib/pure/selectors.nim b/lib/pure/selectors.nim index 89e92c133..098b78c95 100644 --- a/lib/pure/selectors.nim +++ b/lib/pure/selectors.nim @@ -132,11 +132,12 @@ elif defined(linux): s.fds[fd].events = events proc unregister*(s: var Selector, fd: SocketHandle) = - if epoll_ctl(s.epollFD, EPOLL_CTL_DEL, fd, nil) != 0: - let err = osLastError() - if err.cint notin {ENOENT, EBADF}: - # TODO: Why do we sometimes get an EBADF? Is this normal? - raiseOSError(err) + if s.fds[fd].events != {}: + if epoll_ctl(s.epollFD, EPOLL_CTL_DEL, fd, nil) != 0: + let err = osLastError() + if err.cint notin {ENOENT, EBADF}: + # TODO: Why do we sometimes get an EBADF? Is this normal? + raiseOSError(err) s.fds.del(fd) proc close*(s: var Selector) = diff --git a/lib/pure/smtp.nim b/lib/pure/smtp.nim index b2adac2f3..050712902 100644 --- a/lib/pure/smtp.nim +++ b/lib/pure/smtp.nim @@ -20,9 +20,9 @@ ## var msg = createMessage("Hello from Nim's SMTP", ## "Hello!.\n Is this awesome or what?", ## @["foo@gmail.com"]) -## var smtp = connect("smtp.gmail.com", 465, true, true) -## smtp.auth("username", "password") -## smtp.sendmail("username@gmail.com", @["foo@gmail.com"], $msg) +## var smtpConn = connect("smtp.gmail.com", Port 465, true, true) +## smtpConn.auth("username", "password") +## smtpConn.sendmail("username@gmail.com", @["foo@gmail.com"], $msg) ## ## ## For SSL support this module relies on OpenSSL. If you want to @@ -31,6 +31,8 @@ import net, strutils, strtabs, base64, os import asyncnet, asyncdispatch +export Port + type Smtp* = object sock: Socket @@ -258,8 +260,8 @@ when not defined(testing) and isMainModule: # "Hello, my name is dom96.\n What\'s yours?", @["dominik@localhost"]) #echo(msg) - #var smtp = connect("localhost", 25, False, True) - #smtp.sendmail("root@localhost", @["dominik@localhost"], $msg) + #var smtpConn = connect("localhost", Port 25, false, true) + #smtpConn.sendmail("root@localhost", @["dominik@localhost"], $msg) #echo(decode("a17sm3701420wbe.12")) proc main() {.async.} = diff --git a/lib/pure/streams.nim b/lib/pure/streams.nim index c606b4680..eea06f4ce 100644 --- a/lib/pure/streams.nim +++ b/lib/pure/streams.nim @@ -306,68 +306,68 @@ proc peekLine*(s: Stream): TaintedString = defer: setPosition(s, pos) result = readLine(s) -type - StringStream* = ref StringStreamObj ## a stream that encapsulates a string - StringStreamObj* = object of StreamObj - data*: string - pos: int - -{.deprecated: [PStringStream: StringStream, TStringStream: StringStreamObj].} - -proc ssAtEnd(s: Stream): bool = - var s = StringStream(s) - return s.pos >= s.data.len - -proc ssSetPosition(s: Stream, pos: int) = - var s = StringStream(s) - s.pos = clamp(pos, 0, s.data.len) - -proc ssGetPosition(s: Stream): int = - var s = StringStream(s) - return s.pos - -proc ssReadData(s: Stream, buffer: pointer, bufLen: int): int = - var s = StringStream(s) - result = min(bufLen, s.data.len - s.pos) - if result > 0: - copyMem(buffer, addr(s.data[s.pos]), result) - inc(s.pos, result) - -proc ssPeekData(s: Stream, buffer: pointer, bufLen: int): int = - var s = StringStream(s) - result = min(bufLen, s.data.len - s.pos) - if result > 0: - copyMem(buffer, addr(s.data[s.pos]), result) - -proc ssWriteData(s: Stream, buffer: pointer, bufLen: int) = - var s = StringStream(s) - if bufLen <= 0: - return - if s.pos + bufLen > s.data.len: - setLen(s.data, s.pos + bufLen) - copyMem(addr(s.data[s.pos]), buffer, bufLen) - inc(s.pos, bufLen) - -proc ssClose(s: Stream) = - var s = StringStream(s) - s.data = nil - -proc newStringStream*(s: string = ""): StringStream = - ## creates a new stream from the string `s`. - new(result) - result.data = s - result.pos = 0 - result.closeImpl = ssClose - result.atEndImpl = ssAtEnd - result.setPositionImpl = ssSetPosition - result.getPositionImpl = ssGetPosition - result.readDataImpl = ssReadData - result.peekDataImpl = ssPeekData - result.writeDataImpl = ssWriteData - when not defined(js): type + StringStream* = ref StringStreamObj ## a stream that encapsulates a string + StringStreamObj* = object of StreamObj + data*: string + pos: int + + {.deprecated: [PStringStream: StringStream, TStringStream: StringStreamObj].} + + proc ssAtEnd(s: Stream): bool = + var s = StringStream(s) + return s.pos >= s.data.len + + proc ssSetPosition(s: Stream, pos: int) = + var s = StringStream(s) + s.pos = clamp(pos, 0, s.data.len) + + proc ssGetPosition(s: Stream): int = + var s = StringStream(s) + return s.pos + + proc ssReadData(s: Stream, buffer: pointer, bufLen: int): int = + var s = StringStream(s) + result = min(bufLen, s.data.len - s.pos) + if result > 0: + copyMem(buffer, addr(s.data[s.pos]), result) + inc(s.pos, result) + + proc ssPeekData(s: Stream, buffer: pointer, bufLen: int): int = + var s = StringStream(s) + result = min(bufLen, s.data.len - s.pos) + if result > 0: + copyMem(buffer, addr(s.data[s.pos]), result) + + proc ssWriteData(s: Stream, buffer: pointer, bufLen: int) = + var s = StringStream(s) + if bufLen <= 0: + return + if s.pos + bufLen > s.data.len: + setLen(s.data, s.pos + bufLen) + copyMem(addr(s.data[s.pos]), buffer, bufLen) + inc(s.pos, bufLen) + + proc ssClose(s: Stream) = + var s = StringStream(s) + s.data = nil + + proc newStringStream*(s: string = ""): StringStream = + ## creates a new stream from the string `s`. + new(result) + result.data = s + result.pos = 0 + result.closeImpl = ssClose + result.atEndImpl = ssAtEnd + result.setPositionImpl = ssSetPosition + result.getPositionImpl = ssGetPosition + result.readDataImpl = ssReadData + result.peekDataImpl = ssPeekData + result.writeDataImpl = ssWriteData + + type FileStream* = ref FileStreamObj ## a stream that encapsulates a `File` FileStreamObj* = object of Stream f: File diff --git a/lib/pure/strmisc.nim b/lib/pure/strmisc.nim new file mode 100644 index 000000000..89ef2fcd2 --- /dev/null +++ b/lib/pure/strmisc.nim @@ -0,0 +1,83 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Joey Payne +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +## This module contains various string utility routines that are uncommonly +## used in comparison to `strutils <strutils.html>`_. + +import strutils + +{.deadCodeElim: on.} + +proc expandTabs*(s: string, tabSize: int = 8): string {.noSideEffect, + procvar.} = + ## Expand tab characters in `s` by `tabSize` spaces + + result = newStringOfCap(s.len + s.len shr 2) + var pos = 0 + + template addSpaces(n) = + for j in 0 ..< n: + result.add(' ') + pos += 1 + + for i in 0 ..< len(s): + let c = s[i] + if c == '\t': + let + denominator = if tabSize > 0: tabSize else: 1 + numSpaces = tabSize - pos mod denominator + + addSpaces(numSpaces) + else: + result.add(c) + pos += 1 + if c == '\l': + pos = 0 + +proc partition*(s: string, sep: string, + right: bool = false): (string, string, string) + {.noSideEffect, procvar.} = + ## Split the string at the first or last occurrence of `sep` into a 3-tuple + ## + ## Returns a 3 string tuple of (beforeSep, `sep`, afterSep) or + ## (`s`, "", "") if `sep` is not found and `right` is false or + ## ("", "", `s`) if `sep` is not found and `right` is true + let position = if right: s.rfind(sep) else: s.find(sep) + if position != -1: + return (s[0 ..< position], sep, s[position + sep.len ..< s.len]) + return if right: ("", "", s) else: (s, "", "") + +proc rpartition*(s: string, sep: string): (string, string, string) + {.noSideEffect, procvar.} = + ## Split the string at the last occurrence of `sep` into a 3-tuple + ## + ## Returns a 3 string tuple of (beforeSep, `sep`, afterSep) or + ## ("", "", `s`) if `sep` is not found + return partition(s, sep, right = true) + +when isMainModule: + doAssert expandTabs("\t", 4) == " " + doAssert expandTabs("\tfoo\t", 4) == " foo " + doAssert expandTabs("\tfoo\tbar", 4) == " foo bar" + doAssert expandTabs("\tfoo\tbar\t", 4) == " foo bar " + doAssert expandTabs("", 4) == "" + doAssert expandTabs("", 0) == "" + doAssert expandTabs("\t\t\t", 0) == "" + + doAssert partition("foo:bar", ":") == ("foo", ":", "bar") + doAssert partition("foobarbar", "bar") == ("foo", "bar", "bar") + doAssert partition("foobarbar", "bank") == ("foobarbar", "", "") + doAssert partition("foobarbar", "foo") == ("", "foo", "barbar") + doAssert partition("foofoobar", "bar") == ("foofoo", "bar", "") + + doAssert rpartition("foo:bar", ":") == ("foo", ":", "bar") + doAssert rpartition("foobarbar", "bar") == ("foobar", "bar", "") + doAssert rpartition("foobarbar", "bank") == ("", "", "foobarbar") + doAssert rpartition("foobarbar", "foo") == ("", "foo", "barbar") + doAssert rpartition("foofoobar", "bar") == ("foofoo", "bar", "") diff --git a/lib/pure/strscans.nim b/lib/pure/strscans.nim new file mode 100644 index 000000000..246f018c5 --- /dev/null +++ b/lib/pure/strscans.nim @@ -0,0 +1,522 @@ +# +# +# Nim's Runtime Library +# (c) Copyright 2016 Andreas Rumpf +# +# See the file "copying.txt", included in this +# distribution, for details about the copyright. +# + +##[ +This module contains a `scanf`:idx: macro that can be used for extracting +substrings from an input string. This is often easier than regular expressions. +Some examples as an apetizer: + +.. code-block:: nim + # check if input string matches a triple of integers: + const input = "(1,2,4)" + var x, y, z: int + if scanf(input, "($i,$i,$i)", x, y, z): + echo "matches and x is ", x, " y is ", y, " z is ", z + + # check if input string matches an ISO date followed by an identifier followed + # by whitespace and a floating point number: + var year, month, day: int + var identifier: string + var myfloat: float + if scanf(input, "$i-$i-$i $w$s$f", year, month, day, identifier, myfloat): + echo "yes, we have a match!" + +As can be seen from the examples, strings are matched verbatim except for +substrings starting with ``$``. These constructions are available: + +================= ======================================================== +``$i`` Matches an integer. This uses ``parseutils.parseInt``. +``$f`` Matches a floating pointer number. Uses ``parseFloat``. +``$w`` Matches an ASCII identifier: ``[A-Z-a-z_][A-Za-z_0-9]*``. +``$s`` Skips optional whitespace. +``$$`` Matches a single dollar sign. +``$.`` Matches if the end of the input string has been reached. +``$*`` Matches until the token following the ``$*`` was found. + The match is allowed to be of 0 length. +``$+`` Matches until the token following the ``$+`` was found. + The match must consist of at least one char. +``${foo}`` User defined matcher. Uses the proc ``foo`` to perform + the match. See below for more details. +``$[foo]`` Call user defined proc ``foo`` to **skip** some optional + parts in the input string. See below for more details. +================= ======================================================== + +Even though ``$*`` and ``$+`` look similar to the regular expressions ``.*`` +and ``.+`` they work quite differently, there is no non-deterministic +state machine involved and the matches are non-greedy. ``[$*]`` +matches ``[xyz]`` via ``parseutils.parseUntil``. + +Furthermore no backtracking is performed, if parsing fails after a value +has already been bound to a matched subexpression this value is not restored +to its original value. This rarely causes problems in practice and if it does +for you, it's easy enough to bind to a temporary variable first. + + +Startswith vs full match +======================== + +``scanf`` returns true if the input string **starts with** the specified +pattern. If instead it should only return true if theres is also nothing +left in the input, append ``$.`` to your pattern. + + +User definable matchers +======================= + +One very nice advantage over regular expressions is that ``scanf`` is +extensible with ordinary Nim procs. The proc is either enclosed in ``${}`` +or in ``$[]``. ``${}`` matches and binds the result +to a variable (that was passed to the ``scanf`` macro) while ``$[]`` merely +optional tokens. + + +In this example, we define a helper proc ``skipSep`` that skips some separators +which we then use in our scanf pattern to help us in the matching process: + +.. code-block:: nim + + proc someSep(input: string; start: int; seps: set[char] = {':','-','.'}): int = + # Note: The parameters and return value must match to what ``scanf`` requires + result = 0 + while input[start+result] in seps: inc result + + if scanf(input, "$w${someSep}$w", key, value): + ... + +It also possible to pass arguments to a user definable matcher: + +.. code-block:: nim + + proc ndigits(input: string; start: int; intVal: var int; n: int): int = + # matches exactly ``n`` digits. Matchers need to return 0 if nothing + # matched or otherwise the number of processed chars. + var x = 0 + var i = 0 + while i < n and i+start < input.len and input[i+start] in {'0'..'9'}: + x = x * 10 + input[i+start].ord - '0'.ord + inc i + # only overwrite if we had a match + if i == n: + result = n + intVal = x + + # match an ISO date extracting year, month, day at the same time. + # Also ensure the input ends after the ISO date: + var year, month, day: int + if scanf("2013-01-03", "${ndigits(4)}-${ndigits(2)}-${ndigits(2)}$.", year, month, day): + ... + +]## + + +import macros, parseutils + +proc conditionsToIfChain(n, idx, res: NimNode; start: int): NimNode = + assert n.kind == nnkStmtList + if start >= n.len: return newAssignment(res, newLit true) + var ifs: NimNode = nil + if n[start+1].kind == nnkEmpty: + ifs = conditionsToIfChain(n, idx, res, start+3) + else: + ifs = newIfStmt((n[start+1], + newTree(nnkStmtList, newCall(bindSym"inc", idx, n[start+2]), + conditionsToIfChain(n, idx, res, start+3)))) + result = newTree(nnkStmtList, n[start], ifs) + +proc notZero(x: NimNode): NimNode = newCall(bindSym"!=", x, newLit 0) + +proc buildUserCall(x: string; args: varargs[NimNode]): NimNode = + let y = parseExpr(x) + result = newTree(nnkCall) + if y.kind in nnkCallKinds: result.add y[0] + else: result.add y + for a in args: result.add a + if y.kind in nnkCallKinds: + for i in 1..<y.len: result.add y[i] + +macro scanf*(input: string; pattern: static[string]; results: varargs[typed]): bool = + ## See top level documentation of his module of how ``scanf`` works. + template matchBind(parser) {.dirty.} = + var resLen = genSym(nskLet, "resLen") + conds.add newLetStmt(resLen, newCall(bindSym(parser), input, results[i], idx)) + conds.add resLen.notZero + conds.add resLen + + var i = 0 + var p = 0 + var idx = genSym(nskVar, "idx") + var res = genSym(nskVar, "res") + result = newTree(nnkStmtListExpr, newVarStmt(idx, newLit 0), newVarStmt(res, newLit false)) + var conds = newTree(nnkStmtList) + var fullMatch = false + while p < pattern.len: + if pattern[p] == '$': + inc p + case pattern[p] + of '$': + var resLen = genSym(nskLet, "resLen") + conds.add newLetStmt(resLen, newCall(bindSym"skip", input, newLit($pattern[p]), idx)) + conds.add resLen.notZero + conds.add resLen + of 'w': + if i < results.len or getType(results[i]).typeKind != ntyString: + matchBind "parseIdent" + else: + error("no string var given for $w") + inc i + of 'i': + if i < results.len or getType(results[i]).typeKind != ntyInt: + matchBind "parseInt" + else: + error("no int var given for $d") + inc i + of 'f': + if i < results.len or getType(results[i]).typeKind != ntyFloat: + matchBind "parseFloat" + else: + error("no float var given for $f") + inc i + of 's': + conds.add newCall(bindSym"inc", idx, newCall(bindSym"skipWhitespace", input, idx)) + conds.add newEmptyNode() + conds.add newEmptyNode() + of '.': + if p == pattern.len-1: + fullMatch = true + else: + error("invalid format string") + of '*', '+': + if i < results.len or getType(results[i]).typeKind != ntyString: + var min = ord(pattern[p] == '+') + var q=p+1 + var token = "" + while q < pattern.len and pattern[q] != '$': + token.add pattern[q] + inc q + var resLen = genSym(nskLet, "resLen") + conds.add newLetStmt(resLen, newCall(bindSym"parseUntil", input, results[i], newLit(token), idx)) + conds.add newCall(bindSym"!=", resLen, newLit min) + conds.add resLen + else: + error("no string var given for $" & pattern[p]) + inc i + of '{': + inc p + var nesting = 0 + let start = p + while true: + case pattern[p] + of '{': inc nesting + of '}': + if nesting == 0: break + dec nesting + of '\0': error("expected closing '}'") + else: discard + inc p + let expr = pattern.substr(start, p-1) + if i < results.len: + var resLen = genSym(nskLet, "resLen") + conds.add newLetStmt(resLen, buildUserCall(expr, input, results[i], idx)) + conds.add newCall(bindSym"!=", resLen, newLit 0) + conds.add resLen + else: + error("no var given for $" & expr) + inc i + of '[': + inc p + var nesting = 0 + let start = p + while true: + case pattern[p] + of '[': inc nesting + of ']': + if nesting == 0: break + dec nesting + of '\0': error("expected closing ']'") + else: discard + inc p + let expr = pattern.substr(start, p-1) + conds.add newCall(bindSym"inc", idx, buildUserCall(expr, input, idx)) + conds.add newEmptyNode() + conds.add newEmptyNode() + else: error("invalid format string") + inc p + else: + var token = "" + while p < pattern.len and pattern[p] != '$': + token.add pattern[p] + inc p + var resLen = genSym(nskLet, "resLen") + conds.add newLetStmt(resLen, newCall(bindSym"skip", input, newLit(token), idx)) + conds.add resLen.notZero + conds.add resLen + result.add conditionsToIfChain(conds, idx, res, 0) + if fullMatch: + result.add newCall(bindSym"and", res, + newCall(bindSym">=", idx, newCall(bindSym"len", input))) + else: + result.add res + +template atom*(input: string; idx: int; c: char): bool = + ## Used in scanp for the matching of atoms (usually chars). + input[idx] == c + +template atom*(input: string; idx: int; s: set[char]): bool = + input[idx] in s + +#template prepare*(input: string): int = 0 +template success*(x: int): bool = x != 0 + +template nxt*(input: string; idx, step: int = 1) = inc(idx, step) + +macro scanp*(input, idx: typed; pattern: varargs[untyped]): bool = + ## See top level documentation of his module of how ``scanp`` works. + type StmtTriple = tuple[init, cond, action: NimNode] + + template interf(x): untyped = bindSym(x, brForceOpen) + + proc toIfChain(n: seq[StmtTriple]; idx, res: NimNode; start: int): NimNode = + if start >= n.len: return newAssignment(res, newLit true) + var ifs: NimNode = nil + if n[start].cond.kind == nnkEmpty: + ifs = toIfChain(n, idx, res, start+1) + else: + ifs = newIfStmt((n[start].cond, + newTree(nnkStmtList, n[start].action, + toIfChain(n, idx, res, start+1)))) + result = newTree(nnkStmtList, n[start].init, ifs) + + proc attach(x, attached: NimNode): NimNode = + if attached == nil: x + else: newStmtList(attached, x) + + proc placeholder(n, x, j: NimNode): NimNode = + if n.kind == nnkPrefix and n[0].eqIdent("$"): + let n1 = n[1] + if n1.eqIdent"_" or n1.eqIdent"current": + result = newTree(nnkBracketExpr, x, j) + elif n1.eqIdent"input": + result = x + elif n1.eqIdent"i" or n1.eqIdent"index": + result = j + else: + error("unknown pattern " & repr(n)) + else: + result = copyNimNode(n) + for i in 0 ..< n.len: + result.add placeholder(n[i], x, j) + + proc atm(it, input, idx, attached: NimNode): StmtTriple = + template `!!`(x): untyped = attach(x, attached) + case it.kind + of nnkIdent: + var resLen = genSym(nskLet, "resLen") + result = (newLetStmt(resLen, newCall(it, input, idx)), + newCall(interf"success", resLen), + !!newCall(interf"nxt", input, idx, resLen)) + of nnkCallKinds: + # *{'A'..'Z'} !! s.add(!_) + template buildWhile(init, cond, action): untyped = + while true: + init + if not cond: break + action + + # (x) a # bind action a to (x) + if it[0].kind == nnkPar and it.len == 2: + result = atm(it[0], input, idx, placeholder(it[1], input, idx)) + elif it.kind == nnkInfix and it[0].eqIdent"->": + # bind matching to some action: + result = atm(it[1], input, idx, placeholder(it[2], input, idx)) + elif it.kind == nnkInfix and it[0].eqIdent"as": + let cond = if it[1].kind in nnkCallKinds: placeholder(it[1], input, idx) + else: newCall(it[1], input, idx) + result = (newLetStmt(it[2], cond), + newCall(interf"success", it[2]), + !!newCall(interf"nxt", input, idx, it[2])) + elif it.kind == nnkPrefix and it[0].eqIdent"*": + let (init, cond, action) = atm(it[1], input, idx, attached) + result = (getAst(buildWhile(init, cond, action)), + newEmptyNode(), newEmptyNode()) + elif it.kind == nnkPrefix and it[0].eqIdent"+": + # x+ is the same as xx* + result = atm(newTree(nnkPar, it[1], newTree(nnkPrefix, ident"*", it[1])), + input, idx, attached) + elif it.kind == nnkPrefix and it[0].eqIdent"?": + # optional. + let (init, cond, action) = atm(it[1], input, idx, attached) + if cond.kind == nnkEmpty: + error("'?' operator applied to a non-condition") + else: + result = (newTree(nnkStmtList, init, newIfStmt((cond, action))), + newEmptyNode(), newEmptyNode()) + elif it.kind == nnkPrefix and it[0].eqIdent"~": + # not operator + let (init, cond, action) = atm(it[1], input, idx, attached) + if cond.kind == nnkEmpty: + error("'~' operator applied to a non-condition") + else: + result = (init, newCall(bindSym"not", cond), action) + elif it.kind == nnkInfix and it[0].eqIdent"|": + let a = atm(it[1], input, idx, attached) + let b = atm(it[2], input, idx, attached) + if a.cond.kind == nnkEmpty or b.cond.kind == nnkEmpty: + error("'|' operator applied to a non-condition") + else: + result = (newStmtList(a.init, + newIfStmt((a.cond, a.action), (newTree(nnkStmtListExpr, b.init, b.cond), b.action))), + newEmptyNode(), newEmptyNode()) + elif it.kind == nnkInfix and it[0].eqIdent"^*": + # a ^* b is rewritten to: (a *(b a))? + #exprList = expr ^+ comma + template tmp(a, b): untyped = ?(a, *(b, a)) + result = atm(getAst(tmp(it[1], it[2])), input, idx, attached) + + elif it.kind == nnkInfix and it[0].eqIdent"^+": + # a ^* b is rewritten to: (a +(b a))? + template tmp(a, b): untyped = (a, *(b, a)) + result = atm(getAst(tmp(it[1], it[2])), input, idx, attached) + elif it.kind == nnkCommand and it.len == 2 and it[0].eqIdent"pred": + # enforce that the wrapped call is interpreted as a predicate, not a non-terminal: + result = (newEmptyNode(), placeholder(it[1], input, idx), newEmptyNode()) + else: + var resLen = genSym(nskLet, "resLen") + result = (newLetStmt(resLen, placeholder(it, input, idx)), + newCall(interf"success", resLen), !!newCall(interf"nxt", input, idx, resLen)) + of nnkStrLit..nnkTripleStrLit: + var resLen = genSym(nskLet, "resLen") + result = (newLetStmt(resLen, newCall(interf"skip", input, it, idx)), + newCall(interf"success", resLen), !!newCall(interf"nxt", input, idx, resLen)) + of nnkCurly, nnkAccQuoted, nnkCharLit: + result = (newEmptyNode(), newCall(interf"atom", input, idx, it), !!newCall(interf"nxt", input, idx)) + of nnkCurlyExpr: + if it.len == 3 and it[1].kind == nnkIntLit and it[2].kind == nnkIntLit: + var h = newTree(nnkPar, it[0]) + for count in 2..it[1].intVal: h.add(it[0]) + for count in it[1].intVal .. it[2].intVal-1: h.add(newTree(nnkPrefix, ident"?", it[0])) + result = atm(h, input, idx, attached) + elif it.len == 2 and it[1].kind == nnkIntLit: + var h = newTree(nnkPar, it[0]) + for count in 2..it[1].intVal: h.add(it[0]) + result = atm(h, input, idx, attached) + else: + error("invalid pattern") + of nnkPar: + if it.len == 1: + result = atm(it[0], input, idx, attached) + else: + # concatenation: + var conds: seq[StmtTriple] = @[] + for x in it: conds.add atm(x, input, idx, attached) + var res = genSym(nskVar, "res") + result = (newStmtList(newVarStmt(res, newLit false), + toIfChain(conds, idx, res, 0)), res, newEmptyNode()) + else: + error("invalid pattern") + + #var idx = genSym(nskVar, "idx") + var res = genSym(nskVar, "res") + result = newTree(nnkStmtListExpr, #newVarStmt(idx, newCall(interf"prepare", input)), + newVarStmt(res, newLit false)) + var conds: seq[StmtTriple] = @[] + for it in pattern: + conds.add atm(it, input, idx, nil) + result.add toIfChain(conds, idx, res, 0) + result.add res + when defined(debugScanp): + echo repr result + + +when isMainModule: + proc twoDigits(input: string; x: var int; start: int): int = + if input[start] == '0' and input[start+1] == '0': + result = 2 + x = 13 + else: + result = 0 + + proc someSep(input: string; start: int; seps: set[char] = {';',',','-','.'}): int = + result = 0 + while input[start+result] in seps: inc result + + proc demangle(s: string; res: var string; start: int): int = + while s[result+start] in {'_', '@'}: inc result + res = "" + while result+start < s.len and s[result+start] > ' ' and s[result+start] != '_': + res.add s[result+start] + inc result + while result+start < s.len and s[result+start] > ' ': + inc result + + proc parseGDB(resp: string): seq[string] = + const + digits = {'0'..'9'} + hexdigits = digits + {'a'..'f', 'A'..'F'} + whites = {' ', '\t', '\C', '\L'} + result = @[] + var idx = 0 + while true: + var prc = "" + var info = "" + if scanp(resp, idx, *`whites`, '#', *`digits`, +`whites`, ?("0x", *`hexdigits`, " in "), + demangle($input, prc, $index), *`whites`, '(', * ~ ')', ')', + *`whites`, "at ", +(~{'\C', '\L', '\0'} -> info.add($_)) ): + result.add prc & " " & info + else: + break + + var key, val: string + var intval: int + var floatval: float + doAssert scanf("abc:: xyz 89 33.25", "$w$s::$s$w$s$i $f", key, val, intval, floatVal) + doAssert key == "abc" + doAssert val == "xyz" + doAssert intval == 89 + doAssert floatVal == 33.25 + + let xx = scanf("$abc", "$$$i", intval) + doAssert xx == false + + + let xx2 = scanf("$1234", "$$$i", intval) + doAssert xx2 + + let yy = scanf(";.--Breakpoint00 [output]", "$[someSep]Breakpoint${twoDigits}$[someSep({';','.','-'})] [$+]$.", intVal, key) + doAssert yy + doAssert key == "output" + doAssert intVal == 13 + + var ident = "" + var idx = 0 + let zz = scanp("foobar x x x xWZ", idx, +{'a'..'z'} -> add(ident, $_), *(*{' ', '\t'}, "x"), ~'U', "Z") + doAssert zz + doAssert ident == "foobar" + + const digits = {'0'..'9'} + var year = 0 + var idx2 = 0 + if scanp("201655-8-9", idx2, `digits`{4,6} -> (year = year * 10 + ord($_) - ord('0')), "-8", "-9"): + doAssert year == 201655 + + const gdbOut = """ + #0 @foo_96013_1208911747@8 (x0=...) + at c:/users/anwender/projects/nim/temp.nim:11 + #1 0x00417754 in tempInit000 () at c:/users/anwender/projects/nim/temp.nim:13 + #2 0x0041768d in NimMainInner () + at c:/users/anwender/projects/nim/lib/system.nim:2605 + #3 0x004176b1 in NimMain () + at c:/users/anwender/projects/nim/lib/system.nim:2613 + #4 0x004176db in main (argc=1, args=0x712cc8, env=0x711ca8) + at c:/users/anwender/projects/nim/lib/system.nim:2620""" + const result = @["foo c:/users/anwender/projects/nim/temp.nim:11", + "tempInit000 c:/users/anwender/projects/nim/temp.nim:13", + "NimMainInner c:/users/anwender/projects/nim/lib/system.nim:2605", + "NimMain c:/users/anwender/projects/nim/lib/system.nim:2613", + "main c:/users/anwender/projects/nim/lib/system.nim:2620"] + doAssert parseGDB(gdbOut) == result diff --git a/lib/pure/strutils.nim b/lib/pure/strutils.nim index f2c1e77e1..bfc32bc71 100644 --- a/lib/pure/strutils.nim +++ b/lib/pure/strutils.nim @@ -14,6 +14,8 @@ ## <backends.html#the-javascript-target>`_. import parseutils +from math import pow, round, floor, log10 +from algorithm import reverse {.deadCodeElim: on.} @@ -24,6 +26,12 @@ include "system/inclrtl" {.pop.} +# Support old split with set[char] +when defined(nimOldSplit): + {.pragma: deprecatedSplit, deprecated.} +else: + {.pragma: deprecatedSplit.} + type CharSet* {.deprecated.} = set[char] # for compatibility with Nim {.deprecated: [TCharSet: CharSet].} @@ -62,8 +70,8 @@ const ## doAssert "01234".find(invalid) == -1 ## doAssert "01A34".find(invalid) == 2 -proc isAlpha*(c: char): bool {.noSideEffect, procvar, - rtl, extern: "nsuIsAlphaChar".}= +proc isAlphaAscii*(c: char): bool {.noSideEffect, procvar, + rtl, extern: "nsuIsAlphaAsciiChar".}= ## Checks whether or not `c` is alphabetical. ## ## This checks a-z, A-Z ASCII characters only. @@ -83,27 +91,27 @@ proc isDigit*(c: char): bool {.noSideEffect, procvar, ## This checks 0-9 ASCII characters only. return c in Digits -proc isSpace*(c: char): bool {.noSideEffect, procvar, - rtl, extern: "nsuIsSpaceChar".}= +proc isSpaceAscii*(c: char): bool {.noSideEffect, procvar, + rtl, extern: "nsuIsSpaceAsciiChar".}= ## Checks whether or not `c` is a whitespace character. return c in Whitespace -proc isLower*(c: char): bool {.noSideEffect, procvar, - rtl, extern: "nsuIsLowerChar".}= +proc isLowerAscii*(c: char): bool {.noSideEffect, procvar, + rtl, extern: "nsuIsLowerAsciiChar".}= ## Checks whether or not `c` is a lower case character. ## ## This checks ASCII characters only. return c in {'a'..'z'} -proc isUpper*(c: char): bool {.noSideEffect, procvar, - rtl, extern: "nsuIsUpperChar".}= +proc isUpperAscii*(c: char): bool {.noSideEffect, procvar, + rtl, extern: "nsuIsUpperAsciiChar".}= ## Checks whether or not `c` is an upper case character. ## ## This checks ASCII characters only. return c in {'A'..'Z'} -proc isAlpha*(s: string): bool {.noSideEffect, procvar, - rtl, extern: "nsuIsAlphaStr".}= +proc isAlphaAscii*(s: string): bool {.noSideEffect, procvar, + rtl, extern: "nsuIsAlphaAsciiStr".}= ## Checks whether or not `s` is alphabetical. ## ## This checks a-z, A-Z ASCII characters only. @@ -115,7 +123,7 @@ proc isAlpha*(s: string): bool {.noSideEffect, procvar, result = true for c in s: - result = c.isAlpha() and result + result = c.isAlphaAscii() and result proc isAlphaNumeric*(s: string): bool {.noSideEffect, procvar, rtl, extern: "nsuIsAlphaNumericStr".}= @@ -147,8 +155,8 @@ proc isDigit*(s: string): bool {.noSideEffect, procvar, for c in s: result = c.isDigit() and result -proc isSpace*(s: string): bool {.noSideEffect, procvar, - rtl, extern: "nsuIsSpaceStr".}= +proc isSpaceAscii*(s: string): bool {.noSideEffect, procvar, + rtl, extern: "nsuIsSpaceAsciiStr".}= ## Checks whether or not `s` is completely whitespace. ## ## Returns true if all characters in `s` are whitespace @@ -158,10 +166,11 @@ proc isSpace*(s: string): bool {.noSideEffect, procvar, result = true for c in s: - result = c.isSpace() and result + if not c.isSpaceAscii(): + return false -proc isLower*(s: string): bool {.noSideEffect, procvar, - rtl, extern: "nsuIsLowerStr".}= +proc isLowerAscii*(s: string): bool {.noSideEffect, procvar, + rtl, extern: "nsuIsLowerAsciiStr".}= ## Checks whether or not `s` contains all lower case characters. ## ## This checks ASCII characters only. @@ -172,10 +181,10 @@ proc isLower*(s: string): bool {.noSideEffect, procvar, result = true for c in s: - result = c.isLower() and result + result = c.isLowerAscii() and result -proc isUpper*(s: string): bool {.noSideEffect, procvar, - rtl, extern: "nsuIsUpperStr".}= +proc isUpperAscii*(s: string): bool {.noSideEffect, procvar, + rtl, extern: "nsuIsUpperAsciiStr".}= ## Checks whether or not `s` contains all upper case characters. ## ## This checks ASCII characters only. @@ -186,10 +195,10 @@ proc isUpper*(s: string): bool {.noSideEffect, procvar, result = true for c in s: - result = c.isUpper() and result + result = c.isUpperAscii() and result -proc toLower*(c: char): char {.noSideEffect, procvar, - rtl, extern: "nsuToLowerChar".} = +proc toLowerAscii*(c: char): char {.noSideEffect, procvar, + rtl, extern: "nsuToLowerAsciiChar".} = ## Converts `c` into lower case. ## ## This works only for the letters ``A-Z``. See `unicode.toLower @@ -200,8 +209,8 @@ proc toLower*(c: char): char {.noSideEffect, procvar, else: result = c -proc toLower*(s: string): string {.noSideEffect, procvar, - rtl, extern: "nsuToLowerStr".} = +proc toLowerAscii*(s: string): string {.noSideEffect, procvar, + rtl, extern: "nsuToLowerAsciiStr".} = ## Converts `s` into lower case. ## ## This works only for the letters ``A-Z``. See `unicode.toLower @@ -209,10 +218,10 @@ proc toLower*(s: string): string {.noSideEffect, procvar, ## character. result = newString(len(s)) for i in 0..len(s) - 1: - result[i] = toLower(s[i]) + result[i] = toLowerAscii(s[i]) -proc toUpper*(c: char): char {.noSideEffect, procvar, - rtl, extern: "nsuToUpperChar".} = +proc toUpperAscii*(c: char): char {.noSideEffect, procvar, + rtl, extern: "nsuToUpperAsciiChar".} = ## Converts `c` into upper case. ## ## This works only for the letters ``A-Z``. See `unicode.toUpper @@ -223,8 +232,8 @@ proc toUpper*(c: char): char {.noSideEffect, procvar, else: result = c -proc toUpper*(s: string): string {.noSideEffect, procvar, - rtl, extern: "nsuToUpperStr".} = +proc toUpperAscii*(s: string): string {.noSideEffect, procvar, + rtl, extern: "nsuToUpperAsciiStr".} = ## Converts `s` into upper case. ## ## This works only for the letters ``A-Z``. See `unicode.toUpper @@ -232,14 +241,145 @@ proc toUpper*(s: string): string {.noSideEffect, procvar, ## character. result = newString(len(s)) for i in 0..len(s) - 1: - result[i] = toUpper(s[i]) + result[i] = toUpperAscii(s[i]) + +proc capitalizeAscii*(s: string): string {.noSideEffect, procvar, + rtl, extern: "nsuCapitalizeAscii".} = + ## Converts the first character of `s` into upper case. + ## + ## This works only for the letters ``A-Z``. + result = toUpperAscii(s[0]) & substr(s, 1) + +proc isSpace*(c: char): bool {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuIsSpaceChar".}= + ## Checks whether or not `c` is a whitespace character. + ## + ## **Deprecated since version 0.15.0**: use ``isSpaceAscii`` instead. + isSpaceAscii(c) + +proc isLower*(c: char): bool {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuIsLowerChar".}= + ## Checks whether or not `c` is a lower case character. + ## + ## This checks ASCII characters only. + ## + ## **Deprecated since version 0.15.0**: use ``isLowerAscii`` instead. + isLowerAscii(c) + +proc isUpper*(c: char): bool {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuIsUpperChar".}= + ## Checks whether or not `c` is an upper case character. + ## + ## This checks ASCII characters only. + ## + ## **Deprecated since version 0.15.0**: use ``isUpperAscii`` instead. + isUpperAscii(c) + +proc isAlpha*(c: char): bool {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuIsAlphaChar".}= + ## Checks whether or not `c` is alphabetical. + ## + ## This checks a-z, A-Z ASCII characters only. + ## + ## **Deprecated since version 0.15.0**: use ``isAlphaAscii`` instead. + isAlphaAscii(c) + +proc isAlpha*(s: string): bool {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuIsAlphaStr".}= + ## Checks whether or not `s` is alphabetical. + ## + ## This checks a-z, A-Z ASCII characters only. + ## Returns true if all characters in `s` are + ## alphabetic and there is at least one character + ## in `s`. + ## + ## **Deprecated since version 0.15.0**: use ``isAlphaAscii`` instead. + isAlphaAscii(s) + +proc isSpace*(s: string): bool {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuIsSpaceStr".}= + ## Checks whether or not `s` is completely whitespace. + ## + ## Returns true if all characters in `s` are whitespace + ## characters and there is at least one character in `s`. + ## + ## **Deprecated since version 0.15.0**: use ``isSpaceAscii`` instead. + isSpaceAscii(s) + +proc isLower*(s: string): bool {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuIsLowerStr".}= + ## Checks whether or not `s` contains all lower case characters. + ## + ## This checks ASCII characters only. + ## Returns true if all characters in `s` are lower case + ## and there is at least one character in `s`. + ## + ## **Deprecated since version 0.15.0**: use ``isLowerAscii`` instead. + isLowerAscii(s) + +proc isUpper*(s: string): bool {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuIsUpperStr".}= + ## Checks whether or not `s` contains all upper case characters. + ## + ## This checks ASCII characters only. + ## Returns true if all characters in `s` are upper case + ## and there is at least one character in `s`. + ## + ## **Deprecated since version 0.15.0**: use ``isUpperAscii`` instead. + isUpperAscii(s) + +proc toLower*(c: char): char {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuToLowerChar".} = + ## Converts `c` into lower case. + ## + ## This works only for the letters ``A-Z``. See `unicode.toLower + ## <unicode.html#toLower>`_ for a version that works for any Unicode + ## character. + ## + ## **Deprecated since version 0.15.0**: use ``toLowerAscii`` instead. + toLowerAscii(c) + +proc toLower*(s: string): string {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuToLowerStr".} = + ## Converts `s` into lower case. + ## + ## This works only for the letters ``A-Z``. See `unicode.toLower + ## <unicode.html#toLower>`_ for a version that works for any Unicode + ## character. + ## + ## **Deprecated since version 0.15.0**: use ``toLowerAscii`` instead. + toLowerAscii(s) + +proc toUpper*(c: char): char {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuToUpperChar".} = + ## Converts `c` into upper case. + ## + ## This works only for the letters ``A-Z``. See `unicode.toUpper + ## <unicode.html#toUpper>`_ for a version that works for any Unicode + ## character. + ## + ## **Deprecated since version 0.15.0**: use ``toUpperAscii`` instead. + toUpperAscii(c) + +proc toUpper*(s: string): string {.noSideEffect, procvar, + rtl, deprecated, extern: "nsuToUpperStr".} = + ## Converts `s` into upper case. + ## + ## This works only for the letters ``A-Z``. See `unicode.toUpper + ## <unicode.html#toUpper>`_ for a version that works for any Unicode + ## character. + ## + ## **Deprecated since version 0.15.0**: use ``toUpperAscii`` instead. + toUpperAscii(s) proc capitalize*(s: string): string {.noSideEffect, procvar, - rtl, extern: "nsuCapitalize".} = + rtl, deprecated, extern: "nsuCapitalize".} = ## Converts the first character of `s` into upper case. ## ## This works only for the letters ``A-Z``. - result = toUpper(s[0]) & substr(s, 1) + ## + ## **Deprecated since version 0.15.0**: use ``capitalizeAscii`` instead. + capitalizeAscii(s) proc normalize*(s: string): string {.noSideEffect, procvar, rtl, extern: "nsuNormalize".} = @@ -268,7 +408,7 @@ proc cmpIgnoreCase*(a, b: string): int {.noSideEffect, var i = 0 var m = min(a.len, b.len) while i < m: - result = ord(toLower(a[i])) - ord(toLower(b[i])) + result = ord(toLowerAscii(a[i])) - ord(toLowerAscii(b[i])) if result != 0: return inc(i) result = a.len - b.len @@ -289,8 +429,8 @@ proc cmpIgnoreStyle*(a, b: string): int {.noSideEffect, while true: while a[i] == '_': inc(i) while b[j] == '_': inc(j) # BUGFIX: typo - var aa = toLower(a[i]) - var bb = toLower(b[j]) + var aa = toLowerAscii(a[i]) + var bb = toLowerAscii(b[j]) result = ord(aa) - ord(bb) if result != 0 or aa == '\0': break inc(i) @@ -324,16 +464,77 @@ proc toOctal*(c: char): string {.noSideEffect, rtl, extern: "nsuToOctal".} = result[i] = chr(val mod 8 + ord('0')) val = val div 8 -iterator split*(s: string, seps: set[char] = Whitespace): string = +proc isNilOrEmpty*(s: string): bool {.noSideEffect, procvar, rtl, extern: "nsuIsNilOrEmpty".} = + ## Checks if `s` is nil or empty. + result = len(s) == 0 + +proc isNilOrWhitespace*(s: string): bool {.noSideEffect, procvar, rtl, extern: "nsuIsNilOrWhitespace".} = + ## Checks if `s` is nil or consists entirely of whitespace characters. + if len(s) == 0: + return true + + result = true + for c in s: + if not c.isSpaceAscii(): + return false + +proc substrEq(s: string, pos: int, substr: string): bool = + var i = 0 + var length = substr.len + while i < length and s[pos+i] == substr[i]: + inc i + + return i == length + +# --------- Private templates for different split separators ----------- + +template stringHasSep(s: string, index: int, seps: set[char]): bool = + s[index] in seps + +template stringHasSep(s: string, index: int, sep: char): bool = + s[index] == sep + +template stringHasSep(s: string, index: int, sep: string): bool = + s.substrEq(index, sep) + +template splitCommon(s, sep, maxsplit, sepLen) = + ## Common code for split procedures + var last = 0 + var splits = maxsplit + + if len(s) > 0: + while last <= len(s): + var first = last + while last < len(s) and not stringHasSep(s, last, sep): + inc(last) + if splits == 0: last = len(s) + yield substr(s, first, last-1) + if splits == 0: break + dec(splits) + inc(last, sepLen) + +template oldSplit(s, seps, maxsplit) = + var last = 0 + var splits = maxsplit + assert(not ('\0' in seps)) + while last < len(s): + while s[last] in seps: inc(last) + var first = last + while last < len(s) and s[last] notin seps: inc(last) + if first <= last-1: + if splits == 0: last = len(s) + yield substr(s, first, last-1) + if splits == 0: break + dec(splits) + +iterator split*(s: string, seps: set[char] = Whitespace, + maxsplit: int = -1): string = ## Splits the string `s` into substrings using a group of separators. ## - ## Substrings are separated by a substring containing only `seps`. Note - ## that whole sequences of characters found in ``seps`` will be counted as - ## a single split point and leading/trailing separators will be ignored. - ## The following example: + ## Substrings are separated by a substring containing only `seps`. ## ## .. code-block:: nim - ## for word in split(" this is an example "): + ## for word in split("this\lis an\texample"): ## writeLine(stdout, word) ## ## ...generates this output: @@ -347,7 +548,7 @@ iterator split*(s: string, seps: set[char] = Whitespace): string = ## And the following code: ## ## .. code-block:: nim - ## for word in split(";;this;is;an;;example;;;", {';'}): + ## for word in split("this:is;an$example", {';', ':', '$'}): ## writeLine(stdout, word) ## ## ...produces the same output as the first example. The code: @@ -368,22 +569,26 @@ iterator split*(s: string, seps: set[char] = Whitespace): string = ## "08" ## "08.398990" ## - var last = 0 - assert(not ('\0' in seps)) - while last < len(s): - while s[last] in seps: inc(last) - var first = last - while last < len(s) and s[last] notin seps: inc(last) # BUGFIX! - if first <= last-1: - yield substr(s, first, last-1) + when defined(nimOldSplit): + oldSplit(s, seps, maxsplit) + else: + splitCommon(s, seps, maxsplit, 1) + +iterator splitWhitespace*(s: string): string = + ## Splits at whitespace. + oldSplit(s, Whitespace, -1) -iterator split*(s: string, sep: char): string = +proc splitWhitespace*(s: string): seq[string] {.noSideEffect, + rtl, extern: "nsuSplitWhitespace".} = + ## The same as the `splitWhitespace <#splitWhitespace.i,string>`_ + ## iterator, but is a proc that returns a sequence of substrings. + accumulateResult(splitWhitespace(s)) + +iterator split*(s: string, sep: char, maxsplit: int = -1): string = ## Splits the string `s` into substrings using a single separator. ## ## Substrings are separated by the character `sep`. - ## Unlike the version of the iterator which accepts a set of separator - ## characters, this proc will not coalesce groups of the - ## separator, returning a string for each found character. The code: + ## The code: ## ## .. code-block:: nim ## for word in split(";;this;is;an;;example;;;", ';'): @@ -403,28 +608,118 @@ iterator split*(s: string, sep: char): string = ## "" ## "" ## - var last = 0 - assert('\0' != sep) - if len(s) > 0: - # `<=` is correct here for the edge cases! - while last <= len(s): - var first = last - while last < len(s) and s[last] != sep: inc(last) - yield substr(s, first, last-1) - inc(last) + splitCommon(s, sep, maxsplit, 1) -iterator split*(s: string, sep: string): string = +iterator split*(s: string, sep: string, maxsplit: int = -1): string = ## Splits the string `s` into substrings using a string separator. ## ## Substrings are separated by the string `sep`. - var last = 0 + ## The code: + ## + ## .. code-block:: nim + ## for word in split("thisDATAisDATAcorrupted", "DATA"): + ## writeLine(stdout, word) + ## + ## Results in: + ## + ## .. code-block:: + ## "this" + ## "is" + ## "corrupted" + ## + + splitCommon(s, sep, maxsplit, sep.len) + +template rsplitCommon(s, sep, maxsplit, sepLen) = + ## Common code for rsplit functions + var + last = s.len - 1 + first = last + splits = maxsplit + startPos = 0 + if len(s) > 0: - while last <= len(s): - var first = last - while last < len(s) and s.substr(last, last + <sep.len) != sep: - inc(last) - yield substr(s, first, last-1) - inc(last, sep.len) + # go to -1 in order to get separators at the beginning + while first >= -1: + while first >= 0 and not stringHasSep(s, first, sep): + dec(first) + + if splits == 0: + # No more splits means set first to the beginning + first = -1 + + if first == -1: + startPos = 0 + else: + startPos = first + sepLen + + yield substr(s, startPos, last) + + if splits == 0: + break + + dec(splits) + dec(first) + + last = first + +iterator rsplit*(s: string, seps: set[char] = Whitespace, + maxsplit: int = -1): string = + ## Splits the string `s` into substrings from the right using a + ## string separator. Works exactly the same as `split iterator + ## <#split.i,string,char>`_ except in reverse order. + ## + ## .. code-block:: nim + ## for piece in "foo bar".rsplit(WhiteSpace): + ## echo piece + ## + ## Results in: + ## + ## .. code-block:: nim + ## "bar" + ## "foo" + ## + ## Substrings are separated from the right by the set of chars `seps` + + rsplitCommon(s, seps, maxsplit, 1) + +iterator rsplit*(s: string, sep: char, + maxsplit: int = -1): string = + ## Splits the string `s` into substrings from the right using a + ## string separator. Works exactly the same as `split iterator + ## <#split.i,string,char>`_ except in reverse order. + ## + ## .. code-block:: nim + ## for piece in "foo:bar".rsplit(':'): + ## echo piece + ## + ## Results in: + ## + ## .. code-block:: nim + ## "bar" + ## "foo" + ## + ## Substrings are separated from the right by the char `sep` + rsplitCommon(s, sep, maxsplit, 1) + +iterator rsplit*(s: string, sep: string, maxsplit: int = -1, + keepSeparators: bool = false): string = + ## Splits the string `s` into substrings from the right using a + ## string separator. Works exactly the same as `split iterator + ## <#split.i,string,string>`_ except in reverse order. + ## + ## .. code-block:: nim + ## for piece in "foothebar".rsplit("the"): + ## echo piece + ## + ## Results in: + ## + ## .. code-block:: nim + ## "bar" + ## "foo" + ## + ## Substrings are separated from the right by the string `sep` + rsplitCommon(s, sep, maxsplit, sep.len) iterator splitLines*(s: string): string = ## Splits the string `s` into its containing lines. @@ -493,25 +788,92 @@ proc countLines*(s: string): int {.noSideEffect, else: discard inc i -proc split*(s: string, seps: set[char] = Whitespace): seq[string] {. +proc split*(s: string, seps: set[char] = Whitespace, maxsplit: int = -1): seq[string] {. noSideEffect, rtl, extern: "nsuSplitCharSet".} = ## The same as the `split iterator <#split.i,string,set[char]>`_, but is a ## proc that returns a sequence of substrings. - accumulateResult(split(s, seps)) + accumulateResult(split(s, seps, maxsplit)) -proc split*(s: string, sep: char): seq[string] {.noSideEffect, +proc split*(s: string, sep: char, maxsplit: int = -1): seq[string] {.noSideEffect, rtl, extern: "nsuSplitChar".} = ## The same as the `split iterator <#split.i,string,char>`_, but is a proc ## that returns a sequence of substrings. - accumulateResult(split(s, sep)) + accumulateResult(split(s, sep, maxsplit)) -proc split*(s: string, sep: string): seq[string] {.noSideEffect, +proc split*(s: string, sep: string, maxsplit: int = -1): seq[string] {.noSideEffect, rtl, extern: "nsuSplitString".} = ## Splits the string `s` into substrings using a string separator. ## ## Substrings are separated by the string `sep`. This is a wrapper around the ## `split iterator <#split.i,string,string>`_. - accumulateResult(split(s, sep)) + accumulateResult(split(s, sep, maxsplit)) + +proc rsplit*(s: string, seps: set[char] = Whitespace, + maxsplit: int = -1): seq[string] + {.noSideEffect, rtl, extern: "nsuRSplitCharSet".} = + ## The same as the `rsplit iterator <#rsplit.i,string,set[char]>`_, but is a + ## proc that returns a sequence of substrings. + ## + ## A possible common use case for `rsplit` is path manipulation, + ## particularly on systems that don't use a common delimiter. + ## + ## For example, if a system had `#` as a delimiter, you could + ## do the following to get the tail of the path: + ## + ## .. code-block:: nim + ## var tailSplit = rsplit("Root#Object#Method#Index", {'#'}, maxsplit=1) + ## + ## Results in `tailSplit` containing: + ## + ## .. code-block:: nim + ## @["Root#Object#Method", "Index"] + ## + accumulateResult(rsplit(s, seps, maxsplit)) + result.reverse() + +proc rsplit*(s: string, sep: char, maxsplit: int = -1): seq[string] + {.noSideEffect, rtl, extern: "nsuRSplitChar".} = + ## The same as the `split iterator <#rsplit.i,string,char>`_, but is a proc + ## that returns a sequence of substrings. + ## + ## A possible common use case for `rsplit` is path manipulation, + ## particularly on systems that don't use a common delimiter. + ## + ## For example, if a system had `#` as a delimiter, you could + ## do the following to get the tail of the path: + ## + ## .. code-block:: nim + ## var tailSplit = rsplit("Root#Object#Method#Index", '#', maxsplit=1) + ## + ## Results in `tailSplit` containing: + ## + ## .. code-block:: nim + ## @["Root#Object#Method", "Index"] + ## + accumulateResult(rsplit(s, sep, maxsplit)) + result.reverse() + +proc rsplit*(s: string, sep: string, maxsplit: int = -1): seq[string] + {.noSideEffect, rtl, extern: "nsuRSplitString".} = + ## The same as the `split iterator <#rsplit.i,string,string>`_, but is a proc + ## that returns a sequence of substrings. + ## + ## A possible common use case for `rsplit` is path manipulation, + ## particularly on systems that don't use a common delimiter. + ## + ## For example, if a system had `#` as a delimiter, you could + ## do the following to get the tail of the path: + ## + ## .. code-block:: nim + ## var tailSplit = rsplit("Root#Object#Method#Index", "#", maxsplit=1) + ## + ## Results in `tailSplit` containing: + ## + ## .. code-block:: nim + ## @["Root#Object#Method", "Index"] + ## + accumulateResult(rsplit(s, sep, maxsplit)) + result.reverse() proc toHex*(x: BiggestInt, len: Positive): string {.noSideEffect, rtl, extern: "nsuToHex".} = @@ -530,6 +892,10 @@ proc toHex*(x: BiggestInt, len: Positive): string {.noSideEffect, # handle negative overflow if n == 0 and x < 0: n = -1 +proc toHex*[T](x: T): string = + ## Shortcut for ``toHex(x, T.sizeOf * 2)`` + toHex(x, T.sizeOf * 2) + proc intToStr*(x: int, minchars: Positive = 1): string {.noSideEffect, rtl, extern: "nsuIntToStr".} = ## Converts `x` to its decimal representation. @@ -560,6 +926,24 @@ proc parseBiggestInt*(s: string): BiggestInt {.noSideEffect, procvar, if L != s.len or L == 0: raise newException(ValueError, "invalid integer: " & s) +proc parseUInt*(s: string): uint {.noSideEffect, procvar, + rtl, extern: "nsuParseUInt".} = + ## Parses a decimal unsigned integer value contained in `s`. + ## + ## If `s` is not a valid integer, `ValueError` is raised. + var L = parseutils.parseUInt(s, result, 0) + if L != s.len or L == 0: + raise newException(ValueError, "invalid unsigned integer: " & s) + +proc parseBiggestUInt*(s: string): uint64 {.noSideEffect, procvar, + rtl, extern: "nsuParseBiggestUInt".} = + ## Parses a decimal unsigned integer value contained in `s`. + ## + ## If `s` is not a valid integer, `ValueError` is raised. + var L = parseutils.parseBiggestUInt(s, result, 0) + if L != s.len or L == 0: + raise newException(ValueError, "invalid unsigned integer: " & s) + proc parseFloat*(s: string): float {.noSideEffect, procvar, rtl, extern: "nsuParseFloat".} = ## Parses a decimal floating point value contained in `s`. If `s` is not @@ -651,7 +1035,7 @@ proc repeat*(s: string, n: Natural): string {.noSideEffect, result = newStringOfCap(n * s.len) for i in 1..n: result.add(s) -template spaces*(n: Natural): string = repeat(' ',n) +template spaces*(n: Natural): string = repeat(' ', n) ## Returns a String with `n` space characters. You can use this proc ## to left align strings. Example: ## @@ -766,8 +1150,7 @@ proc indent*(s: string, count: Natural, padding: string = " "): string {.noSideEffect, rtl, extern: "nsuIndent".} = ## Indents each line in ``s`` by ``count`` amount of ``padding``. ## - ## **Note:** This currently does not preserve the specific new line characters - ## used. + ## **Note:** This does not preserve the new line characters used in ``s``. result = "" var i = 0 for line in s.splitLines(): @@ -778,32 +1161,39 @@ proc indent*(s: string, count: Natural, padding: string = " "): string result.add(line) i.inc -proc unindent*(s: string, eatAllIndent = false): string {. - noSideEffect, rtl, extern: "nsuUnindent".} = - ## Unindents `s`. - result = newStringOfCap(s.len) +proc unindent*(s: string, count: Natural, padding: string = " "): string + {.noSideEffect, rtl, extern: "nsuUnindent".} = + ## Unindents each line in ``s`` by ``count`` amount of ``padding``. + ## + ## **Note:** This does not preserve the new line characters used in ``s``. + result = "" var i = 0 - var pattern = true - var indent = 0 - while s[i] == ' ': inc i - var level = if i == 0: -1 else: i - while i < s.len: - if s[i] == ' ': - if i > 0 and s[i-1] in {'\l', '\c'}: - pattern = true - indent = 0 - if pattern: - inc(indent) - if indent > level and not eatAllIndent: - result.add(s[i]) - if level < 0: level = indent - else: - # a space somewhere: do not delete - result.add(s[i]) - else: - pattern = false - result.add(s[i]) - inc i + for line in s.splitLines(): + if i != 0: + result.add("\n") + var indentCount = 0 + for j in 0..<count.int: + indentCount.inc + if line[j .. j + <padding.len] != padding: + indentCount = j + break + result.add(line[indentCount*padding.len .. ^1]) + i.inc + +proc unindent*(s: string): string + {.noSideEffect, rtl, extern: "nsuUnindentAll".} = + ## Removes all indentation composed of whitespace from each line in ``s``. + ## + ## For example: + ## + ## .. code-block:: nim + ## const x = """ + ## Hello + ## There + ## """.unindent() + ## + ## doAssert x == "Hello\nThere\n" + unindent(s, 1000) # TODO: Passing a 1000 is a bit hackish. proc startsWith*(s, prefix: string): bool {.noSideEffect, rtl, extern: "nsuStartsWith".} = @@ -816,6 +1206,10 @@ proc startsWith*(s, prefix: string): bool {.noSideEffect, if s[i] != prefix[i]: return false inc(i) +proc startsWith*(s: string, prefix: char): bool {.noSideEffect, inline.} = + ## Returns true iff ``s`` starts with ``prefix``. + result = s[0] == prefix + proc endsWith*(s, suffix: string): bool {.noSideEffect, rtl, extern: "nsuEndsWith".} = ## Returns true iff ``s`` ends with ``suffix``. @@ -828,6 +1222,10 @@ proc endsWith*(s, suffix: string): bool {.noSideEffect, inc(i) if suffix[i] == '\0': return true +proc endsWith*(s: string, suffix: char): bool {.noSideEffect, inline.} = + ## Returns true iff ``s`` ends with ``suffix``. + result = s[s.high] == suffix + proc continuesWith*(s, substr: string, start: Natural): bool {.noSideEffect, rtl, extern: "nsuContinuesWith".} = ## Returns true iff ``s`` continues with ``substr`` at position ``start``. @@ -981,6 +1379,34 @@ proc rfind*(s: string, sub: char, start: int = -1): int {.noSideEffect, if sub == s[i]: return i return -1 +proc center*(s: string, width: int, fillChar: char = ' '): string {. + noSideEffect, rtl, extern: "nsuCenterString".} = + ## Return the contents of `s` centered in a string `width` long using + ## `fillChar` as padding. + ## + ## The original string is returned if `width` is less than or equal + ## to `s.len`. + if width <= s.len: + return s + + result = newString(width) + + # Left padding will be one fillChar + # smaller if there are an odd number + # of characters + let + charsLeft = (width - s.len) + leftPadding = charsLeft div 2 + + for i in 0 ..< width: + if i >= leftPadding and i < leftPadding + s.len: + # we are where the string should be located + result[i] = s[i-leftPadding] + else: + # we are either before or after where + # the string s should go + result[i] = fillChar + proc count*(s: string, sub: string, overlapping: bool = false): int {. noSideEffect, rtl, extern: "nsuCountString".} = ## Count the occurrences of a substring `sub` in the string `s`. @@ -1420,28 +1846,216 @@ proc formatFloat*(f: float, format: FloatFormatMode = ffDefault, ## after the decimal point for Nim's ``float`` type. result = formatBiggestFloat(f, format, precision, decimalSep) -proc formatSize*(bytes: BiggestInt, decimalSep = '.'): string = - ## Rounds and formats `bytes`. Examples: +proc trimZeros*(x: var string) {.noSideEffect.} = + ## Trim trailing zeros from a formatted floating point + ## value (`x`). Modifies the passed value. + var spl: seq[string] + if x.contains('.') or x.contains(','): + if x.contains('e'): + spl= x.split('e') + x = spl[0] + while x[x.high] == '0': + x.setLen(x.len-1) + if x[x.high] in [',', '.']: + x.setLen(x.len-1) + if spl.len > 0: + x &= "e" & spl[1] + +type + BinaryPrefixMode* = enum ## the different names for binary prefixes + bpIEC, # use the IEC/ISO standard prefixes such as kibi + bpColloquial # use the colloquial kilo, mega etc + +proc formatSize*(bytes: int64, + decimalSep = '.', + prefix = bpIEC, + includeSpace = false): string {.noSideEffect.} = + ## Rounds and formats `bytes`. + ## + ## By default, uses the IEC/ISO standard binary prefixes, so 1024 will be + ## formatted as 1KiB. Set prefix to `bpColloquial` to use the colloquial + ## names from the SI standard (e.g. k for 1000 being reused as 1024). + ## + ## `includeSpace` can be set to true to include the (SI preferred) space + ## between the number and the unit (e.g. 1 KiB). + ## + ## Examples: ## ## .. code-block:: nim ## - ## formatSize(1'i64 shl 31 + 300'i64) == "2.204GB" - ## formatSize(4096) == "4KB" - ## - template frmt(a, b, c: expr): expr = - let bs = $b - insertSep($a) & decimalSep & bs.substr(0, 2) & c - let gigabytes = bytes shr 30 - let megabytes = bytes shr 20 - let kilobytes = bytes shr 10 - if gigabytes != 0: - result = frmt(gigabytes, megabytes, "GB") - elif megabytes != 0: - result = frmt(megabytes, kilobytes, "MB") - elif kilobytes != 0: - result = frmt(kilobytes, bytes, "KB") + ## formatSize((1'i64 shl 31) + (300'i64 shl 20)) == "2.293GiB" + ## formatSize((2.234*1024*1024).int) == "2.234MiB" + ## formatSize(4096, includeSpace=true) == "4 KiB" + ## formatSize(4096, prefix=bpColloquial, includeSpace=true) == "4 kB" + ## formatSize(4096) == "4KiB" + ## formatSize(5_378_934, prefix=bpColloquial, decimalSep=',') == "5,13MB" + ## + const iecPrefixes = ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi", "Yi"] + const collPrefixes = ["", "k", "M", "G", "T", "P", "E", "Z", "Y"] + var + xb: int64 = bytes + fbytes: float + last_xb: int64 = bytes + matchedIndex: int + prefixes: array[9, string] + if prefix == bpColloquial: + prefixes = collPrefixes else: - result = insertSep($bytes) & "B" + prefixes = iecPrefixes + + # Iterate through prefixes seeing if value will be greater than + # 0 in each case + for index in 1..<prefixes.len: + last_xb = xb + xb = bytes div (1'i64 shl (index*10)) + matchedIndex = index + if xb == 0: + xb = last_xb + matchedIndex = index - 1 + break + # xb has the integer number for the latest value; index should be correct + fbytes = bytes.float / (1'i64 shl (matchedIndex*10)).float + result = formatFloat(fbytes, format=ffDecimal, precision=3, decimalSep=decimalSep) + result.trimZeros() + if includeSpace: + result &= " " + result &= prefixes[matchedIndex] + result &= "B" + +proc formatEng*(f: BiggestFloat, + precision: range[0..32] = 10, + trim: bool = true, + siPrefix: bool = false, + unit: string = nil, + decimalSep = '.'): string {.noSideEffect.} = + ## Converts a floating point value `f` to a string using engineering notation. + ## + ## Numbers in of the range -1000.0<f<1000.0 will be formatted without an + ## exponent. Numbers outside of this range will be formatted as a + ## significand in the range -1000.0<f<1000.0 and an exponent that will always + ## be an integer multiple of 3, corresponding with the SI prefix scale k, M, + ## G, T etc for numbers with an absolute value greater than 1 and m, μ, n, p + ## etc for numbers with an absolute value less than 1. + ## + ## The default configuration (`trim=true` and `precision=10`) shows the + ## **shortest** form that precisely (up to a maximum of 10 decimal places) + ## displays the value. For example, 4.100000 will be displayed as 4.1 (which + ## is mathematically identical) whereas 4.1000003 will be displayed as + ## 4.1000003. + ## + ## If `trim` is set to true, trailing zeros will be removed; if false, the + ## number of digits specified by `precision` will always be shown. + ## + ## `precision` can be used to set the number of digits to be shown after the + ## decimal point or (if `trim` is true) the maximum number of digits to be + ## shown. + ## + ## .. code-block:: nim + ## + ## formatEng(0, 2, trim=false) == "0.00" + ## formatEng(0, 2) == "0" + ## formatEng(0.053, 0) == "53e-3" + ## formatEng(52731234, 2) == "52.73e6" + ## formatEng(-52731234, 2) == "-52.73e6" + ## + ## If `siPrefix` is set to true, the number will be displayed with the SI + ## prefix corresponding to the exponent. For example 4100 will be displayed + ## as "4.1 k" instead of "4.1e3". Note that `u` is used for micro- in place + ## of the greek letter mu (μ) as per ISO 2955. Numbers with an absolute + ## value outside of the range 1e-18<f<1000e18 (1a<f<1000E) will be displayed + ## with an exponent rather than an SI prefix, regardless of whether + ## `siPrefix` is true. + ## + ## If `unit` is not nil, the provided unit will be appended to the string + ## (with a space as required by the SI standard). This behaviour is slightly + ## different to appending the unit to the result as the location of the space + ## is altered depending on whether there is an exponent. + ## + ## .. code-block:: nim + ## + ## formatEng(4100, siPrefix=true, unit="V") == "4.1 kV" + ## formatEng(4.1, siPrefix=true, unit="V") == "4.1 V" + ## formatEng(4.1, siPrefix=true) == "4.1" # Note lack of space + ## formatEng(4100, siPrefix=true) == "4.1 k" + ## formatEng(4.1, siPrefix=true, unit="") == "4.1 " # Space with unit="" + ## formatEng(4100, siPrefix=true, unit="") == "4.1 k" + ## formatEng(4100) == "4.1e3" + ## formatEng(4100, unit="V") == "4.1e3 V" + ## formatEng(4100, unit="") == "4.1e3 " # Space with unit="" + ## + ## `decimalSep` is used as the decimal separator + var + absolute: BiggestFloat + significand: BiggestFloat + fexponent: BiggestFloat + exponent: int + splitResult: seq[string] + suffix: string = "" + proc getPrefix(exp: int): char = + ## Get the SI prefix for a given exponent + ## + ## Assumes exponent is a multiple of 3; returns ' ' if no prefix found + const siPrefixes = ['a','f','p','n','u','m',' ','k','M','G','T','P','E'] + var index: int = (exp div 3) + 6 + result = ' ' + if index in low(siPrefixes)..high(siPrefixes): + result = siPrefixes[index] + + # Most of the work is done with the sign ignored, so get the absolute value + absolute = abs(f) + significand = f + + if absolute == 0.0: + # Simple case: just format it and force the exponent to 0 + exponent = 0 + result = significand.formatBiggestFloat(ffDecimal, precision, decimalSep='.') + else: + # Find the best exponent that's a multiple of 3 + fexponent = round(floor(log10(absolute))) + fexponent = 3.0 * round(floor(fexponent / 3.0)) + # Adjust the significand for the new exponent + significand /= pow(10.0, fexponent) + + # Round the significand and check whether it has affected + # the exponent + significand = round(significand, precision) + absolute = abs(significand) + if absolute >= 1000.0: + significand *= 0.001 + fexponent += 3 + # Components of the result: + result = significand.formatBiggestFloat(ffDecimal, precision, decimalSep='.') + exponent = fexponent.int() + + splitResult = result.split('.') + result = splitResult[0] + # result should have at most one decimal character + if splitResult.len() > 1: + # If trim is set, we get rid of trailing zeros. Don't use trimZeros here as + # we can be a bit more efficient through knowledge that there will never be + # an exponent in this part. + if trim: + while splitResult[1].endsWith("0"): + # Trim last character + splitResult[1].setLen(splitResult[1].len-1) + if splitResult[1].len() > 0: + result &= decimalSep & splitResult[1] + else: + result &= decimalSep & splitResult[1] + + # Combine the results accordingly + if siPrefix and exponent != 0: + var p = getPrefix(exponent) + if p != ' ': + suffix = " " & p + exponent = 0 # Exponent replaced by SI prefix + if suffix == "" and unit != nil: + suffix = " " + if unit != nil: + suffix &= unit + if exponent != 0: + result &= "e" & $exponent + result &= suffix proc findNormalized(x: string, inArray: openArray[string]): int = var i = 0 @@ -1637,9 +2251,14 @@ when isMainModule: ["1,0e-11", "1,0e-011"] doAssert "$# $3 $# $#" % ["a", "b", "c"] == "a c b c" - when not defined(testing): - echo formatSize(1'i64 shl 31 + 300'i64) # == "4,GB" - echo formatSize(1'i64 shl 31) + + block: # formatSize tests + doAssert formatSize((1'i64 shl 31) + (300'i64 shl 20)) == "2.293GiB" + doAssert formatSize((2.234*1024*1024).int) == "2.234MiB" + doAssert formatSize(4096) == "4KiB" + doAssert formatSize(4096, prefix=bpColloquial, includeSpace=true) == "4 kB" + doAssert formatSize(4096, includeSpace=true) == "4 KiB" + doAssert formatSize(5_378_934, prefix=bpColloquial, decimalSep=',') == "5,13MB" doAssert "$animal eats $food." % ["animal", "The cat", "food", "fish"] == "The cat eats fish." @@ -1652,6 +2271,11 @@ when isMainModule: doAssert parseEnum("invalid enum value", enC) == enC + doAssert center("foo", 13) == " foo " + doAssert center("foo", 0) == "foo" + doAssert center("foo", 3, fillChar = 'a') == "foo" + doAssert center("foo", 10, fillChar = '\t') == "\t\t\tfoo\t\t\t\t" + doAssert count("foofoofoo", "foofoo") == 1 doAssert count("foofoofoo", "foofoo", overlapping = true) == 2 doAssert count("foofoofoo", 'f') == 3 @@ -1668,13 +2292,13 @@ when isMainModule: doAssert " foo\n bar".indent(4, "Q") == "QQQQ foo\nQQQQ bar" - doAssert isAlpha('r') - doAssert isAlpha('A') - doAssert(not isAlpha('$')) + doAssert isAlphaAscii('r') + doAssert isAlphaAscii('A') + doAssert(not isAlphaAscii('$')) - doAssert isAlpha("Rasp") - doAssert isAlpha("Args") - doAssert(not isAlpha("$Tomato")) + doAssert isAlphaAscii("Rasp") + doAssert isAlphaAscii("Args") + doAssert(not isAlphaAscii("$Tomato")) doAssert isAlphaNumeric('3') doAssert isAlphaNumeric('R') @@ -1693,35 +2317,131 @@ when isMainModule: doAssert(not isDigit("12.33")) doAssert(not isDigit("A45b")) - doAssert isSpace('\t') - doAssert isSpace('\l') - doAssert(not isSpace('A')) - - doAssert isSpace("\t\l \v\r\f") - doAssert isSpace(" ") - doAssert(not isSpace("ABc \td")) + doAssert isSpaceAscii('\t') + doAssert isSpaceAscii('\l') + doAssert(not isSpaceAscii('A')) + + doAssert isSpaceAscii("\t\l \v\r\f") + doAssert isSpaceAscii(" ") + doAssert(not isSpaceAscii("ABc \td")) + + doAssert(isNilOrEmpty("")) + doAssert(isNilOrEmpty(nil)) + doAssert(not isNilOrEmpty("test")) + doAssert(not isNilOrEmpty(" ")) + + doAssert(isNilOrWhitespace("")) + doAssert(isNilOrWhitespace(nil)) + doAssert(isNilOrWhitespace(" ")) + doAssert(isNilOrWhitespace("\t\l \v\r\f")) + doAssert(not isNilOrWhitespace("ABc \td")) + + doAssert isLowerAscii('a') + doAssert isLowerAscii('z') + doAssert(not isLowerAscii('A')) + doAssert(not isLowerAscii('5')) + doAssert(not isLowerAscii('&')) + + doAssert isLowerAscii("abcd") + doAssert(not isLowerAscii("abCD")) + doAssert(not isLowerAscii("33aa")) + + doAssert isUpperAscii('A') + doAssert(not isUpperAscii('b')) + doAssert(not isUpperAscii('5')) + doAssert(not isUpperAscii('%')) + + doAssert isUpperAscii("ABC") + doAssert(not isUpperAscii("AAcc")) + doAssert(not isUpperAscii("A#$")) + + doAssert rsplit("foo bar", seps=Whitespace) == @["foo", "bar"] + doAssert rsplit(" foo bar", seps=Whitespace, maxsplit=1) == @[" foo", "bar"] + doAssert rsplit(" foo bar ", seps=Whitespace, maxsplit=1) == @[" foo bar", ""] + doAssert rsplit(":foo:bar", sep=':') == @["", "foo", "bar"] + doAssert rsplit(":foo:bar", sep=':', maxsplit=2) == @["", "foo", "bar"] + doAssert rsplit(":foo:bar", sep=':', maxsplit=3) == @["", "foo", "bar"] + doAssert rsplit("foothebar", sep="the") == @["foo", "bar"] - doAssert isLower('a') - doAssert isLower('z') - doAssert(not isLower('A')) - doAssert(not isLower('5')) - doAssert(not isLower('&')) - - doAssert isLower("abcd") - doAssert(not isLower("abCD")) - doAssert(not isLower("33aa")) - - doAssert isUpper('A') - doAssert(not isUpper('b')) - doAssert(not isUpper('5')) - doAssert(not isUpper('%')) - - doAssert isUpper("ABC") - doAssert(not isUpper("AAcc")) - doAssert(not isUpper("A#$")) doAssert(unescape(r"\x013", "", "") == "\x013") doAssert join(["foo", "bar", "baz"]) == "foobarbaz" doAssert join(@["foo", "bar", "baz"], ", ") == "foo, bar, baz" doAssert join([1, 2, 3]) == "123" doAssert join(@[1, 2, 3], ", ") == "1, 2, 3" + + doAssert """~~!!foo +~~!!bar +~~!!baz""".unindent(2, "~~!!") == "foo\nbar\nbaz" + + doAssert """~~!!foo +~~!!bar +~~!!baz""".unindent(2, "~~!!aa") == "~~!!foo\n~~!!bar\n~~!!baz" + doAssert """~~foo +~~ bar +~~ baz""".unindent(4, "~") == "foo\n bar\n baz" + doAssert """foo +bar + baz + """.unindent(4) == "foo\nbar\nbaz\n" + doAssert """foo + bar + baz + """.unindent(2) == "foo\n bar\n baz\n" + doAssert """foo + bar + baz + """.unindent(100) == "foo\nbar\nbaz\n" + + doAssert """foo + foo + bar + """.unindent() == "foo\nfoo\nbar\n" + + let s = " this is an example " + let s2 = ":this;is;an:example;;" + + doAssert s.split() == @["", "this", "is", "an", "example", "", ""] + doAssert s2.split(seps={':', ';'}) == @["", "this", "is", "an", "example", "", ""] + doAssert s.split(maxsplit=4) == @["", "this", "is", "an", "example "] + doAssert s.split(' ', maxsplit=1) == @["", "this is an example "] + doAssert s.split(" ", maxsplit=4) == @["", "this", "is", "an", "example "] + + block: # formatEng tests + doAssert formatEng(0, 2, trim=false) == "0.00" + doAssert formatEng(0, 2) == "0" + doAssert formatEng(53, 2, trim=false) == "53.00" + doAssert formatEng(0.053, 2, trim=false) == "53.00e-3" + doAssert formatEng(0.053, 4, trim=false) == "53.0000e-3" + doAssert formatEng(0.053, 4, trim=true) == "53e-3" + doAssert formatEng(0.053, 0) == "53e-3" + doAssert formatEng(52731234) == "52.731234e6" + doAssert formatEng(-52731234) == "-52.731234e6" + doAssert formatEng(52731234, 1) == "52.7e6" + doAssert formatEng(-52731234, 1) == "-52.7e6" + doAssert formatEng(52731234, 1, decimalSep=',') == "52,7e6" + doAssert formatEng(-52731234, 1, decimalSep=',') == "-52,7e6" + + doAssert formatEng(4100, siPrefix=true, unit="V") == "4.1 kV" + doAssert formatEng(4.1, siPrefix=true, unit="V") == "4.1 V" + doAssert formatEng(4.1, siPrefix=true) == "4.1" # Note lack of space + doAssert formatEng(4100, siPrefix=true) == "4.1 k" + doAssert formatEng(4.1, siPrefix=true, unit="") == "4.1 " # Includes space + doAssert formatEng(4100, siPrefix=true, unit="") == "4.1 k" + doAssert formatEng(4100) == "4.1e3" + doAssert formatEng(4100, unit="V") == "4.1e3 V" + doAssert formatEng(4100, unit="") == "4.1e3 " # Space with unit="" + # Don't use SI prefix as number is too big + doAssert formatEng(3.1e22, siPrefix=true, unit="a") == "31e21 a" + # Don't use SI prefix as number is too small + doAssert formatEng(3.1e-25, siPrefix=true, unit="A") == "310e-27 A" + + block: # startsWith / endsWith char tests + var s = "abcdef" + doAssert s.startsWith('a') + doAssert s.startsWith('b') == false + doAssert s.endsWith('f') + doAssert s.endsWith('a') == false + doAssert s.endsWith('\0') == false + + #echo("strutils tests passed") diff --git a/lib/pure/subexes.nim b/lib/pure/subexes.nim index 5824ace81..351b3c086 100644 --- a/lib/pure/subexes.nim +++ b/lib/pure/subexes.nim @@ -9,7 +9,7 @@ ## Nim support for `substitution expressions`:idx: (`subex`:idx:). ## -## .. include:: ../doc/subexes.txt +## .. include:: ../../doc/subexes.txt ## {.push debugger:off .} # the user does not want to trace a part @@ -390,11 +390,11 @@ when isMainModule: doAssert(("type MyEnum* = enum\n $', '2i'\n '{..}" % ["fieldA", "fieldB", "FiledClkad", "fieldD", "fieldE", "longishFieldName"]).replace(" \n", "\n") == - strutils.unindent """ + strutils.unindent(""" type MyEnum* = enum fieldA, fieldB, FiledClkad, fieldD, - fieldE, longishFieldName""") + fieldE, longishFieldName""", 6)) doAssert subex"$1($', '{2..})" % ["f", "a", "b", "c"] == "f(a, b, c)" @@ -404,8 +404,8 @@ when isMainModule: doAssert((subex("type\n Enum = enum\n $', '40c'\n '{..}") % [ "fieldNameA", "fieldNameB", "fieldNameC", "fieldNameD"]).replace(" \n", "\n") == - strutils.unindent """ + strutils.unindent(""" type Enum = enum fieldNameA, fieldNameB, fieldNameC, - fieldNameD""") + fieldNameD""", 6)) diff --git a/lib/pure/terminal.nim b/lib/pure/terminal.nim index 60f064e7c..1f34ec07e 100644 --- a/lib/pure/terminal.nim +++ b/lib/pure/terminal.nim @@ -384,7 +384,7 @@ proc setForegroundColor*(f: File, fg: ForegroundColor, bright=false) = var old = getAttributes(h) and not 0x0007 if bright: old = old or FOREGROUND_INTENSITY - const lookup: array [ForegroundColor, int] = [ + const lookup: array[ForegroundColor, int] = [ 0, (FOREGROUND_RED), (FOREGROUND_GREEN), @@ -406,7 +406,7 @@ proc setBackgroundColor*(f: File, bg: BackgroundColor, bright=false) = var old = getAttributes(h) and not 0x0070 if bright: old = old or BACKGROUND_INTENSITY - const lookup: array [BackgroundColor, int] = [ + const lookup: array[BackgroundColor, int] = [ 0, (BACKGROUND_RED), (BACKGROUND_GREEN), @@ -493,17 +493,21 @@ template styledEcho*(args: varargs[expr]): expr = ## Echoes styles arguments to stdout using ``styledWriteLine``. callStyledEcho(args) -when defined(nimdoc): - proc getch*(): char = - ## Read a single character from the terminal, blocking until it is entered. - ## The character is not printed to the terminal. This is not available for - ## Windows. - discard -elif not defined(windows): - proc getch*(): char = - ## Read a single character from the terminal, blocking until it is entered. - ## The character is not printed to the terminal. This is not available for - ## Windows. +proc getch*(): char = + ## Read a single character from the terminal, blocking until it is entered. + ## The character is not printed to the terminal. + when defined(windows): + let fd = getStdHandle(STD_INPUT_HANDLE) + var keyEvent = KEY_EVENT_RECORD() + var numRead: cint + while true: + # Block until character is entered + doAssert(waitForSingleObject(fd, INFINITE) == WAIT_OBJECT_0) + doAssert(readConsoleInput(fd, addr(keyEvent), 1, addr(numRead)) != 0) + if numRead == 0 or keyEvent.eventType != 1 or keyEvent.bKeyDown == 0: + continue + return char(keyEvent.uChar) + else: let fd = getFileHandle(stdin) var oldMode: Termios discard fd.tcgetattr(addr oldMode) diff --git a/lib/pure/times.nim b/lib/pure/times.nim index 29ae52d47..d6eb29e1c 100644 --- a/lib/pure/times.nim +++ b/lib/pure/times.nim @@ -64,8 +64,9 @@ when defined(posix) and not defined(JS): proc posix_gettimeofday(tp: var Timeval, unused: pointer = nil) {. importc: "gettimeofday", header: "<sys/time.h>".} + when not defined(freebsd) and not defined(netbsd) and not defined(openbsd): + var timezone {.importc, header: "<time.h>".}: int var - timezone {.importc, header: "<time.h>".}: int tzname {.importc, header: "<time.h>" .}: array[0..1, cstring] # we also need tzset() to make sure that tzname is initialized proc tzset() {.importc, header: "<time.h>".} @@ -94,7 +95,8 @@ elif defined(windows): elif defined(JS): type - Time* {.importc.} = object + Time* = ref TimeObj + TimeObj {.importc.} = object getDay: proc (): int {.tags: [], raises: [], benign.} getFullYear: proc (): int {.tags: [], raises: [], benign.} getHours: proc (): int {.tags: [], raises: [], benign.} @@ -182,7 +184,16 @@ proc getGMTime*(t: Time): TimeInfo {.tags: [TimeEffect], raises: [], benign.} ## converts the calendar time `t` to broken-down time representation, ## expressed in Coordinated Universal Time (UTC). -proc timeInfoToTime*(timeInfo: TimeInfo): Time {.tags: [], benign.} +proc timeInfoToTime*(timeInfo: TimeInfo): Time {.tags: [], benign, deprecated.} + ## converts a broken-down time structure to + ## calendar time representation. The function ignores the specified + ## contents of the structure members `weekday` and `yearday` and recomputes + ## them from the other information in the broken-down time structure. + ## + ## **Warning:** This procedure is deprecated since version 0.14.0. + ## Use ``toTime`` instead. + +proc toTime*(timeInfo: TimeInfo): Time {.tags: [], benign.} ## converts a broken-down time structure to ## calendar time representation. The function ignores the specified ## contents of the structure members `weekday` and `yearday` and recomputes @@ -356,16 +367,19 @@ proc `+`*(a: TimeInfo, interval: TimeInterval): TimeInfo = ## ## **Note:** This has been only briefly tested and it may not be ## very accurate. - let t = toSeconds(timeInfoToTime(a)) + let t = toSeconds(toTime(a)) let secs = toSeconds(a, interval) - result = getLocalTime(fromSeconds(t + secs)) + if a.tzname == "UTC": + result = getGMTime(fromSeconds(t + secs)) + else: + result = getLocalTime(fromSeconds(t + secs)) proc `-`*(a: TimeInfo, interval: TimeInterval): TimeInfo = ## subtracts ``interval`` time from TimeInfo ``a``. ## ## **Note:** This has been only briefly tested, it is inaccurate especially ## when you subtract so much that you reach the Julian calendar. - let t = toSeconds(timeInfoToTime(a)) + let t = toSeconds(toTime(a)) var intval: TimeInterval intval.milliseconds = - interval.milliseconds intval.seconds = - interval.seconds @@ -375,7 +389,10 @@ proc `-`*(a: TimeInfo, interval: TimeInterval): TimeInfo = intval.months = - interval.months intval.years = - interval.years let secs = toSeconds(a, intval) - result = getLocalTime(fromSeconds(t + secs)) + if a.tzname == "UTC": + result = getGMTime(fromSeconds(t + secs)) + else: + result = getLocalTime(fromSeconds(t + secs)) proc miliseconds*(t: TimeInterval): int {.deprecated.} = t.milliseconds @@ -407,18 +424,32 @@ when not defined(JS): when not defined(JS): # C wrapper: + when defined(freebsd) or defined(netbsd) or defined(openbsd): + type + StructTM {.importc: "struct tm", final.} = object + second {.importc: "tm_sec".}, + minute {.importc: "tm_min".}, + hour {.importc: "tm_hour".}, + monthday {.importc: "tm_mday".}, + month {.importc: "tm_mon".}, + year {.importc: "tm_year".}, + weekday {.importc: "tm_wday".}, + yearday {.importc: "tm_yday".}, + isdst {.importc: "tm_isdst".}: cint + gmtoff {.importc: "tm_gmtoff".}: clong + else: + type + StructTM {.importc: "struct tm", final.} = object + second {.importc: "tm_sec".}, + minute {.importc: "tm_min".}, + hour {.importc: "tm_hour".}, + monthday {.importc: "tm_mday".}, + month {.importc: "tm_mon".}, + year {.importc: "tm_year".}, + weekday {.importc: "tm_wday".}, + yearday {.importc: "tm_yday".}, + isdst {.importc: "tm_isdst".}: cint type - StructTM {.importc: "struct tm", final.} = object - second {.importc: "tm_sec".}, - minute {.importc: "tm_min".}, - hour {.importc: "tm_hour".}, - monthday {.importc: "tm_mday".}, - month {.importc: "tm_mon".}, - year {.importc: "tm_year".}, - weekday {.importc: "tm_wday".}, - yearday {.importc: "tm_yday".}, - isdst {.importc: "tm_isdst".}: cint - TimeInfoPtr = ptr StructTM Clock {.importc: "clock_t".} = distinct int @@ -446,30 +477,53 @@ when not defined(JS): # our own procs on top of that: proc tmToTimeInfo(tm: StructTM, local: bool): TimeInfo = const - weekDays: array [0..6, WeekDay] = [ + weekDays: array[0..6, WeekDay] = [ dSun, dMon, dTue, dWed, dThu, dFri, dSat] - TimeInfo(second: int(tm.second), - minute: int(tm.minute), - hour: int(tm.hour), - monthday: int(tm.monthday), - month: Month(tm.month), - year: tm.year + 1900'i32, - weekday: weekDays[int(tm.weekday)], - yearday: int(tm.yearday), - isDST: tm.isdst > 0, - tzname: if local: - if tm.isdst > 0: - getTzname().DST + when defined(freebsd) or defined(netbsd) or defined(openbsd): + TimeInfo(second: int(tm.second), + minute: int(tm.minute), + hour: int(tm.hour), + monthday: int(tm.monthday), + month: Month(tm.month), + year: tm.year + 1900'i32, + weekday: weekDays[int(tm.weekday)], + yearday: int(tm.yearday), + isDST: tm.isdst > 0, + tzname: if local: + if tm.isdst > 0: + getTzname().DST + else: + getTzname().nonDST + else: + "UTC", + # BSD stores in `gmtoff` offset east of UTC in seconds, + # but posix systems using west of UTC in seconds + timezone: if local: -(tm.gmtoff) else: 0 + ) + else: + TimeInfo(second: int(tm.second), + minute: int(tm.minute), + hour: int(tm.hour), + monthday: int(tm.monthday), + month: Month(tm.month), + year: tm.year + 1900'i32, + weekday: weekDays[int(tm.weekday)], + yearday: int(tm.yearday), + isDST: tm.isdst > 0, + tzname: if local: + if tm.isdst > 0: + getTzname().DST + else: + getTzname().nonDST else: - getTzname().nonDST - else: - "UTC", - timezone: if local: getTimezone() else: 0 - ) + "UTC", + timezone: if local: getTimezone() else: 0 + ) + proc timeInfoToTM(t: TimeInfo): StructTM = const - weekDays: array [WeekDay, int8] = [1'i8,2'i8,3'i8,4'i8,5'i8,6'i8,0'i8] + weekDays: array[WeekDay, int8] = [1'i8,2'i8,3'i8,4'i8,5'i8,6'i8,0'i8] result.second = t.second result.minute = t.minute result.hour = t.hour @@ -517,6 +571,11 @@ when not defined(JS): # because the header of mktime is broken in my version of libc return mktime(timeInfoToTM(cTimeInfo)) + proc toTime(timeInfo: TimeInfo): Time = + var cTimeInfo = timeInfo # for C++ we have to make a copy, + # because the header of mktime is broken in my version of libc + return mktime(timeInfoToTM(cTimeInfo)) + proc toStringTillNL(p: cstring): string = result = "" var i = 0 @@ -550,7 +609,14 @@ when not defined(JS): return ($tzname[0], $tzname[1]) proc getTimezone(): int = - return timezone + when defined(freebsd) or defined(netbsd) or defined(openbsd): + var a = timec(nil) + let lt = localtime(addr(a)) + # BSD stores in `gmtoff` offset east of UTC in seconds, + # but posix systems using west of UTC in seconds + return -(lt.gmtoff) + else: + return timezone proc fromSeconds(since1970: float): Time = Time(since1970) @@ -586,7 +652,7 @@ elif defined(JS): return newDate() const - weekDays: array [0..6, WeekDay] = [ + weekDays: array[0..6, WeekDay] = [ dSun, dMon, dTue, dWed, dThu, dFri, dSat] proc getLocalTime(t: Time): TimeInfo = @@ -618,7 +684,16 @@ elif defined(JS): result.setFullYear(timeInfo.year) result.setDate(timeInfo.monthday) - proc `$`(timeInfo: TimeInfo): string = return $(timeInfoToTime(timeInfo)) + proc toTime*(timeInfo: TimeInfo): Time = + result = internGetTime() + result.setSeconds(timeInfo.second) + result.setMinutes(timeInfo.minute) + result.setHours(timeInfo.hour) + result.setMonth(ord(timeInfo.month)) + result.setFullYear(timeInfo.year) + result.setDate(timeInfo.monthday) + + proc `$`(timeInfo: TimeInfo): string = return $(toTime(timeInfo)) proc `$`(time: Time): string = return $time.toLocaleString() proc `-` (a, b: Time): int64 = @@ -708,24 +783,24 @@ proc years*(y: int): TimeInterval {.inline.} = proc `+=`*(t: var Time, ti: TimeInterval) = ## modifies `t` by adding the interval `ti` - t = timeInfoToTime(getLocalTime(t) + ti) + t = toTime(getLocalTime(t) + ti) proc `+`*(t: Time, ti: TimeInterval): Time = ## adds the interval `ti` to Time `t` ## by converting to localTime, adding the interval, and converting back ## ## ``echo getTime() + 1.day`` - result = timeInfoToTime(getLocalTime(t) + ti) + result = toTime(getLocalTime(t) + ti) proc `-=`*(t: var Time, ti: TimeInterval) = ## modifies `t` by subtracting the interval `ti` - t = timeInfoToTime(getLocalTime(t) - ti) + t = toTime(getLocalTime(t) - ti) proc `-`*(t: Time, ti: TimeInterval): Time = ## adds the interval `ti` to Time `t` ## ## ``echo getTime() - 1.day`` - result = timeInfoToTime(getLocalTime(t) - ti) + result = toTime(getLocalTime(t) - ti) proc formatToken(info: TimeInfo, token: string, buf: var string) = ## Helper of the format proc to parse individual tokens. @@ -923,21 +998,14 @@ proc parseToken(info: var TimeInfo; token, value: string; j: var int) = info.monthday = value[j..j+1].parseInt() j += 2 of "ddd": - case value[j..j+2].toLower() - of "sun": - info.weekday = dSun - of "mon": - info.weekday = dMon - of "tue": - info.weekday = dTue - of "wed": - info.weekday = dWed - of "thu": - info.weekday = dThu - of "fri": - info.weekday = dFri - of "sat": - info.weekday = dSat + case value[j..j+2].toLowerAscii() + of "sun": info.weekday = dSun + of "mon": info.weekday = dMon + of "tue": info.weekday = dTue + of "wed": info.weekday = dWed + of "thu": info.weekday = dThu + of "fri": info.weekday = dFri + of "sat": info.weekday = dSat else: raise newException(ValueError, "Couldn't parse day of week (ddd), got: " & value[j..j+2]) @@ -991,31 +1059,19 @@ proc parseToken(info: var TimeInfo; token, value: string; j: var int) = j += 2 info.month = Month(month-1) of "MMM": - case value[j..j+2].toLower(): - of "jan": - info.month = mJan - of "feb": - info.month = mFeb - of "mar": - info.month = mMar - of "apr": - info.month = mApr - of "may": - info.month = mMay - of "jun": - info.month = mJun - of "jul": - info.month = mJul - of "aug": - info.month = mAug - of "sep": - info.month = mSep - of "oct": - info.month = mOct - of "nov": - info.month = mNov - of "dec": - info.month = mDec + case value[j..j+2].toLowerAscii(): + of "jan": info.month = mJan + of "feb": info.month = mFeb + of "mar": info.month = mMar + of "apr": info.month = mApr + of "may": info.month = mMay + of "jun": info.month = mJun + of "jul": info.month = mJul + of "aug": info.month = mAug + of "sep": info.month = mSep + of "oct": info.month = mOct + of "nov": info.month = mNov + of "dec": info.month = mDec else: raise newException(ValueError, "Couldn't parse month (MMM), got: " & value) @@ -1112,7 +1168,7 @@ proc parseToken(info: var TimeInfo; token, value: string; j: var int) = "Couldn't parse timezone offset (zzz), got: " & value[j]) j += 6 of "ZZZ": - info.tzname = value[j..j+2].toUpper() + info.tzname = value[j..j+2].toUpperAscii() j += 3 else: # Ignore the token and move forward in the value string by the same length @@ -1193,7 +1249,7 @@ proc parse*(value, layout: string): TimeInfo = parseToken(info, token, value, j) token = "" # Reset weekday as it might not have been provided and the default may be wrong - info.weekday = getLocalTime(timeInfoToTime(info)).weekday + info.weekday = getLocalTime(toTime(info)).weekday return info # Leap year calculations are adapted from: @@ -1204,7 +1260,7 @@ proc parse*(value, layout: string): TimeInfo = proc countLeapYears*(yearSpan: int): int = ## Returns the number of leap years spanned by a given number of years. ## - ## Note: for leap years, start date is assumed to be 1 AD. + ## **Note:** For leap years, start date is assumed to be 1 AD. ## counts the number of leap years up to January 1st of a given year. ## Keep in mind that if specified year is a leap year, the leap day ## has not happened before January 1st of that year. @@ -1239,13 +1295,14 @@ proc getDayOfWeek*(day, month, year: int): WeekDay = y = year - a m = month + (12*a) - 2 d = (day + y + (y div 4) - (y div 100) + (y div 400) + (31*m) div 12) mod 7 - # The value of d is 0 for a Sunday, 1 for a Monday, 2 for a Tuesday, etc. so we must correct - # for the WeekDay type. + # The value of d is 0 for a Sunday, 1 for a Monday, 2 for a Tuesday, etc. + # so we must correct for the WeekDay type. if d == 0: return dSun result = (d-1).WeekDay proc getDayOfWeekJulian*(day, month, year: int): WeekDay = - ## Returns the day of the week enum from day, month and year, according to the Julian calendar. + ## Returns the day of the week enum from day, month and year, + ## according to the Julian calendar. # Day & month start from one. let a = (14 - month) div 12 @@ -1254,8 +1311,11 @@ proc getDayOfWeekJulian*(day, month, year: int): WeekDay = d = (5 + day + y + (y div 4) + (31*m) div 12) mod 7 result = d.WeekDay -proc timeToTimeInfo*(t: Time): TimeInfo = +proc timeToTimeInfo*(t: Time): TimeInfo {.deprecated.} = ## Converts a Time to TimeInfo. + ## + ## **Warning:** This procedure is deprecated since version 0.14.0. + ## Use ``getLocalTime`` or ``getGMTime`` instead. let secs = t.toSeconds().int daysSinceEpoch = secs div secondsInDay @@ -1286,34 +1346,21 @@ proc timeToTimeInfo*(t: Time): TimeInfo = s = daySeconds mod secondsInMin result = TimeInfo(year: y, yearday: yd, month: m, monthday: md, weekday: wd, hour: h, minute: mi, second: s) -proc timeToTimeInterval*(t: Time): TimeInterval = +proc timeToTimeInterval*(t: Time): TimeInterval {.deprecated.} = ## Converts a Time to a TimeInterval. + ## + ## **Warning:** This procedure is deprecated since version 0.14.0. + ## Use ``toTimeInterval`` instead. + # Milliseconds not available from Time + var tInfo = t.getLocalTime() + initInterval(0, tInfo.second, tInfo.minute, tInfo.hour, tInfo.weekday.ord, tInfo.month.ord, tInfo.year) - let - secs = t.toSeconds().int - daysSinceEpoch = secs div secondsInDay - (yearsSinceEpoch, daysRemaining) = countYearsAndDays(daysSinceEpoch) - daySeconds = secs mod secondsInDay - - result.years = yearsSinceEpoch + epochStartYear - - var - mon = mJan - days = daysRemaining - daysInMonth = getDaysInMonth(mon, result.years) - - # calculate month and day remainder - while days > daysInMonth and mon <= mDec: - days -= daysInMonth - mon.inc - daysInMonth = getDaysInMonth(mon, result.years) - - result.months = mon.int + 1 # month is 1 indexed int - result.days = days - result.hours = daySeconds div secondsInHour + 1 - result.minutes = (daySeconds mod secondsInHour) div secondsInMin - result.seconds = daySeconds mod secondsInMin +proc toTimeInterval*(t: Time): TimeInterval = + ## Converts a Time to a TimeInterval. # Milliseconds not available from Time + var tInfo = t.getLocalTime() + initInterval(0, tInfo.second, tInfo.minute, tInfo.hour, tInfo.weekday.ord, tInfo.month.ord, tInfo.year) + when isMainModule: # this is testing non-exported function diff --git a/lib/pure/unicode.nim b/lib/pure/unicode.nim index 45f52eb7f..0cbe8de7a 100644 --- a/lib/pure/unicode.nim +++ b/lib/pure/unicode.nim @@ -14,7 +14,7 @@ include "system/inclrtl" type - RuneImpl = int # underlying type of Rune + RuneImpl = int32 # underlying type of Rune Rune* = distinct RuneImpl ## type that can hold any Unicode character Rune16* = distinct int16 ## 16 bit Unicode character @@ -135,45 +135,62 @@ proc runeAt*(s: string, i: Natural): Rune = ## Returns the unicode character in ``s`` at byte index ``i`` fastRuneAt(s, i, result, false) -proc toUTF8*(c: Rune): string {.rtl, extern: "nuc$1".} = - ## Converts a rune into its UTF-8 representation +template fastToUTF8Copy*(c: Rune, s: var string, pos: int, doInc = true) = + ## Copies UTF-8 representation of `c` into the preallocated string `s` + ## starting at position `pos`. If `doInc == true`, `pos` is incremented + ## by the number of bytes that have been processed. + ## + ## To be the most efficient, make sure `s` is preallocated + ## with an additional amount equal to the byte length of + ## `c`. var i = RuneImpl(c) if i <=% 127: - result = newString(1) - result[0] = chr(i) + s.setLen(pos+1) + s[pos+0] = chr(i) + when doInc: inc(pos) elif i <=% 0x07FF: - result = newString(2) - result[0] = chr((i shr 6) or 0b110_00000) - result[1] = chr((i and ones(6)) or 0b10_0000_00) + s.setLen(pos+2) + s[pos+0] = chr((i shr 6) or 0b110_00000) + s[pos+1] = chr((i and ones(6)) or 0b10_0000_00) + when doInc: inc(pos, 2) elif i <=% 0xFFFF: - result = newString(3) - result[0] = chr(i shr 12 or 0b1110_0000) - result[1] = chr(i shr 6 and ones(6) or 0b10_0000_00) - result[2] = chr(i and ones(6) or 0b10_0000_00) + s.setLen(pos+3) + s[pos+0] = chr(i shr 12 or 0b1110_0000) + s[pos+1] = chr(i shr 6 and ones(6) or 0b10_0000_00) + s[pos+2] = chr(i and ones(6) or 0b10_0000_00) + when doInc: inc(pos, 3) elif i <=% 0x001FFFFF: - result = newString(4) - result[0] = chr(i shr 18 or 0b1111_0000) - result[1] = chr(i shr 12 and ones(6) or 0b10_0000_00) - result[2] = chr(i shr 6 and ones(6) or 0b10_0000_00) - result[3] = chr(i and ones(6) or 0b10_0000_00) + s.setLen(pos+4) + s[pos+0] = chr(i shr 18 or 0b1111_0000) + s[pos+1] = chr(i shr 12 and ones(6) or 0b10_0000_00) + s[pos+2] = chr(i shr 6 and ones(6) or 0b10_0000_00) + s[pos+3] = chr(i and ones(6) or 0b10_0000_00) + when doInc: inc(pos, 4) elif i <=% 0x03FFFFFF: - result = newString(5) - result[0] = chr(i shr 24 or 0b111110_00) - result[1] = chr(i shr 18 and ones(6) or 0b10_0000_00) - result[2] = chr(i shr 12 and ones(6) or 0b10_0000_00) - result[3] = chr(i shr 6 and ones(6) or 0b10_0000_00) - result[4] = chr(i and ones(6) or 0b10_0000_00) + s.setLen(pos+5) + s[pos+0] = chr(i shr 24 or 0b111110_00) + s[pos+1] = chr(i shr 18 and ones(6) or 0b10_0000_00) + s[pos+2] = chr(i shr 12 and ones(6) or 0b10_0000_00) + s[pos+3] = chr(i shr 6 and ones(6) or 0b10_0000_00) + s[pos+4] = chr(i and ones(6) or 0b10_0000_00) + when doInc: inc(pos, 5) elif i <=% 0x7FFFFFFF: - result = newString(6) - result[0] = chr(i shr 30 or 0b1111110_0) - result[1] = chr(i shr 24 and ones(6) or 0b10_0000_00) - result[2] = chr(i shr 18 and ones(6) or 0b10_0000_00) - result[3] = chr(i shr 12 and ones(6) or 0b10_0000_00) - result[4] = chr(i shr 6 and ones(6) or 0b10_0000_00) - result[5] = chr(i and ones(6) or 0b10_0000_00) + s.setLen(pos+6) + s[pos+0] = chr(i shr 30 or 0b1111110_0) + s[pos+1] = chr(i shr 24 and ones(6) or 0b10_0000_00) + s[pos+2] = chr(i shr 18 and ones(6) or 0b10_0000_00) + s[pos+3] = chr(i shr 12 and ones(6) or 0b10_0000_00) + s[pos+4] = chr(i shr 6 and ones(6) or 0b10_0000_00) + s[pos+5] = chr(i and ones(6) or 0b10_0000_00) + when doInc: inc(pos, 6) else: discard # error, exception? +proc toUTF8*(c: Rune): string {.rtl, extern: "nuc$1".} = + ## Converts a rune into its UTF-8 representation + result = "" + fastToUTF8Copy(c, result, 0, false) + proc `$`*(rune: Rune): string = ## Converts a Rune to a string rune.toUTF8 @@ -183,6 +200,105 @@ proc `$`*(runes: seq[Rune]): string = result = "" for rune in runes: result.add(rune.toUTF8) +proc runeOffset*(s: string, pos:Natural, start: Natural = 0): int = + ## Returns the byte position of unicode character + ## at position pos in s with an optional start byte position. + ## returns the special value -1 if it runs out of the string + ## + ## Beware: This can lead to unoptimized code and slow execution! + ## Most problems are solve more efficient by using an iterator + ## or conversion to a seq of Rune. + var + i = 0 + o = start + while i < pos: + o += runeLenAt(s, o) + if o >= s.len: + return -1 + inc i + return o + +proc runeAtPos*(s: string, pos: int): Rune = + ## Returns the unicode character at position pos + ## + ## Beware: This can lead to unoptimized code and slow execution! + ## Most problems are solve more efficient by using an iterator + ## or conversion to a seq of Rune. + fastRuneAt(s, runeOffset(s, pos), result, false) + +proc runeStrAtPos*(s: string, pos: Natural): string = + ## Returns the unicode character at position pos as UTF8 String + ## + ## Beware: This can lead to unoptimized code and slow execution! + ## Most problems are solve more efficient by using an iterator + ## or conversion to a seq of Rune. + let o = runeOffset(s, pos) + s[o.. (o+runeLenAt(s, o)-1)] + +proc runeReverseOffset*(s: string, rev:Positive): (int, int) = + ## Returns a tuple with the the byte offset of the + ## unicode character at position ``rev`` in s counting + ## from the end (starting with 1) and the total + ## number of runes in the string. Returns a negative value + ## for offset if there are to few runes in the string to + ## satisfy the request. + ## + ## Beware: This can lead to unoptimized code and slow execution! + ## Most problems are solve more efficient by using an iterator + ## or conversion to a seq of Rune. + var + a = rev.int + o = 0 + x = 0 + while o < s.len: + let r = runeLenAt(s, o) + o += r + if a < 0: + x += r + dec a + + if a > 0: + return (-a, rev.int-a) + return (x, -a+rev.int) + +proc runeSubStr*(s: string, pos:int, len:int = int.high): string = + ## Returns the UTF-8 substring starting at codepoint pos + ## with len codepoints. If pos or len is negativ they count from + ## the end of the string. If len is not given it means the longest + ## possible string. + ## + ## (Needs some examples) + if pos < 0: + let (o, rl) = runeReverseOffset(s, -pos) + if len >= rl: + result = s[o.. s.len-1] + elif len < 0: + let e = rl + len + if e < 0: + result = "" + else: + result = s[o.. runeOffset(s, e-(rl+pos) , o)-1] + else: + result = s[o.. runeOffset(s, len, o)-1] + else: + let o = runeOffset(s, pos) + if o < 0: + result = "" + elif len == int.high: + result = s[o.. s.len-1] + elif len < 0: + let (e, rl) = runeReverseOffset(s, -len) + discard rl + if e <= 0: + result = "" + else: + result = s[o.. e-1] + else: + var e = runeOffset(s, len, o) + if e < 0: + e = s.len + result = s[o.. e-1] + const alphaRanges = [ 0x00d8, 0x00f6, # - @@ -1148,7 +1264,7 @@ const 0x01f1, 501, # 0x01f3, 499] # -proc binarySearch(c: RuneImpl, tab: openArray[RuneImpl], len, stride: int): int = +proc binarySearch(c: RuneImpl, tab: openArray[int], len, stride: int): int = var n = len var t = 0 while n > 1: @@ -1253,8 +1369,210 @@ proc isCombining*(c: Rune): bool {.rtl, extern: "nuc$1", procvar.} = (c >= 0x20d0 and c <= 0x20ff) or (c >= 0xfe20 and c <= 0xfe2f)) +template runeCheck(s, runeProc) = + ## Common code for rune.isLower, rune.isUpper, etc + result = if len(s) == 0: false else: true + + var + i = 0 + rune: Rune + + while i < len(s) and result: + fastRuneAt(s, i, rune, doInc=true) + result = runeProc(rune) and result + +proc isUpper*(s: string): bool {.noSideEffect, procvar, + rtl, extern: "nuc$1Str".} = + ## Returns true iff `s` contains all upper case unicode characters. + runeCheck(s, isUpper) + +proc isLower*(s: string): bool {.noSideEffect, procvar, + rtl, extern: "nuc$1Str".} = + ## Returns true iff `s` contains all lower case unicode characters. + runeCheck(s, isLower) + +proc isAlpha*(s: string): bool {.noSideEffect, procvar, + rtl, extern: "nuc$1Str".} = + ## Returns true iff `s` contains all alphabetic unicode characters. + runeCheck(s, isAlpha) + +proc isSpace*(s: string): bool {.noSideEffect, procvar, + rtl, extern: "nuc$1Str".} = + ## Returns true iff `s` contains all whitespace unicode characters. + runeCheck(s, isWhiteSpace) + +template convertRune(s, runeProc) = + ## Convert runes in `s` using `runeProc` as the converter. + result = newString(len(s)) + + var + i = 0 + lastIndex = 0 + rune: Rune + + while i < len(s): + lastIndex = i + fastRuneAt(s, i, rune, doInc=true) + rune = runeProc(rune) + + rune.fastToUTF8Copy(result, lastIndex) + +proc toUpper*(s: string): string {.noSideEffect, procvar, + rtl, extern: "nuc$1Str".} = + ## Converts `s` into upper-case unicode characters. + convertRune(s, toUpper) + +proc toLower*(s: string): string {.noSideEffect, procvar, + rtl, extern: "nuc$1Str".} = + ## Converts `s` into lower-case unicode characters. + convertRune(s, toLower) + +proc swapCase*(s: string): string {.noSideEffect, procvar, + rtl, extern: "nuc$1".} = + ## Swaps the case of unicode characters in `s` + ## + ## Returns a new string such that the cases of all unicode characters + ## are swapped if possible + + var + i = 0 + lastIndex = 0 + rune: Rune + + result = newString(len(s)) + + while i < len(s): + lastIndex = i + + fastRuneAt(s, i, rune) + + if rune.isUpper(): + rune = rune.toLower() + elif rune.isLower(): + rune = rune.toUpper() + + rune.fastToUTF8Copy(result, lastIndex) + +proc capitalize*(s: string): string {.noSideEffect, procvar, + rtl, extern: "nuc$1".} = + ## Converts the first character of `s` into an upper-case unicode character. + if len(s) == 0: + return s + + var + rune: Rune + i = 0 + + fastRuneAt(s, i, rune, doInc=true) + + result = $toUpper(rune) & substr(s, i) + +proc translate*(s: string, replacements: proc(key: string): string): string {. + rtl, extern: "nuc$1".} = + ## Translates words in a string using the `replacements` proc to substitute + ## words inside `s` with their replacements + ## + ## `replacements` is any proc that takes a word and returns + ## a new word to fill it's place. + + # Allocate memory for the new string based on the old one. + # If the new string length is less than the old, no allocations + # will be needed. If the new string length is greater than the + # old, then maybe only one allocation is needed + result = newStringOfCap(s.len) + + var + index = 0 + lastIndex = 0 + wordStart = 0 + inWord = false + rune: Rune + + while index < len(s): + lastIndex = index + + fastRuneAt(s, index, rune) + + let whiteSpace = rune.isWhiteSpace() + + if whiteSpace and inWord: + # If we've reached the end of a word + let word = s[wordStart ..< lastIndex] + result.add(replacements(word)) + result.add($rune) + + inWord = false + elif not whiteSpace and not inWord: + # If we've hit a non space character and + # are not currently in a word, track + # the starting index of the word + inWord = true + wordStart = lastIndex + elif whiteSpace: + result.add($rune) + + if wordStart < len(s) and inWord: + # Get the trailing word at the end + let word = s[wordStart .. ^1] + result.add(replacements(word)) + +proc title*(s: string): string {.noSideEffect, procvar, + rtl, extern: "nuc$1".} = + ## Converts `s` to a unicode title. + ## + ## Returns a new string such that the first character + ## in each word inside `s` is capitalized + + var + i = 0 + lastIndex = 0 + rune: Rune + + result = newString(len(s)) + + var firstRune = true + + while i < len(s): + lastIndex = i + + fastRuneAt(s, i, rune) + + if not rune.isWhiteSpace() and firstRune: + rune = rune.toUpper() + firstRune = false + elif rune.isWhiteSpace(): + firstRune = true + + rune.fastToUTF8Copy(result, lastIndex) + +proc isTitle*(s: string): bool {.noSideEffect, procvar, + rtl, extern: "nuc$1Str".}= + ## Checks whether or not `s` is a unicode title. + ## + ## Returns true if the first character in each word inside `s` + ## are upper case and there is at least one character in `s`. + if s.len() == 0: + return false + + result = true + + var + i = 0 + rune: Rune + + var firstRune = true + + while i < len(s) and result: + fastRuneAt(s, i, rune, doInc=true) + + if not rune.isWhiteSpace() and firstRune: + result = rune.isUpper() and result + firstRune = false + elif rune.isWhiteSpace(): + firstRune = true + iterator runes*(s: string): Rune = - ## Iterates over any unicode character of the string ``s`` + ## Iterates over any unicode character of the string ``s`` returning runes var i = 0 result: Rune @@ -1262,6 +1580,14 @@ iterator runes*(s: string): Rune = fastRuneAt(s, i, result, true) yield result +iterator utf8*(s: string): string = + ## Iterates over any unicode character of the string ``s`` returning utf8 values + var o = 0 + while o < s.len: + let n = runeLenAt(s, o) + yield s[o.. (o+n-1)] + o += n + proc toRunes*(s: string): seq[Rune] = ## Obtains a sequence containing the Runes in ``s`` result = newSeq[Rune]() @@ -1352,6 +1678,101 @@ when isMainModule: compared = (someString == $someRunes) doAssert compared == true + proc test_replacements(word: string): string = + case word + of "two": + return "2" + of "foo": + return "BAR" + of "βeta": + return "beta" + of "alpha": + return "αlpha" + else: + return "12345" + + doAssert translate("two not alpha foo βeta", test_replacements) == "2 12345 αlpha BAR beta" + doAssert translate(" two not foo βeta ", test_replacements) == " 2 12345 BAR beta " + + doAssert title("foo bar") == "Foo Bar" + doAssert title("αlpha βeta γamma") == "Αlpha Βeta Γamma" + doAssert title("") == "" + + doAssert capitalize("βeta") == "Βeta" + doAssert capitalize("foo") == "Foo" + doAssert capitalize("") == "" + + doAssert isTitle("Foo") + doAssert(not isTitle("Foo bar")) + doAssert(not isTitle("αlpha Βeta")) + doAssert(isTitle("Αlpha Βeta Γamma")) + doAssert(not isTitle("fFoo")) + + doAssert swapCase("FooBar") == "fOObAR" + doAssert swapCase(" ") == " " + doAssert swapCase("Αlpha Βeta Γamma") == "αLPHA βETA γAMMA" + doAssert swapCase("a✓B") == "A✓b" + doAssert swapCase("") == "" + + doAssert isAlpha("r") + doAssert isAlpha("α") + doAssert(not isAlpha("$")) + doAssert(not isAlpha("")) + + doAssert isAlpha("Βeta") + doAssert isAlpha("Args") + doAssert(not isAlpha("$Foo✓")) + + doAssert isSpace("\t") + doAssert isSpace("\l") + doAssert(not isSpace("Β")) + doAssert(not isSpace("Βeta")) + + doAssert isSpace("\t\l \v\r\f") + doAssert isSpace(" ") + doAssert(not isSpace("")) + doAssert(not isSpace("ΑΓc \td")) + + doAssert isLower("a") + doAssert isLower("γ") + doAssert(not isLower("Γ")) + doAssert(not isLower("4")) + doAssert(not isLower("")) + + doAssert isLower("abcdγ") + doAssert(not isLower("abCDΓ")) + doAssert(not isLower("33aaΓ")) + + doAssert isUpper("Γ") + doAssert(not isUpper("b")) + doAssert(not isUpper("α")) + doAssert(not isUpper("✓")) + doAssert(not isUpper("")) + + doAssert isUpper("ΑΒΓ") + doAssert(not isUpper("AAccβ")) + doAssert(not isUpper("A#$β")) + + doAssert toUpper("Γ") == "Γ" + doAssert toUpper("b") == "B" + doAssert toUpper("α") == "Α" + doAssert toUpper("✓") == "✓" + doAssert toUpper("") == "" + + doAssert toUpper("ΑΒΓ") == "ΑΒΓ" + doAssert toUpper("AAccβ") == "AACCΒ" + doAssert toUpper("A✓$β") == "A✓$Β" + + doAssert toLower("a") == "a" + doAssert toLower("γ") == "γ" + doAssert toLower("Γ") == "γ" + doAssert toLower("4") == "4" + doAssert toLower("") == "" + + doAssert toLower("abcdγ") == "abcdγ" + doAssert toLower("abCDΓ") == "abcdγ" + doAssert toLower("33aaΓ") == "33aaγ" + doAssert reversed("Reverse this!") == "!siht esreveR" doAssert reversed("先秦兩漢") == "漢兩秦先" doAssert reversed("as⃝df̅") == "f̅ds⃝a" @@ -1360,3 +1781,44 @@ when isMainModule: const test = "as⃝" doAssert lastRune(test, test.len-1)[1] == 3 doAssert graphemeLen("è", 0) == 2 + + # test for rune positioning and runeSubStr() + let s = "Hänsel ««: 10,00€" + + var t = "" + for c in s.utf8: + t.add c + + doAssert(s == t) + + doAssert(runeReverseOffset(s, 1) == (20, 18)) + doAssert(runeReverseOffset(s, 19) == (-1, 18)) + + doAssert(runeStrAtPos(s, 0) == "H") + doAssert(runeSubStr(s, 0, 1) == "H") + doAssert(runeStrAtPos(s, 10) == ":") + doAssert(runeSubStr(s, 10, 1) == ":") + doAssert(runeStrAtPos(s, 9) == "«") + doAssert(runeSubStr(s, 9, 1) == "«") + doAssert(runeStrAtPos(s, 17) == "€") + doAssert(runeSubStr(s, 17, 1) == "€") + # echo runeStrAtPos(s, 18) # index error + + doAssert(runeSubStr(s, 0) == "Hänsel ««: 10,00€") + doAssert(runeSubStr(s, -18) == "Hänsel ««: 10,00€") + doAssert(runeSubStr(s, 10) == ": 10,00€") + doAssert(runeSubStr(s, 18) == "") + doAssert(runeSubStr(s, 0, 10) == "Hänsel ««") + + doAssert(runeSubStr(s, 12) == "10,00€") + doAssert(runeSubStr(s, -6) == "10,00€") + + doAssert(runeSubStr(s, 12, 5) == "10,00") + doAssert(runeSubStr(s, 12, -1) == "10,00") + doAssert(runeSubStr(s, -6, 5) == "10,00") + doAssert(runeSubStr(s, -6, -1) == "10,00") + + doAssert(runeSubStr(s, 0, 100) == "Hänsel ««: 10,00€") + doAssert(runeSubStr(s, -100, 100) == "Hänsel ««: 10,00€") + doAssert(runeSubStr(s, 0, -100) == "") + doAssert(runeSubStr(s, 100, -100) == "") diff --git a/lib/pure/unittest.nim b/lib/pure/unittest.nim index aca9d51e2..92ddc3e75 100644 --- a/lib/pure/unittest.nim +++ b/lib/pure/unittest.nim @@ -41,7 +41,11 @@ when not defined(ECMAScript): import terminal type - TestStatus* = enum OK, FAILED ## The status of a test when it is done. + TestStatus* = enum ## The status of a test when it is done. + OK, + FAILED, + SKIPPED + OutputLevel* = enum ## The output verbosity of the tests. PRINT_ALL, ## Print as much as possible. PRINT_FAILURES, ## Print only the failed tests. @@ -73,7 +77,7 @@ checkpoints = @[] proc shouldRun(testName: string): bool = result = true -template suite*(name: expr, body: stmt): stmt {.immediate, dirty.} = +template suite*(name, body) {.dirty.} = ## Declare a test suite identified by `name` with optional ``setup`` ## and/or ``teardown`` section. ## @@ -102,13 +106,13 @@ template suite*(name: expr, body: stmt): stmt {.immediate, dirty.} = ## [OK] 2 + 2 = 4 ## [OK] (2 + -2) != 4 block: - template setup(setupBody: stmt): stmt {.immediate, dirty.} = + template setup(setupBody: untyped) {.dirty.} = var testSetupIMPLFlag = true - template testSetupIMPL: stmt {.immediate, dirty.} = setupBody + template testSetupIMPL: untyped {.dirty.} = setupBody - template teardown(teardownBody: stmt): stmt {.immediate, dirty.} = + template teardown(teardownBody: untyped) {.dirty.} = var testTeardownIMPLFlag = true - template testTeardownIMPL: stmt {.immediate, dirty.} = teardownBody + template testTeardownIMPL: untyped {.dirty.} = teardownBody body @@ -120,14 +124,18 @@ proc testDone(name: string, s: TestStatus) = template rawPrint() = echo("[", $s, "] ", name) when not defined(ECMAScript): if colorOutput and not defined(ECMAScript): - var color = (if s == OK: fgGreen else: fgRed) + var color = case s + of OK: fgGreen + of FAILED: fgRed + of SKIPPED: fgYellow + else: fgWhite styledEcho styleBright, color, "[", $s, "] ", fgWhite, name else: rawPrint() else: rawPrint() -template test*(name: expr, body: stmt): stmt {.immediate, dirty.} = +template test*(name, body) {.dirty.} = ## Define a single test case identified by `name`. ## ## .. code-block:: nim @@ -203,7 +211,22 @@ template fail* = checkpoints = @[] -macro check*(conditions: stmt): stmt {.immediate.} = +template skip* = + ## Makes test to be skipped. Should be used directly + ## in case when it is not possible to perform test + ## for reasons depending on outer environment, + ## or certain application logic conditions or configurations. + ## + ## .. code-block:: nim + ## + ## if not isGLConextCreated(): + ## skip() + bind checkpoints + + testStatusIMPL = SKIPPED + checkpoints = @[] + +macro check*(conditions: untyped): untyped = ## Verify if a statement or a list of statements is true. ## A helpful error message and set checkpoints are printed out on ## failure (if ``outputLevel`` is not ``PRINT_NONE``). @@ -236,31 +259,34 @@ macro check*(conditions: stmt): stmt {.immediate.} = proc inspectArgs(exp: NimNode): NimNode = result = copyNimTree(exp) - for i in countup(1, exp.len - 1): - if exp[i].kind notin nnkLiterals: - inc counter - var arg = newIdentNode(":p" & $counter) - var argStr = exp[i].toStrLit - var paramAst = exp[i] - if exp[i].kind == nnkIdent: - argsPrintOuts.add getAst(print(argStr, paramAst)) - if exp[i].kind in nnkCallKinds: - var callVar = newIdentNode(":c" & $counter) - argsAsgns.add getAst(asgn(callVar, paramAst)) - result[i] = callVar - argsPrintOuts.add getAst(print(argStr, callVar)) - if exp[i].kind == nnkExprEqExpr: - # ExprEqExpr - # Ident !"v" - # IntLit 2 - result[i] = exp[i][1] - if exp[i].typekind notin {ntyTypeDesc}: - argsAsgns.add getAst(asgn(arg, paramAst)) - argsPrintOuts.add getAst(print(argStr, arg)) - if exp[i].kind != nnkExprEqExpr: - result[i] = arg - else: - result[i][1] = arg + if exp[0].kind == nnkIdent and + $exp[0] in ["and", "or", "not", "in", "notin", "==", "<=", + ">=", "<", ">", "!=", "is", "isnot"]: + for i in countup(1, exp.len - 1): + if exp[i].kind notin nnkLiterals: + inc counter + var arg = newIdentNode(":p" & $counter) + var argStr = exp[i].toStrLit + var paramAst = exp[i] + if exp[i].kind == nnkIdent: + argsPrintOuts.add getAst(print(argStr, paramAst)) + if exp[i].kind in nnkCallKinds: + var callVar = newIdentNode(":c" & $counter) + argsAsgns.add getAst(asgn(callVar, paramAst)) + result[i] = callVar + argsPrintOuts.add getAst(print(argStr, callVar)) + if exp[i].kind == nnkExprEqExpr: + # ExprEqExpr + # Ident !"v" + # IntLit 2 + result[i] = exp[i][1] + if exp[i].typekind notin {ntyTypeDesc}: + argsAsgns.add getAst(asgn(arg, paramAst)) + argsPrintOuts.add getAst(print(argStr, arg)) + if exp[i].kind != nnkExprEqExpr: + result[i] = arg + else: + result[i][1] = arg case checked.kind of nnkCallKinds: @@ -292,7 +318,7 @@ macro check*(conditions: stmt): stmt {.immediate.} = result = getAst(rewrite(checked, checked.lineinfo, checked.toStrLit)) -template require*(conditions: stmt): stmt {.immediate.} = +template require*(conditions: untyped) = ## Same as `check` except any failed test causes the program to quit ## immediately. Any teardown statements are not executed and the failed ## test output is not generated. @@ -302,7 +328,7 @@ template require*(conditions: stmt): stmt {.immediate.} = check conditions abortOnError = savedAbortOnError -macro expect*(exceptions: varargs[expr], body: stmt): stmt {.immediate.} = +macro expect*(exceptions: varargs[typed], body: untyped): untyped = ## Test if `body` raises an exception found in the passed `exceptions`. ## The test passes if the raised exception is part of the acceptable ## exceptions. Otherwise, it fails. @@ -310,7 +336,7 @@ macro expect*(exceptions: varargs[expr], body: stmt): stmt {.immediate.} = ## ## .. code-block:: nim ## - ## import math + ## import math, random ## proc defectiveRobot() = ## randomize() ## case random(1..4) diff --git a/lib/pure/xmldom.nim b/lib/pure/xmldom.nim index 6cf837f25..559f45348 100644 --- a/lib/pure/xmldom.nim +++ b/lib/pure/xmldom.nim @@ -51,6 +51,9 @@ const # Illegal characters illegalChars = {'>', '<', '&', '"'} + # standard xml: attribute names + # see https://www.w3.org/XML/1998/namespace + stdattrnames = ["lang", "space", "base", "id"] type Feature = tuple[name: string, version: string] @@ -229,12 +232,15 @@ proc createAttributeNS*(doc: PDocument, namespaceURI: string, qualifiedName: str raise newException(EInvalidCharacterErr, "Invalid character") # Exceptions if qualifiedName.contains(':'): + let qfnamespaces = qualifiedName.toLower().split(':') if isNil(namespaceURI): raise newException(ENamespaceErr, "When qualifiedName contains a prefix namespaceURI cannot be nil") - elif qualifiedName.split(':')[0].toLower() == "xml" and namespaceURI != "http://www.w3.org/XML/1998/namespace": + elif qfnamespaces[0] == "xml" and + namespaceURI != "http://www.w3.org/XML/1998/namespace" and + qfnamespaces[1] notin stdattrnames: raise newException(ENamespaceErr, "When the namespace prefix is \"xml\" namespaceURI has to be \"http://www.w3.org/XML/1998/namespace\"") - elif qualifiedName.split(':')[1].toLower() == "xmlns" and namespaceURI != "http://www.w3.org/2000/xmlns/": + elif qfnamespaces[1] == "xmlns" and namespaceURI != "http://www.w3.org/2000/xmlns/": raise newException(ENamespaceErr, "When the namespace prefix is \"xmlns\" namespaceURI has to be \"http://www.w3.org/2000/xmlns/\"") @@ -305,9 +311,12 @@ proc createElement*(doc: PDocument, tagName: string): PElement = proc createElementNS*(doc: PDocument, namespaceURI: string, qualifiedName: string): PElement = ## Creates an element of the given qualified name and namespace URI. if qualifiedName.contains(':'): + let qfnamespaces = qualifiedName.toLower().split(':') if isNil(namespaceURI): raise newException(ENamespaceErr, "When qualifiedName contains a prefix namespaceURI cannot be nil") - elif qualifiedName.split(':')[0].toLower() == "xml" and namespaceURI != "http://www.w3.org/XML/1998/namespace": + elif qfnamespaces[0] == "xml" and + namespaceURI != "http://www.w3.org/XML/1998/namespace" and + qfnamespaces[1] notin stdattrnames: raise newException(ENamespaceErr, "When the namespace prefix is \"xml\" namespaceURI has to be \"http://www.w3.org/XML/1998/namespace\"") |