about summary refs log tree commit diff stats
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/data/idna.nim 212
-rw-r--r-- src/types/url.nim 34
-rw-r--r-- src/utils/twtstr.nim 96
3 files changed, 330 insertions, 12 deletions
diff --git a/src/data/idna.nim b/src/data/idna.nim
new file mode 100644
index 00000000..5bd38ee4
--- /dev/null
+++ b/src/data/idna.nim
@@ -0,0 +1,212 @@
+import algorithm
+import unicode
+import sets
+import tables
+import sugar
+import strutils
+
+# Status a code point can have in the IDNA mapping table; this is the value
+# reported by `getIdnaTableStatus`.
+type IDNATableStatus* = enum
+  IDNA_VALID, IDNA_IGNORED, IDNA_MAPPED, IDNA_DEVIATION, IDNA_DISALLOWED
+
+# Raw text of the mapping table, read from the resource file at compile time.
+const IdnaMappingTable = staticRead"res/IdnaMappingTable.txt"
+
+func loadStuff(s: string): (seq[(uint16, cstring)], seq[(int, cstring)],
+                            seq[(uint16, uint16)], seq[(int, int)],
+                            set[uint16], HashSet[int],
+                            set[uint16], HashSet[int],
+                            seq[(uint16, cstring)]) =
+  ## Parses the text of an IDNA mapping table (`s`) into lookup structures.
+  ##
+  ## Data lines have the form `codepoints ; status ; mapping ; flag # note`,
+  ## where `codepoints` is a hex code point or a `first..last` range. Empty
+  ## lines and lines starting with '#' are skipped.
+  ##
+  ## The result tuple holds, in order:
+  ## * 0, 1: (code point, mapping) for "mapped"/"disallowed_STD3_mapped"
+  ## * 2, 3: inclusive (first, last) ranges that are disallowed
+  ## * 4, 5: single disallowed code points
+  ## * 6, 7: ignored code points
+  ## * 8: (code point, mapping) for "deviation" entries
+  ##
+  ## Where two slots are listed, the first stores code points that fit in a
+  ## uint16 and the second stores the rest. "valid" entries flagged NV8 or
+  ## XV8 are stored as disallowed; any other status is not recorded.
+  template add_map(i: int, str: string) =
+    if cast[uint](i) <= high(uint16):
+      result[0].add((cast[uint16](i), cstring(str)))
+    else:
+      result[1].add((i, cstring(str)))
+  template add_disallow(i, j: int) =
+    # A range is filed by its first code point only.
+    if cast[uint](i) <= high(uint16):
+      result[2].add((cast[uint16](i), cast[uint16](j)))
+    else:
+      result[3].add((i, j))
+  template add_disallow(i: int) =
+    if cast[uint](i) <= high(uint16):
+      result[4].incl(cast[uint16](i))
+    else:
+      result[5].incl(i)
+  template add_ignore(i: int) =
+    if cast[uint](i) <= high(uint16):
+      result[6].incl(cast[uint16](i))
+    else:
+      result[7].incl(i)
+  template add_deviation(i: int, str: string) =
+    if cast[uint](i) <= high(uint16):
+      result[8].add((cast[uint16](i), cstring(str)))
+    else:
+      # There is no storage for deviation entries above 0xFFFF.
+      assert false
+
+  for line in s.split('\n'):
+    if line.len == 0 or line[0] == '#':
+      continue
+    var i = 0
+    var firstcol = ""
+    var status = ""
+    var thirdcol: seq[string]
+    var fourthcol = ""
+
+    # Column 1: code point or range, spaces dropped.
+    while i < line.len and line[i] notin {'#', ';'}:
+      if line[i] != ' ':
+        firstcol &= line[i]
+      inc i
+    # Step over the ';' separator. The bounds check matters for lines that
+    # end without a ';' or '#': indexing line[line.len] would be out of range.
+    if i < line.len and line[i] != '#':
+      inc i
+
+    # Column 2: status keyword.
+    while i < line.len and line[i] notin {'#', ';'}:
+      if line[i] != ' ':
+        status &= line[i]
+      inc i
+    if i < line.len and line[i] != '#':
+      inc i
+
+    # Column 3: space-separated hex code points of the mapping.
+    var nw = true
+    while i < line.len and line[i] notin {'#', ';'}:
+      if line[i] == ' ':
+        nw = true
+      else:
+        if nw:
+          thirdcol.add("")
+          nw = false
+        thirdcol[^1] &= line[i]
+      inc i
+    if i < line.len and line[i] != '#':
+      inc i
+
+    # Column 4: optional flag (compared against NV8/XV8 below).
+    while i < line.len and line[i] notin {'#', ';'}:
+      if line[i] != ' ':
+        fourthcol &= line[i]
+      inc i
+
+    case status
+    of "mapped", "disallowed_STD3_mapped":
+      let codepoints = thirdcol
+      var str = ""
+      for code in codepoints:
+        str &= Rune(parseHexInt(code))
+
+      if firstcol.contains(".."):
+        let fcs = firstcol.split("..")
+        let rstart = parseHexInt(fcs[0])
+        let rend = parseHexInt(fcs[1])
+        for i in rstart..rend:
+          add_map(i, str)
+      else:
+        add_map(parseHexInt(firstcol), str)
+    of "deviation":
+      let codepoints = thirdcol
+      var str = ""
+      for code in codepoints:
+        str &= Rune(parseHexInt(code))
+      if firstcol.contains(".."):
+        let fcs = firstcol.split("..")
+        let rstart = parseHexInt(fcs[0])
+        let rend = parseHexInt(fcs[1])
+        for i in rstart..rend:
+          add_deviation(i, str)
+      else:
+        add_deviation(parseHexInt(firstcol), str)
+    of "valid":
+      if fourthcol == "NV8" or fourthcol == "XV8":
+        if firstcol.contains(".."):
+          let fcs = firstcol.split("..")
+          let rstart = parseHexInt(fcs[0])
+          let rend = parseHexInt(fcs[1])
+          add_disallow(rstart, rend)
+        else:
+          add_disallow(parseHexInt(firstcol))
+    of "disallowed":
+      if firstcol.contains(".."):
+        let fcs = firstcol.split("..")
+        let rstart = parseHexInt(fcs[0])
+        let rend = parseHexInt(fcs[1])
+        add_disallow(rstart, rend)
+      else:
+        add_disallow(parseHexInt(firstcol))
+    of "ignored":
+      if firstcol.contains(".."):
+        let fcs = firstcol.split("..")
+        let rstart = parseHexInt(fcs[0])
+        let rend = parseHexInt(fcs[1])
+        for i in rstart..rend:
+          add_ignore(i)
+      else:
+        add_ignore(parseHexInt(firstcol))
+
+# Bind the nine lookup structures produced by `loadStuff`.
+# With -d:release the table is parsed at compile time (`const`); otherwise it
+# is parsed when the module is initialised at run time (`let`).
+# NOTE(review): presumably this keeps non-release compile times short; confirm.
+when defined(release):
+  const (MappedMap1,
+         MappedMap2,
+         DisallowedRanges1,
+         DisallowedRanges2,
+         Disallowed1,
+         Disallowed2,
+         Ignored1,
+         Ignored2,
+         Deviation) = loadStuff(IdnaMappingTable)
+else:
+  let (MappedMap1,
+         MappedMap2,
+         DisallowedRanges1,
+         DisallowedRanges2,
+         Disallowed1,
+         Disallowed2,
+         Ignored1,
+         Ignored2,
+         Deviation) = loadStuff(IdnaMappingTable)
+
+func searchInMap[U, T](a: openarray[(U, T)], u: U): int =
+  ## Binary-searches `a` (which must be sorted by the first tuple field) for
+  ## an entry whose first field equals `u`.
+  ## Returns the entry's index, or -1 when there is none.
+  return binarySearch(a, u,
+    proc(entry: (U, T), key: U): int = cmp(entry[0], key))
+
+func isInMap[U, T](a: openarray[(U, T)], u: U): bool =
+  ## True when `searchInMap` finds an entry keyed by `u` in `a`.
+  let idx = searchInMap(a, u)
+  return idx >= 0
+
+func isInRange[U](a: openarray[(U, U)], u: U): bool =
+  ## True when `u` lies inside one of the inclusive `(first, last)` ranges in
+  ## `a`. `a` must be sorted and its ranges must not overlap.
+  # The comparator orders a range relative to the key: a range is "less" when
+  # it ends before the key and "greater" when it starts after it; otherwise
+  # the key is inside the range.
+  let idx = binarySearch(a, u,
+    (x, y) => (if x[1] < y: -1 elif x[0] > y: 1 else: 0))
+  idx != -1
+
+func getIdnaTableStatus*(r: Rune): IDNATableStatus =
+  ## Classifies `r` using the loaded mapping-table structures.
+  ## Code points up to 0xFFFF consult the uint16-keyed tables, larger ones
+  ## the int-keyed tables. A code point found in none of them is IDNA_VALID.
+  let cp = int(r)
+  {.cast(noSideEffect).}:
+    result = IDNA_VALID
+    if cp in 0..0xFFFF:
+      let bmp = uint16(cp)
+      if bmp in Ignored1:
+        result = IDNA_IGNORED
+      elif bmp in Disallowed1:
+        result = IDNA_DISALLOWED
+      else:
+        # Deviation entries are scanned linearly, before the range and map
+        # lookups.
+        for entry in Deviation:
+          if entry[0] == bmp:
+            return IDNA_DEVIATION
+        if DisallowedRanges1.isInRange(bmp):
+          result = IDNA_DISALLOWED
+        elif MappedMap1.isInMap(bmp):
+          result = IDNA_MAPPED
+    elif cp in Ignored2:
+      result = IDNA_IGNORED
+    elif cp in Disallowed2 or DisallowedRanges2.isInRange(cp):
+      result = IDNA_DISALLOWED
+    elif MappedMap2.isInMap(cp):
+      result = IDNA_MAPPED
+
+func getIdnaMapped*(r: Rune): string =
+  ## Returns the mapping string stored for `r`.
+  ## Code points up to 0xFFFF are looked up in the uint16-keyed map first;
+  ## anything not found there is looked up in the int-keyed map.
+  ## NOTE(review): the second lookup indexes with the search result
+  ## unchecked, so `r` is expected to have status IDNA_MAPPED.
+  {.cast(noSideEffect).}:
+    let cp = int(r)
+    if cp in 0..0xFFFF:
+      let idx = MappedMap1.searchInMap(uint16(cp))
+      if idx != -1:
+        return $MappedMap1[idx][1]
+    let idx = MappedMap2.searchInMap(cp)
+    result = $MappedMap2[idx][1]
+
+func getDeviationMapped*(r: Rune): string =
+  ## Returns the mapping stored for the deviation code point `r`, or an empty
+  ## string when `r` has no deviation entry.
+  {.cast(noSideEffect).}:
+    # Compare full code point values: truncating `r` to 16 bits would make
+    # code points above 0xFFFF alias entries that share their low 16 bits.
+    let cp = int(r)
+    for item in Deviation:
+      if int(item[0]) == cp:
+        return $item[1]
diff --git a/src/types/url.nim b/src/types/url.nim
index 6dea5064..81ae4ea4 100644
--- a/src/types/url.nim
+++ b/src/types/url.nim
@@ -248,6 +248,40 @@ func endsInNumber(input: string): bool =
     return true
   return false
 
