diff options
-rw-r--r-- | changelog.md | 2 | ||||
-rw-r--r-- | lib/pure/base64.nim | 264 | ||||
-rw-r--r-- | lib/pure/httpclient.nim | 2 | ||||
-rw-r--r-- | tests/stdlib/tbase64.nim | 44 |
4 files changed, 191 insertions, 121 deletions
diff --git a/changelog.md b/changelog.md index 681d5166d..22b0b130d 100644 --- a/changelog.md +++ b/changelog.md @@ -7,6 +7,7 @@ ### Breaking changes in the standard library +- `base64.encode` no longer supports `lineLen` and `newLine` use `base64.encodeMIME` instead. ### Breaking changes in the compiler @@ -22,6 +23,7 @@ ## Library changes +- `base64.encode` and `base64.decode` was made faster by about 50%. ## Language additions diff --git a/lib/pure/base64.nim b/lib/pure/base64.nim index a5b69fadb..615af24d1 100644 --- a/lib/pure/base64.nim +++ b/lib/pure/base64.nim @@ -56,100 +56,119 @@ const cb64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" + invalidChar = 255 -template encodeInternal(s: typed, lineLen: int, newLine: string): untyped = - ## encodes `s` into base64 representation. After `lineLen` characters, a - ## `newline` is added. - var total = ((len(s) + 2) div 3) * 4 - let numLines = (total + lineLen - 1) div lineLen - if numLines > 0: inc(total, (numLines - 1) * newLine.len) +template encodeInternal(s: typed): untyped = + ## encodes `s` into base64 representation. + proc encodeSize(size: int): int = + return (size * 4 div 3) + 6 + + result.setLen(encodeSize(s.len)) - result = newString(total) var - i = 0 - r = 0 - currLine = 0 - while i < s.len - 2: - let - a = ord(s[i]) - b = ord(s[i+1]) - c = ord(s[i+2]) - result[r] = cb64[a shr 2] - result[r+1] = cb64[((a and 3) shl 4) or ((b and 0xF0) shr 4)] - result[r+2] = cb64[((b and 0x0F) shl 2) or ((c and 0xC0) shr 6)] - result[r+3] = cb64[c and 0x3F] - inc(r, 4) - inc(i, 3) - inc(currLine, 4) - # avoid index out of bounds when lineLen == encoded length - if currLine >= lineLen and i != s.len-2 and r < total: - for x in items(newLine): - result[r] = x - inc(r) - currLine = 0 - - if i < s.len-1: - let - a = ord(s[i]) - b = ord(s[i+1]) - result[r] = cb64[a shr 2] - result[r+1] = cb64[((a and 3) shl 4) or ((b and 0xF0) shr 4)] - result[r+2] = cb64[((b and 0x0F) shl 2)] - result[r+3] = '=' - if r+4 != result.len: - setLen(result, r+4) - elif i < s.len: - let a = ord(s[i]) - result[r] = cb64[a shr 2] - result[r+1] = cb64[(a and 3) shl 4] - result[r+2] = '=' - result[r+3] = '=' - if r+4 != result.len: - setLen(result, r+4) - else: - if r != result.len: - setLen(result, r) - #assert(r == result.len) - discard - -proc encode*[T: SomeInteger|char](s: openArray[T], lineLen = 75, - newLine = ""): string = - ## Encodes ``s`` into base64 representation. After ``lineLen`` characters, a - ## ``newline`` is added. + inputIndex = 0 + outputIndex = 0 + inputEnds = s.len - s.len mod 3 + n: uint32 + b: uint32 + + template inputByte(exp: untyped) = + b = uint32(s[inputIndex]) + n = exp + inc inputIndex + + template outputChar(x: untyped) = + result[outputIndex] = cb64[x and 63] + inc outputIndex + + template outputChar(c: char) = + result[outputIndex] = c + inc outputIndex + + while inputIndex != inputEnds: + inputByte(b shl 16) + inputByte(n or b shl 8) + inputByte(n or b shl 0) + outputChar(n shr 18) + outputChar(n shr 12) + outputChar(n shr 6) + outputChar(n shr 0) + + var padding = s.len mod 3 + if padding == 1: + inputByte(b shl 16) + outputChar(n shr 18) + outputChar(n shr 12) + outputChar('=') + outputChar('=') + + elif padding == 2: + inputByte(b shl 16) + inputByte(n or b shl 8) + outputChar(n shr 18) + outputChar(n shr 12) + outputChar(n shr 6) + outputChar('=') + + result.setLen(outputIndex) + +proc encode*[T: SomeInteger|char](s: openarray[T]): string = + ## Encodes `s` into base64 representation. ## ## This procedure encodes an openarray (array or sequence) of either integers ## or characters. ## ## **See also:** - ## * `encode proc<#encode,string,int,string>`_ for encoding a string + ## * `encode proc<#encode,string>`_ for encoding a string ## * `decode proc<#decode,string>`_ for decoding a string runnableExamples: assert encode(['n', 'i', 'm']) == "bmlt" assert encode(@['n', 'i', 'm']) == "bmlt" assert encode([1, 2, 3, 4, 5]) == "AQIDBAU=" - encodeInternal(s, lineLen, newLine) + encodeInternal(s) -proc encode*(s: string, lineLen = 75, newLine = ""): string = - ## Encodes ``s`` into base64 representation. After ``lineLen`` characters, a - ## ``newline`` is added. +proc encode*(s: string): string = + ## Encodes ``s`` into base64 representation. ## ## This procedure encodes a string. ## ## **See also:** - ## * `encode proc<#encode,openArray[T],int,string>`_ for encoding an openarray + ## * `encode proc<#encode,openArray[T]>`_ for encoding an openarray ## * `decode proc<#decode,string>`_ for decoding a string runnableExamples: assert encode("Hello World") == "SGVsbG8gV29ybGQ=" - assert encode("Hello World", 3, "\n") == "SGVs\nbG8g\nV29ybGQ=" - encodeInternal(s, lineLen, newLine) + encodeInternal(s) -proc decodeByte(b: char): int {.inline.} = - case b - of '+': result = ord('>') - of '0'..'9': result = ord(b) + 4 - of 'A'..'Z': result = ord(b) - ord('A') - of 'a'..'z': result = ord(b) - 71 - else: result = 63 +proc encodeMIME*(s: string, lineLen = 75, newLine = "\r\n"): string = + ## Encodes ``s`` into base64 representation as lines. + ## Used in email MIME forma, use ``lineLen`` and ``newline``. + ## + ## This procedure encodes a string according to MIME spec. + ## + ## **See also:** + ## * `encode proc<#encode,string>`_ for encoding a string + ## * `decode proc<#decode,string>`_ for decoding a string + runnableExamples: + assert encodeMIME("Hello World", 4, "\n") == "SGVs\nbG8g\nV29y\nbGQ=" + for i, c in encode(s): + if i != 0 and (i mod lineLen == 0): + result.add(newLine) + result.add(c) + +proc initDecodeTable*(): array[256, char] = + # computes a decode table at compile time + for i in 0 ..< 256: + let ch = char(i) + var code = invalidChar + if ch >= 'A' and ch <= 'Z': code = i - 0x00000041 + if ch >= 'a' and ch <= 'z': code = i - 0x00000047 + if ch >= '0' and ch <= '9': code = i + 0x00000004 + if ch == '+' or ch == '-': code = 0x0000003E + if ch == '/' or ch == '_': code = 0x0000003F + result[i] = char(code) + +const + decodeTable = initDecodeTable() proc decode*(s: string): string = ## Decodes string ``s`` in base64 representation back into its original form. @@ -161,53 +180,58 @@ proc decode*(s: string): string = runnableExamples: assert decode("SGVsbG8gV29ybGQ=") == "Hello World" assert decode(" SGVsbG8gV29ybGQ=") == "Hello World" - const Whitespace = {' ', '\t', '\v', '\r', '\l', '\f'} - var total = ((len(s) + 3) div 4) * 3 - # total is an upper bound, as we will skip arbitrary whitespace: - result = newString(total) + if s.len == 0: return + proc decodeSize(size: int): int = + return (size * 3 div 4) + 6 + + template inputChar(x: untyped) = + let x = int decode_table[ord(s[inputIndex])] + inc inputIndex + if x == invalidChar: + raise newException(ValueError, + "Invalid base64 format character " & repr(s[inputIndex]) & + " at location " & $inputIndex & ".") + + template outputChar(x: untyped) = + result[outputIndex] = char(x and 255) + inc outputIndex + + # pre allocate output string once + result.setLen(decodeSize(s.len)) var - i = 0 - r = 0 - while true: - while i < s.len and s[i] in Whitespace: inc(i) - if i < s.len-3: - let - a = s[i].decodeByte - b = s[i+1].decodeByte - c = s[i+2].decodeByte - d = s[i+3].decodeByte - - result[r] = chr((a shl 2) and 0xff or ((b shr 4) and 0x03)) - result[r+1] = chr((b shl 4) and 0xff or ((c shr 2) and 0x0F)) - result[r+2] = chr((c shl 6) and 0xff or (d and 0x3F)) - inc(r, 3) - inc(i, 4) - else: break - assert i == s.len - # adjust the length: - if i > 0 and s[i-1] == '=': - dec(r) - if i > 1 and s[i-2] == '=': dec(r) - setLen(result, r) - -when isMainModule: - assert encode("leasure.") == "bGVhc3VyZS4=" - assert encode("easure.") == "ZWFzdXJlLg==" - assert encode("asure.") == "YXN1cmUu" - assert encode("sure.") == "c3VyZS4=" - - const testInputExpandsTo76 = "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++" - const testInputExpands = "++++++++++++++++++++++++++++++" - const longText = """Man is distinguished, not only by his reason, but by this - singular passion from other animals, which is a lust of the mind, - that by a perseverance of delight in the continued and indefatigable - generation of knowledge, exceeds the short vehemence of any carnal - pleasure.""" - const tests = ["", "abc", "xyz", "man", "leasure.", "sure.", "easure.", - "asure.", longText, testInputExpandsTo76, testInputExpands] - - for t in items(tests): - assert decode(encode(t)) == t - assert decode(encode(t, lineLen = 40)) == t - assert decode(encode(t, lineLen = 76)) == t + inputIndex = 0 + outputIndex = 0 + inputLen = s.len + inputEnds = 0 + # strip trailing characters + while s[inputLen - 1] in {'\n', '\r', ' ', '='}: + dec inputLen + # hot loop: read 4 characters at at time + inputEnds = inputLen - 4 + while inputIndex <= inputEnds: + while s[inputIndex] in {'\n', '\r', ' '}: + inc inputIndex + inputChar(a) + inputChar(b) + inputChar(c) + inputChar(d) + outputChar(a shl 2 or b shr 4) + outputChar(b shl 4 or c shr 2) + outputChar(c shl 6 or d shr 0) + # do the last 2 or 3 characters + var leftLen = abs((inputIndex - inputLen) mod 4) + if leftLen == 2: + inputChar(a) + inputChar(b) + outputChar(a shl 2 or b shr 4) + elif leftLen == 3: + inputChar(a) + inputChar(b) + inputChar(c) + outputChar(a shl 2 or b shr 4) + outputChar(b shl 4 or c shr 2) + result.setLen(outputIndex) + + + diff --git a/lib/pure/httpclient.nim b/lib/pure/httpclient.nim index c26d8920c..430e15c5a 100644 --- a/lib/pure/httpclient.nim +++ b/lib/pure/httpclient.nim @@ -454,7 +454,7 @@ proc generateHeaders(requestUrl: Uri, httpMethod: string, # Proxy auth header. if not proxy.isNil and proxy.auth != "": - let auth = base64.encode(proxy.auth, newline = "") + let auth = base64.encode(proxy.auth) add(result, "Proxy-Authorization: basic " & auth & "\c\L") for key, val in headers: diff --git a/tests/stdlib/tbase64.nim b/tests/stdlib/tbase64.nim new file mode 100644 index 000000000..9db6e8802 --- /dev/null +++ b/tests/stdlib/tbase64.nim @@ -0,0 +1,44 @@ +discard """ + output: "OK" +""" +import base64 + +proc main() = + doAssert encode("Hello World") == "SGVsbG8gV29ybGQ=" + doAssert encode("leasure.") == "bGVhc3VyZS4=" + doAssert encode("easure.") == "ZWFzdXJlLg==" + doAssert encode("asure.") == "YXN1cmUu" + doAssert encode("sure.") == "c3VyZS4=" + doAssert encode([1,2,3]) == "AQID" + doAssert encode(['h','e','y']) == "aGV5" + + doAssert encode("") == "" + doAssert decode("") == "" + + const testInputExpandsTo76 = "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++" + const testInputExpands = "++++++++++++++++++++++++++++++" + const longText = """Man is distinguished, not only by his reason, but by this + singular passion from other animals, which is a lust of the mind, + that by a perseverance of delight in the continued and indefatigable + generation of knowledge, exceeds the short vehemence of any carnal + pleasure.""" + const tests = ["", "abc", "xyz", "man", "leasure.", "sure.", "easure.", + "asure.", longText, testInputExpandsTo76, testInputExpands] + + doAssert encodeMIME("foobarbaz", lineLen=4) == "Zm9v\r\nYmFy\r\nYmF6" + doAssert decode("Zm9v\r\nYmFy\r\nYmF6") == "foobarbaz" + + for t in items(tests): + doAssert decode(encode(t)) == t + doAssert decode(encodeMIME(t, lineLen=40)) == t + doAssert decode(encodeMIME(t, lineLen=76)) == t + + const invalid = "SGVsbG\x008gV29ybGQ=" + try: + doAssert decode(invalid) == "will throw error" + except ValueError: + discard + + echo "OK" + +main() |