summary refs log tree commit diff stats
path: root/lib/pure/net.nim
blob: 9ee98cbe696fb1655fd47aba109d29393727c43b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
#
#
#            Nimrod's Runtime Library
#        (c) Copyright 2014 Dominik Picheta
#
#    See the file "copying.txt", included in this
#    distribution, for details about the copyright.
#

## This module implements a high-level cross-platform sockets interface.

import sockets2, os, strutils, unsigned

type
  IpAddressFamily* {.pure.} = enum ## Distinguishes the two kinds of IP address
    IPv6, ## IPv6 address (listed first, so it is the enum's zero value)
    IPv4  ## IPv4 address

  TIpAddress* = object ## Stores an IP address of either family
    case family*: IpAddressFamily      ## selects which address field is present
    of IpAddressFamily.IPv6:
      address_v6*: array[0..15, uint8] ## the 16 address bytes of an IPv6 address
    of IpAddressFamily.IPv4:
      address_v4*: array[0..3, uint8]  ## the 4 address bytes of an IPv4 address

proc IPv4_any*(): TIpAddress =
  ## Returns the IPv4 any address (all four bytes zero), which can be used
  ## to listen on all available network adapters
  # Fields that are not mentioned in the constructor are zero-initialised,
  # which is exactly the address wanted here.
  return TIpAddress(family: IpAddressFamily.IPv4)

proc IPv4_loopback*(): TIpAddress =
  ## Returns the IPv4 loopback address (127.0.0.1)
  result = TIpAddress(family: IpAddressFamily.IPv4)
  # The two middle bytes keep their zero default.
  result.address_v4[0] = 127'u8
  result.address_v4[3] = 1'u8

proc IPv4_broadcast*(): TIpAddress =
  ## Returns the IPv4 broadcast address (255.255.255.255)
  result = TIpAddress(family: IpAddressFamily.IPv4)
  for idx in 0 .. 3:
    result.address_v4[idx] = 255'u8

proc IPv6_any*(): TIpAddress =
  ## Returns the IPv6 any address (::0, all sixteen bytes zero), which can
  ## be used to listen on all available network adapters
  # Fields that are not mentioned in the constructor are zero-initialised,
  # which is exactly the address wanted here.
  return TIpAddress(family: IpAddressFamily.IPv6)

proc IPv6_loopback*(): TIpAddress =
  ## Returns the IPv6 loopback address (::1)
  result = TIpAddress(family: IpAddressFamily.IPv6)
  # Only the last byte is non-zero; the other fifteen keep their default.
  result.address_v6[15] = 1'u8

proc `==`*(lhs, rhs: TIpAddress): bool =
  ## Compares two IP addresses for equality. Returns true only if both have
  ## the same family and identical address bytes
  if lhs.family != rhs.family:
    return false
  result = true
  case lhs.family
  of IpAddressFamily.IPv4:
    for idx in 0 .. 3:
      if lhs.address_v4[idx] != rhs.address_v4[idx]:
        return false
  of IpAddressFamily.IPv6:
    for idx in 0 .. 15:
      if lhs.address_v6[idx] != rhs.address_v6[idx]:
        return false

