summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
authortreeform <starplant@gmail.com>2018-04-12 08:49:24 -0700
committerAndreas Rumpf <rumpf_a@web.de>2018-04-12 17:49:24 +0200
commitf3db632b1d730892e6770a6034acfb8aec077b34 (patch)
tree6b403b2825c475b34573a82b23f86198eaab4845 /lib
parent63160855aade7fd64755e4bdc78bf4694513df50 (diff)
downloadNim-f3db632b1d730892e6770a6034acfb8aec077b34.tar.gz
Added count(*) support to sql parser. Fixed warnings in sql parser. (#7490)
Diffstat (limited to 'lib')
-rw-r--r--lib/pure/parsesql.nim171
1 files changed, 99 insertions, 72 deletions
diff --git a/lib/pure/parsesql.nim b/lib/pure/parsesql.nim
index ae192ab9a..fcc757bea 100644
--- a/lib/pure/parsesql.nim
+++ b/lib/pure/parsesql.nim
@@ -11,7 +11,7 @@
 ## parser. It parses PostgreSQL syntax and the SQL ANSI standard.
 
 import
-  hashes, strutils, lexbase, streams
+  hashes, strutils, lexbase
 
 # ------------------- scanner -------------------------------------------------
 
@@ -62,10 +62,6 @@ const
     "count",
   ]
 
-proc open(L: var SqlLexer, input: Stream, filename: string) =
-  lexbase.open(L, input)
-  L.filename = filename
-
 proc close(L: var SqlLexer) =
   lexbase.close(L)
 
@@ -496,6 +492,7 @@ type
   SqlNodeKind* = enum ## kind of SQL abstract syntax tree
     nkNone,
     nkIdent,
+    nkQuotedIdent,
     nkStringLit,
     nkBitStringLit,
     nkHexStringLit,
@@ -551,13 +548,18 @@ type
     nkCreateIndexIfNotExists,
     nkEnumDef
 
+const
+  LiteralNodes = {
+    nkIdent, nkQuotedIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
+    nkIntegerLit, nkNumericLit
+  }
+
 type
   SqlParseError* = object of ValueError ## Invalid SQL encountered
   SqlNode* = ref SqlNodeObj        ## an SQL abstract syntax tree node
   SqlNodeObj* = object              ## an SQL abstract syntax tree node
     case kind*: SqlNodeKind      ## kind of syntax tree
-    of nkIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
-                nkIntegerLit, nkNumericLit:
+    of LiteralNodes:
       strVal*: string             ## AST leaf: the identifier, numeric literal
                                   ## string literal, etc.
     else:
@@ -566,21 +568,26 @@ type
   SqlParser* = object of SqlLexer ## SQL parser object
     tok: Token
 
+
 {.deprecated: [EInvalidSql: SqlParseError, PSqlNode: SqlNode,
     TSqlNode: SqlNodeObj, TSqlParser: SqlParser, TSqlNodeKind: SqlNodeKind].}
 
-proc newNode(k: SqlNodeKind): SqlNode =
+proc newNode*(k: SqlNodeKind): SqlNode =
   new(result)
   result.kind = k
 
-proc newNode(k: SqlNodeKind, s: string): SqlNode =
+proc newNode*(k: SqlNodeKind, s: string): SqlNode =
   new(result)
   result.kind = k
   result.strVal = s
 
+proc newNode*(k: SqlNodeKind, sons: seq[SqlNode]): SqlNode =
+  new(result)
+  result.kind = k
+  result.sons = sons
+
 proc len*(n: SqlNode): int =
-  if n.kind in {nkIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
-                nkIntegerLit, nkNumericLit}:
+  if n.kind in LiteralNodes:
     result = 0
   else:
     result = n.sons.len
@@ -630,7 +637,7 @@ proc eat(p: var SqlParser, keyw: string) =
   if isKeyw(p, keyw):
     getTok(p)
   else:
-    sqlError(p, keyw.toUpper() & " expected")
+    sqlError(p, keyw.toUpperAscii() & " expected")
 
 proc opt(p: var SqlParser, kind: TokKind) =
   if p.tok.kind == kind: getTok(p)
@@ -689,7 +696,10 @@ proc parseSelect(p: var SqlParser): SqlNode
 
 proc identOrLiteral(p: var SqlParser): SqlNode =
   case p.tok.kind
-  of tkIdentifier, tkQuotedIdentifier:
+  of tkQuotedIdentifier:
+    result = newNode(nkQuotedIdent, p.tok.literal)
+    getTok(p)
+  of tkIdentifier:
     result = newNode(nkIdent, p.tok.literal)
     getTok(p)
   of tkStringConstant, tkEscapeConstant, tkDollarQuotedConstant:
