summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r-- lib/pure/parsesql.nim 415
1 files changed, 256 insertions, 159 deletions
diff --git a/lib/pure/parsesql.nim b/lib/pure/parsesql.nim
index 6891e2ff7..b53f46f82 100644
--- a/lib/pure/parsesql.nim
+++ b/lib/pure/parsesql.nim
@@ -462,27 +462,27 @@ proc errorStr(L: SqlLexer, msg: string): string =
 
 # ----------------------------- parser ----------------------------------------
 
-# Operator/Element	Associativity	Description
-# .	                left    	table/column name separator
-# ::            	left	        PostgreSQL-style typecast
-# [ ]	                left    	array element selection
-# -	                right	        unary minus
-# ^                     left	        exponentiation
-# * / %	                left	        multiplication, division, modulo
-# + -	                left	        addition, subtraction
-# IS	 	IS TRUE, IS FALSE, IS UNKNOWN, IS NULL
-# ISNULL	 	test for null
-# NOTNULL	 	test for not null
-# (any other)	        left    	all other native and user-defined oprs
-# IN	          	set membership
-# BETWEEN	 	range containment
-# OVERLAPS	 	time interval overlap
-# LIKE ILIKE SIMILAR	 	string pattern matching
-# < >	 	less than, greater than
-# =	                right	        equality, assignment
-# NOT	                right	        logical negation
-# AND	                left	        logical conjunction
-# OR              	left	        logical disjunction
+# Operator/Element   Associativity  Description
+# .                  left           table/column name separator
+# ::                 left           PostgreSQL-style typecast
+# [ ]                left           array element selection
+# -                  right          unary minus
+# ^                  left           exponentiation
+# * / %              left           multiplication, division, modulo
+# + -                left           addition, subtraction
+# IS                                IS TRUE, IS FALSE, IS UNKNOWN, IS NULL
+# ISNULL                            test for null
+# NOTNULL                           test for not null
+# (any other)        left           all other native and user-defined oprs
+# IN                                set membership
+# BETWEEN                           range containment
+# OVERLAPS                          time interval overlap
+# LIKE ILIKE SIMILAR                string pattern matching
+# < >                               less than, greater than
+# =                  right          equality, assignment
+# NOT                right          logical negation
+# AND                left           logical conjunction
+# OR                 left           logical disjunction
 
 type
   SqlNodeKind* = enum ## kind of SQL abstract syntax tree
@@ -518,11 +518,15 @@ type
     nkSelect,
     nkSelectDistinct,
     nkSelectColumns,
+    nkSelectPair,
     nkAsgn,
     nkFrom,
+    nkFromItemPair,
     nkGroup,
+    nkLimit,
     nkHaving,
     nkOrder,
+    nkJoin,
     nkDesc,
     nkUnion,
     nkIntersect,
@@ -670,6 +674,7 @@ proc getPrecedence(p: SqlParser): int =
     result = - 1
 
 proc parseExpr(p: var SqlParser): SqlNode
+proc parseSelect(p: var SqlParser): SqlNode
 
 proc identOrLiteral(p: var SqlParser): SqlNode =
   case p.tok.kind
@@ -921,6 +926,19 @@ proc parseWhere(p: var SqlParser): SqlNode =
   result = newNode(nkWhere)
   result.add(parseExpr(p))
 
+proc parseFromItem(p: var SqlParser): SqlNode =
+  ## Parses one item of a FROM or JOIN clause into an `nkFromItemPair` node.
+  ## The first son is either a parenthesised sub-select or a plain
+  ## expression; a second son is added only when an `AS` alias follows.
+  result = newNode(nkFromItemPair)
+  case p.tok.kind
+  of tkParLe:
+    # "(" opens a sub-select that has to be closed by a matching ")".
+    getTok(p)
+    let subSelect = parseSelect(p)
+    result.add(subSelect)
+    eat(p, tkParRi)
+  else:
+    let source = parseExpr(p)
+    result.add(source)
+  if isKeyw(p, "as"):
+    getTok(p)  # step over the keyword itself
+    let alias = parseExpr(p)
+    result.add(alias)
+
 proc parseIndexDef(p: var SqlParser): SqlNode =
   result = parseIfNotExists(p, nkCreateIndex)
   if isKeyw(p, "primary"):