proc `$`*(address: TIpAddress): string =
  ## Converts a TIpAddress into its textual representation.
  ##
  ## IPv4 addresses are written in dotted decimal form ("a.b.c.d").
  ## IPv6 addresses are written as lowercase hexadecimal groups without
  ## leading zeros, separated by ':'. The longest run of all-zero groups
  ## (the first one if several runs share that length) is abbreviated to
  ## "::"; every other all-zero group is written as "0".
  result = ""
  case address.family
  of IpAddressFamily.IPv4:
    for i in 0 .. 3:
      if i != 0:
        result.add('.')
      result.add($address.address_v4[i])
  of IpAddressFamily.IPv6:
    var
      currentZeroStart = -1 # first group of the zero run being scanned
      currentZeroCount = 0  # length of the zero run being scanned
      biggestZeroStart = -1 # first group of the longest zero run so far
      biggestZeroCount = 0  # length of the longest zero run so far
    # Look for the largest block of zeros
    for i in 0..7:
      var isZero = address.address_v6[i*2] == 0 and address.address_v6[i*2+1] == 0
      if isZero:
        if currentZeroStart == -1:
          currentZeroStart = i
          currentZeroCount = 1
        else:
          currentZeroCount.inc()
        # The strict comparison keeps the first run when lengths are equal
        if currentZeroCount > biggestZeroCount:
          biggestZeroCount = currentZeroCount
          biggestZeroStart = currentZeroStart
      else:
        currentZeroStart = -1

    if biggestZeroCount == 8: # Special case ::0
      result.add("::")
    else: # Print address
      var printedLastGroup = false
      for i in 0..7:
        # Combine the two bytes of the group, high byte first
        var word:uint16 = (cast[uint16](address.address_v6[i*2])) shl 8
        word = word or cast[uint16](address.address_v6[i*2+1])

        # Is this group part of the run that is abbreviated to "::"?
        let inSkippedRun = biggestZeroCount != 0 and
          i >= biggestZeroStart and i < (biggestZeroStart + biggestZeroCount)
        if inSkippedRun:
          if i == biggestZeroStart: # skip start
            result.add("::")
          # "::" already acts as the separator for the next printed group
          printedLastGroup = false
        else:
          if printedLastGroup:
            result.add(':')
          var
            afterLeadingZeros = false
            mask = 0xF000'u16
          # Emit the four hex digits, most significant first, dropping
          # leading zeros
          for j in 0'u16..3'u16:
            var val = (mask and word) shr (4'u16*(3'u16-j))
            if val != 0 or afterLeadingZeros:
              if val < 0xA:
                result.add(chr(uint16(ord('0'))+val))
              else: # val >= 0xA
                result.add(chr(uint16(ord('a'))+val-0xA))
              afterLeadingZeros = true
            mask = mask shr 4
          # An all-zero group outside the abbreviated run produced no digit
          # above. It still has to be written, otherwise 1:0:1:0:0:0:0:1
          # would come out as the ambiguous "1::1::1".
          if not afterLeadingZeros:
            result.add('0')
          printedLastGroup = true

proc parseIPv4Address(address_str: string): TIpAddress =
  ## Parses an IPv4 address written in dotted decimal form ("a.b.c.d")
  ## Raises EInvalidValue on errors
  var
    octetIndex = 0         # index of the byte that is currently being read
    octetValue:uint16 = 0  # decimal value accumulated for that byte
    sawDigit = false       # true once the current byte has at least one digit

  result.family = IpAddressFamily.IPv4

  for ch in address_str:
    case ch
    of '0'..'9':
      octetValue = octetValue * 10 + cast[uint16](ord(ch) - ord('0'))
      if octetValue > 255'u16:
        raise newException(EInvalidValue, "Invalid IP Address. Value is out of range")
      sawDigit = true
    of '.':
      # A dot finishes the current byte; it needs a digit before it and
      # only three dots are allowed
      if octetIndex >= 3 or not sawDigit:
        raise newException(EInvalidValue, "Invalid IP Address. The address consists of too many groups")
      result.address_v4[octetIndex] = cast[uint8](octetValue)
      octetValue = 0
      octetIndex.inc
      sawDigit = false
    else:
      raise newException(EInvalidValue, "Invalid IP Address. Address contains an invalid character")

  # The last byte has no trailing dot, so it is stored here
  if not sawDigit or octetIndex != 3:
    raise newException(EInvalidValue, "Invalid IP Address")
  result.address_v4[octetIndex] = cast[uint8](octetValue)

