diff options
author | Daniil Yarancev <21169548+Yardanico@users.noreply.github.com> | 2018-06-05 21:25:45 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-05 21:25:45 +0300 |
commit | 642641359821b6a63c6cf7edaaa45873b7ea59c7 (patch) | |
tree | 627af3020528cb916b3174bd94304307ca875c77 /lib/pure/net.nim | |
parent | fb44c522e6173528efa8035ecc459c84887d0167 (diff) | |
parent | 3cbc07ac7877b03c605498760fe198e3200cc197 (diff) | |
download | Nim-642641359821b6a63c6cf7edaaa45873b7ea59c7.tar.gz |
Merge pull request #2 from nim-lang/devel
Update
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r-- | lib/pure/net.nim | 133 |
1 files changed, 94 insertions, 39 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim index aad6ab3e8..bf5f3f57e 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -64,8 +64,11 @@ ## socket.acceptAddr(client, address) ## echo("Client connected from: ", address) ## +## **Note:** The ``client`` variable is initialised with ``new Socket`` **not** +## ``newSocket()``. The difference is that the latter creates a new file +## descriptor. -{.deadCodeElim: on.} +{.deadCodeElim: on.} # dce option deprecated import nativesockets, os, strutils, parseutils, times, sets, options export Port, `$`, `==` export Domain, SockType, Protocol @@ -107,9 +110,6 @@ when defineSsl: serverGetPskFunc: SslServerGetPskFunc clientGetPskFunc: SslClientGetPskFunc - {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode, - TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext, - TSSLAcceptResult: SSLAcceptResult].} else: type SslContext* = void # TODO: Workaround #4797. @@ -156,10 +156,6 @@ type Peek, SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown. -{.deprecated: [TSocketFlags: SocketFlag, ETimeout: TimeoutError, - TReadLineResult: ReadLineResult, TSOBool: SOBool, PSocket: Socket, - TSocketImpl: SocketImpl].} - type IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address IPv6, ## IPv6 address @@ -173,8 +169,6 @@ type of IpAddressFamily.IPv4: address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in ## case of IPv4 -{.deprecated: [TIpAddress: IpAddress].} - proc socketError*(socket: Socket, err: int = -1, async = false, lastError = (-1).OSErrorCode): void {.gcsafe.} @@ -221,7 +215,7 @@ proc newSocket*(domain, sockType, protocol: cint, buffered = true): Socket = ## Creates a new socket. ## ## If an error occurs EOS will be raised. - let fd = newNativeSocket(domain, sockType, protocol) + let fd = createNativeSocket(domain, sockType, protocol) if fd == osInvalidSocket: raiseOSError(osLastError()) result = newSocket(fd, domain.Domain, sockType.SockType, protocol.Protocol, @@ -232,7 +226,7 @@ proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM, ## Creates a new socket. ## ## If an error occurs EOS will be raised. - let fd = newNativeSocket(domain, sockType, protocol) + let fd = createNativeSocket(domain, sockType, protocol) if fd == osInvalidSocket: raiseOSError(osLastError()) result = newSocket(fd, domain, sockType, protocol, buffered) @@ -411,9 +405,46 @@ proc isIpAddress*(address_str: string): bool {.tags: [].} = return false return true +proc toSockAddr*(address: IpAddress, port: Port, sa: var Sockaddr_storage, sl: var Socklen) = + ## Converts `IpAddress` and `Port` to `SockAddr` and `Socklen` + let port = htons(uint16(port)) + case address.family + of IpAddressFamily.IPv4: + sl = sizeof(Sockaddr_in).Socklen + let s = cast[ptr Sockaddr_in](addr sa) + s.sin_family = type(s.sin_family)(AF_INET) + s.sin_port = port + copyMem(addr s.sin_addr, unsafeAddr address.address_v4[0], sizeof(s.sin_addr)) + of IpAddressFamily.IPv6: + sl = sizeof(Sockaddr_in6).Socklen + let s = cast[ptr Sockaddr_in6](addr sa) + s.sin6_family = type(s.sin6_family)(AF_INET6) + s.sin6_port = port + copyMem(addr s.sin6_addr, unsafeAddr address.address_v6[0], sizeof(s.sin6_addr)) + +proc fromSockAddrAux(sa: ptr Sockaddr_storage, sl: Socklen, address: var IpAddress, port: var Port) = + if sa.ss_family.int == AF_INET.int and sl == sizeof(Sockaddr_in).Socklen: + address = IpAddress(family: IpAddressFamily.IPv4) + let s = cast[ptr Sockaddr_in](sa) + copyMem(addr address.address_v4[0], addr s.sin_addr, sizeof(address.address_v4)) + port = ntohs(s.sin_port).Port + elif sa.ss_family.int == AF_INET6.int and sl == sizeof(Sockaddr_in6).Socklen: + address = IpAddress(family: IpAddressFamily.IPv6) + let s = cast[ptr Sockaddr_in6](sa) + copyMem(addr address.address_v6[0], addr s.sin6_addr, sizeof(address.address_v6)) + port = ntohs(s.sin6_port).Port + else: + raise newException(ValueError, "Neither IPv4 nor IPv6") + +proc fromSockAddr*(sa: Sockaddr_storage | SockAddr | Sockaddr_in | Sockaddr_in6, + sl: Socklen, address: var IpAddress, port: var Port) {.inline.} = + ## Converts `SockAddr` and `Socklen` to `IpAddress` and `Port`. Raises + ## `ObjectConversionError` in case of invalid `sa` and `sl` arguments. + fromSockAddrAux(unsafeAddr sa, sl, address, port) + when defineSsl: CRYPTO_malloc_init() - SslLibraryInit() + doAssert SslLibraryInit() == 1 SslLoadErrorStrings() ErrLoadBioStrings() OpenSSL_add_all_algorithms() @@ -427,8 +458,14 @@ when defineSsl: raise newException(SSLError, "No error reported.") if err == -1: raiseOSError(osLastError()) - var errStr = ErrErrorString(err, nil) - raise newException(SSLError, $errStr) + var errStr = $ErrErrorString(err, nil) + case err + of 336032814, 336032784: + errStr = "Please upgrade your OpenSSL library, it does not support the " & + "necessary protocols. OpenSSL error is: " & errStr + else: + discard + raise newException(SSLError, errStr) proc getExtraData*(ctx: SSLContext, index: int): RootRef = ## Retrieves arbitrary data stored inside SSLContext. @@ -753,10 +790,10 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string, ## flag is specified then this error will not be raised and instead ## accept will be called again. assert(client != nil) - var sockAddress: Sockaddr_in - var addrLen = sizeof(sockAddress).SockLen - var sock = accept(server.fd, cast[ptr SockAddr](addr(sockAddress)), - addr(addrLen)) + assert client.fd.int <= 0, "Client socket needs to be initialised with " & + "`new`, not `newSocket`." + let ret = accept(server.fd) + let sock = ret[0] if sock == osInvalidSocket: let err = osLastError() @@ -764,7 +801,9 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string, acceptAddr(server, client, address, flags) raiseOSError(err) else: + address = ret[1] client.fd = sock + client.domain = getSockDomain(sock) client.isBuffered = server.isBuffered # Handle SSL. @@ -776,9 +815,6 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string, let ret = SSLAccept(client.sslHandle) socketError(client, ret, false) - # Client socket is set above. - address = $inet_ntoa(sockAddress.sin_addr) - when false: #defineSsl: proc acceptAddrSSL*(server: Socket, client: var Socket, address: var string): SSLAcceptResult {. @@ -868,6 +904,7 @@ proc close*(socket: Socket) = socket.sslHandle = nil socket.fd.close() + socket.fd = osInvalidSocket when defined(posix): from posix import TCP_NODELAY @@ -924,7 +961,7 @@ when defined(posix) and not defined(nimdoc): raise newException(ValueError, "socket path too long") copyMem(addr result.sun_path, path.cstring, path.len + 1) -when defined(posix): +when defined(posix) or defined(nimdoc): proc connectUnix*(socket: Socket, path: string) = ## Connects to Unix socket on `path`. ## This only works on Unix-style systems: Mac OS X, BSD and Linux @@ -1005,15 +1042,25 @@ proc select(readfd: Socket, timeout = 500): int = var fds = @[readfd.fd] result = select(fds, timeout) -proc readIntoBuf(socket: Socket, flags: int32): int = +proc isClosed(socket: Socket): bool = + socket.fd == osInvalidSocket + +proc uniRecv(socket: Socket, buffer: pointer, size, flags: cint): int = + ## Handles SSL and non-ssl recv in a nice package. + ## + ## In particular handles the case where socket has been closed properly + ## for both SSL and non-ssl. result = 0 + assert(not socket.isClosed, "Cannot `recv` on a closed socket") when defineSsl: - if socket.isSSL: - result = SSLRead(socket.sslHandle, addr(socket.buffer), int(socket.buffer.high)) - else: - result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags) - else: - result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags) + if socket.isSsl: + return SSLRead(socket.sslHandle, buffer, size) + + return recv(socket.fd, buffer, size, flags) + +proc readIntoBuf(socket: Socket, flags: int32): int = + result = 0 + result = uniRecv(socket, addr(socket.buffer), socket.buffer.high, flags) if result < 0: # Save it in case it gets reset (the Nim codegen occasionally may call # Win API functions which reset it). @@ -1059,16 +1106,16 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect] else: when defineSsl: if socket.isSSL: - if socket.sslHasPeekChar: + if socket.sslHasPeekChar: # TODO: Merge this peek char mess into uniRecv copyMem(data, addr(socket.sslPeekChar), 1) socket.sslHasPeekChar = false if size-1 > 0: var d = cast[cstring](data) - result = SSLRead(socket.sslHandle, addr(d[1]), size-1) + 1 + result = uniRecv(socket, addr(d[1]), cint(size-1), 0'i32) + 1 else: result = 1 else: - result = SSLRead(socket.sslHandle, data, size) + result = uniRecv(socket, data, size.cint, 0'i32) else: result = recv(socket.fd, data, size.cint, 0'i32) else: @@ -1145,7 +1192,11 @@ proc recv*(socket: Socket, data: var string, size: int, timeout = -1, ## ## **Warning**: Only the ``SafeDisconn`` flag is currently supported. data.setLen(size) - result = recv(socket, cstring(data), size, timeout) + result = + if timeout == -1: + recv(socket, cstring(data), size) + else: + recv(socket, cstring(data), size, timeout) if result < 0: data.setLen(0) let lastError = getSocketError(socket) @@ -1182,7 +1233,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} = when defineSsl: if socket.isSSL: if not socket.sslHasPeekChar: - result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1) + result = uniRecv(socket, addr(socket.sslPeekChar), 1, 0'i32) socket.sslHasPeekChar = true c = socket.sslPeekChar @@ -1316,6 +1367,7 @@ proc send*(socket: Socket, data: pointer, size: int): int {. ## ## **Note**: This is a low-level version of ``send``. You likely should use ## the version below. + assert(not socket.isClosed, "Cannot `send` on a closed socket") when defineSsl: if socket.isSSL: return SSLWrite(socket.sslHandle, cast[cstring](data), size) @@ -1360,8 +1412,8 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer, ## which is defined below. ## ## **Note:** This proc is not available for SSL sockets. - var aiList = getAddrInfo(address, port, af) - + assert(not socket.isClosed, "Cannot `sendTo` on a closed socket") + var aiList = getAddrInfo(address, port, af, socket.sockType, socket.protocol) # try all possibilities: var success = false var it = aiList @@ -1382,7 +1434,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, ## this function will try each IP of that hostname. ## ## This is the high-level version of the above ``sendTo`` function. - result = socket.sendTo(address, port, cstring(data), data.len) + result = socket.sendTo(address, port, cstring(data), data.len, socket.domain ) proc isSsl*(socket: Socket): bool = @@ -1531,7 +1583,7 @@ proc dial*(address: string, port: Port, domain = domainOpt.unsafeGet() lastFd = fdPerDomain[ord(domain)] if lastFd == osInvalidSocket: - lastFd = newNativeSocket(domain, sockType, protocol) + lastFd = createNativeSocket(domain, sockType, protocol) if lastFd == osInvalidSocket: # we always raise if socket creation failed, because it means a # network system problem (e.g. not enough FDs), and not an unreachable @@ -1640,6 +1692,9 @@ proc connect*(socket: Socket, address: string, port = Port(0), if selectWrite(s, timeout) != 1: raise newException(TimeoutError, "Call to 'connect' timed out.") else: + let res = getSockOptInt(socket.fd, SOL_SOCKET, SO_ERROR) + if res != 0: + raiseOSError(OSErrorCode(res)) when defineSsl and not defined(nimdoc): if socket.isSSL: socket.fd.setBlocking(true) |