@@ -1019,7 +1037,12 @@ proc parseSelect(p: var SqlParser): SqlNode =
       a.add(newNode(nkIdent, "*"))
       getTok(p)
     else:
-      a.add(parseExpr(p))
+      var pair = newNode(nkSelectPair)
+      pair.add(parseExpr(p))
+      a.add(pair)
+      if isKeyw(p, "as"):
+        getTok(p)
+        pair.add(parseExpr(p))
     if p.tok.kind != tkComma: break
     getTok(p)
   result.add(a)
@@ -1027,7 +1050,7 @@ proc parseSelect(p: var SqlParser): SqlNode =
     var f = newNode(nkFrom)
     while true:
       getTok(p)
-      f.add(parseExpr(p))
+      f.add(parseFromItem(p))
       if p.tok.kind != tkComma: break
     result.add(f)
   if isKeyw(p, "where"):
@@ -1041,6 +1064,11 @@ proc parseSelect(p: var SqlParser): SqlNode =
       if p.tok.kind != tkComma: break
       getTok(p)
     result.add(g)
+  if isKeyw(p, "limit"):
+    getTok(p)
+    var l = newNode(nkLimit)
+    l.add(parseExpr(p))
+    result.add(l)
   if isKeyw(p, "having"):
     var h = newNode(nkHaving)
     while true:
@@ -1073,6 +1101,18 @@ proc parseSelect(p: var SqlParser): SqlNode =
       if p.tok.kind != tkComma: break
       getTok(p)
     result.add(n)
+  if isKeyw(p, "join") or isKeyw(p, "inner") or isKeyw(p, "outer") or isKeyw(p, "cross"):
+    var join = newNode(nkJoin)
+    result.add(join)
+    if isKeyw(p, "join"):
+      join.add(newNode(nkIdent, ""))
+      getTok(p)
+    else:
+      join.add(parseExpr(p))
+      eat(p, "join")
+    join.add(parseFromItem(p))
+    eat(p, "on")
+    join.add(parseExpr(p))
 
 proc parseStmt(p: var SqlParser; parent: SqlNode) =
   if isKeyw(p, "create"):
@@ -1104,7 +1144,7 @@ proc parseStmt(p: var SqlParser; parent: SqlNode) =
   elif isKeyw(p, "begin"):
     getTok(p)
   else:
-    sqlError(p, "CREATE expected")
+    sqlError(p, "SELECT, CREATE, UPDATE or DELETE expected")
 
 proc open(p: var SqlParser, input: Stream, filename: string) =
   ## opens the parser `p` and assigns the input stream `input` to it.
@@ -1120,6 +1160,8 @@ proc parse(p: var SqlParser): SqlNode =
   result = newNode(nkStmtList)
   while p.tok.kind != tkEof:
     parseStmt(p, result)
+    if p.tok.kind == tkEof:
+      break
     eat(p, tkSemicolon)
   if result.len == 1:
     result = result.sons[0]
@@ -1139,19 +1181,69 @@ proc parseSQL*(input: Stream, filename: string): SqlNode =
   finally:
     close(p)
 
-proc ra(n: SqlNode, s: var string, indent: int)
+proc parseSQL*(input: string, filename=""): SqlNode =
+  ## Parses the SQL text in `input` into an AST and returns the AST.
+  ## `filename` is only used for error messages.
+  ## Syntax errors raise an `EInvalidSql` exception.
+  # Forward `filename` to the stream overload; passing a literal "" here
+  # would silently drop the caller's value from error messages.
+  parseSQL(newStringStream(input), filename)
 
