summary refs log tree commit diff stats
diff options
context:
space:
mode:
authortreeform <starplant@gmail.com>2019-10-17 00:50:00 -0700
committerAndreas Rumpf <rumpf_a@web.de>2019-10-17 09:50:00 +0200
commit5ba932e43c9c971555d8fdc9e25e2edcdcdd70b4 (patch)
tree6dd27e515d5af0408efa64f84d6a460f4d98850b
parent37dfb7ecc6b482578bdd8778379f506032b961c9 (diff)
downloadNim-5ba932e43c9c971555d8fdc9e25e2edcdcdd70b4.tar.gz
About 50% faster base64 implemention. (#12436)
-rw-r--r--changelog.md2
-rw-r--r--lib/pure/base64.nim264
-rw-r--r--lib/pure/httpclient.nim2
-rw-r--r--tests/stdlib/tbase64.nim44
4 files changed, 191 insertions, 121 deletions
diff --git a/changelog.md b/changelog.md
index 681d5166d..22b0b130d 100644
--- a/changelog.md
+++ b/changelog.md
@@ -7,6 +7,7 @@
 
 ### Breaking changes in the standard library
 
+- `base64.encode` no longer supports `lineLen` and `newLine` use `base64.encodeMIME` instead.
 
 
 ### Breaking changes in the compiler
@@ -22,6 +23,7 @@
 
 ## Library changes
 
+- `base64.encode` and `base64.decode` was made faster by about 50%.
 
 
 ## Language additions
diff --git a/lib/pure/base64.nim b/lib/pure/base64.nim
index a5b69fadb..615af24d1 100644
--- a/lib/pure/base64.nim
+++ b/lib/pure/base64.nim
@@ -56,100 +56,119 @@
 
 const
   cb64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
+  invalidChar = 255
 
-template encodeInternal(s: typed, lineLen: int, newLine: string): untyped =
-  ## encodes `s` into base64 representation. After `lineLen` characters, a
-  ## `newline` is added.
-  var total = ((len(s) + 2) div 3) * 4
-  let numLines = (total + lineLen - 1) div lineLen
-  if numLines > 0: inc(total, (numLines - 1) * newLine.len)
+template encodeInternal(s: typed): untyped =
+  ## encodes `s` into base64 representation.
+  proc encodeSize(size: int): int =
+    return (size * 4 div 3) + 6
+
+  result.setLen(encodeSize(s.len))
 
-  result = newString(total)
   var
-    i = 0
-    r = 0
-    currLine = 0
-  while i < s.len - 2:
-    let
-      a = ord(s[i])
-      b = ord(s[i+1])
-      c = ord(s[i+2])
-    result[r] = cb64[a shr 2]
-    result[r+1] = cb64[((a and 3) shl 4) or ((b and 0xF0) shr 4)]
-    result[r+2] = cb64[((b and 0x0F) shl 2) or ((c and 0xC0) shr 6)]
-    result[r+3] = cb64[c and 0x3F]
-    inc(r, 4)
-    inc(i, 3)
-    inc(currLine, 4)
-    # avoid index out of bounds when lineLen == encoded length
-    if currLine >= lineLen and i != s.len-2 and r < total:
-      for x in items(newLine):
-        result[r] = x
-        inc(r)
-      currLine = 0
-
-  if i < s.len-1:
-    let
-      a = ord(s[i])
-      b = ord(s[i+1])
-    result[r] = cb64[a shr 2]
-    result[r+1] = cb64[((a and 3) shl 4) or ((b and 0xF0) shr 4)]
-    result[r+2] = cb64[((b and 0x0F) shl 2)]
-    result[r+3] = '='
-    if r+4 != result.len:
-      setLen(result, r+4)
-  elif i < s.len:
-    let a = ord(s[i])
-    result[r] = cb64[a shr 2]
-    result[r+1] = cb64[(a and 3) shl 4]
-    result[r+2] = '='
-    result[r+3] = '='
-    if r+4 != result.len:
-      setLen(result, r+4)
-  else:
-    if r != result.len:
-      setLen(result, r)
-    #assert(r == result.len)
-    discard
-
-proc encode*[T: SomeInteger|char](s: openArray[T], lineLen = 75,
-    newLine = ""): string =
-  ## Encodes ``s`` into base64 representation. After ``lineLen`` characters, a
-  ## ``newline`` is added.
+    inputIndex = 0
+    outputIndex = 0
+    inputEnds = s.len - s.len mod 3
+    n: uint32
+    b: uint32
+
+  template inputByte(exp: untyped) =
+    b = uint32(s[inputIndex])
+    n = exp
+    inc inputIndex
+
+  template outputChar(x: untyped) =
+    result[outputIndex] = cb64[x and 63]
+    inc outputIndex
+
+  template outputChar(c: char) =
+    result[outputIndex] = c
+    inc outputIndex
+
+  while inputIndex != inputEnds:
+    inputByte(b shl 16)
+    inputByte(n or b shl 8)
+    inputByte(n or b shl 0)
+    outputChar(n shr 18)
+    outputChar(n shr 12)
+    outputChar(n shr 6)
+    outputChar(n shr 0)
+
+  var padding = s.len mod 3
+  if padding == 1:
+    inputByte(b shl 16)
+    outputChar(n shr 18)
+    outputChar(n shr 12)
+    outputChar('=')
+    outputChar('=')
+
+  elif padding == 2:
+    inputByte(b shl 16)
+    inputByte(n or b shl 8)
+    outputChar(n shr 18)
+    outputChar(n shr 12)
+    outputChar(n shr 6)
+    outputChar('=')
+
+  result.setLen(outputIndex)
+
+proc encode*[T: SomeInteger|char](s: openarray[T]): string =
+  ## Encodes `s` into base64 representation.
   ##
   ## This procedure encodes an openarray (array or sequence) of either integers
   ## or characters.
   ##
   ## **See also:**
-  ## * `encode proc<#encode,string,int,string>`_ for encoding a string
+  ## * `encode proc<#encode,string>`_ for encoding a string
   ## * `decode proc<#decode,string>`_ for decoding a string
   runnableExamples:
     assert encode(['n', 'i', 'm']) == "bmlt"
     assert encode(@['n', 'i', 'm']) == "bmlt"
     assert encode([1, 2, 3, 4, 5]) == "AQIDBAU="
-  encodeInternal(s, lineLen, newLine)
+  encodeInternal(s)
 
-proc encode*(s: string, lineLen = 75, newLine = ""): string =
-  ## Encodes ``s`` into base64 representation. After ``lineLen`` characters, a
-  ## ``newline`` is added.
+proc encode*(s: string): string =
+  ## Encodes ``s`` into base64 representation.
   ##
   ## This procedure encodes a string.
   ##
   ## **See also:**
-  ## * `encode proc<#encode,openArray[T],int,string>`_ for encoding an openarray
+  ## * `encode proc<#encode,openArray[T]>`_ for encoding an openarray
   ## * `decode proc<#decode,string>`_ for decoding a string
   runnableExamples:
     assert encode("Hello World") == "SGVsbG8gV29ybGQ="
-    assert encode("Hello World", 3, "\n") == "SGVs\nbG8g\nV29ybGQ="
-  encodeInternal(s, lineLen, newLine)
+  encodeInternal(s)
 
-proc decodeByte(b: char): int {.inline.} =
-  case b
-  of '+': result = ord('>')
-  of '0'..'9': result = ord(b) + 4
-  of 'A'..'Z': result = ord(b) - ord('A')
-  of 'a'..'z': result = ord(b) - 71
-  else: result = 63
+proc encodeMIME*(s: string, lineLen = 75, newLine = "\r\n"): string =
+  ## Encodes ``s`` into base64 representation as lines.
+  ## Used in email MIME forma, use ``lineLen`` and ``newline``.
+  ##
+  ## This procedure encodes a string according to MIME spec.
+  ##
+  ## **See also:**
+  ## * `encode proc<#encode,string>`_ for encoding a string
+  ## * `decode proc<#decode,string>`_ for decoding a string
+  runnableExamples:
+    assert encodeMIME("Hello World", 4, "\n") == "SGVs\nbG8g\nV29y\nbGQ="
+  for i, c in encode(s):
+    if i != 0 and (i mod lineLen == 0):
+      result.add(newLine)
+    result.add(c)
+
+proc initDecodeTable*(): array[256, char] =
+  # computes a decode table at compile time
+  for i in 0 ..< 256:
+    let ch = char(i)
+    var code = invalidChar
+    if ch >= 'A' and ch <= 'Z': code = i - 0x00000041
+    if ch >= 'a' and ch <= 'z': code = i - 0x00000047
+    if ch >= '0' and ch <= '9': code = i + 0x00000004
+    if ch == '+' or ch == '-': code = 0x0000003E
+    if ch == '/' or ch == '_': code = 0x0000003F
+    result[i] = char(code)
+
+const
+  decodeTable = initDecodeTable()
 
 proc decode*(s: string): string =
   ## Decodes string ``s`` in base64 representation back into its original form.
@@ -161,53 +180,58 @@ proc decode*(s: string): string =
   runnableExamples:
     assert decode("SGVsbG8gV29ybGQ=") == "Hello World"
     assert decode("  SGVsbG8gV29ybGQ=") == "Hello World"
-  const Whitespace = {' ', '\t', '\v', '\r', '\l', '\f'}
-  var total = ((len(s) + 3) div 4) * 3
-  # total is an upper bound, as we will skip arbitrary whitespace:
-  result = newString(total)
+  if s.len == 0: return
 
+  proc decodeSize(size: int): int =
+    return (size * 3 div 4) + 6
+
+  template inputChar(x: untyped) =
+    let x = int decode_table[ord(s[inputIndex])]
+    inc inputIndex
+    if x == invalidChar:
+      raise newException(ValueError,
+        "Invalid base64 format character " & repr(s[inputIndex]) &
+        " at location " & $inputIndex & ".")
+
+  template outputChar(x: untyped) =
+    result[outputIndex] = char(x and 255)
+    inc outputIndex
+
+  # pre allocate output string once
+  result.setLen(decodeSize(s.len))
   var
-    i = 0
-    r = 0
-  while true:
-    while i < s.len and s[i] in Whitespace: inc(i)
-    if i < s.len-3:
-      let
-        a = s[i].decodeByte
-        b = s[i+1].decodeByte
-        c = s[i+2].decodeByte
-        d = s[i+3].decodeByte
-
-      result[r] = chr((a shl 2) and 0xff or ((b shr 4) and 0x03))
-      result[r+1] = chr((b shl 4) and 0xff or ((c shr 2) and 0x0F))
-      result[r+2] = chr((c shl 6) and 0xff or (d and 0x3F))
-      inc(r, 3)
-      inc(i, 4)
-    else: break
-  assert i == s.len
-  # adjust the length:
-  if i > 0 and s[i-1] == '=':
-    dec(r)
-    if i > 1 and s[i-2] == '=': dec(r)
-  setLen(result, r)
-
-when isMainModule:
-  assert encode("leasure.") == "bGVhc3VyZS4="
-  assert encode("easure.") == "ZWFzdXJlLg=="
-  assert encode("asure.") == "YXN1cmUu"
-  assert encode("sure.") == "c3VyZS4="
-
-  const testInputExpandsTo76 = "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
-  const testInputExpands = "++++++++++++++++++++++++++++++"
-  const longText = """Man is distinguished, not only by his reason, but by this
-    singular passion from other animals, which is a lust of the mind,
-    that by a perseverance of delight in the continued and indefatigable
-    generation of knowledge, exceeds the short vehemence of any carnal
-    pleasure."""
-  const tests = ["", "abc", "xyz", "man", "leasure.", "sure.", "easure.",
-                 "asure.", longText, testInputExpandsTo76, testInputExpands]
-
-  for t in items(tests):
-    assert decode(encode(t)) == t
-    assert decode(encode(t, lineLen = 40)) == t
-    assert decode(encode(t, lineLen = 76)) == t
+    inputIndex = 0
+    outputIndex = 0
+    inputLen = s.len
+    inputEnds = 0
+  # strip trailing characters
+  while s[inputLen - 1] in {'\n', '\r', ' ', '='}:
+    dec inputLen
+  # hot loop: read 4 characters at at time
+  inputEnds = inputLen - 4
+  while inputIndex <= inputEnds:
+    while s[inputIndex] in {'\n', '\r', ' '}:
+      inc inputIndex
+    inputChar(a)
+    inputChar(b)
+    inputChar(c)
+    inputChar(d)
+    outputChar(a shl 2 or b shr 4)
+    outputChar(b shl 4 or c shr 2)
+    outputChar(c shl 6 or d shr 0)
+  # do the last 2 or 3 characters
+  var leftLen = abs((inputIndex - inputLen) mod 4)
+  if leftLen == 2:
+    inputChar(a)
+    inputChar(b)
+    outputChar(a shl 2 or b shr 4)
+  elif leftLen == 3:
+    inputChar(a)
+    inputChar(b)
+    inputChar(c)
+    outputChar(a shl 2 or b shr 4)
+    outputChar(b shl 4 or c shr 2)
+  result.setLen(outputIndex)
+
+
+
diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim
index c26d8920c..430e15c5a 100644
--- a/lib/pure/httpclient.nim
+++ b/lib/pure/httpclient.nim
@@ -454,7 +454,7 @@ proc generateHeaders(requestUrl: Uri, httpMethod: string,
 
   # Proxy auth header.
   if not proxy.isNil and proxy.auth != "":
-    let auth = base64.encode(proxy.auth, newline = "")
+    let auth = base64.encode(proxy.auth)
     add(result, "Proxy-Authorization: basic " & auth & "\c\L")
 
   for key, val in headers:
diff --git a/tests/stdlib/tbase64.nim b/tests/stdlib/tbase64.nim
new file mode 100644
index 000000000..9db6e8802
--- /dev/null
+++ b/tests/stdlib/tbase64.nim
@@ -0,0 +1,44 @@
+discard """
+  output: "OK"
+"""
+import base64
+
+proc main() =
+  doAssert encode("Hello World") == "SGVsbG8gV29ybGQ="
+  doAssert encode("leasure.") == "bGVhc3VyZS4="
+  doAssert encode("easure.") == "ZWFzdXJlLg=="
+  doAssert encode("asure.") == "YXN1cmUu"
+  doAssert encode("sure.") == "c3VyZS4="
+  doAssert encode([1,2,3]) == "AQID"
+  doAssert encode(['h','e','y']) == "aGV5"
+
+  doAssert encode("") == ""
+  doAssert decode("") == ""
+
+  const testInputExpandsTo76 = "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
+  const testInputExpands = "++++++++++++++++++++++++++++++"
+  const longText = """Man is distinguished, not only by his reason, but by this
+    singular passion from other animals, which is a lust of the mind,
+    that by a perseverance of delight in the continued and indefatigable
+    generation of knowledge, exceeds the short vehemence of any carnal
+    pleasure."""
+  const tests = ["", "abc", "xyz", "man", "leasure.", "sure.", "easure.",
+                 "asure.", longText, testInputExpandsTo76, testInputExpands]
+
+  doAssert encodeMIME("foobarbaz", lineLen=4) == "Zm9v\r\nYmFy\r\nYmF6"
+  doAssert decode("Zm9v\r\nYmFy\r\nYmF6") == "foobarbaz"
+
+  for t in items(tests):
+    doAssert decode(encode(t)) == t
+    doAssert decode(encodeMIME(t, lineLen=40)) == t
+    doAssert decode(encodeMIME(t, lineLen=76)) == t
+
+  const invalid = "SGVsbG\x008gV29ybGQ="
+  try:
+    doAssert decode(invalid) == "will throw error"
+  except ValueError:
+    discard
+
+  echo "OK"
+
+main()