@@ -713,11 +723,15 @@ proc identOrLiteral(p: var SqlParser): SqlNode =
     result.add(parseExpr(p))
     eat(p, tkParRi)
   else:
-    sqlError(p, "expression expected")
-    getTok(p) # we must consume a token here to prevend endless loops!
+    if p.tok.literal == "*":
+      result = newNode(nkIdent, p.tok.literal)
+      getTok(p)
+    else:
+      sqlError(p, "expression expected")
+      getTok(p) # we must consume a token here to prevend endless loops!
 
 proc primary(p: var SqlParser): SqlNode =
-  if p.tok.kind == tkOperator or isKeyw(p, "not"):
+  if (p.tok.kind == tkOperator and (p.tok.literal == "+" or p.tok.literal == "-")) or isKeyw(p, "not"):
     result = newNode(nkPrefix)
     result.add(newNode(nkIdent, p.tok.literal))
     getTok(p)
@@ -762,7 +776,7 @@ proc lowestExprAux(p: var SqlParser, v: var SqlNode, limit: int): int =
   result = opPred
   while opPred > limit:
     node = newNode(nkInfix)
-    opNode = newNode(nkIdent, p.tok.literal.toLower())
+    opNode = newNode(nkIdent, p.tok.literal.toLowerAscii())
     getTok(p)
     result = lowestExprAux(p, v2, opPred)
     node.add(opNode)
@@ -1078,11 +1092,23 @@ proc parseSelect(p: var SqlParser): SqlNode =
       if p.tok.kind != tkComma: break
       getTok(p)
     result.add(g)
-  if isKeyw(p, "limit"):
+  if isKeyw(p, "order"):
     getTok(p)
-    var l = newNode(nkLimit)
-    l.add(parseExpr(p))
-    result.add(l)
+    eat(p, "by")
+    var n = newNode(nkOrder)
+    while true:
+      var e = parseExpr(p)
+      if isKeyw(p, "asc"):
+        getTok(p) # is default
+      elif isKeyw(p, "desc"):
+        getTok(p)
+        var x = newNode(nkDesc)
+        x.add(e)
+        e = x
+      n.add(e)
+      if p.tok.kind != tkComma: break
+      getTok(p)
+    result.add(n)
   if isKeyw(p, "having"):
     var h = newNode(nkHaving)
     while true:
@@ -1099,22 +1125,6 @@ proc parseSelect(p: var SqlParser): SqlNode =
   elif isKeyw(p, "except"):
     result.add(newNode(nkExcept))
     getTok(p)
-  if isKeyw(p, "order"):
-    getTok(p)
-    eat(p, "by")
-    var n = newNode(nkOrder)
-    while true:
-      var e = parseExpr(p)
-      if isKeyw(p, "asc"): getTok(p) # is default
-      elif isKeyw(p, "desc"):
-        getTok(p)
-        var x = newNode(nkDesc)
-        x.add(e)
-        e = x
-      n.add(e)
-      if p.tok.kind != tkComma: break
-      getTok(p)
-    result.add(n)
   if isKeyw(p, "join") or isKeyw(p, "inner") or isKeyw(p, "outer") or isKeyw(p, "cross"):
     var join = newNode(nkJoin)
     result.add(join)
@@ -1122,12 +1132,17 @@ proc parseSelect(p: var SqlParser): SqlNode =
       join.add(newNode(nkIdent, ""))
       getTok(p)
     else:
-      join.add(newNode(nkIdent, p.tok.literal.toLower()))
+      join.add(newNode(nkIdent, p.tok.literal.toLowerAscii()))
       getTok(p)
       eat(p, "join")
     join.add(parseFromItem(p))
     eat(p, "on")
     join.add(parseExpr(p))
+  if isKeyw(p, "limit"):
+    getTok(p)
+    var l = newNode(nkLimit)
+    l.add(parseExpr(p))
+    result.add(l)
 
 proc parseStmt(p: var SqlParser; parent: SqlNode) =
   if isKeyw(p, "create"):
@@ -1161,14 +1176,6 @@ proc parseStmt(p: var SqlParser; parent: SqlNode) =
   else:
     sqlError(p, "SELECT, CREATE, UPDATE or DELETE expected")
 
-proc open(p: var SqlParser, input: Stream, filename: string) =
-  ## opens the parser `p` and assigns the input stream `input` to it.
-  ## `filename` is only used for error messages.
-  open(SqlLexer(p), input, filename)
-  p.tok.kind = tkInvalid
-  p.tok.literal = ""
-  getTok(p)
-
 proc parse(p: var SqlParser): SqlNode =
   ## parses the content of `p`'s input stream and returns the SQL AST.
   ## Syntax errors raise an `SqlParseError` exception.
@@ -1183,24 +1190,6 @@ proc close(p: var SqlParser) =
   ## closes the parser `p`. The associated input stream is closed too.
   close(SqlLexer(p))
 