+func domainToAscii*(domain: string, bestrict = false): Option[string] =
+  ## Converts `domain` to its ASCII form.
+  ##
+  ## When `bestrict` is false, every label is ASCII and no label starts with
+  ## "xn--", the domain is only ASCII-lowercased. Otherwise it goes through
+  ## `unicodeToAscii`; a failed or empty result yields `none`.
+  var needsprocessing = false
+  for s in domain.split('.'):
+    # The punycode prefix is matched ASCII case-insensitively, so that labels
+    # such as "XN--..." are not merely lowercased without being validated.
+    if s.len >= 4 and s[0] in {'x', 'X'} and s[1] in {'n', 'N'} and
+        s[2] == '-' and s[3] == '-':
+      needsprocessing = true
+      break
+    for c in s:
+      if c notin Ascii:
+        needsprocessing = true
+        break
+    if needsprocessing:
+      break
+  if bestrict or needsprocessing:
+    #Note: we don't implement STD3 separately, it's always true
+    result = domain.unicodeToAscii(false, true, true, false, bestrict)
+    if result.isnone or result.get == "":
+      #TODO validation error
+      return none(string)
+    return result
+  else:
+    return domain.toAsciiLower().some
+
+
 func parseHost(input: string, isnotspecial = false): Option[Host] =
   if input.len == 0: return
   if input[0] == '[':
