summary refs log tree commit diff stats
path: root/lib/pure/parsesql.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/parsesql.nim')
-rw-r--r--lib/pure/parsesql.nim141
1 files changed, 81 insertions, 60 deletions
diff --git a/lib/pure/parsesql.nim b/lib/pure/parsesql.nim
index b84c1a744..a7c938d01 100644
--- a/lib/pure/parsesql.nim
+++ b/lib/pure/parsesql.nim
@@ -7,13 +7,16 @@
 #    distribution, for details about the copyright.
 #
 
-## The ``parsesql`` module implements a high performance SQL file
+## The `parsesql` module implements a high performance SQL file
 ## parser. It parses PostgreSQL syntax and the SQL ANSI standard.
 ##
 ## Unstable API.
 
-import
-  strutils, lexbase
+import std/[strutils, lexbase]
+import std/private/decode_helpers
+
+when defined(nimPreviewSlimSystem):
+  import std/assertions
 
 # ------------------- scanner -------------------------------------------------
 
@@ -57,7 +60,7 @@ const
 
   reservedKeywords = @[
     # statements
-    "select", "from", "where", "group", "limit", "having",
+    "select", "from", "where", "group", "limit", "offset", "having",
     # functions
     "count",
   ]
@@ -72,20 +75,6 @@ proc getColumn(L: SqlLexer): int =
 proc getLine(L: SqlLexer): int =
   result = L.lineNumber
 
-proc handleHexChar(c: var SqlLexer, xi: var int) =
-  case c.buf[c.bufpos]
-  of '0'..'9':
-    xi = (xi shl 4) or (ord(c.buf[c.bufpos]) - ord('0'))
-    inc(c.bufpos)
-  of 'a'..'f':
-    xi = (xi shl 4) or (ord(c.buf[c.bufpos]) - ord('a') + 10)
-    inc(c.bufpos)
-  of 'A'..'F':
-    xi = (xi shl 4) or (ord(c.buf[c.bufpos]) - ord('A') + 10)
-    inc(c.bufpos)
-  else:
-    discard
-
 proc handleOctChar(c: var SqlLexer, xi: var int) =
   if c.buf[c.bufpos] in {'0'..'7'}:
     xi = (xi shl 3) or (ord(c.buf[c.bufpos]) - ord('0'))
@@ -130,8 +119,10 @@ proc getEscapedChar(c: var SqlLexer, tok: var Token) =
   of 'x', 'X':
     inc(c.bufpos)
     var xi = 0
-    handleHexChar(c, xi)
-    handleHexChar(c, xi)
+    if handleHexChar(c.buf[c.bufpos], xi):
+      inc(c.bufpos)
+      if handleHexChar(c.buf[c.bufpos], xi):
+        inc(c.bufpos)
     add(tok.literal, chr(xi))
   of '0'..'7':
     var xi = 0
@@ -518,6 +509,7 @@ type
     nkFromItemPair,
     nkGroup,
     nkLimit,
+    nkOffset,
     nkHaving,
     nkOrder,
     nkJoin,
@@ -557,7 +549,14 @@ type
     tok: Token
 
 proc newNode*(k: SqlNodeKind): SqlNode =
-  result = SqlNode(kind: k)
+  when defined(js): # bug #14117
+    case k
+    of LiteralNodes:
+      result = SqlNode(kind: k, strVal: "")
+    else:
+      result = SqlNode(kind: k, sons: @[])
+  else:
+    result = SqlNode(kind: k)
 
 proc newNode*(k: SqlNodeKind, s: string): SqlNode =
   result = SqlNode(kind: k)
@@ -672,8 +671,8 @@ proc getPrecedence(p: SqlParser): int =
   else:
     result = - 1
 
-proc parseExpr(p: var SqlParser): SqlNode
-proc parseSelect(p: var SqlParser): SqlNode
+proc parseExpr(p: var SqlParser): SqlNode {.gcsafe.}
+proc parseSelect(p: var SqlParser): SqlNode {.gcsafe.}
 
 proc identOrLiteral(p: var SqlParser): SqlNode =
   case p.tok.kind
@@ -987,7 +986,7 @@ proc parseInsert(p: var SqlParser): SqlNode =
     parseParIdentList(p, n)
     result.add n
   else:
-    result.add(nil)
+    result.add(newNode(nkNone))
   if isKeyw(p, "default"):
     getTok(p)
     eat(p, "values")
@@ -1022,7 +1021,7 @@ proc parseUpdate(p: var SqlParser): SqlNode =
   if isKeyw(p, "where"):
     result.add(parseWhere(p))
   else:
-    result.add(nil)
+    result.add(newNode(nkNone))
 
 proc parseDelete(p: var SqlParser): SqlNode =
   getTok(p)
@@ -1034,7 +1033,7 @@ proc parseDelete(p: var SqlParser): SqlNode =
   if isKeyw(p, "where"):
     result.add(parseWhere(p))
   else:
-    result.add(nil)
+    result.add(newNode(nkNone))
 
 proc parseSelect(p: var SqlParser): SqlNode =
   getTok(p)
@@ -1128,6 +1127,11 @@ proc parseSelect(p: var SqlParser): SqlNode =
     var l = newNode(nkLimit)
     l.add(parseExpr(p))
     result.add(l)
+  if isKeyw(p, "offset"):
+    getTok(p)
+    var o = newNode(nkOffset)
+    o.add(parseExpr(p))
+    result.add(o)
 
 proc parseStmt(p: var SqlParser; parent: SqlNode) =
   if isKeyw(p, "create"):
@@ -1184,9 +1188,12 @@ type
 proc add(s: var SqlWriter, thing: char) =
   s.buffer.add(thing)
 
