about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--src/data/idna.nim9
-rw-r--r--src/utils/twtstr.nim112
2 files changed, 63 insertions, 58 deletions
diff --git a/src/data/idna.nim b/src/data/idna.nim
index b636aa8c..451792a2 100644
--- a/src/data/idna.nim
+++ b/src/data/idna.nim
@@ -135,14 +135,7 @@ func loadStuff(s: string): (FullMap[cstring], # Map
     of "ignored":
       add(firstcol, add_ignore)
 
-when defined(release) or defined(small):
-  const (MappedMap,
-         DisallowedRanges,
-         Disallowed,
-         Ignored,
-         Deviation) = loadStuff(IdnaMappingTable)
-else:
-  let (MappedMap,
+const (MappedMap,
        DisallowedRanges,
        Disallowed,
        Ignored,
diff --git a/src/utils/twtstr.nim b/src/utils/twtstr.nim
index de3418ab..82ce2e57 100644
--- a/src/utils/twtstr.nim
+++ b/src/utils/twtstr.nim
@@ -127,6 +127,8 @@ func toScreamingSnakeCase*(str: string): string = # input is camel case
     else:
       result &= c.toUpperAscii()
 
+func isAscii*(r: Rune): bool =
+  return cast[uint32](r) < 128
 
 func startsWithNoCase*(str, prefix: string): bool =
   if str.len < prefix.len: return false
@@ -159,9 +161,6 @@ func hexValue*(c: char): int =
 func decValue*(c: char): int =
   return decCharMap[c]
 
-func isAscii*(r: Rune): bool =
-  return int32(r) < 128
-
 func isAscii*(s: string): bool =
   for c in s:
     if c > char(0x80):
@@ -579,12 +578,20 @@ func normalize*(rs: seq[Rune], form = UNICODE_NFC): seq[Rune] = {.cast(noSideEff
   copyMem(addr result[0], outbuf, out_len * sizeof(uint32))
   dealloc(outbuf)
 
+type u32pair {.packed.} = object
+  a: uint32
+  b: uint32
+
+func cmpRange(x: u32pair, y: uint32): int =
+  if x.a < y:
+    return -1
+  elif x.b > y:
+    return 1
+  return 0
+
 func processIdna(str: string, checkhyphens, checkbidi, checkjoiners, transitionalprocessing: bool): Option[string] =
   var mapped: seq[Rune]
-  var i = 0
-  while i < str.len:
-    var r: Rune
-    fastRuneAt(str, i, r)
+  for r in str.runes():
     let status = getIdnaTableStatus(r)
     case status
     of IDNA_DISALLOWED: return none(string) #error
@@ -604,46 +611,43 @@ func processIdna(str: string, checkhyphens, checkbidi, checkjoiners, transitiona
     assert unicode_general_category(addr cr, "Mark") == 0
   var labels: seq[string]
   for label in ($mapped).split('.'):
-    var s = label
     if label.startsWith("xn--"):
       try:
-        s = punycode.decode(label.substr("xn--".len))
-      except PunyError:
-        return none(string) #error
-    let x0 = s.toRunes()
-    block:
-      let x1 = normalize(x0)
-      if x0 == x1:
-        return none(string) #error
-    if checkhyphens:
-      if s.len >= 4 and s[2] == '-' and s[3] == '-':
-        return none(string) #error
-      if s.len > 0 and s[0] == '-' and s[^1] == '-':
-        return none(string) #error
-    if x0.len > 0:
-      let r = x0[0]
-      for i in 0 ..< cr.len div 2:
-        #TODO bisearch instead
-        var a = cast[ptr uint32](cast[int](cr.points) + i * sizeof(uint32) * 2)[]
-        var b = cast[ptr uint32](cast[int](cr.points) + i * sizeof(uint32) * 2 + 1)[]
-        if cast[uint32](r) in a .. b:
+        let s = punycode.decode(label.substr("xn--".len))
+        let x0 = s.toRunes()
+        let x1 = normalize(x0)
+        if x0 != x1:
           return none(string) #error
-    for r in x0:
-      if r == Rune('.'):
-        return none(string) #error
-      let status = getIdnaTableStatus(r)
-      case status
-      of IDNA_DISALLOWED, IDNA_IGNORED, IDNA_MAPPED:
+        if checkhyphens:
+          if s.len >= 4 and s[2] == '-' and s[3] == '-':
+            return none(string) #error
+          if s.len > 0 and s[0] == '-' and s[^1] == '-':
+            return none(string) #error
+        if x0.len > 0:
+          let cps = cast[ptr UncheckedArray[u32pair]](cr.points)
+          let c = cast[uint32](x0[0])
+          if binarySearch(toOpenArray(cps, 0, cr.len div 2 - 1), c, cmpRange) != -1:
+            return none(string) #error
+        for r in x0:
+          if r == Rune('.'):
+            return none(string) #error
+          let status = getIdnaTableStatus(r)
+          case status
+          of IDNA_DISALLOWED, IDNA_IGNORED, IDNA_MAPPED:
+            return none(string) #error
+          of IDNA_DEVIATION:
+            if transitionalprocessing:
+              return none(string) #error
+          of IDNA_VALID: discard
+          #TODO check joiners
+          #TODO check bidi
+        labels.add(s)
+      except PunyError:
         return none(string) #error
-      of IDNA_DEVIATION:
-        if transitionalprocessing:
-          return none(string) #error
-      of IDNA_VALID: discard
-      #TODO check joiners
-      #TODO check bidi
-    labels.add(s)
+    else:
+      labels.add(label)
   cr_free(addr cr)
-  return labels.join('.').some
+  return some(labels.join('.'))
 
 func unicodeToAscii*(s: string, checkhyphens, checkbidi, checkjoiners, transitionalprocessing, verifydnslength: bool): Option[string] =
   let processed = s.processIdna(checkhyphens, checkbidi, checkjoiners,
@@ -651,22 +655,30 @@ func unicodeToAscii*(s: string, checkhyphens, checkbidi, checkjoiners, transitio
   if processed.isnone:
     return none(string) #error
   var labels: seq[string]
+  var all = 0
   for label in processed.get.split('.'):
-    var needsconversion = false
-    for c in label:
-      if c notin Ascii:
-        needsconversion = true
-        break
-    if needsconversion:
+    if not label.isAscii():
       try:
         let converted = "xn--" & punycode.encode(label)
-        #TODO verify dns length
         labels.add(converted)
       except PunyError:
         return none(string) #error
     else:
       labels.add(label)
-  return labels.join('.').some
+    if verifydnslength:
+      let rl = labels[^1].runeLen()
+      if rl notin 1..63:
+        return none(string)
+  if verifydnslength:
+    var all = 0
+    for label in labels:
+      let rl = label.runeLen()
+      if rl notin 0..63:
+        return none(string) #error
+      all += rl
+    if all notin 1..253:
+      return none(string) #error
+  return some(labels.join('.'))
 
 #TODO this is stupid
 func isValidNonZeroInt*(str: string): bool =