diff --git a/src/utils/twtstr.nim b/src/utils/twtstr.nim
index 5fb45111..6d42134b 100644
--- a/src/utils/twtstr.nim
+++ b/src/utils/twtstr.nim
@@ -8,6 +8,9 @@ import math
 import sugar
 import sequtils
 import options
+import punycode
+
+import data/idna
 
 when defined(posix):
   import posix
@@ -104,7 +107,7 @@ func findChar*(str: string, c: Rune, start: int = 0): int =
 func getLowerChars*(): string =
   ## Builds a 256-character table indexed by byte value: 'A'..'Z' map to
   ## their lowercase letters, every other byte maps to itself.
   result = ""
   for i in 0..255:
-    if chr(i) >= 'A' and chr(i) <= 'Z':
+    if chr(i) in 'A'..'Z':
       result &= chr(i + 32)
     else:
       result &= chr(i)
@@ -114,6 +117,11 @@ const lowerChars = getLowerChars()
 func tolower*(c: char): char =
   ## Looks `c` up in the precomputed `lowerChars` table.
   lowerChars[ord(c)]
 
+func toAsciiLower*(str: string): string =
+  ## Returns a copy of `str` in which every byte has been passed through
+  ## `tolower`.
+  result = newStringOfCap(str.len)
+  for c in str:
+    result.add(c.tolower())
+
 func getrune(s: string): Rune =
   ## Decodes `s` into runes and returns the first one.
   let runes = s.toRunes()
   result = runes[0]
 