-proc rs(n: SqlNode, s: var string, indent: int,
-        prefix = "(", suffix = ")",
-        sep = ", ") =
+
+type
+  SqlWriter = object ## output state threaded through the SQL rendering procs
+    indent: int      ## nesting depth; `newLine` writes two spaces per level
+    upperCase: bool  ## when true, `addKeyw` upper-cases what it writes
+    buffer: string   ## text produced so far
+
+proc add(s: var SqlWriter, thing: string) =
+  ## Appends `thing` verbatim to the writer's buffer.
+  add(s.buffer, thing)
+
+proc add(s: var SqlWriter, thing: char) =
+  ## Appends the single character `thing` to the writer's buffer.
+  add(s.buffer, thing)
+
+proc addKeyw(s: var SqlWriter, thing: string) =
+  ## Appends the keyword `thing` followed by one space.
+  ## A separating space is written first unless the buffer is empty or
+  ## already ends in a space, comma, line feed or opening parenthesis.
+  ## The keyword is upper-cased when `s.upperCase` is set.
+  if s.buffer.len > 0 and s.buffer[^1] notin " ,\L(":
+    s.buffer.add(" ")
+  if s.upperCase:
+    # `toUpperAscii` is the non-deprecated strutils spelling of `toUpper`.
+    s.buffer.add(thing.toUpperAscii())
+  else:
+    s.buffer.add(thing)
+  s.buffer.add(" ")
+
+proc rm(s: var SqlWriter, chars = " \L,") =
+  ## Strips trailing characters contained in `chars` from the buffer.
+  ## Stops once the buffer is empty instead of indexing past its start, so
+  ## it is safe to call on an empty or all-strippable buffer.
+  var keep = s.buffer.len
+  while keep > 0 and s.buffer[keep-1] in chars:
+    dec keep
+  # Truncate once, in place, rather than copying the buffer per character.
+  s.buffer.setLen(keep)
+
+proc newLine(s: var SqlWriter) =
+  ## Ends the current line: trailing spaces and line feeds are removed, one
+  ## line feed is written, then two spaces per indentation level.
+  rm(s, " \L")
+  add(s.buffer, "\L")
+  var level = 0
+  while level < s.indent:
+    add(s.buffer, "  ")
+    inc level
+
+template inner(s: var SqlWriter, body: untyped) =
+  ## Runs `body` one indentation level deeper, starting on a fresh line.
+  ## The previous level is restored after `body` completes normally.
+  s.indent += 1
+  newLine(s)
+  body
+  s.indent -= 1
+
+template innerKeyw(s: var SqlWriter, keyw: string, body: untyped) =
+  ## Writes `keyw` on a line of its own, then runs `body` one indentation
+  ## level deeper on the following line.
+  newLine(s)
+  addKeyw(s, keyw)
+  # The indent / fresh line / body / dedent sequence is exactly `inner`.
+  inner(s):
+    body
+
+proc ra(n: SqlNode, s: var SqlWriter)
+
+proc rs(n: SqlNode, s: var SqlWriter, prefix = "(", suffix = ")", sep = ", ") =
+  ## Renders the sons of `n` into `s`, separated by `sep` and wrapped in
+  ## `prefix`/`suffix`. Writes nothing at all when `n` has no sons.
   if n.len > 0:
     s.add(prefix)
     for i in 0 .. n.len-1:
       if i > 0: s.add(sep)
-      ra(n.sons[i], s, indent)
+      ra(n.sons[i], s)
     s.add(suffix)
 
-proc ra(n: SqlNode, s: var string, indent: int) =
+proc ra(n: SqlNode, s: var SqlWriter) =
   if n == nil: return
   case n.kind
   of nkNone: discard
@@ -1169,217 +1261,222 @@ proc ra(n: SqlNode, s: var string, indent: int) =
   of nkIntegerLit, nkNumericLit:
     s.add(n.strVal)
   of nkPrimaryKey:
-    s.add(" primary key")
-    rs(n, s, indent)
+    s.addKeyw("primary key")
+    rs(n, s)
   of nkForeignKey:
-    s.add(" foreign key")
-    rs(n, s, indent)
+    s.addKeyw("foreign key")
+    rs(n, s)
   of nkNotNull:
-    s.add(" not null")
+    s.addKeyw("not null")
   of nkNull:
-    s.add(" null")
+    s.addKeyw("null")
   of nkDot:
-    ra(n.sons[0], s, indent)
+    ra(n.sons[0], s)
     s.add(".")
-    ra(n.sons[1], s, indent)
+    ra(n.sons[1], s)
   of nkDotDot:
-    ra(n.sons[0], s, indent)
+    ra(n.sons[0], s)
     s.add(". .")
-    ra(n.sons[1], s, indent)
+    ra(n.sons[1], s)
   of nkPrefix:
     s.add('(')
-    ra(n.sons[0], s, indent)
+    ra(n.sons[0], s)
     s.add(' ')
-    ra(n.sons[1], s, indent)
+    ra(n.sons[1], s)
     s.add(')')
   of nkInfix:
     s.add('(')
-    ra(n.sons[1], s, indent)
+    ra(n.sons[1], s)
     s.add(' ')
-    ra(n.sons[0], s, indent)
+    ra(n.sons[0], s)
     s.add(' ')
-    ra(n.sons[2], s, indent)
+    ra(n.sons[2], s)
     s.add(')')
   of nkCall, nkColumnReference:
-    ra(n.sons[0], s, indent)
+    ra(n.sons[0], s)
     s.add('(')
     for i in 1..n.len-1:
       if i > 1: s.add(", ")
-      ra(n.sons[i], s, indent)
+      ra(n.sons[i], s)
     s.add(')')
   of nkReferences:
-    s.add(" references ")
-    ra(n.sons[0], s, indent)
+    s.addKeyw("references")
+    ra(n.sons[0], s)
   of nkDefault:
-    s.add(" default ")
-    ra(n.sons[0], s, indent)
+    s.addKeyw("default")
+    ra(n.sons[0], s)
   of nkCheck:
-    s.add(" check ")
-    ra(n.sons[0], s, indent)
+    s.addKeyw("check")
+    ra(n.sons[0], s)
   of nkConstraint:
-    s.add(" constraint ")
-    ra(n.sons[0], s, indent)
-    s.add(" check ")
-    ra(n.sons[1], s, indent)
+    s.addKeyw("constraint")
+    ra(n.sons[0], s)
+    s.addKeyw("check")
+    ra(n.sons[1], s)
   of nkUnique:
-    s.add(" unique")
-    rs(n, s, indent)
+    s.addKeyw("unique")
+    rs(n, s)
   of nkIdentity:
-    s.add(" identity")
+    s.addKeyw("identity")
   of nkColumnDef:
     s.add("\n  ")
-    rs(n, s, indent, "", "", " ")
+    rs(n, s, "", "", " ")
   of nkStmtList:
     for i in 0..n.len-1:
-      ra(n.sons[i], s, indent)
+      ra(n.sons[i], s)
       s.add("\n")
   of nkInsert:
     assert n.len == 3
-    s.add("insert into ")
-    ra(n.sons[0], s, indent)
-    ra(n.sons[1], s, indent)
+    s.addKeyw("insert into")
+    ra(n.sons[0], s)
+    ra(n.sons[1], s)
     if n.sons[2].kind == nkDefault:
-      s.add("default values")
+      s.addKeyw("default values")
     else:
-      s.add("\n")
-      ra(n.sons[2], s, indent)
+      s.newLine()
+      ra(n.sons[2], s)
     s.add(';')
   of nkUpdate:
-    s.add("update ")
-    ra(n.sons[0], s, indent)
-    s.add(" set ")
+    s.addKeyw("update")
+    ra(n.sons[0], s)
+    s.addKeyw("set")
     var L = n.len
     for i in 1 .. L-2:
       if i > 1: s.add(", ")
       var it = n.sons[i]
       assert it.kind == nkAsgn