-proc parseSQL*(input: Stream, filename: string): SqlNode =
-  ## parses the SQL from `input` into an AST and returns the AST.
-  ## `filename` is only used for error messages.
-  ## Syntax errors raise an `SqlParseError` exception.
-  var p: SqlParser
-  open(p, input, filename)
-  try:
-    result = parse(p)
-  finally:
-    close(p)
-
-proc parseSQL*(input: string, filename=""): SqlNode =
-  ## parses the SQL from `input` into an AST and returns the AST.
-  ## `filename` is only used for error messages.
-  ## Syntax errors raise an `SqlParseError` exception.
-  parseSQL(newStringStream(input), "")
-
-
 type
   SqlWriter = object
     indent: int
@@ -1218,12 +1207,12 @@ proc add(s: var SqlWriter, thing: string) =
 proc addKeyw(s: var SqlWriter, thing: string) =
   var keyw = thing
   if s.upperCase:
-    keyw = keyw.toUpper()
+    keyw = keyw.toUpperAscii()
   s.add(keyw)
 
 proc addIden(s: var SqlWriter, thing: string) =
   var iden = thing
-  if iden.toLower() in reservedKeywords:
+  if iden.toLowerAscii() in reservedKeywords:
     iden = '"' & iden & '"'
   s.add(iden)
 
@@ -1251,15 +1240,20 @@ proc addMulti(s: var SqlWriter, n: SqlNode, sep = ',', prefix, suffix: char) =
       ra(n.sons[i], s)
     s.add(suffix)
 
+proc quoted(s: string): string =
+  "\"" & replace(s, "\"", "\"\"") & "\""
+
 proc ra(n: SqlNode, s: var SqlWriter) =
   if n == nil: return
   case n.kind
   of nkNone: discard
   of nkIdent:
-    if allCharsInSet(n.strVal, {'\33'..'\127'}) and n.strVal.toLower() notin reservedKeywords:
+    if allCharsInSet(n.strVal, {'\33'..'\127'}):
       s.add(n.strVal)
     else:
-      s.add("\"" & replace(n.strVal, "\"", "\"\"") & "\"")
+      s.add(quoted(n.strVal))
+  of nkQuotedIdent:
+    s.add(quoted(n.strVal))
   of nkStringLit:
     s.add(escape(n.strVal, "'", "'"))
   of nkBitStringLit:
@@ -1361,18 +1355,19 @@ proc ra(n: SqlNode, s: var SqlWriter) =
     s.addKeyw("select")
     if n.kind == nkSelectDistinct:
       s.addKeyw("distinct")
-    s.addMulti(n.sons[0])
-    for i in 1 .. n.len-1:
+    for i in 0 ..< n.len:
       ra(n.sons[i], s)
   of nkSelectColumns:
-    assert(false)
+    for i, column in n.sons:
+      if i > 0: s.add(',')
+      ra(column, s)
   of nkSelectPair:
     ra(n.sons[0], s)
     if n.sons.len == 2:
       s.addKeyw("as")
       ra(n.sons[1], s)
   of nkFromItemPair:
-    if n.sons[0].kind == nkIdent:
+    if n.sons[0].kind in {nkIdent, nkQuotedIdent}:
       ra(n.sons[0], s)
     else:
       assert n.sons[0].kind == nkSelect
@@ -1472,3 +1467,35 @@ proc renderSQL*(n: SqlNode, upperCase=false): string =
 proc `$`*(n: SqlNode): string =
   ## an alias for `renderSQL`.
   renderSQL(n)
+
+when not defined(js):
+  import streams
+
+  proc open(L: var SqlLexer, input: Stream, filename: string) =
+    lexbase.open(L, input)
+    L.filename = filename
+
+  proc open(p: var SqlParser, input: Stream, filename: string) =
+    ## opens the parser `p` and assigns the input stream `input` to it.
+    ## `filename` is only used for error messages.
+    open(SqlLexer(p), input, filename)
+    p.tok.kind = tkInvalid
+    p.tok.literal = ""
+    getTok(p)
+
+  proc parseSQL*(input: Stream, filename: string): SqlNode =
+    ## parses the SQL from `input` into an AST and returns the AST.
+    ## `filename` is only used for error messages.
+    ## Syntax errors raise an `SqlParseError` exception.
+    var p: SqlParser
+    open(p, input, filename)
+    try:
+      result = parse(p)
+    finally:
+      close(p)
+
+  proc parseSQL*(input: string, filename=""): SqlNode =
+    ## parses the SQL from `input` into an AST and returns the AST.
+    ## `filename` is only used for error messages.
+    ## Syntax errors raise an `SqlParseError` exception.
+    parseSQL(newStringStream(input), "")