summary refs log tree commit diff stats
path: root/lib/pure/parsesql.nim
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pure/parsesql.nim')
-rw-r--r--lib/pure/parsesql.nim57
1 files changed, 42 insertions, 15 deletions
diff --git a/lib/pure/parsesql.nim b/lib/pure/parsesql.nim
index 23d43dfe0..a7c938d01 100644
--- a/lib/pure/parsesql.nim
+++ b/lib/pure/parsesql.nim
@@ -7,14 +7,17 @@
 #    distribution, for details about the copyright.
 #
 
-## The ``parsesql`` module implements a high performance SQL file
+## The `parsesql` module implements a high performance SQL file
 ## parser. It parses PostgreSQL syntax and the SQL ANSI standard.
 ##
 ## Unstable API.
 
-import strutils, lexbase
+import std/[strutils, lexbase]
 import std/private/decode_helpers
 
+when defined(nimPreviewSlimSystem):
+  import std/assertions
+
 # ------------------- scanner -------------------------------------------------
 
 type
@@ -57,7 +60,7 @@ const
 
   reservedKeywords = @[
     # statements
-    "select", "from", "where", "group", "limit", "having",
+    "select", "from", "where", "group", "limit", "offset", "having",
     # functions
     "count",
   ]
@@ -506,6 +509,7 @@ type
     nkFromItemPair,
     nkGroup,
     nkLimit,
+    nkOffset,
     nkHaving,
     nkOrder,
     nkJoin,
@@ -982,7 +986,7 @@ proc parseInsert(p: var SqlParser): SqlNode =
     parseParIdentList(p, n)
     result.add n
   else:
-    result.add(nil)
+    result.add(newNode(nkNone))
   if isKeyw(p, "default"):
     getTok(p)
     eat(p, "values")
@@ -1017,7 +1021,7 @@ proc parseUpdate(p: var SqlParser): SqlNode =
   if isKeyw(p, "where"):
     result.add(parseWhere(p))
   else:
-    result.add(nil)
+    result.add(newNode(nkNone))
 
 proc parseDelete(p: var SqlParser): SqlNode =
   getTok(p)
@@ -1029,7 +1033,7 @@ proc parseDelete(p: var SqlParser): SqlNode =
   if isKeyw(p, "where"):
     result.add(parseWhere(p))
   else:
-    result.add(nil)
+    result.add(newNode(nkNone))
 
 proc parseSelect(p: var SqlParser): SqlNode =
   getTok(p)
@@ -1123,6 +1127,11 @@ proc parseSelect(p: var SqlParser): SqlNode =
     var l = newNode(nkLimit)
     l.add(parseExpr(p))
     result.add(l)
+  if isKeyw(p, "offset"):
+    getTok(p)
+    var o = newNode(nkOffset)
+    o.add(parseExpr(p))
+    result.add(o)
 
 proc parseStmt(p: var SqlParser; parent: SqlNode) =
   if isKeyw(p, "create"):
@@ -1179,9 +1188,12 @@ type
 proc add(s: var SqlWriter, thing: char) =
   s.buffer.add(thing)
 
-proc add(s: var SqlWriter, thing: string) =
+proc prepareAdd(s: var SqlWriter) {.inline.} =
   if s.buffer.len > 0 and s.buffer[^1] notin {' ', '\L', '(', '.'}:
     s.buffer.add(" ")
+
+proc add(s: var SqlWriter, thing: string) =
+  s.prepareAdd
   s.buffer.add(thing)
 
 proc addKeyw(s: var SqlWriter, thing: string) =
@@ -1223,6 +1235,17 @@ proc addMulti(s: var SqlWriter, n: SqlNode, sep = ',', prefix, suffix: char) =
 proc quoted(s: string): string =
   "\"" & replace(s, "\"", "\"\"") & "\""
 
+func escape(result: var string; s: string) =
+  result.add('\'')
+  for c in items(s):
+    case c
+    of '\0'..'\31':
+      result.add("\\x")
+      result.add(toHex(ord(c), 2))
+    of '\'': result.add("''")
+    else: result.add(c)
+  result.add('\'')
+
 proc ra(n: SqlNode, s: var SqlWriter) =
   if n == nil: return
   case n.kind
@@ -1235,7 +1258,8 @@ proc ra(n: SqlNode, s: var SqlWriter) =
   of nkQuotedIdent:
     s.add(quoted(n.strVal))
   of nkStringLit:
-    s.add(escape(n.strVal, "'", "'"))
+    s.prepareAdd
+    s.buffer.escape(n.strVal)
   of nkBitStringLit:
     s.add("b'" & n.strVal & "'")
   of nkHexStringLit:
@@ -1370,6 +1394,9 @@ proc ra(n: SqlNode, s: var SqlWriter) =
   of nkLimit:
     s.addKeyw("limit")
     s.addMulti(n)
+  of nkOffset:
+    s.addKeyw("offset")
+    s.addMulti(n)
   of nkHaving:
     s.addKeyw("having")
     s.addMulti(n)
@@ -1436,7 +1463,7 @@ proc ra(n: SqlNode, s: var SqlWriter) =
     s.addKeyw("enum")
     rs(n, s)
 
-proc renderSQL*(n: SqlNode, upperCase = false): string =
+proc renderSql*(n: SqlNode, upperCase = false): string =
   ## Converts an SQL abstract syntax tree to its string representation.
   var s: SqlWriter
   s.buffer = ""
@@ -1445,8 +1472,8 @@ proc renderSQL*(n: SqlNode, upperCase = false): string =
   return s.buffer
 
 proc `$`*(n: SqlNode): string =
-  ## an alias for `renderSQL`.
-  renderSQL(n)
+  ## an alias for `renderSql`.
+  renderSql(n)
 
 proc treeReprAux(s: SqlNode, level: int, result: var string) =
   result.add('\n')
@@ -1464,7 +1491,7 @@ proc treeRepr*(s: SqlNode): string =
   result = newStringOfCap(128)
   treeReprAux(s, 0, result)
 
-import streams
+import std/streams
 
 proc open(L: var SqlLexer, input: Stream, filename: string) =
   lexbase.open(L, input)
@@ -1478,7 +1505,7 @@ proc open(p: var SqlParser, input: Stream, filename: string) =
   p.tok.literal = ""
   getTok(p)
 
-proc parseSQL*(input: Stream, filename: string): SqlNode =
+proc parseSql*(input: Stream, filename: string): SqlNode =
   ## parses the SQL from `input` into an AST and returns the AST.
   ## `filename` is only used for error messages.
   ## Syntax errors raise an `SqlParseError` exception.
@@ -1489,8 +1516,8 @@ proc parseSQL*(input: Stream, filename: string): SqlNode =
   finally:
     close(p)
 
-proc parseSQL*(input: string, filename = ""): SqlNode =
+proc parseSql*(input: string, filename = ""): SqlNode =
   ## parses the SQL from `input` into an AST and returns the AST.
   ## `filename` is only used for error messages.
   ## Syntax errors raise an `SqlParseError` exception.
-  parseSQL(newStringStream(input), "")
+  parseSql(newStringStream(input), "")