-      ra(it, s, indent)
-    ra(n.sons[L-1], s, indent)
+      ra(it, s)
+    ra(n.sons[L-1], s)
     s.add(';')
   of nkDelete:
-    s.add("delete from ")
-    ra(n.sons[0], s, indent)
-    ra(n.sons[1], s, indent)
+    s.addKeyw("delete from")
+    ra(n.sons[0], s)
+    ra(n.sons[1], s)
     s.add(';')
   of nkSelect, nkSelectDistinct:
-    s.add("select ")
+    s.addKeyw("select")
     if n.kind == nkSelectDistinct:
-      s.add("distinct ")
-    rs(n.sons[0], s, indent, "", "", ", ")
-    for i in 1 .. n.len-1: ra(n.sons[i], s, indent)
+      s.addKeyw("distinct")
+    s.inner:
+      for son in n.sons[0].sons:
+        ra(son, s)
+        s.add(',')
+        s.newLine()
+      s.rm()
+    for i in 1 .. n.len-1:
+      ra(n.sons[i], s)
     s.add(';')
   of nkSelectColumns:
     assert(false)
+  of nkSelectPair:
+    ra(n.sons[0], s)
+    if n.sons.len == 2:
+      s.addKeyw("as")
+      ra(n.sons[1], s)
+  of nkFromItemPair:
+    if n.sons[0].kind == nkIdent:
+      ra(n.sons[0], s)
+    else:
+      assert n.sons[0].kind == nkSelect
+      s.add("(")
+      s.inner:
+        ra(n.sons[0], s)
+      s.rm("; \L")
+      s.newLine()
+      s.add(")")
+    if n.sons.len == 2:
+      s.addKeyw("as")
+      ra(n.sons[1], s)
   of nkAsgn:
-    ra(n.sons[0], s, indent)
+    ra(n.sons[0], s)
     s.add(" = ")
-    ra(n.sons[1], s, indent)
+    ra(n.sons[1], s)
   of nkFrom:
-    s.add("\nfrom ")
-    rs(n, s, indent, "", "", ", ")
+    s.innerKeyw("from"):
+      rs(n, s, "", "", ", ")
   of nkGroup:
-    s.add("\ngroup by")
-    rs(n, s, indent, "", "", ", ")
+    s.innerKeyw("group by"):
+      rs(n, s, "", "", ", ")
+  of nkLimit:
+    s.innerKeyw("limit"):
+      rs(n, s, "", "", ", ")
   of nkHaving:
-    s.add("\nhaving")
-    rs(n, s, indent, "", "", ", ")
+    s.innerKeyw("having"):
+      rs(n, s, "", "", ", ")
   of nkOrder:
-    s.add("\norder by ")
-    rs(n, s, indent, "", "", ", ")
+    s.addKeyw("order by")
+    rs(n, s, "", "", ", ")
+  of nkJoin:
+    var joinType = n.sons[0].strVal
+    if joinType == "":
+      joinType = "join"
+    else:
+      joinType &= " " & "join"
+    s.innerKeyw(joinType):
+      ra(n.sons[1], s)
+    s.innerKeyw("on"):
+      ra(n.sons[2], s)
   of nkDesc:
-    ra(n.sons[0], s, indent)
-    s.add(" desc")
+    ra(n.sons[0], s)
+    s.addKeyw("desc")
   of nkUnion:
-    s.add(" union")
+    s.addKeyw("union")
   of nkIntersect:
-    s.add(" intersect")
+    s.addKeyw("intersect")
   of nkExcept:
-    s.add(" except")
+    s.addKeyw("except")
   of nkColumnList:
-    rs(n, s, indent)
+    rs(n, s)
   of nkValueList:
-    s.add("values ")
-    rs(n, s, indent)
+    s.addKeyw("values")
+    rs(n, s)
   of nkWhere:
-    s.add("\nwhere ")
-    ra(n.sons[0], s, indent)
+    s.newLine()
+    s.addKeyw("where")
+    s.inner:
+      ra(n.sons[0], s)
   of nkCreateTable, nkCreateTableIfNotExists:
