diff options
Diffstat (limited to 'lib/pure/net.nim')
-rw-r--r-- | lib/pure/net.nim | 586 |
1 files changed, 299 insertions, 287 deletions
diff --git a/lib/pure/net.nim b/lib/pure/net.nim index 5f83b1dea..f35e0cf63 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -15,288 +15,6 @@ export TPort, `$` const useWinVersion = defined(Windows) or defined(nimdoc) -type - IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address - IPv6, ## IPv6 address - IPv4 ## IPv4 address - - TIpAddress* = object ## stores an arbitrary IP address - case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6) - of IpAddressFamily.IPv6: - address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in - ## case of IPv6 - of IpAddressFamily.IPv4: - address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in - ## case of IPv4 - -proc IPv4_any*(): TIpAddress = - ## Returns the IPv4 any address, which can be used to listen on all available - ## network adapters - result = TIpAddress( - family: IpAddressFamily.IPv4, - address_v4: [0'u8, 0, 0, 0]) - -proc IPv4_loopback*(): TIpAddress = - ## Returns the IPv4 loopback address (127.0.0.1) - result = TIpAddress( - family: IpAddressFamily.IPv4, - address_v4: [127'u8, 0, 0, 1]) - -proc IPv4_broadcast*(): TIpAddress = - ## Returns the IPv4 broadcast address (255.255.255.255) - result = TIpAddress( - family: IpAddressFamily.IPv4, - address_v4: [255'u8, 255, 255, 255]) - -proc IPv6_any*(): TIpAddress = - ## Returns the IPv6 any address (::0), which can be used - ## to listen on all available network adapters - result = TIpAddress( - family: IpAddressFamily.IPv6, - address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) - -proc IPv6_loopback*(): TIpAddress = - ## Returns the IPv6 loopback address (::1) - result = TIpAddress( - family: IpAddressFamily.IPv6, - address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) - -proc `==`*(lhs, rhs: TIpAddress): bool = - ## Compares two IpAddresses for Equality. Returns two if the addresses are equal - if lhs.family != rhs.family: return false - if lhs.family == IpAddressFamily.IPv4: - for i in low(lhs.address_v4) .. high(lhs.address_v4): - if lhs.address_v4[i] != rhs.address_v4[i]: return false - else: # IPv6 - for i in low(lhs.address_v6) .. high(lhs.address_v6): - if lhs.address_v6[i] != rhs.address_v6[i]: return false - return true - -proc `$`*(address: TIpAddress): string = - ## Converts an TIpAddress into the textual representation - result = "" - case address.family - of IpAddressFamily.IPv4: - for i in 0 .. 3: - if i != 0: - result.add('.') - result.add($address.address_v4[i]) - of IpAddressFamily.IPv6: - var - currentZeroStart = -1 - currentZeroCount = 0 - biggestZeroStart = -1 - biggestZeroCount = 0 - # Look for the largest block of zeros - for i in 0..7: - var isZero = address.address_v6[i*2] == 0 and address.address_v6[i*2+1] == 0 - if isZero: - if currentZeroStart == -1: - currentZeroStart = i - currentZeroCount = 1 - else: - currentZeroCount.inc() - if currentZeroCount > biggestZeroCount: - biggestZeroCount = currentZeroCount - biggestZeroStart = currentZeroStart - else: - currentZeroStart = -1 - - if biggestZeroCount == 8: # Special case ::0 - result.add("::") - else: # Print address - var printedLastGroup = false - for i in 0..7: - var word:uint16 = (cast[uint16](address.address_v6[i*2])) shl 8 - word = word or cast[uint16](address.address_v6[i*2+1]) - - if biggestZeroCount != 0 and # Check if group is in skip group - (i >= biggestZeroStart and i < (biggestZeroStart + biggestZeroCount)): - if i == biggestZeroStart: # skip start - result.add("::") - printedLastGroup = false - else: - if printedLastGroup: - result.add(':') - var - afterLeadingZeros = false - mask = 0xF000'u16 - for j in 0'u16..3'u16: - var val = (mask and word) shr (4'u16*(3'u16-j)) - if val != 0 or afterLeadingZeros: - if val < 0xA: - result.add(chr(uint16(ord('0'))+val)) - else: # val >= 0xA - result.add(chr(uint16(ord('a'))+val-0xA)) - afterLeadingZeros = true - mask = mask shr 4 - printedLastGroup = true - -proc parseIPv4Address(address_str: string): TIpAddress = - ## Parses IPv4 adresses - ## Raises EInvalidValue on errors - var - byteCount = 0 - currentByte:uint16 = 0 - seperatorValid = false - - result.family = IpAddressFamily.IPv4 - - for i in 0 .. high(address_str): - if address_str[i] in strutils.Digits: # Character is a number - currentByte = currentByte * 10 + - cast[uint16](ord(address_str[i]) - ord('0')) - if currentByte > 255'u16: - raise newException(EInvalidValue, - "Invalid IP Address. Value is out of range") - seperatorValid = true - elif address_str[i] == '.': # IPv4 address separator - if not seperatorValid or byteCount >= 3: - raise newException(EInvalidValue, - "Invalid IP Address. The address consists of too many groups") - result.address_v4[byteCount] = cast[uint8](currentByte) - currentByte = 0 - byteCount.inc - seperatorValid = false - else: - raise newException(EInvalidValue, - "Invalid IP Address. Address contains an invalid character") - - if byteCount != 3 or not seperatorValid: - raise newException(EInvalidValue, "Invalid IP Address") - result.address_v4[byteCount] = cast[uint8](currentByte) - -proc parseIPv6Address(address_str: string): TIpAddress = - ## Parses IPv6 adresses - ## Raises EInvalidValue on errors - result.family = IpAddressFamily.IPv6 - if address_str.len < 2: - raise newException(EInvalidValue, "Invalid IP Address") - - var - groupCount = 0 - currentGroupStart = 0 - currentShort:uint32 = 0 - seperatorValid = true - dualColonGroup = -1 - lastWasColon = false - v4StartPos = -1 - byteCount = 0 - - for i,c in address_str: - if c == ':': - if not seperatorValid: - raise newException(EInvalidValue, - "Invalid IP Address. Address contains an invalid seperator") - if lastWasColon: - if dualColonGroup != -1: - raise newException(EInvalidValue, - "Invalid IP Address. Address contains more than one \"::\" seperator") - dualColonGroup = groupCount - seperatorValid = false - elif i != 0 and i != high(address_str): - if groupCount >= 8: - raise newException(EInvalidValue, - "Invalid IP Address. The address consists of too many groups") - result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) - result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) - currentShort = 0 - groupCount.inc() - if dualColonGroup != -1: seperatorValid = false - elif i == 0: # only valid if address starts with :: - if address_str[1] != ':': - raise newException(EInvalidValue, - "Invalid IP Address. Address may not start with \":\"") - else: # i == high(address_str) - only valid if address ends with :: - if address_str[high(address_str)-1] != ':': - raise newException(EInvalidValue, - "Invalid IP Address. Address may not end with \":\"") - lastWasColon = true - currentGroupStart = i + 1 - elif c == '.': # Switch to parse IPv4 mode - if i < 3 or not seperatorValid or groupCount >= 7: - raise newException(EInvalidValue, "Invalid IP Address") - v4StartPos = currentGroupStart - currentShort = 0 - seperatorValid = false - break - elif c in strutils.HexDigits: - if c in strutils.Digits: # Normal digit - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0')) - elif c >= 'a' and c <= 'f': # Lower case hex - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10 - else: # Upper case hex - currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10 - if currentShort > 65535'u32: - raise newException(EInvalidValue, - "Invalid IP Address. Value is out of range") - lastWasColon = false - seperatorValid = true - else: - raise newException(EInvalidValue, - "Invalid IP Address. Address contains an invalid character") - - - if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff - if seperatorValid: # Copy remaining data - if groupCount >= 8: - raise newException(EInvalidValue, - "Invalid IP Address. The address consists of too many groups") - result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) - result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) - groupCount.inc() - else: # Must parse IPv4 address - for i,c in address_str[v4StartPos..high(address_str)]: - if c in strutils.Digits: # Character is a number - currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0')) - if currentShort > 255'u32: - raise newException(EInvalidValue, - "Invalid IP Address. Value is out of range") - seperatorValid = true - elif c == '.': # IPv4 address separator - if not seperatorValid or byteCount >= 3: - raise newException(EInvalidValue, "Invalid IP Address") - result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) - currentShort = 0 - byteCount.inc() - seperatorValid = false - else: # Invalid character - raise newException(EInvalidValue, - "Invalid IP Address. Address contains an invalid character") - - if byteCount != 3 or not seperatorValid: - raise newException(EInvalidValue, "Invalid IP Address") - result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) - groupCount += 2 - - # Shift and fill zeros in case of :: - if groupCount > 8: - raise newException(EInvalidValue, - "Invalid IP Address. The address consists of too many groups") - elif groupCount < 8: # must fill - if dualColonGroup == -1: - raise newException(EInvalidValue, - "Invalid IP Address. The address consists of too few groups") - var toFill = 8 - groupCount # The number of groups to fill - var toShift = groupCount - dualColonGroup # Nr of known groups after :: - for i in 0..2*toShift-1: # shift - result.address_v6[15-i] = result.address_v6[groupCount*2-i-1] - for i in 0..2*toFill-1: # fill with 0s - result.address_v6[dualColonGroup*2+i] = 0 - elif dualColonGroup != -1: - raise newException(EInvalidValue, - "Invalid IP Address. The address consists of too many groups") - -proc parseIpAddress*(address_str: string): TIpAddress = - ## Parses an IP address - ## Raises EInvalidValue on error - if address_str == nil: - raise newException(EInvalidValue, "IP Address string is nil") - if address_str.contains(':'): - return parseIPv6Address(address_str) - else: - return parseIPv4Address(address_str) - when defined(ssl): import openssl @@ -569,8 +287,8 @@ proc bindAddr*(socket: PSocket, port = TPort(0), address = "") {. osError(osLastError()) dealloc(aiList) -proc acceptAddr*(server: PSocket, client: var PSocket, address: var string) {. - tags: [FReadIO].} = +proc acceptAddr*(server: PSocket, client: var PSocket, address: var string, + flags = {TSocketFlags.SafeDisconn}) {.tags: [FReadIO].} = ## Blocks until a connection is being made from a client. When a connection ## is made sets ``client`` to the client socket and ``address`` to the address ## of the connecting client. @@ -581,6 +299,11 @@ proc acceptAddr*(server: PSocket, client: var PSocket, address: var string) {. ## ## **Note**: ``client`` must be initialised (with ``new``), this function ## makes no effort to initialise the ``client`` variable. + ## + ## The ``accept`` call may result in an error if the connecting socket + ## disconnects during the duration of the ``accept``. If the ``SafeDisconn`` + ## flag is specified then this error will not be raised and instead + ## accept will be called again. assert(client != nil) var sockAddress: Tsockaddr_in var addrLen = sizeof(sockAddress).TSocklen @@ -589,6 +312,8 @@ proc acceptAddr*(server: PSocket, client: var PSocket, address: var string) {. if sock == osInvalidSocket: let err = osLastError() + if flags.isDisconnectionError(err): + acceptAddr(server, client, address, flags) osError(err) else: client.fd = sock @@ -658,15 +383,20 @@ when false: #defined(ssl): acceptAddrPlain(AcceptNoClient, AcceptSuccess): doHandshake() -proc accept*(server: PSocket, client: var PSocket) {.tags: [FReadIO].} = +proc accept*(server: PSocket, client: var PSocket, + flags = {TSocketFlags.SafeDisconn}) {.tags: [FReadIO].} = ## Equivalent to ``acceptAddr`` but doesn't return the address, only the ## socket. ## ## **Note**: ``client`` must be initialised (with ``new``), this function ## makes no effort to initialise the ``client`` variable. - + ## + ## The ``accept`` call may result in an error if the connecting socket + ## disconnects during the duration of the ``accept``. If the ``SafeDisconn`` + ## flag is specified then this error will not be raised and instead + ## accept will be called again. var addrDummy = "" - acceptAddr(server, client, addrDummy) + acceptAddr(server, client, addrDummy, flags) proc close*(socket: PSocket) = ## Closes a socket. @@ -1173,3 +903,285 @@ proc isSSL*(socket: PSocket): bool = return socket.isSSL proc getFD*(socket: PSocket): TSocketHandle = return socket.fd ## Returns the socket's file descriptor + +type + IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address + IPv6, ## IPv6 address + IPv4 ## IPv4 address + + TIpAddress* = object ## stores an arbitrary IP address + case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6) + of IpAddressFamily.IPv6: + address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in + ## case of IPv6 + of IpAddressFamily.IPv4: + address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in + ## case of IPv4 + +proc IPv4_any*(): TIpAddress = + ## Returns the IPv4 any address, which can be used to listen on all available + ## network adapters + result = TIpAddress( + family: IpAddressFamily.IPv4, + address_v4: [0'u8, 0, 0, 0]) + +proc IPv4_loopback*(): TIpAddress = + ## Returns the IPv4 loopback address (127.0.0.1) + result = TIpAddress( + family: IpAddressFamily.IPv4, + address_v4: [127'u8, 0, 0, 1]) + +proc IPv4_broadcast*(): TIpAddress = + ## Returns the IPv4 broadcast address (255.255.255.255) + result = TIpAddress( + family: IpAddressFamily.IPv4, + address_v4: [255'u8, 255, 255, 255]) + +proc IPv6_any*(): TIpAddress = + ## Returns the IPv6 any address (::0), which can be used + ## to listen on all available network adapters + result = TIpAddress( + family: IpAddressFamily.IPv6, + address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + +proc IPv6_loopback*(): TIpAddress = + ## Returns the IPv6 loopback address (::1) + result = TIpAddress( + family: IpAddressFamily.IPv6, + address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) + +proc `==`*(lhs, rhs: TIpAddress): bool = + ## Compares two IpAddresses for Equality. Returns two if the addresses are equal + if lhs.family != rhs.family: return false + if lhs.family == IpAddressFamily.IPv4: + for i in low(lhs.address_v4) .. high(lhs.address_v4): + if lhs.address_v4[i] != rhs.address_v4[i]: return false + else: # IPv6 + for i in low(lhs.address_v6) .. high(lhs.address_v6): + if lhs.address_v6[i] != rhs.address_v6[i]: return false + return true + +proc `$`*(address: TIpAddress): string = + ## Converts an TIpAddress into the textual representation + result = "" + case address.family + of IpAddressFamily.IPv4: + for i in 0 .. 3: + if i != 0: + result.add('.') + result.add($address.address_v4[i]) + of IpAddressFamily.IPv6: + var + currentZeroStart = -1 + currentZeroCount = 0 + biggestZeroStart = -1 + biggestZeroCount = 0 + # Look for the largest block of zeros + for i in 0..7: + var isZero = address.address_v6[i*2] == 0 and address.address_v6[i*2+1] == 0 + if isZero: + if currentZeroStart == -1: + currentZeroStart = i + currentZeroCount = 1 + else: + currentZeroCount.inc() + if currentZeroCount > biggestZeroCount: + biggestZeroCount = currentZeroCount + biggestZeroStart = currentZeroStart + else: + currentZeroStart = -1 + + if biggestZeroCount == 8: # Special case ::0 + result.add("::") + else: # Print address + var printedLastGroup = false + for i in 0..7: + var word:uint16 = (cast[uint16](address.address_v6[i*2])) shl 8 + word = word or cast[uint16](address.address_v6[i*2+1]) + + if biggestZeroCount != 0 and # Check if group is in skip group + (i >= biggestZeroStart and i < (biggestZeroStart + biggestZeroCount)): + if i == biggestZeroStart: # skip start + result.add("::") + printedLastGroup = false + else: + if printedLastGroup: + result.add(':') + var + afterLeadingZeros = false + mask = 0xF000'u16 + for j in 0'u16..3'u16: + var val = (mask and word) shr (4'u16*(3'u16-j)) + if val != 0 or afterLeadingZeros: + if val < 0xA: + result.add(chr(uint16(ord('0'))+val)) + else: # val >= 0xA + result.add(chr(uint16(ord('a'))+val-0xA)) + afterLeadingZeros = true + mask = mask shr 4 + printedLastGroup = true + +proc parseIPv4Address(address_str: string): TIpAddress = + ## Parses IPv4 adresses + ## Raises EInvalidValue on errors + var + byteCount = 0 + currentByte:uint16 = 0 + seperatorValid = false + + result.family = IpAddressFamily.IPv4 + + for i in 0 .. high(address_str): + if address_str[i] in strutils.Digits: # Character is a number + currentByte = currentByte * 10 + + cast[uint16](ord(address_str[i]) - ord('0')) + if currentByte > 255'u16: + raise newException(EInvalidValue, + "Invalid IP Address. Value is out of range") + seperatorValid = true + elif address_str[i] == '.': # IPv4 address separator + if not seperatorValid or byteCount >= 3: + raise newException(EInvalidValue, + "Invalid IP Address. The address consists of too many groups") + result.address_v4[byteCount] = cast[uint8](currentByte) + currentByte = 0 + byteCount.inc + seperatorValid = false + else: + raise newException(EInvalidValue, + "Invalid IP Address. Address contains an invalid character") + + if byteCount != 3 or not seperatorValid: + raise newException(EInvalidValue, "Invalid IP Address") + result.address_v4[byteCount] = cast[uint8](currentByte) + +proc parseIPv6Address(address_str: string): TIpAddress = + ## Parses IPv6 adresses + ## Raises EInvalidValue on errors + result.family = IpAddressFamily.IPv6 + if address_str.len < 2: + raise newException(EInvalidValue, "Invalid IP Address") + + var + groupCount = 0 + currentGroupStart = 0 + currentShort:uint32 = 0 + seperatorValid = true + dualColonGroup = -1 + lastWasColon = false + v4StartPos = -1 + byteCount = 0 + + for i,c in address_str: + if c == ':': + if not seperatorValid: + raise newException(EInvalidValue, + "Invalid IP Address. Address contains an invalid seperator") + if lastWasColon: + if dualColonGroup != -1: + raise newException(EInvalidValue, + "Invalid IP Address. Address contains more than one \"::\" seperator") + dualColonGroup = groupCount + seperatorValid = false + elif i != 0 and i != high(address_str): + if groupCount >= 8: + raise newException(EInvalidValue, + "Invalid IP Address. The address consists of too many groups") + result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) + result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) + currentShort = 0 + groupCount.inc() + if dualColonGroup != -1: seperatorValid = false + elif i == 0: # only valid if address starts with :: + if address_str[1] != ':': + raise newException(EInvalidValue, + "Invalid IP Address. Address may not start with \":\"") + else: # i == high(address_str) - only valid if address ends with :: + if address_str[high(address_str)-1] != ':': + raise newException(EInvalidValue, + "Invalid IP Address. Address may not end with \":\"") + lastWasColon = true + currentGroupStart = i + 1 + elif c == '.': # Switch to parse IPv4 mode + if i < 3 or not seperatorValid or groupCount >= 7: + raise newException(EInvalidValue, "Invalid IP Address") + v4StartPos = currentGroupStart + currentShort = 0 + seperatorValid = false + break + elif c in strutils.HexDigits: + if c in strutils.Digits: # Normal digit + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0')) + elif c >= 'a' and c <= 'f': # Lower case hex + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10 + else: # Upper case hex + currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10 + if currentShort > 65535'u32: + raise newException(EInvalidValue, + "Invalid IP Address. Value is out of range") + lastWasColon = false + seperatorValid = true + else: + raise newException(EInvalidValue, + "Invalid IP Address. Address contains an invalid character") + + + if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff + if seperatorValid: # Copy remaining data + if groupCount >= 8: + raise newException(EInvalidValue, + "Invalid IP Address. The address consists of too many groups") + result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8) + result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF) + groupCount.inc() + else: # Must parse IPv4 address + for i,c in address_str[v4StartPos..high(address_str)]: + if c in strutils.Digits: # Character is a number + currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0')) + if currentShort > 255'u32: + raise newException(EInvalidValue, + "Invalid IP Address. Value is out of range") + seperatorValid = true + elif c == '.': # IPv4 address separator + if not seperatorValid or byteCount >= 3: + raise newException(EInvalidValue, "Invalid IP Address") + result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) + currentShort = 0 + byteCount.inc() + seperatorValid = false + else: # Invalid character + raise newException(EInvalidValue, + "Invalid IP Address. Address contains an invalid character") + + if byteCount != 3 or not seperatorValid: + raise newException(EInvalidValue, "Invalid IP Address") + result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort) + groupCount += 2 + + # Shift and fill zeros in case of :: + if groupCount > 8: + raise newException(EInvalidValue, + "Invalid IP Address. The address consists of too many groups") + elif groupCount < 8: # must fill + if dualColonGroup == -1: + raise newException(EInvalidValue, + "Invalid IP Address. The address consists of too few groups") + var toFill = 8 - groupCount # The number of groups to fill + var toShift = groupCount - dualColonGroup # Nr of known groups after :: + for i in 0..2*toShift-1: # shift + result.address_v6[15-i] = result.address_v6[groupCount*2-i-1] + for i in 0..2*toFill-1: # fill with 0s + result.address_v6[dualColonGroup*2+i] = 0 + elif dualColonGroup != -1: + raise newException(EInvalidValue, + "Invalid IP Address. The address consists of too many groups") + +proc parseIpAddress*(address_str: string): TIpAddress = + ## Parses an IP address + ## Raises EInvalidValue on error + if address_str == nil: + raise newException(EInvalidValue, "IP Address string is nil") + if address_str.contains(':'): + return parseIPv6Address(address_str) + else: + return parseIPv4Address(address_str) |