proc parseIPv6Address(address_str: string): TIpAddress =
  ## Parses an IPv6 address in textual form: hexadecimal groups separated
  ## by ':', with at most one "::" abbreviation and optionally a trailing
  ## dotted decimal IPv4 part that supplies the last two groups.
  ## Raises EInvalidValue on errors
  result.family = IpAddressFamily.IPv6
  if address_str.len < 2: raise newException(EInvalidValue, "Invalid IP Address")

  var
    groupCount = 0 # number of 16-bit groups stored in the result so far
    currentGroupStart = 0 # index of the first character after the last ':'
    currentShort:uint32 = 0 # value of the group (or IPv4 byte) being read
    seperatorValid = true # whether a separator may appear at this position
    dualColonGroup = -1 # group index at which "::" was found, -1 if none
    lastWasColon = false # whether the previous character was a ':'
    v4StartPos = -1 # start index of the embedded IPv4 part, -1 if none
    byteCount = 0 # number of completed bytes of the embedded IPv4 part

  for i,c in address_str:
    if c == ':':
      if not seperatorValid: raise newException(EInvalidValue, "Invalid IP Address. Address contains an invalid seperator")
      if lastWasColon:        
        # Second ':' in a row, i.e. the "::" abbreviation. Remember the group
        # index it occurs at so that the gap can be filled in afterwards.
        if dualColonGroup != -1: raise newException(EInvalidValue, "Invalid IP Address. Address contains more than one \"::\" seperator")
        dualColonGroup = groupCount
        seperatorValid = false
      elif i != 0 and i != high(address_str):
        # A ':' inside the string finishes the current group, which is
        # stored as two bytes, high byte first.
        if groupCount >= 8: raise newException(EInvalidValue, "Invalid IP Address. The address consists of too many groups")
        result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
        result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
        currentShort = 0
        groupCount.inc()        
        if dualColonGroup != -1: seperatorValid = false
      elif i == 0: # only valid if address starts with ::
        if address_str[1] != ':':
          raise newException(EInvalidValue, "Invalid IP Address. Address may not start with \":\"")
      else: # i == high(address_str) - only valid if address ends with ::
        if address_str[high(address_str)-1] != ':': 
          raise newException(EInvalidValue, "Invalid IP Address. Address may not end with \":\"")
      lastWasColon = true
      currentGroupStart = i + 1
    elif c == '.': # Switch to parse IPv4 mode
      # The IPv4 part takes two groups, so at most six groups may precede it.
      # The characters read so far for this group are discarded; the second
      # loop below reads the group again, as decimal, from its start.
      if i < 3 or not seperatorValid or groupCount >= 7: raise newException(EInvalidValue, "Invalid IP Address")
      v4StartPos = currentGroupStart
      currentShort = 0
      seperatorValid = false
      break
    elif c in strutils.HexDigits:
      # Append one hexadecimal digit to the value of the current group
      if c in strutils.Digits: # Normal digit
        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0'))
      elif c >= 'a' and c <= 'f': # Lower case hex
        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10
      else: # Upper case hex
        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10
      if currentShort > 65535'u32:
        raise newException(EInvalidValue, "Invalid IP Address. Value is out of range")
      lastWasColon = false
      seperatorValid = true
    else:
      raise newException(EInvalidValue, "Invalid IP Address. Address contains an invalid character")


  if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff
    if seperatorValid: # Copy remaining data
      if groupCount >= 8: raise newException(EInvalidValue, "Invalid IP Address. The address consists of too many groups")
      result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
      result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
      groupCount.inc()
  else: # Must parse IPv4 address
    # The four decimal bytes are written right behind the groups stored so far
    for i,c in address_str[v4StartPos..high(address_str)]:
      if c in strutils.Digits: # Character is a number
        currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0'))
        if currentShort > 255'u32:
          raise newException(EInvalidValue, "Invalid IP Address. Value is out of range")
        seperatorValid = true
      elif c == '.': # IPv4 address separator
        if not seperatorValid or byteCount >= 3:
          raise newException(EInvalidValue, "Invalid IP Address")
        result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
        currentShort = 0
        byteCount.inc()
        seperatorValid = false
      else: # Invalid character
        raise newException(EInvalidValue, "Invalid IP Address. Address contains an invalid character")

    # The last byte is not followed by a '.', so it is stored here
    if byteCount != 3 or not seperatorValid:
      raise newException(EInvalidValue, "Invalid IP Address")
    result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
    groupCount += 2

  # Shift and fill zeros in case of ::
  # Fewer than 8 groups are only acceptable together with a "::"; with
  # exactly 8 explicit groups a "::" is rejected instead.
  if groupCount > 8:
    raise newException(EInvalidValue, "Invalid IP Address. The address consists of too many groups")
  elif groupCount < 8: # must fill
    if dualColonGroup == -1: raise newException(EInvalidValue, "Invalid IP Address. The address consists of too few groups")
    var toFill = 8 - groupCount # The number of groups to fill
    var toShift = groupCount - dualColonGroup # Nr of known groups after ::
    # Copy from the back so that no byte is overwritten before it is moved
    for i in 0..2*toShift-1: # shift
      result.address_v6[15-i] = result.address_v6[groupCount*2-i-1]
    for i in 0..2*toFill-1: # fill with 0s
      result.address_v6[dualColonGroup*2+i] = 0
  elif dualColonGroup != -1: raise newException(EInvalidValue, "Invalid IP Address. The address consists of too many groups")