-    s.add("create table ")
+    s.addKeyw("create table")
     if n.kind == nkCreateTableIfNotExists:
-      s.add("if not exists ")
-    ra(n.sons[0], s, indent)
+      s.addKeyw("if not exists")
+    ra(n.sons[0], s)
     s.add('(')
     for i in 1..n.len-1:
-      if i > 1: s.add(", ")
-      ra(n.sons[i], s, indent)
+      if i > 1: s.add(",")
+      ra(n.sons[i], s)
     s.add(");")
   of nkCreateType, nkCreateTypeIfNotExists:
-    s.add("create type ")
+    s.addKeyw("create type")
     if n.kind == nkCreateTypeIfNotExists:
-      s.add("if not exists ")
-    ra(n.sons[0], s, indent)
-    s.add(" as ")
-    ra(n.sons[1], s, indent)
+      s.addKeyw("if not exists")
+    ra(n.sons[0], s)
+    s.addKeyw("as")
+    ra(n.sons[1], s)
     s.add(';')
   of nkCreateIndex, nkCreateIndexIfNotExists:
-    s.add("create index ")
+    s.addKeyw("create index")
     if n.kind == nkCreateIndexIfNotExists:
-      s.add("if not exists ")
-    ra(n.sons[0], s, indent)
-    s.add(" on ")
-    ra(n.sons[1], s, indent)
+      s.addKeyw("if not exists")
+    ra(n.sons[0], s)
+    s.addKeyw("on")
+    ra(n.sons[1], s)
     s.add('(')
     for i in 2..n.len-1:
       if i > 2: s.add(", ")
-      ra(n.sons[i], s, indent)
+      ra(n.sons[i], s)
     s.add(");")
   of nkEnumDef:
-    s.add("enum ")
-    rs(n, s, indent)
+    s.addKeyw("enum")
+    rs(n, s)
 
-# What I want:
-#
-#select(columns = [T1.all, T2.name],
-#       fromm = [T1, T2],
-#       where = T1.name ==. T2.name,
-#       orderby = [name]):
-#
-#for row in dbQuery(db, """select x, y, z
-#                          from a, b
-#                          where a.name = b.name"""):
-#
-
-#select x, y, z:
-#  fromm: Table1, Table2
-#  where: x.name == y.name
-#db.select(fromm = [t1, t2], where = t1.name == t2.name):
-#for x, y, z in db.select(fromm = a, b where = a.name == b.name):
-#  writeLine x, y, z
-
-proc renderSQL*(n: SqlNode): string =
+proc renderSQL*(n: SqlNode, upperCase=false): string =
   ## Converts an SQL abstract syntax tree to its string representation.
+  ## Keywords written through `addKeyw` are upper-cased when `upperCase`
+  ## is true.
-  result = ""
-  ra(n, result, 0)
+  var writer = SqlWriter(indent: 0, upperCase: upperCase, buffer: "")
+  ra(n, writer)
+  result = writer.buffer
 
 proc `$`*(n: SqlNode): string =
   ## An alias for `renderSQL`; calls it with `upperCase` left at its default.
   renderSQL(n)
-
-when not defined(testing) and isMainModule:
-  echo(renderSQL(parseSQL(newStringStream("""
-      CREATE TYPE happiness AS ENUM ('happy', 'very happy', 'ecstatic');
-      CREATE TABLE holidays (
-         num_weeks int,
-         happiness happiness
-      );
-      CREATE INDEX table1_attr1 ON table1(attr1);
-
-      SELECT * FROM myTab WHERE col1 = 'happy';
-  """), "stdin")))
-
-# CREATE TYPE happiness AS ENUM ('happy', 'very happy', 'ecstatic');
-# CREATE TABLE holidays (
-#    num_weeks int,
-#    happiness happiness
-# );
-# CREATE INDEX table1_attr1 ON table1(attr1)