diff options
Diffstat (limited to 'lib/pure/parsesql.nim')
-rw-r--r--[-rwxr-xr-x] | lib/pure/parsesql.nim | 1256 |
1 files changed, 717 insertions, 539 deletions
diff --git a/lib/pure/parsesql.nim b/lib/pure/parsesql.nim index 2109c273a..a7c938d01 100755..100644 --- a/lib/pure/parsesql.nim +++ b/lib/pure/parsesql.nim @@ -1,27 +1,32 @@ # # -# Nimrod's Runtime Library +# Nim's Runtime Library # (c) Copyright 2009 Andreas Rumpf # # See the file "copying.txt", included in this # distribution, for details about the copyright. # -## The ``parsesql`` module implements a high performance SQL file +## The `parsesql` module implements a high performance SQL file ## parser. It parses PostgreSQL syntax and the SQL ANSI standard. +## +## Unstable API. -import - hashes, strutils, lexbase, streams +import std/[strutils, lexbase] +import std/private/decode_helpers + +when defined(nimPreviewSlimSystem): + import std/assertions # ------------------- scanner ------------------------------------------------- type - TTokKind = enum ## enumeration of all SQL tokens - tkInvalid, ## invalid token - tkEof, ## end of file reached - tkIdentifier, ## abc - tkQuotedIdentifier, ## "abc" - tkStringConstant, ## 'abc' + TokKind = enum ## enumeration of all SQL tokens + tkInvalid, ## invalid token + tkEof, ## end of file reached + tkIdentifier, ## abc + tkQuotedIdentifier, ## "abc" + tkStringConstant, ## 'abc' tkEscapeConstant, ## e'abc' tkDollarQuotedConstant, ## $tag$abc$tag$ tkBitStringConstant, ## B'00011' @@ -37,211 +42,195 @@ type tkBracketLe, ## '[' tkBracketRi, ## ']' tkDot ## '.' - - TToken {.final.} = object # a token - kind: TTokKind # the type of the token - literal: string # the parsed (string) literal - - TSqlLexer* = object of TBaseLexer ## the parser object. + + Token = object # a token + kind: TokKind # the type of the token + literal: string # the parsed (string) literal + + SqlLexer* = object of BaseLexer ## the parser object. filename: string const - tokKindToStr: array[TTokKind, string] = [ + tokKindToStr: array[TokKind, string] = [ "invalid", "[EOF]", "identifier", "quoted identifier", "string constant", "escape string constant", "dollar quoted constant", "bit string constant", "hex string constant", "integer constant", "numeric constant", "operator", ";", ":", ",", "(", ")", "[", "]", "." ] -proc open(L: var TSqlLexer, input: PStream, filename: string) = - lexbase.open(L, input) - L.filename = filename - -proc close(L: var TSqlLexer) = + reservedKeywords = @[ + # statements + "select", "from", "where", "group", "limit", "offset", "having", + # functions + "count", + ] + +proc close(L: var SqlLexer) = lexbase.close(L) -proc getColumn(L: TSqlLexer): int = +proc getColumn(L: SqlLexer): int = ## get the current column the parser has arrived at. - result = getColNumber(L, L.bufPos) + result = getColNumber(L, L.bufpos) -proc getLine(L: TSqlLexer): int = - result = L.linenumber +proc getLine(L: SqlLexer): int = + result = L.lineNumber -proc handleHexChar(c: var TSqlLexer, xi: var int) = - case c.buf[c.bufpos] - of '0'..'9': - xi = (xi shl 4) or (ord(c.buf[c.bufpos]) - ord('0')) - inc(c.bufpos) - of 'a'..'f': - xi = (xi shl 4) or (ord(c.buf[c.bufpos]) - ord('a') + 10) - inc(c.bufpos) - of 'A'..'F': - xi = (xi shl 4) or (ord(c.buf[c.bufpos]) - ord('A') + 10) - inc(c.bufpos) - else: - nil - -proc handleOctChar(c: var TSqlLexer, xi: var int) = +proc handleOctChar(c: var SqlLexer, xi: var int) = if c.buf[c.bufpos] in {'0'..'7'}: xi = (xi shl 3) or (ord(c.buf[c.bufpos]) - ord('0')) inc(c.bufpos) -proc getEscapedChar(c: var TSqlLexer, tok: var TToken) = +proc getEscapedChar(c: var SqlLexer, tok: var Token) = inc(c.bufpos) case c.buf[c.bufpos] - of 'n', 'N': + of 'n', 'N': add(tok.literal, '\L') - Inc(c.bufpos) - of 'r', 'R', 'c', 'C': + inc(c.bufpos) + of 'r', 'R', 'c', 'C': add(tok.literal, '\c') - Inc(c.bufpos) - of 'l', 'L': + inc(c.bufpos) + of 'l', 'L': add(tok.literal, '\L') - Inc(c.bufpos) - of 'f', 'F': + inc(c.bufpos) + of 'f', 'F': add(tok.literal, '\f') inc(c.bufpos) - of 'e', 'E': + of 'e', 'E': add(tok.literal, '\e') - Inc(c.bufpos) - of 'a', 'A': + inc(c.bufpos) + of 'a', 'A': add(tok.literal, '\a') - Inc(c.bufpos) - of 'b', 'B': + inc(c.bufpos) + of 'b', 'B': add(tok.literal, '\b') - Inc(c.bufpos) - of 'v', 'V': + inc(c.bufpos) + of 'v', 'V': add(tok.literal, '\v') - Inc(c.bufpos) - of 't', 'T': + inc(c.bufpos) + of 't', 'T': add(tok.literal, '\t') - Inc(c.bufpos) - of '\'', '\"': + inc(c.bufpos) + of '\'', '\"': add(tok.literal, c.buf[c.bufpos]) - Inc(c.bufpos) - of '\\': + inc(c.bufpos) + of '\\': add(tok.literal, '\\') - Inc(c.bufpos) - of 'x', 'X': + inc(c.bufpos) + of 'x', 'X': inc(c.bufpos) var xi = 0 - handleHexChar(c, xi) - handleHexChar(c, xi) - add(tok.literal, Chr(xi)) - of '0'..'7': + if handleHexChar(c.buf[c.bufpos], xi): + inc(c.bufpos) + if handleHexChar(c.buf[c.bufpos], xi): + inc(c.bufpos) + add(tok.literal, chr(xi)) + of '0'..'7': var xi = 0 handleOctChar(c, xi) handleOctChar(c, xi) handleOctChar(c, xi) - if (xi <= 255): add(tok.literal, Chr(xi)) + if (xi <= 255): add(tok.literal, chr(xi)) else: tok.kind = tkInvalid else: tok.kind = tkInvalid - -proc HandleCRLF(c: var TSqlLexer, pos: int): int = + +proc handleCRLF(c: var SqlLexer, pos: int): int = case c.buf[pos] - of '\c': result = lexbase.HandleCR(c, pos) - of '\L': result = lexbase.HandleLF(c, pos) + of '\c': result = lexbase.handleCR(c, pos) + of '\L': result = lexbase.handleLF(c, pos) else: result = pos -proc skip(c: var TSqlLexer) = +proc skip(c: var SqlLexer) = var pos = c.bufpos - var buf = c.buf var nested = 0 - while true: - case buf[pos] - of ' ', '\t': - Inc(pos) + while true: + case c.buf[pos] + of ' ', '\t': + inc(pos) of '-': - if buf[pos+1] == '-': - while not (buf[pos] in {'\c', '\L', lexbase.EndOfFile}): inc(pos) + if c.buf[pos+1] == '-': + while not (c.buf[pos] in {'\c', '\L', lexbase.EndOfFile}): inc(pos) else: break of '/': - if buf[pos+1] == '*': - inc(pos,2) + if c.buf[pos+1] == '*': + inc(pos, 2) while true: - case buf[pos] + case c.buf[pos] of '\0': break - of '\c', '\L': - pos = HandleCRLF(c, pos) - buf = c.buf + of '\c', '\L': + pos = handleCRLF(c, pos) of '*': - if buf[pos+1] == '/': + if c.buf[pos+1] == '/': inc(pos, 2) if nested <= 0: break dec(nested) else: inc(pos) of '/': - if buf[pos+1] == '*': + if c.buf[pos+1] == '*': inc(pos, 2) inc(nested) else: inc(pos) else: inc(pos) else: break - of '\c', '\L': - pos = HandleCRLF(c, pos) - buf = c.buf - else: - break # EndOfFile also leaves the loop + of '\c', '\L': + pos = handleCRLF(c, pos) + else: + break # EndOfFile also leaves the loop c.bufpos = pos - -proc getString(c: var TSqlLexer, tok: var TToken, kind: TTokKind) = - var pos = c.bufPos + 1 - var buf = c.buf + +proc getString(c: var SqlLexer, tok: var Token, kind: TokKind) = + var pos = c.bufpos + 1 tok.kind = kind block parseLoop: while true: - while true: - var ch = buf[pos] + while true: + var ch = c.buf[pos] if ch == '\'': - if buf[pos+1] == '\'': + if c.buf[pos+1] == '\'': inc(pos, 2) add(tok.literal, '\'') else: inc(pos) - break - elif ch in {'\c', '\L', lexbase.EndOfFile}: + break + elif ch in {'\c', '\L', lexbase.EndOfFile}: tok.kind = tkInvalid break parseLoop - elif (ch == '\\') and kind == tkEscapeConstant: - c.bufPos = pos + elif (ch == '\\') and kind == tkEscapeConstant: + c.bufpos = pos getEscapedChar(c, tok) - pos = c.bufPos - else: + pos = c.bufpos + else: add(tok.literal, ch) - Inc(pos) + inc(pos) c.bufpos = pos - var line = c.linenumber + var line = c.lineNumber skip(c) - if c.linenumber > line: + if c.lineNumber > line: # a new line whitespace has been parsed, so we check if the string # continues after the whitespace: - buf = c.buf # may have been reallocated pos = c.bufpos - if buf[pos] == '\'': inc(pos) + if c.buf[pos] == '\'': inc(pos) else: break parseLoop else: break parseLoop c.bufpos = pos -proc getDollarString(c: var TSqlLexer, tok: var TToken) = - var pos = c.bufPos + 1 - var buf = c.buf +proc getDollarString(c: var SqlLexer, tok: var Token) = + var pos = c.bufpos + 1 tok.kind = tkDollarQuotedConstant var tag = "$" - while buf[pos] in IdentChars: - add(tag, buf[pos]) + while c.buf[pos] in IdentChars: + add(tag, c.buf[pos]) inc(pos) - if buf[pos] == '$': inc(pos) + if c.buf[pos] == '$': inc(pos) else: tok.kind = tkInvalid return while true: - case buf[pos] - of '\c', '\L': - pos = HandleCRLF(c, pos) - buf = c.buf + case c.buf[pos] + of '\c', '\L': + pos = handleCRLF(c, pos) add(tok.literal, "\L") of '\0': tok.kind = tkInvalid @@ -249,190 +238,185 @@ proc getDollarString(c: var TSqlLexer, tok: var TToken) = of '$': inc(pos) var tag2 = "$" - while buf[pos] in IdentChars: - add(tag2, buf[pos]) + while c.buf[pos] in IdentChars: + add(tag2, c.buf[pos]) inc(pos) - if buf[pos] == '$': inc(pos) + if c.buf[pos] == '$': inc(pos) if tag2 == tag: break add(tok.literal, tag2) add(tok.literal, '$') else: - add(tok.literal, buf[pos]) + add(tok.literal, c.buf[pos]) inc(pos) c.bufpos = pos -proc getSymbol(c: var TSqlLexer, tok: var TToken) = +proc getSymbol(c: var SqlLexer, tok: var Token) = var pos = c.bufpos - var buf = c.buf - while true: - add(tok.literal, buf[pos]) - Inc(pos) - if not (buf[pos] in {'a'..'z','A'..'Z','0'..'9','_','$', '\128'..'\255'}): + while true: + add(tok.literal, c.buf[pos]) + inc(pos) + if c.buf[pos] notin {'a'..'z', 'A'..'Z', '0'..'9', '_', '$', + '\128'..'\255'}: break c.bufpos = pos tok.kind = tkIdentifier -proc getQuotedIdentifier(c: var TSqlLexer, tok: var TToken) = - var pos = c.bufPos + 1 - var buf = c.buf +proc getQuotedIdentifier(c: var SqlLexer, tok: var Token, quote = '\"') = + var pos = c.bufpos + 1 tok.kind = tkQuotedIdentifier while true: - var ch = buf[pos] - if ch == '\"': - if buf[pos+1] == '\"': + var ch = c.buf[pos] + if ch == quote: + if c.buf[pos+1] == quote: inc(pos, 2) - add(tok.literal, '\"') + add(tok.literal, quote) else: inc(pos) break - elif ch in {'\c', '\L', lexbase.EndOfFile}: + elif ch in {'\c', '\L', lexbase.EndOfFile}: tok.kind = tkInvalid break else: add(tok.literal, ch) - Inc(pos) + inc(pos) c.bufpos = pos -proc getBitHexString(c: var TSqlLexer, tok: var TToken, validChars: TCharSet) = - var pos = c.bufPos + 1 - var buf = c.buf +proc getBitHexString(c: var SqlLexer, tok: var Token, validChars: set[char]) = + var pos = c.bufpos + 1 block parseLoop: while true: - while true: - var ch = buf[pos] + while true: + var ch = c.buf[pos] if ch in validChars: add(tok.literal, ch) - Inc(pos) + inc(pos) elif ch == '\'': inc(pos) break - else: + else: tok.kind = tkInvalid break parseLoop c.bufpos = pos - var line = c.linenumber + var line = c.lineNumber skip(c) - if c.linenumber > line: + if c.lineNumber > line: # a new line whitespace has been parsed, so we check if the string # continues after the whitespace: - buf = c.buf # may have been reallocated pos = c.bufpos - if buf[pos] == '\'': inc(pos) + if c.buf[pos] == '\'': inc(pos) else: break parseLoop else: break parseLoop c.bufpos = pos -proc getNumeric(c: var TSqlLexer, tok: var TToken) = +proc getNumeric(c: var SqlLexer, tok: var Token) = tok.kind = tkInteger - var pos = c.bufPos - var buf = c.buf - while buf[pos] in Digits: - add(tok.literal, buf[pos]) + var pos = c.bufpos + while c.buf[pos] in Digits: + add(tok.literal, c.buf[pos]) inc(pos) - if buf[pos] == '.': + if c.buf[pos] == '.': tok.kind = tkNumeric - add(tok.literal, buf[pos]) + add(tok.literal, c.buf[pos]) inc(pos) - while buf[pos] in Digits: - add(tok.literal, buf[pos]) + while c.buf[pos] in Digits: + add(tok.literal, c.buf[pos]) inc(pos) - if buf[pos] in {'E', 'e'}: + if c.buf[pos] in {'E', 'e'}: tok.kind = tkNumeric - add(tok.literal, buf[pos]) + add(tok.literal, c.buf[pos]) inc(pos) - if buf[pos] == '+': + if c.buf[pos] == '+': inc(pos) - elif buf[pos] == '-': - add(tok.literal, buf[pos]) + elif c.buf[pos] == '-': + add(tok.literal, c.buf[pos]) inc(pos) - if buf[pos] in Digits: - while buf[pos] in Digits: - add(tok.literal, buf[pos]) + if c.buf[pos] in Digits: + while c.buf[pos] in Digits: + add(tok.literal, c.buf[pos]) inc(pos) else: tok.kind = tkInvalid - c.bufpos = pos + c.bufpos = pos -proc getOperator(c: var TSqlLexer, tok: var TToken) = +proc getOperator(c: var SqlLexer, tok: var Token) = const operators = {'+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%', '^', '&', '|', '`', '?'} tok.kind = tkOperator - var pos = c.bufPos - var buf = c.buf + var pos = c.bufpos var trailingPlusMinus = false while true: - case buf[pos] + case c.buf[pos] of '-': - if buf[pos] == '-': break - if not trailingPlusMinus and buf[pos+1] notin operators and + if c.buf[pos] == '-': break + if not trailingPlusMinus and c.buf[pos+1] notin operators and tok.literal.len > 0: break of '/': - if buf[pos] == '*': break + if c.buf[pos] == '*': break of '~', '!', '@', '#', '%', '^', '&', '|', '`', '?': trailingPlusMinus = true of '+': - if not trailingPlusMinus and buf[pos+1] notin operators and + if not trailingPlusMinus and c.buf[pos+1] notin operators and tok.literal.len > 0: break - of '*', '<', '>', '=': nil + of '*', '<', '>', '=': discard else: break - add(tok.literal, buf[pos]) + add(tok.literal, c.buf[pos]) inc(pos) c.bufpos = pos -proc getTok(c: var TSqlLexer, tok: var TToken) = +proc getTok(c: var SqlLexer, tok: var Token) = tok.kind = tkInvalid - setlen(tok.literal, 0) + setLen(tok.literal, 0) skip(c) case c.buf[c.bufpos] - of ';': - tok.kind = tkSemiColon - inc(c.bufPos) + of ';': + tok.kind = tkSemicolon + inc(c.bufpos) add(tok.literal, ';') of ',': tok.kind = tkComma inc(c.bufpos) add(tok.literal, ',') - of ':': + of ':': tok.kind = tkColon inc(c.bufpos) add(tok.literal, ':') - of 'e', 'E': - if c.buf[c.bufPos + 1] == '\'': - Inc(c.bufPos) + of 'e', 'E': + if c.buf[c.bufpos + 1] == '\'': + inc(c.bufpos) getString(c, tok, tkEscapeConstant) - else: + else: getSymbol(c, tok) of 'b', 'B': - if c.buf[c.bufPos + 1] == '\'': + if c.buf[c.bufpos + 1] == '\'': tok.kind = tkBitStringConstant getBitHexString(c, tok, {'0'..'1'}) else: getSymbol(c, tok) of 'x', 'X': - if c.buf[c.bufPos + 1] == '\'': + if c.buf[c.bufpos + 1] == '\'': tok.kind = tkHexStringConstant - getBitHexString(c, tok, {'a'..'f','A'..'F','0'..'9'}) + getBitHexString(c, tok, {'a'..'f', 'A'..'F', '0'..'9'}) else: getSymbol(c, tok) of '$': getDollarString(c, tok) - of '[': + of '[': tok.kind = tkBracketLe inc(c.bufpos) add(tok.literal, '[') - of ']': + of ']': tok.kind = tkBracketRi - Inc(c.bufpos) + inc(c.bufpos) add(tok.literal, ']') of '(': tok.kind = tkParLe - Inc(c.bufpos) + inc(c.bufpos) add(tok.literal, '(') of ')': tok.kind = tkParRi - Inc(c.bufpos) + inc(c.bufpos) add(tok.literal, ')') - of '.': - if c.buf[c.bufPos + 1] in Digits: + of '.': + if c.buf[c.bufpos + 1] in Digits: getNumeric(c, tok) else: tok.kind = tkDot @@ -440,52 +424,54 @@ proc getTok(c: var TSqlLexer, tok: var TToken) = add(tok.literal, '.') of '0'..'9': getNumeric(c, tok) of '\'': getString(c, tok, tkStringConstant) - of '"': getQuotedIdentifier(c, tok) - of lexbase.EndOfFile: + of '"': getQuotedIdentifier(c, tok, '"') + of '`': getQuotedIdentifier(c, tok, '`') + of lexbase.EndOfFile: tok.kind = tkEof tok.literal = "[EOF]" of 'a', 'c', 'd', 'f'..'w', 'y', 'z', 'A', 'C', 'D', 'F'..'W', 'Y', 'Z', '_', '\128'..'\255': getSymbol(c, tok) of '+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%', - '^', '&', '|', '`', '?': + '^', '&', '|', '?': getOperator(c, tok) else: add(tok.literal, c.buf[c.bufpos]) inc(c.bufpos) - -proc errorStr(L: TSqlLexer, msg: string): string = + +proc errorStr(L: SqlLexer, msg: string): string = result = "$1($2, $3) Error: $4" % [L.filename, $getLine(L), $getColumn(L), msg] # ----------------------------- parser ---------------------------------------- -# Operator/Element Associativity Description -# . left table/column name separator -# :: left PostgreSQL-style typecast -# [ ] left array element selection -# - right unary minus -# ^ left exponentiation -# * / % left multiplication, division, modulo -# + - left addition, subtraction -# IS IS TRUE, IS FALSE, IS UNKNOWN, IS NULL -# ISNULL test for null -# NOTNULL test for not null -# (any other) left all other native and user-defined oprs -# IN set membership -# BETWEEN range containment -# OVERLAPS time interval overlap -# LIKE ILIKE SIMILAR string pattern matching -# < > less than, greater than -# = right equality, assignment -# NOT right logical negation -# AND left logical conjunction -# OR left logical disjunction +# Operator/Element Associativity Description +# . left table/column name separator +# :: left PostgreSQL-style typecast +# [ ] left array element selection +# - right unary minus +# ^ left exponentiation +# * / % left multiplication, division, modulo +# + - left addition, subtraction +# IS IS TRUE, IS FALSE, IS UNKNOWN, IS NULL +# ISNULL test for null +# NOTNULL test for not null +# (any other) left all other native and user-defined oprs +# IN set membership +# BETWEEN range containment +# OVERLAPS time interval overlap +# LIKE ILIKE SIMILAR string pattern matching +# < > less than, greater than +# = right equality, assignment +# NOT right logical negation +# AND left logical conjunction +# OR left logical disjunction type - TSqlNodeKind* = enum ## kind of SQL abstract syntax tree + SqlNodeKind* = enum ## kind of SQL abstract syntax tree nkNone, nkIdent, + nkQuotedIdent, nkStringLit, nkBitStringLit, nkHexStringLit, @@ -494,13 +480,15 @@ type nkPrimaryKey, nkForeignKey, nkNotNull, - + nkNull, + nkStmtList, nkDot, nkDotDot, nkPrefix, nkInfix, nkCall, + nkPrGroup, nkColumnReference, nkReferences, nkDefault, @@ -508,18 +496,23 @@ type nkConstraint, nkUnique, nkIdentity, - nkColumnDef, ## name, datatype, constraints + nkColumnDef, ## name, datatype, constraints nkInsert, nkUpdate, nkDelete, nkSelect, nkSelectDistinct, nkSelectColumns, + nkSelectPair, nkAsgn, nkFrom, + nkFromItemPair, nkGroup, + nkLimit, + nkOffset, nkHaving, nkOrder, + nkJoin, nkDesc, nkUnion, nkIntersect, @@ -528,87 +521,108 @@ type nkValueList, nkWhere, nkCreateTable, - nkCreateTableIfNotExists, + nkCreateTableIfNotExists, nkCreateType, nkCreateTypeIfNotExists, nkCreateIndex, nkCreateIndexIfNotExists, nkEnumDef - + +const + LiteralNodes = { + nkIdent, nkQuotedIdent, nkStringLit, nkBitStringLit, nkHexStringLit, + nkIntegerLit, nkNumericLit + } + type - EInvalidSql* = object of EBase ## Invalid SQL encountered - PSqlNode* = ref TSqlNode ## an SQL abstract syntax tree node - TSqlNode* = object ## an SQL abstract syntax tree node - case kind*: TSqlNodeKind ## kind of syntax tree - of nkIdent, nkStringLit, nkBitStringLit, nkHexStringLit, - nkIntegerLit, nkNumericLit: - strVal*: string ## AST leaf: the identifier, numeric literal - ## string literal, etc. + SqlParseError* = object of ValueError ## Invalid SQL encountered + SqlNode* = ref SqlNodeObj ## an SQL abstract syntax tree node + SqlNodeObj* = object ## an SQL abstract syntax tree node + case kind*: SqlNodeKind ## kind of syntax tree + of LiteralNodes: + strVal*: string ## AST leaf: the identifier, numeric literal + ## string literal, etc. else: - sons*: seq[PSqlNode] ## the node's children + sons*: seq[SqlNode] ## the node's children - TSqlParser* = object of TSqlLexer ## SQL parser object - tok: TToken + SqlParser* = object of SqlLexer ## SQL parser object + tok: Token -proc newNode(k: TSqlNodeKind): PSqlNode = - new(result) - result.kind = k +proc newNode*(k: SqlNodeKind): SqlNode = + when defined(js): # bug #14117 + case k + of LiteralNodes: + result = SqlNode(kind: k, strVal: "") + else: + result = SqlNode(kind: k, sons: @[]) + else: + result = SqlNode(kind: k) -proc newNode(k: TSqlNodeKind, s: string): PSqlNode = - new(result) - result.kind = k +proc newNode*(k: SqlNodeKind, s: string): SqlNode = + result = SqlNode(kind: k) result.strVal = s - -proc len*(n: PSqlNode): int = - if isNil(n.sons): result = 0 - else: result = n.sons.len - -proc add*(father, n: PSqlNode) = - if isNil(father.sons): father.sons = @[] + +proc newNode*(k: SqlNodeKind, sons: seq[SqlNode]): SqlNode = + result = SqlNode(kind: k) + result.sons = sons + +proc len*(n: SqlNode): int = + if n.kind in LiteralNodes: + result = 0 + else: + result = n.sons.len + +proc `[]`*(n: SqlNode; i: int): SqlNode = n.sons[i] +proc `[]`*(n: SqlNode; i: BackwardsIndex): SqlNode = n.sons[n.len - int(i)] + +proc add*(father, n: SqlNode) = add(father.sons, n) -proc getTok(p: var TSqlParser) = +proc getTok(p: var SqlParser) = getTok(p, p.tok) -proc sqlError(p: TSqlParser, msg: string) = - var e: ref EInvalidSql +proc sqlError(p: SqlParser, msg: string) = + var e: ref SqlParseError new(e) e.msg = errorStr(p, msg) raise e -proc isKeyw(p: TSqlParser, keyw: string): bool = +proc isKeyw(p: SqlParser, keyw: string): bool = result = p.tok.kind == tkIdentifier and cmpIgnoreCase(p.tok.literal, keyw) == 0 -proc isOpr(p: TSqlParser, opr: string): bool = +proc isOpr(p: SqlParser, opr: string): bool = result = p.tok.kind == tkOperator and cmpIgnoreCase(p.tok.literal, opr) == 0 -proc optKeyw(p: var TSqlParser, keyw: string) = +proc optKeyw(p: var SqlParser, keyw: string) = if p.tok.kind == tkIdentifier and cmpIgnoreCase(p.tok.literal, keyw) == 0: getTok(p) -proc expectIdent(p: TSqlParser) = +proc expectIdent(p: SqlParser) = if p.tok.kind != tkIdentifier and p.tok.kind != tkQuotedIdentifier: sqlError(p, "identifier expected") -proc expect(p: TSqlParser, kind: TTokKind) = +proc expect(p: SqlParser, kind: TokKind) = if p.tok.kind != kind: sqlError(p, tokKindToStr[kind] & " expected") -proc eat(p: var TSqlParser, kind: TTokKind) = +proc eat(p: var SqlParser, kind: TokKind) = if p.tok.kind == kind: getTok(p) else: sqlError(p, tokKindToStr[kind] & " expected") -proc eat(p: var TSqlParser, keyw: string) = +proc eat(p: var SqlParser, keyw: string) = if isKeyw(p, keyw): getTok(p) else: - sqlError(p, keyw.toUpper() & " expected") + sqlError(p, keyw.toUpperAscii() & " expected") -proc parseDataType(p: var TSqlParser): PSqlNode = +proc opt(p: var SqlParser, kind: TokKind) = + if p.tok.kind == kind: getTok(p) + +proc parseDataType(p: var SqlParser): SqlNode = if isKeyw(p, "enum"): result = newNode(nkEnumDef) getTok(p) @@ -636,18 +650,20 @@ proc parseDataType(p: var TSqlParser): PSqlNode = getTok(p) eat(p, tkParRi) -proc getPrecedence(p: TSqlParser): int = +proc getPrecedence(p: SqlParser): int = if isOpr(p, "*") or isOpr(p, "/") or isOpr(p, "%"): result = 6 elif isOpr(p, "+") or isOpr(p, "-"): - result = 5 + result = 5 elif isOpr(p, "=") or isOpr(p, "<") or isOpr(p, ">") or isOpr(p, ">=") or isOpr(p, "<=") or isOpr(p, "<>") or isOpr(p, "!=") or isKeyw(p, "is") or - isKeyw(p, "like"): - result = 3 + isKeyw(p, "like") or isKeyw(p, "in"): + result = 4 elif isKeyw(p, "and"): - result = 2 + result = 3 elif isKeyw(p, "or"): + result = 2 + elif isKeyw(p, "between"): result = 1 elif p.tok.kind == tkOperator: # user-defined operator: @@ -655,11 +671,15 @@ proc getPrecedence(p: TSqlParser): int = else: result = - 1 -proc parseExpr(p: var TSqlParser): PSqlNode +proc parseExpr(p: var SqlParser): SqlNode {.gcsafe.} +proc parseSelect(p: var SqlParser): SqlNode {.gcsafe.} -proc identOrLiteral(p: var TSqlParser): PSqlNode = +proc identOrLiteral(p: var SqlParser): SqlNode = case p.tok.kind - of tkIdentifier, tkQuotedIdentifier: + of tkQuotedIdentifier: + result = newNode(nkQuotedIdent, p.tok.literal) + getTok(p) + of tkIdentifier: result = newNode(nkIdent, p.tok.literal) getTok(p) of tkStringConstant, tkEscapeConstant, tkDollarQuotedConstant: @@ -679,33 +699,42 @@ proc identOrLiteral(p: var TSqlParser): PSqlNode = getTok(p) of tkParLe: getTok(p) - result = parseExpr(p) + result = newNode(nkPrGroup) + while true: + result.add(parseExpr(p)) + if p.tok.kind != tkComma: break + getTok(p) eat(p, tkParRi) - else: - sqlError(p, "expression expected") - getTok(p) # we must consume a token here to prevend endless loops! + else: + if p.tok.literal == "*": + result = newNode(nkIdent, p.tok.literal) + getTok(p) + else: + sqlError(p, "expression expected") + getTok(p) # we must consume a token here to prevent endless loops! -proc primary(p: var TSqlParser): PSqlNode = - if p.tok.kind == tkOperator or isKeyw(p, "not"): +proc primary(p: var SqlParser): SqlNode = + if (p.tok.kind == tkOperator and (p.tok.literal == "+" or p.tok.literal == + "-")) or isKeyw(p, "not"): result = newNode(nkPrefix) result.add(newNode(nkIdent, p.tok.literal)) getTok(p) result.add(primary(p)) return result = identOrLiteral(p) - while true: + while true: case p.tok.kind - of tkParLe: + of tkParLe: var a = result result = newNode(nkCall) result.add(a) getTok(p) - while true: + while p.tok.kind != tkParRi: result.add(parseExpr(p)) if p.tok.kind == tkComma: getTok(p) else: break eat(p, tkParRi) - of tkDot: + of tkDot: getTok(p) var a = result if p.tok.kind == tkDot: @@ -722,16 +751,16 @@ proc primary(p: var TSqlParser): PSqlNode = sqlError(p, "identifier expected") getTok(p) else: break - -proc lowestExprAux(p: var TSqlParser, v: var PSqlNode, limit: int): int = + +proc lowestExprAux(p: var SqlParser, v: var SqlNode, limit: int): int = var - v2, node, opNode: PSqlNode + v2, node, opNode: SqlNode v = primary(p) # expand while operators have priorities higher than 'limit' var opPred = getPrecedence(p) result = opPred - while opPred > limit: + while opPred > limit: node = newNode(nkInfix) - opNode = newNode(nkIdent, p.tok.literal) + opNode = newNode(nkIdent, p.tok.literal.toLowerAscii()) getTok(p) result = lowestExprAux(p, v2, opPred) node.add(opNode) @@ -739,15 +768,15 @@ proc lowestExprAux(p: var TSqlParser, v: var PSqlNode, limit: int): int = node.add(v2) v = node opPred = getPrecedence(p) - -proc parseExpr(p: var TSqlParser): PSqlNode = + +proc parseExpr(p: var SqlParser): SqlNode = discard lowestExprAux(p, result, - 1) -proc parseTableName(p: var TSqlParser): PSqlNode = +proc parseTableName(p: var SqlParser): SqlNode = expectIdent(p) result = primary(p) -proc parseColumnReference(p: var TSqlParser): PSqlNode = +proc parseColumnReference(p: var SqlParser): SqlNode = result = parseTableName(p) if p.tok.kind == tkParLe: getTok(p) @@ -760,21 +789,31 @@ proc parseColumnReference(p: var TSqlParser): PSqlNode = result.add(parseTableName(p)) eat(p, tkParRi) -proc parseCheck(p: var TSqlParser): PSqlNode = +proc parseCheck(p: var SqlParser): SqlNode = getTok(p) result = newNode(nkCheck) result.add(parseExpr(p)) -proc parseConstraint(p: var TSqlParser): PSqlNode = +proc parseConstraint(p: var SqlParser): SqlNode = getTok(p) result = newNode(nkConstraint) expectIdent(p) result.add(newNode(nkIdent, p.tok.literal)) getTok(p) - eat(p, "check") + optKeyw(p, "check") result.add(parseExpr(p)) -proc parseColumnConstraints(p: var TSqlParser, result: PSqlNode) = +proc parseParIdentList(p: var SqlParser, father: SqlNode) = + eat(p, tkParLe) + while true: + expectIdent(p) + father.add(newNode(nkIdent, p.tok.literal)) + getTok(p) + if p.tok.kind != tkComma: break + getTok(p) + eat(p, tkParRi) + +proc parseColumnConstraints(p: var SqlParser, result: SqlNode) = while true: if isKeyw(p, "default"): getTok(p) @@ -790,6 +829,9 @@ proc parseColumnConstraints(p: var TSqlParser, result: PSqlNode) = getTok(p) eat(p, "null") result.add(newNode(nkNotNull)) + elif isKeyw(p, "null"): + getTok(p) + result.add(newNode(nkNull)) elif isKeyw(p, "identity"): getTok(p) result.add(newNode(nkIdentity)) @@ -802,19 +844,20 @@ proc parseColumnConstraints(p: var TSqlParser, result: PSqlNode) = elif isKeyw(p, "constraint"): result.add(parseConstraint(p)) elif isKeyw(p, "unique"): + getTok(p) result.add(newNode(nkUnique)) else: break -proc parseColumnDef(p: var TSqlParser): PSqlNode = +proc parseColumnDef(p: var SqlParser): SqlNode = expectIdent(p) result = newNode(nkColumnDef) result.add(newNode(nkIdent, p.tok.literal)) getTok(p) result.add(parseDataType(p)) - parseColumnConstraints(p, result) + parseColumnConstraints(p, result) -proc parseIfNotExists(p: var TSqlParser, k: TSqlNodeKind): PSqlNode = +proc parseIfNotExists(p: var SqlParser, k: SqlNodeKind): SqlNode = getTok(p) if isKeyw(p, "if"): getTok(p) @@ -824,17 +867,7 @@ proc parseIfNotExists(p: var TSqlParser, k: TSqlNodeKind): PSqlNode = else: result = newNode(k) -proc parseParIdentList(p: var TSqlParser, father: PSqlNode) = - eat(p, tkParLe) - while true: - expectIdent(p) - father.add(newNode(nkIdent, p.tok.literal)) - getTok(p) - if p.tok.kind != tkComma: break - getTok(p) - eat(p, tkParRi) - -proc parseTableConstraint(p: var TSqlParser): PSqlNode = +proc parseTableConstraint(p: var SqlParser): SqlNode = if isKeyw(p, "primary"): getTok(p) eat(p, "key") @@ -861,22 +894,36 @@ proc parseTableConstraint(p: var TSqlParser): PSqlNode = else: sqlError(p, "column definition expected") -proc parseTableDef(p: var TSqlParser): PSqlNode = +proc parseUnique(p: var SqlParser): SqlNode = + result = parseExpr(p) + if result.kind == nkCall: result.kind = nkUnique + +proc parseTableDef(p: var SqlParser): SqlNode = result = parseIfNotExists(p, nkCreateTable) expectIdent(p) result.add(newNode(nkIdent, p.tok.literal)) getTok(p) if p.tok.kind == tkParLe: - while true: - getTok(p) - if p.tok.kind == tkIdentifier or p.tok.kind == tkQuotedIdentifier: + getTok(p) + while p.tok.kind != tkParRi: + if isKeyw(p, "constraint"): + result.add parseConstraint(p) + elif isKeyw(p, "primary") or isKeyw(p, "foreign"): + result.add parseTableConstraint(p) + elif isKeyw(p, "unique"): + result.add parseUnique(p) + elif p.tok.kind == tkIdentifier or p.tok.kind == tkQuotedIdentifier: result.add(parseColumnDef(p)) else: result.add(parseTableConstraint(p)) if p.tok.kind != tkComma: break + getTok(p) eat(p, tkParRi) - -proc parseTypeDef(p: var TSqlParser): PSqlNode = + # skip additional crap after 'create table (...) crap;' + while p.tok.kind notin {tkSemicolon, tkEof}: + getTok(p) + +proc parseTypeDef(p: var SqlParser): SqlNode = result = parseIfNotExists(p, nkCreateType) expectIdent(p) result.add(newNode(nkIdent, p.tok.literal)) @@ -884,12 +931,25 @@ proc parseTypeDef(p: var TSqlParser): PSqlNode = eat(p, "as") result.add(parseDataType(p)) -proc parseWhere(p: var TSqlParser): PSqlNode = +proc parseWhere(p: var SqlParser): SqlNode = getTok(p) result = newNode(nkWhere) result.add(parseExpr(p)) -proc parseIndexDef(p: var TSqlParser): PSqlNode = +proc parseFromItem(p: var SqlParser): SqlNode = + result = newNode(nkFromItemPair) + if p.tok.kind == tkParLe: + getTok(p) + var select = parseSelect(p) + result.add(select) + eat(p, tkParRi) + else: + result.add(parseExpr(p)) + if isKeyw(p, "as"): + getTok(p) + result.add(parseExpr(p)) + +proc parseIndexDef(p: var SqlParser): SqlNode = result = parseIfNotExists(p, nkCreateIndex) if isKeyw(p, "primary"): getTok(p) @@ -914,7 +974,7 @@ proc parseIndexDef(p: var TSqlParser): PSqlNode = getTok(p) eat(p, tkParRi) -proc parseInsert(p: var TSqlParser): PSqlNode = +proc parseInsert(p: var SqlParser): SqlNode = getTok(p) eat(p, "into") expectIdent(p) @@ -924,8 +984,9 @@ proc parseInsert(p: var TSqlParser): PSqlNode = if p.tok.kind == tkParLe: var n = newNode(nkColumnList) parseParIdentList(p, n) + result.add n else: - result.add(nil) + result.add(newNode(nkNone)) if isKeyw(p, "default"): getTok(p) eat(p, "values") @@ -941,7 +1002,7 @@ proc parseInsert(p: var TSqlParser): PSqlNode = result.add(n) eat(p, tkParRi) -proc parseUpdate(p: var TSqlParser): PSqlNode = +proc parseUpdate(p: var SqlParser): SqlNode = getTok(p) result = newNode(nkUpdate) result.add(primary(p)) @@ -960,19 +1021,21 @@ proc parseUpdate(p: var TSqlParser): PSqlNode = if isKeyw(p, "where"): result.add(parseWhere(p)) else: - result.add(nil) - -proc parseDelete(p: var TSqlParser): PSqlNode = + result.add(newNode(nkNone)) + +proc parseDelete(p: var SqlParser): SqlNode = getTok(p) + if isOpr(p, "*"): + getTok(p) result = newNode(nkDelete) eat(p, "from") result.add(primary(p)) if isKeyw(p, "where"): result.add(parseWhere(p)) else: - result.add(nil) + result.add(newNode(nkNone)) -proc parseSelect(p: var TSqlParser): PSqlNode = +proc parseSelect(p: var SqlParser): SqlNode = getTok(p) if isKeyw(p, "distinct"): getTok(p) @@ -986,7 +1049,12 @@ proc parseSelect(p: var TSqlParser): PSqlNode = a.add(newNode(nkIdent, "*")) getTok(p) else: - a.add(parseExpr(p)) + var pair = newNode(nkSelectPair) + pair.add(parseExpr(p)) + a.add(pair) + if isKeyw(p, "as"): + getTok(p) + pair.add(parseExpr(p)) if p.tok.kind != tkComma: break getTok(p) result.add(a) @@ -994,7 +1062,7 @@ proc parseSelect(p: var TSqlParser): PSqlNode = var f = newNode(nkFrom) while true: getTok(p) - f.add(parseExpr(p)) + f.add(parseFromItem(p)) if p.tok.kind != tkComma: break result.add(f) if isKeyw(p, "where"): @@ -1008,29 +1076,14 @@ proc parseSelect(p: var TSqlParser): PSqlNode = if p.tok.kind != tkComma: break getTok(p) result.add(g) - if isKeyw(p, "having"): - var h = newNode(nkHaving) - while true: - getTok(p) - h.add(parseExpr(p)) - if p.tok.kind != tkComma: break - result.add(h) - if isKeyw(p, "union"): - result.add(newNode(nkUnion)) - getTok(p) - elif isKeyw(p, "intersect"): - result.add(newNode(nkIntersect)) - getTok(p) - elif isKeyw(p, "except"): - result.add(newNode(nkExcept)) - getTok(p) if isKeyw(p, "order"): getTok(p) eat(p, "by") var n = newNode(nkOrder) while true: var e = parseExpr(p) - if isKeyw(p, "asc"): getTok(p) # is default + if isKeyw(p, "asc"): + getTok(p) # is default elif isKeyw(p, "desc"): getTok(p) var x = newNode(nkDesc) @@ -1040,8 +1093,47 @@ proc parseSelect(p: var TSqlParser): PSqlNode = if p.tok.kind != tkComma: break getTok(p) result.add(n) + if isKeyw(p, "having"): + var h = newNode(nkHaving) + while true: + getTok(p) + h.add(parseExpr(p)) + if p.tok.kind != tkComma: break + result.add(h) + if isKeyw(p, "union"): + result.add(newNode(nkUnion)) + getTok(p) + elif isKeyw(p, "intersect"): + result.add(newNode(nkIntersect)) + getTok(p) + elif isKeyw(p, "except"): + result.add(newNode(nkExcept)) + getTok(p) + if isKeyw(p, "join") or isKeyw(p, "inner") or isKeyw(p, "outer") or isKeyw(p, "cross"): + var join = newNode(nkJoin) + result.add(join) + if isKeyw(p, "join"): + join.add(newNode(nkIdent, "")) + getTok(p) + else: + join.add(newNode(nkIdent, p.tok.literal.toLowerAscii())) + getTok(p) + eat(p, "join") + join.add(parseFromItem(p)) + eat(p, "on") + join.add(parseExpr(p)) + if isKeyw(p, "limit"): + getTok(p) + var l = newNode(nkLimit) + l.add(parseExpr(p)) + result.add(l) + if isKeyw(p, "offset"): + getTok(p) + var o = newNode(nkOffset) + o.add(parseExpr(p)) + result.add(o) -proc parseStmt(p: var TSqlParser): PSqlNode = +proc parseStmt(p: var SqlParser; parent: SqlNode) = if isKeyw(p, "create"): getTok(p) optKeyw(p, "cached") @@ -1053,81 +1145,121 @@ proc parseStmt(p: var TSqlParser): PSqlNode = optKeyw(p, "unique") optKeyw(p, "hash") if isKeyw(p, "table"): - result = parseTableDef(p) + parent.add parseTableDef(p) elif isKeyw(p, "type"): - result = parseTypeDef(p) + parent.add parseTypeDef(p) elif isKeyw(p, "index"): - result = parseIndexDef(p) + parent.add parseIndexDef(p) else: sqlError(p, "TABLE expected") elif isKeyw(p, "insert"): - result = parseInsert(p) + parent.add parseInsert(p) elif isKeyw(p, "update"): - result = parseUpdate(p) + parent.add parseUpdate(p) elif isKeyw(p, "delete"): - result = parseDelete(p) + parent.add parseDelete(p) elif isKeyw(p, "select"): - result = parseSelect(p) + parent.add parseSelect(p) + elif isKeyw(p, "begin"): + getTok(p) else: - sqlError(p, "CREATE expected") + sqlError(p, "SELECT, CREATE, UPDATE or DELETE expected") -proc open(p: var TSqlParser, input: PStream, filename: string) = - ## opens the parser `p` and assigns the input stream `input` to it. - ## `filename` is only used for error messages. - open(TSqlLexer(p), input, filename) - p.tok.kind = tkInvalid - p.tok.literal = "" - getTok(p) - -proc parse(p: var TSqlParser): PSqlNode = +proc parse(p: var SqlParser): SqlNode = ## parses the content of `p`'s input stream and returns the SQL AST. - ## Syntax errors raise an `EInvalidSql` exception. + ## Syntax errors raise an `SqlParseError` exception. result = newNode(nkStmtList) while p.tok.kind != tkEof: - var s = parseStmt(p) - eat(p, tkSemiColon) - result.add(s) - if result.len == 1: - result = result.sons[0] - -proc close(p: var TSqlParser) = + parseStmt(p, result) + if p.tok.kind == tkEof: + break + eat(p, tkSemicolon) + +proc close(p: var SqlParser) = ## closes the parser `p`. The associated input stream is closed too. - close(TSqlLexer(p)) + close(SqlLexer(p)) -proc parseSQL*(input: PStream, filename: string): PSqlNode = - ## parses the SQL from `input` into an AST and returns the AST. - ## `filename` is only used for error messages. - ## Syntax errors raise an `EInvalidSql` exception. - var p: TSqlParser - open(p, input, filename) - try: - result = parse(p) - finally: - close(p) +type + SqlWriter = object + indent: int + upperCase: bool + buffer: string + +proc add(s: var SqlWriter, thing: char) = + s.buffer.add(thing) + +proc prepareAdd(s: var SqlWriter) {.inline.} = + if s.buffer.len > 0 and s.buffer[^1] notin {' ', '\L', '(', '.'}: + s.buffer.add(" ") + +proc add(s: var SqlWriter, thing: string) = + s.prepareAdd + s.buffer.add(thing) + +proc addKeyw(s: var SqlWriter, thing: string) = + var keyw = thing + if s.upperCase: + keyw = keyw.toUpperAscii() + s.add(keyw) + +proc addIden(s: var SqlWriter, thing: string) = + var iden = thing + if iden.toLowerAscii() in reservedKeywords: + iden = '"' & iden & '"' + s.add(iden) + +proc ra(n: SqlNode, s: var SqlWriter) {.gcsafe.} + +proc rs(n: SqlNode, s: var SqlWriter, prefix = "(", suffix = ")", sep = ", ") = + if n.len > 0: + s.add(prefix) + for i in 0 .. n.len-1: + if i > 0: s.add(sep) + ra(n.sons[i], s) + s.add(suffix) -proc ra(n: PSqlNode, s: var string, indent: int) +proc addMulti(s: var SqlWriter, n: SqlNode, sep = ',') = + if n.len > 0: + for i in 0 .. n.len-1: + if i > 0: s.add(sep) + ra(n.sons[i], s) -proc rs(n: PSqlNode, s: var string, indent: int, - prefix = "(", suffix = ")", - sep = ", ") = +proc addMulti(s: var SqlWriter, n: SqlNode, sep = ',', prefix, suffix: char) = if n.len > 0: s.add(prefix) for i in 0 .. n.len-1: if i > 0: s.add(sep) - ra(n.sons[i], s, indent) + ra(n.sons[i], s) s.add(suffix) -proc ra(n: PSqlNode, s: var string, indent: int) = +proc quoted(s: string): string = + "\"" & replace(s, "\"", "\"\"") & "\"" + +func escape(result: var string; s: string) = + result.add('\'') + for c in items(s): + case c + of '\0'..'\31': + result.add("\\x") + result.add(toHex(ord(c), 2)) + of '\'': result.add("''") + else: result.add(c) + result.add('\'') + +proc ra(n: SqlNode, s: var SqlWriter) = if n == nil: return case n.kind - of nkNone: nil + of nkNone: discard of nkIdent: if allCharsInSet(n.strVal, {'\33'..'\127'}): s.add(n.strVal) else: - s.add("\"" & replace(n.strVal, "\"", "\"\"") & "\"") + s.add(quoted(n.strVal)) + of nkQuotedIdent: + s.add(quoted(n.strVal)) of nkStringLit: - s.add(escape(n.strVal, "e'", "'")) + s.prepareAdd + s.buffer.escape(n.strVal) of nkBitStringLit: s.add("b'" & n.strVal & "'") of nkHexStringLit: @@ -1135,211 +1267,257 @@ proc ra(n: PSqlNode, s: var string, indent: int) = of nkIntegerLit, nkNumericLit: s.add(n.strVal) of nkPrimaryKey: - s.add(" primary key") - rs(n, s, indent) + s.addKeyw("primary key") + rs(n, s) of nkForeignKey: - s.add(" foreign key") - rs(n, s, indent) + s.addKeyw("foreign key") + rs(n, s) of nkNotNull: - s.add(" not null") + s.addKeyw("not null") + of nkNull: + s.addKeyw("null") of nkDot: - ra(n.sons[0], s, indent) - s.add(".") - ra(n.sons[1], s, indent) + ra(n.sons[0], s) + s.add('.') + ra(n.sons[1], s) of nkDotDot: - ra(n.sons[0], s, indent) + ra(n.sons[0], s) s.add(". .") - ra(n.sons[1], s, indent) + ra(n.sons[1], s) of nkPrefix: - s.add('(') - ra(n.sons[0], s, indent) + ra(n.sons[0], s) s.add(' ') - ra(n.sons[1], s, indent) - s.add(')') + ra(n.sons[1], s) of nkInfix: - s.add('(') - ra(n.sons[1], s, indent) + ra(n.sons[1], s) s.add(' ') - ra(n.sons[0], s, indent) + ra(n.sons[0], s) s.add(' ') - ra(n.sons[2], s, indent) - s.add(')') + ra(n.sons[2], s) of nkCall, nkColumnReference: - ra(n.sons[0], s, indent) + ra(n.sons[0], s) s.add('(') for i in 1..n.len-1: - if i > 1: s.add(", ") - ra(n.sons[i], s, indent) + if i > 1: s.add(',') + ra(n.sons[i], s) + s.add(')') + of nkPrGroup: + s.add('(') + s.addMulti(n) s.add(')') of nkReferences: - s.add(" references ") - ra(n.sons[0], s, indent) + s.addKeyw("references") + ra(n.sons[0], s) of nkDefault: - s.add(" default ") - ra(n.sons[0], s, indent) + s.addKeyw("default") + ra(n.sons[0], s) of nkCheck: - s.add(" check ") - ra(n.sons[0], s, indent) + s.addKeyw("check") + ra(n.sons[0], s) of nkConstraint: - s.add(" constraint ") - ra(n.sons[0], s, indent) - s.add(" check ") - ra(n.sons[1], s, indent) + s.addKeyw("constraint") + ra(n.sons[0], s) + s.addKeyw("check") + ra(n.sons[1], s) of nkUnique: - s.add(" unique") - rs(n, s, indent) + s.addKeyw("unique") + rs(n, s) of nkIdentity: - s.add(" identity") + s.addKeyw("identity") of nkColumnDef: - s.add("\n ") - rs(n, s, indent, "", "", " ") + rs(n, s, "", "", " ") of nkStmtList: for i in 0..n.len-1: - ra(n.sons[i], s, indent) - s.add("\n") + ra(n.sons[i], s) + s.add(';') of nkInsert: assert n.len == 3 - s.add("insert into ") - ra(n.sons[0], s, indent) - ra(n.sons[1], s, indent) - if n.sons[2].kind == nkDefault: - s.add("default values") + s.addKeyw("insert into") + ra(n.sons[0], s) + s.add(' ') + ra(n.sons[1], s) + if n.sons[2].kind == nkDefault: + s.addKeyw("default values") else: - s.add("\nvalues ") - ra(n.sons[2], s, indent) - s.add(';') - of nkUpdate: - s.add("update ") - ra(n.sons[0], s, indent) - s.add(" set ") + ra(n.sons[2], s) + of nkUpdate: + s.addKeyw("update") + ra(n.sons[0], s) + s.addKeyw("set") var L = n.len for i in 1 .. L-2: if i > 1: s.add(", ") var it = n.sons[i] assert it.kind == nkAsgn - ra(it, s, indent) - ra(n.sons[L-1], s, indent) - s.add(';') - of nkDelete: - s.add("delete from ") - ra(n.sons[0], s, indent) - ra(n.sons[1], s, indent) - s.add(';') + ra(it, s) + ra(n.sons[L-1], s) + of nkDelete: + s.addKeyw("delete from") + ra(n.sons[0], s) + ra(n.sons[1], s) of nkSelect, nkSelectDistinct: - s.add("select ") + s.addKeyw("select") if n.kind == nkSelectDistinct: - s.add("distinct ") - rs(n.sons[0], s, indent, "", "", ", ") - for i in 1 .. n.len-1: ra(n.sons[i], s, indent) - s.add(';') - of nkSelectColumns: - assert(false) + s.addKeyw("distinct") + for i in 0 ..< n.len: + ra(n.sons[i], s) + of nkSelectColumns: + for i, column in n.sons: + if i > 0: s.add(',') + ra(column, s) + of nkSelectPair: + ra(n.sons[0], s) + if n.sons.len == 2: + s.addKeyw("as") + ra(n.sons[1], s) + of nkFromItemPair: + if n.sons[0].kind in {nkIdent, nkQuotedIdent}: + ra(n.sons[0], s) + else: + assert n.sons[0].kind == nkSelect + s.add('(') + ra(n.sons[0], s) + s.add(')') + if n.sons.len == 2: + s.addKeyw("as") + ra(n.sons[1], s) of nkAsgn: - ra(n.sons[0], s, indent) + ra(n.sons[0], s) s.add(" = ") - ra(n.sons[1], s, indent) + ra(n.sons[1], s) of nkFrom: - s.add("\nfrom ") - rs(n, s, indent, "", "", ", ") + s.addKeyw("from") + s.addMulti(n) of nkGroup: - s.add("\ngroup by") - rs(n, s, indent, "", "", ", ") + s.addKeyw("group by") + s.addMulti(n) + of nkLimit: + s.addKeyw("limit") + s.addMulti(n) + of nkOffset: + s.addKeyw("offset") + s.addMulti(n) of nkHaving: - s.add("\nhaving") - rs(n, s, indent, "", "", ", ") + s.addKeyw("having") + s.addMulti(n) of nkOrder: - s.add("\norder by ") - rs(n, s, indent, "", "", ", ") + s.addKeyw("order by") + s.addMulti(n) + of nkJoin: + var joinType = n.sons[0].strVal + if joinType == "": + joinType = "join" + else: + joinType &= " " & "join" + s.addKeyw(joinType) + ra(n.sons[1], s) + s.addKeyw("on") + ra(n.sons[2], s) of nkDesc: - ra(n.sons[0], s, indent) - s.add(" desc") + ra(n.sons[0], s) + s.addKeyw("desc") of nkUnion: - s.add(" union") + s.addKeyw("union") of nkIntersect: - s.add(" intersect") + s.addKeyw("intersect") of nkExcept: - s.add(" except") + s.addKeyw("except") of nkColumnList: - rs(n, s, indent) + rs(n, s) of nkValueList: - s.add("values ") - rs(n, s, indent) + s.addKeyw("values") + rs(n, s) of nkWhere: - s.add("\nwhere ") - ra(n.sons[0], s, indent) + s.addKeyw("where") + ra(n.sons[0], s) of nkCreateTable, nkCreateTableIfNotExists: - s.add("create table ") + s.addKeyw("create table") if n.kind == nkCreateTableIfNotExists: - s.add("if not exists ") - ra(n.sons[0], s, indent) + s.addKeyw("if not exists") + ra(n.sons[0], s) s.add('(') for i in 1..n.len-1: - if i > 1: s.add(", ") - ra(n.sons[i], s, indent) + if i > 1: s.add(',') + ra(n.sons[i], s) s.add(");") of nkCreateType, nkCreateTypeIfNotExists: - s.add("create type ") + s.addKeyw("create type") if n.kind == nkCreateTypeIfNotExists: - s.add("if not exists ") - ra(n.sons[0], s, indent) - s.add(" as ") - ra(n.sons[1], s, indent) - s.add(';') + s.addKeyw("if not exists") + ra(n.sons[0], s) + s.addKeyw("as") + ra(n.sons[1], s) of nkCreateIndex, nkCreateIndexIfNotExists: - s.add("create index ") + s.addKeyw("create index") if n.kind == nkCreateIndexIfNotExists: - s.add("if not exists ") - ra(n.sons[0], s, indent) - s.add(" on ") - ra(n.sons[1], s, indent) + s.addKeyw("if not exists") + ra(n.sons[0], s) + s.addKeyw("on") + ra(n.sons[1], s) s.add('(') for i in 2..n.len-1: if i > 2: s.add(", ") - ra(n.sons[i], s, indent) + ra(n.sons[i], s) s.add(");") of nkEnumDef: - s.add("enum ") - rs(n, s, indent) + s.addKeyw("enum") + rs(n, s) -# What I want: -# -#select(columns = [T1.all, T2.name], -# fromm = [T1, T2], -# where = T1.name ==. T2.name, -# orderby = [name]): -# -#for row in dbQuery(db, """select x, y, z -# from a, b -# where a.name = b.name"""): -# - -#select x, y, z: -# fromm: Table1, Table2 -# where: x.name == y.name -#db.select(fromm = [t1, t2], where = t1.name == t2.name): -#for x, y, z in db.select(fromm = a, b where = a.name == b.name): -# writeln x, y, z - -proc renderSQL*(n: PSqlNode): string = +proc renderSql*(n: SqlNode, upperCase = false): string = ## Converts an SQL abstract syntax tree to its string representation. - result = "" - ra(n, result, 0) - -when isMainModule: - echo(renderSQL(parseSQL(newStringStream(""" - CREATE TYPE happiness AS ENUM ('happy', 'very happy', 'ecstatic'); - CREATE TABLE holidays ( - num_weeks int, - happiness happiness - ); - CREATE INDEX table1_attr1 ON table1(attr1); - - SELECT * FROM myTab WHERE col1 = 'happy'; - """), "stdin"))) - -# CREATE TYPE happiness AS ENUM ('happy', 'very happy', 'ecstatic'); -# CREATE TABLE holidays ( -# num_weeks int, -# happiness happiness -# ); -# CREATE INDEX table1_attr1 ON table1(attr1) + var s: SqlWriter + s.buffer = "" + s.upperCase = upperCase + ra(n, s) + return s.buffer + +proc `$`*(n: SqlNode): string = + ## an alias for `renderSql`. + renderSql(n) + +proc treeReprAux(s: SqlNode, level: int, result: var string) = + result.add('\n') + for i in 0 ..< level: result.add(" ") + + result.add($s.kind) + if s.kind in LiteralNodes: + result.add(' ') + result.add(s.strVal) + else: + for son in s.sons: + treeReprAux(son, level + 1, result) + +proc treeRepr*(s: SqlNode): string = + result = newStringOfCap(128) + treeReprAux(s, 0, result) + +import std/streams + +proc open(L: var SqlLexer, input: Stream, filename: string) = + lexbase.open(L, input) + L.filename = filename + +proc open(p: var SqlParser, input: Stream, filename: string) = + ## opens the parser `p` and assigns the input stream `input` to it. + ## `filename` is only used for error messages. + open(SqlLexer(p), input, filename) + p.tok.kind = tkInvalid + p.tok.literal = "" + getTok(p) + +proc parseSql*(input: Stream, filename: string): SqlNode = + ## parses the SQL from `input` into an AST and returns the AST. + ## `filename` is only used for error messages. + ## Syntax errors raise an `SqlParseError` exception. + var p: SqlParser + open(p, input, filename) + try: + result = parse(p) + finally: + close(p) + +proc parseSql*(input: string, filename = ""): SqlNode = + ## parses the SQL from `input` into an AST and returns the AST. + ## `filename` is only used for error messages. + ## Syntax errors raise an `SqlParseError` exception. + parseSql(newStringStream(input), "") |