-proc add(s: var SqlWriter, thing: string) =
+proc prepareAdd(s: var SqlWriter) {.inline.} =
   if s.buffer.len > 0 and s.buffer[^1] notin {' ', '\L', '(', '.'}:
     s.buffer.add(" ")
+
+proc add(s: var SqlWriter, thing: string) =
+  s.prepareAdd
   s.buffer.add(thing)
 
 proc addKeyw(s: var SqlWriter, thing: string) =
@@ -1201,7 +1208,7 @@ proc addIden(s: var SqlWriter, thing: string) =
     iden = '"' & iden & '"'
   s.add(iden)
 
-proc ra(n: SqlNode, s: var SqlWriter)
+proc ra(n: SqlNode, s: var SqlWriter) {.gcsafe.}
 
 proc rs(n: SqlNode, s: var SqlWriter, prefix = "(", suffix = ")", sep = ", ") =
   if n.len > 0:
@@ -1228,6 +1235,17 @@ proc addMulti(s: var SqlWriter, n: SqlNode, sep = ',', prefix, suffix: char) =
 proc quoted(s: string): string =
   "\"" & replace(s, "\"", "\"\"") & "\""
 
+func escape(result: var string; s: string) =
+  result.add('\'')
+  for c in items(s):
+    case c
+    of '\0'..'\31':
+      result.add("\\x")
+      result.add(toHex(ord(c), 2))
+    of '\'': result.add("''")
+    else: result.add(c)
+  result.add('\'')
+
 proc ra(n: SqlNode, s: var SqlWriter) =
   if n == nil: return
   case n.kind
@@ -1240,7 +1258,8 @@ proc ra(n: SqlNode, s: var SqlWriter) =
   of nkQuotedIdent:
     s.add(quoted(n.strVal))
   of nkStringLit:
-    s.add(escape(n.strVal, "'", "'"))
+    s.prepareAdd
+    s.buffer.escape(n.strVal)
   of nkBitStringLit:
     s.add("b'" & n.strVal & "'")
   of nkHexStringLit:
@@ -1375,6 +1394,9 @@ proc ra(n: SqlNode, s: var SqlWriter) =
   of nkLimit:
     s.addKeyw("limit")
     s.addMulti(n)
+  of nkOffset:
+    s.addKeyw("offset")
+    s.addMulti(n)
   of nkHaving:
     s.addKeyw("having")
     s.addMulti(n)
@@ -1441,7 +1463,7 @@ proc ra(n: SqlNode, s: var SqlWriter) =
     s.addKeyw("enum")
     rs(n, s)
 
-proc renderSQL*(n: SqlNode, upperCase = false): string =
+proc renderSql*(n: SqlNode, upperCase = false): string =
   ## Converts an SQL abstract syntax tree to its string representation.
   var s: SqlWriter
   s.buffer = ""
@@ -1450,8 +1472,8 @@ proc renderSQL*(n: SqlNode, upperCase = false): string =
   return s.buffer
 
 proc `$`*(n: SqlNode): string =
-  ## an alias for `renderSQL`.
-  renderSQL(n)
+  ## an alias for `renderSql`.
+  renderSql(n)
 
 proc treeReprAux(s: SqlNode, level: int, result: var string) =
   result.add('\n')
@@ -1469,34 +1491,33 @@ proc treeRepr*(s: SqlNode): string =
   result = newStringOfCap(128)
   treeReprAux(s, 0, result)
 
-when not defined(js):
-  import streams
+import std/streams
 
-  proc open(L: var SqlLexer, input: Stream, filename: string) =
-    lexbase.open(L, input)
-    L.filename = filename
+proc open(L: var SqlLexer, input: Stream, filename: string) =
+  lexbase.open(L, input)
+  L.filename = filename
 
-  proc open(p: var SqlParser, input: Stream, filename: string) =
-    ## opens the parser `p` and assigns the input stream `input` to it.
-    ## `filename` is only used for error messages.
-    open(SqlLexer(p), input, filename)
-    p.tok.kind = tkInvalid
-    p.tok.literal = ""
-    getTok(p)
+proc open(p: var SqlParser, input: Stream, filename: string) =
+  ## opens the parser `p` and assigns the input stream `input` to it.
+  ## `filename` is only used for error messages.
+  open(SqlLexer(p), input, filename)
+  p.tok.kind = tkInvalid
+  p.tok.literal = ""
+  getTok(p)
 
-  proc parseSQL*(input: Stream, filename: string): SqlNode =
-    ## parses the SQL from `input` into an AST and returns the AST.
-    ## `filename` is only used for error messages.
-    ## Syntax errors raise an `SqlParseError` exception.
-    var p: SqlParser
-    open(p, input, filename)
-    try:
-      result = parse(p)
-    finally:
-      close(p)
-
-  proc parseSQL*(input: string, filename = ""): SqlNode =
-    ## parses the SQL from `input` into an AST and returns the AST.
-    ## `filename` is only used for error messages.
-    ## Syntax errors raise an `SqlParseError` exception.
-    parseSQL(newStringStream(input), "")
+proc parseSql*(input: Stream, filename: string): SqlNode =
+  ## parses the SQL from `input` into an AST and returns the AST.
+  ## `filename` is only used for error messages.
+  ## Syntax errors raise an `SqlParseError` exception.
+  var p: SqlParser
+  open(p, input, filename)
+  try:
+    result = parse(p)
+  finally:
+    close(p)
+
+proc parseSql*(input: string, filename = ""): SqlNode =
+  ## parses the SQL from `input` into an AST and returns the AST.
+  ## `filename` is only used for error messages.
+  ## Syntax errors raise an `SqlParseError` exception.
+  parseSql(newStringStream(input), "")