proc parseIpAddress*(address_str: string): TIpAddress =
  ## Parses an IP address given as text. A string that contains a ':' is
  ## handed to the IPv6 parser, any other string to the IPv4 parser.
  ## Raises EInvalidValue on error
  if address_str == nil:
    raise newException(EInvalidValue, "IP Address string is nil")
  if ':' in address_str:
    result = parseIPv6Address(address_str)
  else:
    result = parseIPv4Address(address_str)


type
  TSocket* = TSocketHandle ## socket type taken by the procs below; an alias for TSocketHandle

proc bindAddr*(socket: TSocket, port = TPort(0), address = "") {.
  tags: [FReadIO].} =
  ## Binds an address/port number to a socket.
  ##
  ## Use an address string in dotted decimal form like "a.b.c.d"
  ## or leave "" for any address (INADDR_ANY). A non-empty address is
  ## looked up with ``getAddrInfo`` for AF_INET and the first entry of the
  ## returned list is used.
  ##
  ## If the underlying bind call fails, ``osError`` is called with the
  ## error code reported by ``osLastError``.

  if address == "":
    var name: TSockaddr_in
    when defined(windows):
      name.sin_family = toInt(AF_INET).int16
    else:
      name.sin_family = toInt(AF_INET)
    name.sin_port = htons(int16(port))
    name.sin_addr.s_addr = htonl(INADDR_ANY)
    if bindAddr(socket, cast[ptr TSockAddr](addr(name)),
                  sizeof(name).TSocklen) < 0'i32:
      osError(osLastError())
  else:
    var aiList = getAddrInfo(address, port, AF_INET)
    if bindAddr(socket, aiList.ai_addr, aiList.ai_addrlen.TSocklen) < 0'i32:
      # Read the error code before the list is released, so that the
      # cleanup call cannot overwrite the code of the failed bind.
      let bindError = osLastError()
      dealloc(aiList)
      osError(bindError)
    dealloc(aiList)

proc setBlocking*(s: TSocket, blocking: bool) {.tags: [].} =
  ## Switches the socket between blocking and non-blocking mode.
  ## ``osError`` is called when one of the underlying system calls
  ## reports failure.
  when defined(Windows):
    # FIONBIO takes 1 for non-blocking and 0 for blocking
    var nonBlocking = clong(ord(not blocking))
    if ioctlsocket(s, FIONBIO, addr(nonBlocking)) == -1:
      osError(osLastError())
  else: # BSD sockets
    # Read the current status flags and change only O_NONBLOCK
    var flags: int = fcntl(s, F_GETFL, 0)
    if flags != -1:
      var updated = flags
      if blocking:
        updated = flags and not O_NONBLOCK
      else:
        updated = flags or O_NONBLOCK
      if fcntl(s, F_SETFL, updated) == -1:
        osError(osLastError())
    else:
      osError(osLastError())