@@ -165,8 +173,8 @@ func decValue*(r: Rune): int =
 const HexChars = "0123456789ABCDEF"
 func toHex*(c: char): string =
+  ## Returns the two-digit uppercase hexadecimal form of the byte `c`, most
+  ## significant digit first (e.g. 'J', 0x4A, yields "4A").
   result = newString(2)
-  result[0] = HexChars[(int8(c) and 0xF)]
-  result[1] = HexChars[(int8(c) shr 4)]
+  result[0] = HexChars[(uint8(c) shr 4)]
+  result[1] = HexChars[(uint8(c) and 0xF)]
 
 func equalsIgnoreCase*(s1: seq[Rune], s2: string): bool =
   var i = 0
@@ -449,13 +457,79 @@ func clearControls*(s: string): string =
     if c notin Controls:
       result &= c
 
-#TODO ugh this'll take a while to implement properly
-func domainToAscii*(domain: string): Option[string] =
-  result = some("")
-  for c in domain:
-    if c notin Ascii:
-      return none(string)
-    result.get &= c
+func processIdna(str: string, checkhyphens, checkbidi, checkjoiners, transitionalprocessing: bool): Option[string] =
+  ## Maps and validates the domain name `str` using the IDNA tables.
+  ##
+  ## Step 1 maps every rune of `str` by its table status: disallowed runes
+  ## fail, ignored runes are dropped, mapped runes are replaced, deviation
+  ## runes are replaced only under `transitionalprocessing`.
+  ## Step 2 splits the mapped string on '.', punycode-decodes "xn--" labels
+  ## and validates each label.
+  ##
+  ## Returns the '.'-joined labels, or `none` on any error.
+  ## `checkbidi` and `checkjoiners` are accepted but not implemented yet, and
+  ## normalization is not performed (see TODOs).
+  var mapped = ""
+  var i = 0
+  while i < str.len:
+    var r: Rune
+    fastRuneAt(str, i, r)
+    let status = getIdnaTableStatus(r)
+    case status
+    of IDNA_DISALLOWED: return none(string) #error
+    of IDNA_IGNORED: discard
+    of IDNA_MAPPED: mapped &= getIdnaMapped(r)
+    of IDNA_DEVIATION:
+      if transitionalprocessing: mapped &= getDeviationMapped(r)
+      else: mapped &= r
+    of IDNA_VALID: mapped &= r
+
+  #TODO normalize
+  var labels: seq[string]
+  # Split the *mapped* string; splitting the raw input would discard the
+  # mapping step above.
+  for label in mapped.split('.'):
+    var s = label
+    if label.startsWith("xn--"):
+      try:
+        s = punycode.decode(label)
+      except PunyError:
+        return none(string) #error
+    #TODO check normalization
+    if checkhyphens:
+      if s.len >= 4 and s[2] == '-' and s[3] == '-':
+        return none(string) #error
+      # A label must neither begin nor end with a hyphen.
+      if s.len > 0 and (s[0] == '-' or s[^1] == '-'):
+        return none(string) #error
+    var pos = 0
+    while pos < s.len:
+      if s[pos] == '.':
+        return none(string) #error
+      var r: Rune
+      # Decode from the label being validated; this also advances `pos`.
+      fastRuneAt(s, pos, r)
+      #TODO check general category mark
+      let status = getIdnaTableStatus(r)
+      case status
+      of IDNA_DISALLOWED, IDNA_IGNORED, IDNA_MAPPED:
+        return none(string) #error
+      of IDNA_DEVIATION:
+        if transitionalprocessing:
+          return none(string) #error
+      of IDNA_VALID: discard
+      #TODO check joiners
+      #TODO check bidi
+    labels.add(s)
+  return labels.join('.').some
+
+func unicodeToAscii*(s: string, checkhyphens, checkbidi, checkjoiners, transitionalprocessing, verifydnslength: bool): Option[string] =
+  ## Runs `processIdna` on `s`, then punycode-encodes (with an "xn--" prefix)
+  ## every resulting label that contains a non-ASCII byte.
+  ## Returns the labels joined by '.', or `none` when processing or encoding
+  ## fails. `verifydnslength` is currently not acted upon (see TODO).
+  let processed = processIdna(s, checkhyphens, checkbidi, checkjoiners,
+                              transitionalprocessing)
+  if processed.isnone:
+    return none(string) #error
+  var encoded: seq[string]
+  for label in processed.get.split('.'):
+    if label.anyIt(it notin Ascii):
+      try:
+        #TODO verify dns length
+        encoded.add("xn--" & punycode.encode(label))
+      except PunyError:
+        return none(string) #error
+    else:
+      encoded.add(label)
+  return encoded.join('.').some
 
 proc expandPath*(path: string): string =
   if path.len == 0:
@@ -672,8 +746,6 @@ const ambiguous = [
 # variant might be useful for users of CJK legacy encodings who want to migrate
 # to UCS without changing the traditional terminal character-width behaviour.
 # It is not otherwise recommended for general use.
-#
-# TODO: currently these are unused, the user should be able to toggle them
 
 # auxiliary function for binary search in interval table
 func bisearch(ucs: Rune, table: openarray[(int, int)]): bool =