#
#
# Nim's Runtime Library
# (c) Copyright 2009 Andreas Rumpf
#
# See the file "copying.txt", included in this
# distribution, for details about the copyright.
#
## The ``parsesql`` module implements a high-performance SQL file
## parser. It parses PostgreSQL syntax and the ANSI SQL standard.
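##
## Example (a minimal usage sketch; the exact layout of the rendered text is
## whatever ``renderSQL`` produces):
##
## .. code-block:: nim
##
##   import parsesql
##
##   let ast = parseSQL("SELECT id, name FROM users WHERE id = 1;")
##   echo renderSQL(ast)   # turns the AST back into SQL text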
import
hashes, strutils, lexbase, streams
# ------------------- scanner -------------------------------------------------
type
TokKind = enum ## enumeration of all SQL tokens
tkInvalid, ## invalid token
tkEof, ## end of file reached
tkIdentifier, ## abc
tkQuotedIdentifier, ## "abc"
tkStringConstant, ## 'abc'
tkEscapeConstant, ## e'abc'
tkDollarQuotedConstant, ## $tag$abc$tag$
tkBitStringConstant, ## B'00011'
tkHexStringConstant, ## x'00011'
    tkInteger,            ## 123
    tkNumeric,            ## 123.45
tkOperator, ## + - * / < > = ~ ! @ # % ^ & | ` ?
tkSemicolon, ## ';'
tkColon, ## ':'
tkComma, ## ','
tkParLe, ## '('
tkParRi, ## ')'
tkBracketLe, ## '['
tkBracketRi, ## ']'
tkDot ## '.'
Token = object # a token
kind: TokKind # the type of the token
literal: string # the parsed (string) literal
  SqlLexer* = object of BaseLexer ## the lexer object.
filename: string
{.deprecated: [TToken: Token, TSqlLexer: SqlLexer].}
const
tokKindToStr: array[TokKind, string] = [
"invalid", "[EOF]", "identifier", "quoted identifier", "string constant",
"escape string constant", "dollar quoted constant", "bit string constant",
"hex string constant", "integer constant", "numeric constant", "operator",
";", ":", ",", "(", ")", "[", "]", "."
]
proc open(L: var SqlLexer, input: Stream, filename: string) =
lexbase.open(L, input)
L.filename = filename
proc close(L: var SqlLexer) =
lexbase.close(L)
proc getColumn(L: SqlLexer): int =
  ## get the current column the lexer has arrived at.
result = getColNumber(L, L.bufpos)
proc getLine(L: SqlLexer): int =
result = L.lineNumber
proc handleHexChar(c: var SqlLexer, xi: var int) =
case c.buf[c.bufpos]
of '0'..'9':
xi = (xi shl 4) or (ord(c.buf[c.bufpos]) - ord('0'))
inc(c.bufpos)
of 'a'..'f':
xi = (xi shl 4) or (ord(c.buf[c.bufpos]) - ord('a') + 10)
inc(c.bufpos)
of 'A'..'F':
xi = (xi shl 4) or (ord(c.buf[c.bufpos]) - ord('A') + 10)
inc(c.bufpos)
else:
discard
proc handleOctChar(c: var SqlLexer, xi: var int) =
if c.buf[c.bufpos] in {'0'..'7'}:
xi = (xi shl 3) or (ord(c.buf[c.bufpos]) - ord('0'))
inc(c.bufpos)
proc getEscapedChar(c: var SqlLexer, tok: var Token) =
inc(c.bufpos)
case c.buf[c.bufpos]
of 'n', 'N':
add(tok.literal, '\L')
inc(c.bufpos)
of 'r', 'R', 'c', 'C':
add(tok.literal, '\c')
inc(c.bufpos)
of 'l', 'L':
add(tok.literal, '\L')
inc(c.bufpos)
of 'f', 'F':
add(tok.literal, '\f')
inc(c.bufpos)
of 'e', 'E':
add(tok.literal, '\e')
inc(c.bufpos)
of 'a', 'A':
add(tok.literal, '\a')
inc(c.bufpos)
of 'b', 'B':
add(tok.literal, '\b')
inc(c.bufpos)
of 'v', 'V':
add(tok.literal, '\v')
inc(c.bufpos)
of 't', 'T':
add(tok.literal, '\t')
inc(c.bufpos)
of '\'', '\"':
add(tok.literal, c.buf[c.bufpos])
inc(c.bufpos)
of '\\':
add(tok.literal, '\\')
inc(c.bufpos)
of 'x', 'X':
inc(c.bufpos)
var xi = 0
handleHexChar(c, xi)
handleHexChar(c, xi)
add(tok.literal, chr(xi))
of '0'..'7':
var xi = 0
handleOctChar(c, xi)
handleOctChar(c, xi)
handleOctChar(c, xi)
if (xi <= 255): add(tok.literal, chr(xi))
else: tok.kind = tkInvalid
else: tok.kind = tkInvalid
proc handleCRLF(c: var SqlLexer, pos: int): int =
case c.buf[pos]
of '\c': result = lexbase.handleCR(c, pos)
of '\L': result = lexbase.handleLF(c, pos)
else: result = pos
proc skip(c: var SqlLexer) =
var pos = c.bufpos
var buf = c.buf
var nested = 0
while true:
case buf[pos]
of ' ', '\t':
inc(pos)
of '-':
if buf[pos+1] == '-':
while not (buf[pos] in {'\c', '\L', lexbase.EndOfFile}): inc(pos)
else:
break
of '/':
if buf[pos+1] == '*':
inc(pos,2)
while true:
case buf[pos]
of '\0': break
of '\c', '\L':
pos = handleCRLF(c, pos)
buf = c.buf
of '*':
if buf[pos+1] == '/':
inc(pos, 2)
if nested <= 0: break
dec(nested)
else:
inc(pos)
of '/':
if buf[pos+1] == '*':
inc(pos, 2)
inc(nested)
else:
inc(pos)
else: inc(pos)
else: break
of '\c', '\L':
pos = handleCRLF(c, pos)
buf = c.buf
else:
break # EndOfFile also leaves the loop
c.bufpos = pos
proc getString(c: var SqlLexer, tok: var Token, kind: TokKind) =
var pos = c.bufpos + 1
var buf = c.buf
tok.kind = kind
block parseLoop:
while true:
while true:
var ch = buf[pos]
if ch == '\'':
if buf[pos+1] == '\'':
inc(pos, 2)
add(tok.literal, '\'')
else:
inc(pos)
break
elif ch in {'\c', '\L', lexbase.EndOfFile}:
tok.kind = tkInvalid
break parseLoop
elif (ch == '\\') and kind == tkEscapeConstant:
c.bufpos = pos
getEscapedChar(c, tok)
pos = c.bufpos
else:
add(tok.literal, ch)
inc(pos)
c.bufpos = pos
var line = c.lineNumber
skip(c)
if c.lineNumber > line:
        # the skipped whitespace contained a newline, so check whether the
        # string constant continues after it (such constants are concatenated):
buf = c.buf # may have been reallocated
pos = c.bufpos
if buf[pos] == '\'': inc(pos)
else: break parseLoop
else: break parseLoop
c.bufpos = pos
proc getDollarString(c: var SqlLexer, tok: var Token) =
var pos = c.bufpos + 1
var buf = c.buf
tok.kind = tkDollarQuotedConstant
var tag = "$"
while buf[pos] in IdentChars:
add(tag, buf[pos])
inc(pos)
if buf[pos] == '$': inc(pos)
else:
tok.kind = tkInvalid
return
while true:
case buf[pos]
of '\c', '\L':
pos = handleCRLF(c, pos)
buf = c.buf
add(tok.literal, "\L")
of '\0':
tok.kind = tkInvalid
break
of '$':
inc(pos)
var tag2 = "$"
while buf[pos] in IdentChars:
add(tag2, buf[pos])
inc(pos)
if buf[pos] == '$': inc(pos)
if tag2 == tag: break
add(tok.literal, tag2)
add(tok.literal, '$')
else:
add(tok.literal, buf[pos])
inc(pos)
c.bufpos = pos
proc getSymbol(c: var SqlLexer, tok: var Token) =
var pos = c.bufpos
var buf = c.buf
while true:
add(tok.literal, buf[pos])
inc(pos)
if buf[pos] notin {'a'..'z','A'..'Z','0'..'9','_','$', '\128'..'\255'}:
break
c.bufpos = pos
tok.kind = tkIdentifier
proc getQuotedIdentifier(c: var SqlLexer, tok: var Token) =
var pos = c.bufpos + 1
var buf = c.buf
tok.kind = tkQuotedIdentifier
while true:
var ch = buf[pos]
if ch == '\"':
if buf[pos+1] == '\"':
inc(pos, 2)
add(tok.literal, '\"')
else:
inc(pos)
break
elif ch in {'\c', '\L', lexbase.EndOfFile}:
tok.kind = tkInvalid
break
else:
add(tok.literal, ch)
inc(pos)
c.bufpos = pos
proc getBitHexString(c: var SqlLexer, tok: var Token, validChars: set[char]) =
var pos = c.bufpos + 1
var buf = c.buf
block parseLoop:
while true:
while true:
var ch = buf[pos]
if ch in validChars:
add(tok.literal, ch)
inc(pos)
elif ch == '\'':
inc(pos)
break
else:
tok.kind = tkInvalid
break parseLoop
c.bufpos = pos
var line = c.lineNumber
skip(c)
if c.lineNumber > line:
        # the skipped whitespace contained a newline, so check whether the
        # constant continues after it (constants split across lines are concatenated):
buf = c.buf # may have been reallocated
pos = c.bufpos
if buf[pos] == '\'': inc(pos)
else: break parseLoop
else: break parseLoop
c.bufpos = pos
proc getNumeric(c: var SqlLexer, tok: var Token) =
tok.kind = tkInteger
var pos = c.bufpos
var buf = c.buf
while buf[pos] in Digits:
add(tok.literal, buf[pos])
inc(pos)
if buf[pos] == '.':
tok.kind = tkNumeric
add(tok.literal, buf[pos])
inc(pos)
while buf[pos] in Digits:
add(tok.literal, buf[pos])
inc(pos)
if buf[pos] in {'E', 'e'}:
tok.kind = tkNumeric
add(tok.literal, buf[pos])
inc(pos)
if buf[pos] == '+':
inc(pos)
elif buf[pos] == '-':
add(tok.literal, buf[pos])
inc(pos)
if buf[pos] in Digits:
while buf[pos] in Digits:
add(tok.literal, buf[pos])
inc(pos)
else:
tok.kind = tkInvalid
c.bufpos = pos
proc getOperator(c: var SqlLexer, tok: var Token) =
const operators = {'+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%',
'^', '&', '|', '`', '?'}
tok.kind = tkOperator
var pos = c.bufpos
var buf = c.buf
var trailingPlusMinus = false
while true:
case buf[pos]
of '-':
      if buf[pos+1] == '-': break   # '--' starts a comment, not an operator
if not trailingPlusMinus and buf[pos+1] notin operators and
tok.literal.len > 0: break
of '/':
      if buf[pos+1] == '*': break   # '/*' starts a comment, not an operator
of '~', '!', '@', '#', '%', '^', '&', '|', '`', '?':
trailingPlusMinus = true
of '+':
if not trailingPlusMinus and buf[pos+1] notin operators and
tok.literal.len > 0: break
of '*', '<', '>', '=': discard
else: break
add(tok.literal, buf[pos])
inc(pos)
c.bufpos = pos
proc getTok(c: var SqlLexer, tok: var Token) =
tok.kind = tkInvalid
setLen(tok.literal, 0)
skip(c)
case c.buf[c.bufpos]
of ';':
tok.kind = tkSemicolon
inc(c.bufpos)
add(tok.literal, ';')
of ',':
tok.kind = tkComma
inc(c.bufpos)
add(tok.literal, ',')
of ':':
tok.kind = tkColon
inc(c.bufpos)
add(tok.literal, ':')
of 'e', 'E':
if c.buf[c.bufpos + 1] == '\'':
inc(c.bufpos)
getString(c, tok, tkEscapeConstant)
else:
getSymbol(c, tok)
of 'b', 'B':
if c.buf[c.bufpos + 1] == '\'':
tok.kind = tkBitStringConstant
getBitHexString(c, tok, {'0'..'1'})
else:
getSymbol(c, tok)
of 'x', 'X':
if c.buf[c.bufpos + 1] == '\'':
tok.kind = tkHexStringConstant
getBitHexString(c, tok, {'a'..'f','A'..'F','0'..'9'})
else:
getSymbol(c, tok)
of '$': getDollarString(c, tok)
of '[':
tok.kind = tkBracketLe
inc(c.bufpos)
add(tok.literal, '[')
of ']':
tok.kind = tkBracketRi
inc(c.bufpos)
add(tok.literal, ']')
of '(':
tok.kind = tkParLe
inc(c.bufpos)
add(tok.literal, '(')
of ')':
tok.kind = tkParRi
inc(c.bufpos)
add(tok.literal, ')')
of '.':
if c.buf[c.bufpos + 1] in Digits:
getNumeric(c, tok)
else:
tok.kind = tkDot
inc(c.bufpos)
add(tok.literal, '.')
of '0'..'9': getNumeric(c, tok)
of '\'': getString(c, tok, tkStringConstant)
of '"': getQuotedIdentifier(c, tok)
of lexbase.EndOfFile:
tok.kind = tkEof
tok.literal = "[EOF]"
of 'a', 'c', 'd', 'f'..'w', 'y', 'z', 'A', 'C', 'D', 'F'..'W', 'Y', 'Z', '_',
'\128'..'\255':
getSymbol(c, tok)
of '+', '-', '*', '/', '<', '>', '=', '~', '!', '@', '#', '%',
'^', '&', '|', '`', '?':
getOperator(c, tok)
else:
add(tok.literal, c.buf[c.bufpos])
inc(c.bufpos)
proc errorStr(L: SqlLexer, msg: string): string =
result = "$1($2, $3) Error: $4" % [L.filename, $getLine(L), $getColumn(L), msg]
# ----------------------------- parser ----------------------------------------
#   Operator/Element     Associativity  Description
#   .                    left           table/column name separator
#   ::                   left           PostgreSQL-style typecast
#   [ ]                  left           array element selection
#   -                    right          unary minus
#   ^                    left           exponentiation
#   * / %                left           multiplication, division, modulo
#   + -                  left           addition, subtraction
#   IS                                  IS TRUE, IS FALSE, IS UNKNOWN, IS NULL
#   ISNULL                              test for null
#   NOTNULL                             test for not null
#   (any other)          left           all other native and user-defined operators
#   IN                                  set membership
#   BETWEEN                             range containment
#   OVERLAPS                            time interval overlap
#   LIKE ILIKE SIMILAR                  string pattern matching
#   < >                                 less than, greater than
#   =                    right          equality, assignment
#   NOT                  right          logical negation
#   AND                  left           logical conjunction
#   OR                   left           logical disjunction
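#
# For example, with the precedence levels that ``getPrecedence`` below assigns
# ('=' binds tighter than AND, and AND binds tighter than OR), the expression
#
#   a = 1 AND b = 2 OR c = 3
#
# is parsed as
#
#   ((a = 1 AND b = 2) OR c = 3)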
type
SqlNodeKind* = enum ## kind of SQL abstract syntax tree
nkNone,
nkIdent,
nkStringLit,
nkBitStringLit,
nkHexStringLit,
nkIntegerLit,
nkNumericLit,
nkPrimaryKey,
nkForeignKey,
nkNotNull,
nkNull,
nkStmtList,
nkDot,
nkDotDot,
nkPrefix,
nkInfix,
nkCall,
nkColumnReference,
nkReferences,
nkDefault,
nkCheck,
nkConstraint,
nkUnique,
nkIdentity,
nkColumnDef, ## name, datatype, constraints
nkInsert,
nkUpdate,
nkDelete,
nkSelect,
nkSelectDistinct,
nkSelectColumns,
nkSelectPair,
nkAsgn,
nkFrom,
nkFromItemPair,
nkGroup,
nkLimit,
nkHaving,
nkOrder,
nkJoin,
nkDesc,
nkUnion,
nkIntersect,
nkExcept,
nkColumnList,
nkValueList,
nkWhere,
nkCreateTable,
nkCreateTableIfNotExists,
nkCreateType,
nkCreateTypeIfNotExists,
nkCreateIndex,
nkCreateIndexIfNotExists,
nkEnumDef
type
SqlParseError* = object of ValueError ## Invalid SQL encountered
SqlNode* = ref SqlNodeObj ## an SQL abstract syntax tree node
SqlNodeObj* = object ## an SQL abstract syntax tree node
case kind*: SqlNodeKind ## kind of syntax tree
of nkIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
nkIntegerLit, nkNumericLit:
      strVal*: string         ## AST leaf: the identifier, numeric literal,
                              ## string literal, etc.
else:
sons*: seq[SqlNode] ## the node's children
SqlParser* = object of SqlLexer ## SQL parser object
tok: Token
{.deprecated: [EInvalidSql: SqlParseError, PSqlNode: SqlNode,
TSqlNode: SqlNodeObj, TSqlParser: SqlParser, TSqlNodeKind: SqlNodeKind].}
proc newNode(k: SqlNodeKind): SqlNode =
new(result)
result.kind = k
proc newNode(k: SqlNodeKind, s: string): SqlNode =
new(result)
result.kind = k
result.strVal = s
proc len*(n: SqlNode): int =
if n.kind in {nkIdent, nkStringLit, nkBitStringLit, nkHexStringLit,
nkIntegerLit, nkNumericLit}:
result = 0
else:
result = n.sons.len
proc `[]`*(n: SqlNode; i: int): SqlNode = n.sons[i]
proc add*(father, n: SqlNode) =
if isNil(father.sons): father.sons = @[]
add(father.sons, n)
proc getTok(p: var SqlParser) =
getTok(p, p.tok)
proc sqlError(p: SqlParser, msg: string) =
var e: ref SqlParseError
new(e)
e.msg = errorStr(p, msg)
raise e
proc isKeyw(p: SqlParser, keyw: string): bool =
result = p.tok.kind == tkIdentifier and
cmpIgnoreCase(p.tok.literal, keyw) == 0
proc isOpr(p: SqlParser, opr: string): bool =
result = p.tok.kind == tkOperator and
cmpIgnoreCase(p.tok.literal, opr) == 0
proc optKeyw(p: var SqlParser, keyw: string) =
if p.tok.kind == tkIdentifier and cmpIgnoreCase(p.tok.literal, keyw) == 0:
getTok(p)
proc expectIdent(p: SqlParser) =
if p.tok.kind != tkIdentifier and p.tok.kind != tkQuotedIdentifier:
sqlError(p, "identifier expected")
proc expect(p: SqlParser, kind: TokKind) =
if p.tok.kind != kind:
sqlError(p, tokKindToStr[kind] & " expected")
proc eat(p: var SqlParser, kind: TokKind) =
if p.tok.kind == kind:
getTok(p)
else:
sqlError(p, tokKindToStr[kind] & " expected")
proc eat(p: var SqlParser, keyw: string) =
if isKeyw(p, keyw):
getTok(p)
else:
sqlError(p, keyw.toUpper() & " expected")
proc opt(p: var SqlParser, kind: TokKind) =
if p.tok.kind == kind: getTok(p)
proc parseDataType(p: var SqlParser): SqlNode =
if isKeyw(p, "enum"):
result = newNode(nkEnumDef)
getTok(p)
if p.tok.kind == tkParLe:
getTok(p)
result.add(newNode(nkStringLit, p.tok.literal))
getTok(p)
while p.tok.kind == tkComma:
getTok(p)
result.add(newNode(nkStringLit, p.tok.literal))
getTok(p)
eat(p, tkParRi)
else:
expectIdent(p)
result = newNode(nkIdent, p.tok.literal)
getTok(p)
# ignore (12, 13) part:
if p.tok.kind == tkParLe:
getTok(p)
expect(p, tkInteger)
getTok(p)
while p.tok.kind == tkComma:
getTok(p)
expect(p, tkInteger)
getTok(p)
eat(p, tkParRi)
proc getPrecedence(p: SqlParser): int =
if isOpr(p, "*") or isOpr(p, "/") or isOpr(p, "%"):
result = 6
elif isOpr(p, "+") or isOpr(p, "-"):
result = 5
elif isOpr(p, "=") or isOpr(p, "<") or isOpr(p, ">") or isOpr(p, ">=") or
isOpr(p, "<=") or isOpr(p, "<>") or isOpr(p, "!=") or isKeyw(p, "is") or
isKeyw(p, "like"):
result = 3
elif isKeyw(p, "and"):
result = 2
elif isKeyw(p, "or"):
result = 1
elif p.tok.kind == tkOperator:
# user-defined operator:
result = 0
else:
    result = -1
proc parseExpr(p: var SqlParser): SqlNode
proc parseSelect(p: var SqlParser): SqlNode
proc identOrLiteral(p: var SqlParser): SqlNode =
case p.tok.kind
of tkIdentifier, tkQuotedIdentifier:
result = newNode(nkIdent, p.tok.literal)
getTok(p)
of tkStringConstant, tkEscapeConstant, tkDollarQuotedConstant:
result = newNode(nkStringLit, p.tok.literal)
getTok(p)
of tkBitStringConstant:
result = newNode(nkBitStringLit, p.tok.literal)
getTok(p)
of tkHexStringConstant:
result = newNode(nkHexStringLit, p.tok.literal)
getTok(p)
of tkInteger:
result = newNode(nkIntegerLit, p.tok.literal)
getTok(p)
of tkNumeric:
result = newNode(nkNumericLit, p.tok.literal)
getTok(p)
of tkParLe:
getTok(p)
result = parseExpr(p)
eat(p, tkParRi)
else:
sqlError(p, "expression expected")
    getTok(p) # we must consume a token here to prevent endless loops!
proc primary(p: var SqlParser): SqlNode =
if p.tok.kind == tkOperator or isKeyw(p, "not"):
result = newNode(nkPrefix)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
result.add(primary(p))
return
result = identOrLiteral(p)
while true:
case p.tok.kind
of tkParLe:
var a = result
result = newNode(nkCall)
result.add(a)
getTok(p)
while p.tok.kind != tkParRi:
result.add(parseExpr(p))
if p.tok.kind == tkComma: getTok(p)
else: break
eat(p, tkParRi)
of tkDot:
getTok(p)
var a = result
if p.tok.kind == tkDot:
getTok(p)
result = newNode(nkDotDot)
else:
result = newNode(nkDot)
result.add(a)
if isOpr(p, "*"):
result.add(newNode(nkIdent, "*"))
elif p.tok.kind in {tkIdentifier, tkQuotedIdentifier}:
result.add(newNode(nkIdent, p.tok.literal))
else:
sqlError(p, "identifier expected")
getTok(p)
else: break
proc lowestExprAux(p: var SqlParser, v: var SqlNode, limit: int): int =
var
v2, node, opNode: SqlNode
v = primary(p) # expand while operators have priorities higher than 'limit'
var opPred = getPrecedence(p)
result = opPred
while opPred > limit:
node = newNode(nkInfix)
opNode = newNode(nkIdent, p.tok.literal)
getTok(p)
result = lowestExprAux(p, v2, opPred)
node.add(opNode)
node.add(v)
node.add(v2)
v = node
opPred = getPrecedence(p)
proc parseExpr(p: var SqlParser): SqlNode =
  discard lowestExprAux(p, result, -1)
proc parseTableName(p: var SqlParser): SqlNode =
expectIdent(p)
result = primary(p)
proc parseColumnReference(p: var SqlParser): SqlNode =
result = parseTableName(p)
if p.tok.kind == tkParLe:
getTok(p)
var a = result
result = newNode(nkColumnReference)
result.add(a)
result.add(parseTableName(p))
while p.tok.kind == tkComma:
getTok(p)
result.add(parseTableName(p))
eat(p, tkParRi)
proc parseCheck(p: var SqlParser): SqlNode =
getTok(p)
result = newNode(nkCheck)
result.add(parseExpr(p))
proc parseConstraint(p: var SqlParser): SqlNode =
getTok(p)
result = newNode(nkConstraint)
expectIdent(p)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
optKeyw(p, "check")
result.add(parseExpr(p))
proc parseParIdentList(p: var SqlParser, father: SqlNode) =
eat(p, tkParLe)
while true:
expectIdent(p)
father.add(newNode(nkIdent, p.tok.literal))
getTok(p)
if p.tok.kind != tkComma: break
getTok(p)
eat(p, tkParRi)
proc parseColumnConstraints(p: var SqlParser, result: SqlNode) =
while true:
if isKeyw(p, "default"):
getTok(p)
var n = newNode(nkDefault)
n.add(parseExpr(p))
result.add(n)
elif isKeyw(p, "references"):
getTok(p)
var n = newNode(nkReferences)
n.add(parseColumnReference(p))
result.add(n)
elif isKeyw(p, "not"):
getTok(p)
eat(p, "null")
result.add(newNode(nkNotNull))
elif isKeyw(p, "null"):
getTok(p)
result.add(newNode(nkNull))
elif isKeyw(p, "identity"):
getTok(p)
result.add(newNode(nkIdentity))
elif isKeyw(p, "primary"):
getTok(p)
eat(p, "key")
result.add(newNode(nkPrimaryKey))
elif isKeyw(p, "check"):
result.add(parseCheck(p))
elif isKeyw(p, "constraint"):
result.add(parseConstraint(p))
elif isKeyw(p, "unique"):
getTok(p)
result.add(newNode(nkUnique))
else:
break
proc parseColumnDef(p: var SqlParser): SqlNode =
expectIdent(p)
result = newNode(nkColumnDef)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
result.add(parseDataType(p))
parseColumnConstraints(p, result)
proc parseIfNotExists(p: var SqlParser, k: SqlNodeKind): SqlNode =
getTok(p)
if isKeyw(p, "if"):
getTok(p)
eat(p, "not")
eat(p, "exists")
    result = newNode(succ(k)) # the ...IfNotExists kind directly follows k in SqlNodeKind
else:
result = newNode(k)
proc parseTableConstraint(p: var SqlParser): SqlNode =
if isKeyw(p, "primary"):
getTok(p)
eat(p, "key")
result = newNode(nkPrimaryKey)
parseParIdentList(p, result)
elif isKeyw(p, "foreign"):
getTok(p)
eat(p, "key")
result = newNode(nkForeignKey)
parseParIdentList(p, result)
eat(p, "references")
var m = newNode(nkReferences)
m.add(parseColumnReference(p))
result.add(m)
elif isKeyw(p, "unique"):
getTok(p)
eat(p, "key")
result = newNode(nkUnique)
parseParIdentList(p, result)
elif isKeyw(p, "check"):
result = parseCheck(p)
elif isKeyw(p, "constraint"):
result = parseConstraint(p)
else:
sqlError(p, "column definition expected")
proc parseUnique(p: var SqlParser): SqlNode =
result = parseExpr(p)
if result.kind == nkCall: result.kind = nkUnique
proc parseTableDef(p: var SqlParser): SqlNode =
result = parseIfNotExists(p, nkCreateTable)
expectIdent(p)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
if p.tok.kind == tkParLe:
getTok(p)
while p.tok.kind != tkParRi:
if isKeyw(p, "constraint"):
result.add parseConstraint(p)
elif isKeyw(p, "primary") or isKeyw(p, "foreign"):
result.add parseTableConstraint(p)
elif isKeyw(p, "unique"):
result.add parseUnique(p)
elif p.tok.kind == tkIdentifier or p.tok.kind == tkQuotedIdentifier:
result.add(parseColumnDef(p))
else:
result.add(parseTableConstraint(p))
if p.tok.kind != tkComma: break
getTok(p)
eat(p, tkParRi)
  # skip any trailing tokens after 'create table (...)' until ';' or EOF:
while p.tok.kind notin {tkSemicolon, tkEof}:
getTok(p)
proc parseTypeDef(p: var SqlParser): SqlNode =
result = parseIfNotExists(p, nkCreateType)
expectIdent(p)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
eat(p, "as")
result.add(parseDataType(p))
proc parseWhere(p: var SqlParser): SqlNode =
getTok(p)
result = newNode(nkWhere)
result.add(parseExpr(p))
proc parseFromItem(p: var SqlParser): SqlNode =
result = newNode(nkFromItemPair)
if p.tok.kind == tkParLe:
getTok(p)
var select = parseSelect(p)
result.add(select)
eat(p, tkParRi)
else:
result.add(parseExpr(p))
if isKeyw(p, "as"):
getTok(p)
result.add(parseExpr(p))
proc parseIndexDef(p: var SqlParser): SqlNode =
result = parseIfNotExists(p, nkCreateIndex)
if isKeyw(p, "primary"):
getTok(p)
eat(p, "key")
result.add(newNode(nkPrimaryKey))
else:
expectIdent(p)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
eat(p, "on")
expectIdent(p)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
eat(p, tkParLe)
expectIdent(p)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
while p.tok.kind == tkComma:
getTok(p)
expectIdent(p)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
eat(p, tkParRi)
proc parseInsert(p: var SqlParser): SqlNode =
getTok(p)
eat(p, "into")
expectIdent(p)
result = newNode(nkInsert)
result.add(newNode(nkIdent, p.tok.literal))
getTok(p)
if p.tok.kind == tkParLe:
var n = newNode(nkColumnList)
parseParIdentList(p, n)
result.add n
else:
result.add(nil)
if isKeyw(p, "default"):
getTok(p)
eat(p, "values")
result.add(newNode(nkDefault))
else:
eat(p, "values")
eat(p, tkParLe)
var n = newNode(nkValueList)
while true:
n.add(parseExpr(p))
if p.tok.kind != tkComma: break
getTok(p)
result.add(n)
eat(p, tkParRi)
proc parseUpdate(p: var SqlParser): SqlNode =
getTok(p)
result = newNode(nkUpdate)
result.add(primary(p))
eat(p, "set")
while true:
var a = newNode(nkAsgn)
expectIdent(p)
a.add(newNode(nkIdent, p.tok.literal))
getTok(p)
if isOpr(p, "="): getTok(p)
else: sqlError(p, "= expected")
a.add(parseExpr(p))
result.add(a)
if p.tok.kind != tkComma: break
getTok(p)
if isKeyw(p, "where"):
result.add(parseWhere(p))
else:
result.add(nil)
proc parseDelete(p: var SqlParser): SqlNode =
getTok(p)
result = newNode(nkDelete)
eat(p, "from")
result.add(primary(p))
if isKeyw(p, "where"):
result.add(parseWhere(p))
else:
result.add(nil)
proc parseSelect(p: var SqlParser): SqlNode =
getTok(p)
if isKeyw(p, "distinct"):
getTok(p)
result = newNode(nkSelectDistinct)
elif isKeyw(p, "all"):
getTok(p)
result = newNode(nkSelect)
var a = newNode(nkSelectColumns)
while true:
if isOpr(p, "*"):
a.add(newNode(nkIdent, "*"))
getTok(p)
else:
var pair = newNode(nkSelectPair)
pair.add(parseExpr(p))
a.add(pair)
if isKeyw(p, "as"):
getTok(p)
pair.add(parseExpr(p))
if p.tok.kind != tkComma: break
getTok(p)
result.add(a)
if isKeyw(p, "from"):
var f = newNode(nkFrom)
while true:
getTok(p)
f.add(parseFromItem(p))
if p.tok.kind != tkComma: break
result.add(f)
if isKeyw(p, "where"):
result.add(parseWhere(p))
if isKeyw(p, "group"):
getTok(p)
eat(p, "by")
var g = newNode(nkGroup)
while true:
g.add(parseExpr(p))
if p.tok.kind != tkComma: break
getTok(p)
result.add(g)
if isKeyw(p, "limit"):
getTok(p)
var l = newNode(nkLimit)
l.add(parseExpr(p))
result.add(l)
if isKeyw(p, "having"):
var h = newNode(nkHaving)
while true:
getTok(p)
h.add(parseExpr(p))
if p.tok.kind != tkComma: break
result.add(h)
if isKeyw(p, "union"):
result.add(newNode(nkUnion))
getTok(p)
elif isKeyw(p, "intersect"):
result.add(newNode(nkIntersect))
getTok(p)
elif isKeyw(p, "except"):
result.add(newNode(nkExcept))
getTok(p)
if isKeyw(p, "order"):
getTok(p)
eat(p, "by")
var n = newNode(nkOrder)
while true:
var e = parseExpr(p)
if isKeyw(p, "asc"): getTok(p) # is default
elif isKeyw(p, "desc"):
getTok(p)
var x = newNode(nkDesc)
x.add(e)
e = x
n.add(e)
if p.tok.kind != tkComma: break
getTok(p)
result.add(n)
if isKeyw(p, "join") or isKeyw(p, "inner") or isKeyw(p, "outer") or isKeyw(p, "cross"):
var join = newNode(nkJoin)
result.add(join)
if isKeyw(p, "join"):
join.add(newNode(nkIdent, ""))
getTok(p)
else:
join.add(parseExpr(p))
eat(p, "join")
join.add(parseFromItem(p))
eat(p, "on")
join.add(parseExpr(p))
proc parseStmt(p: var SqlParser; parent: SqlNode) =
if isKeyw(p, "create"):
getTok(p)
optKeyw(p, "cached")
optKeyw(p, "memory")
optKeyw(p, "temp")
optKeyw(p, "global")
optKeyw(p, "local")
optKeyw(p, "temporary")
optKeyw(p, "unique")
optKeyw(p, "hash")
if isKeyw(p, "table"):
parent.add parseTableDef(p)
elif isKeyw(p, "type"):
parent.add parseTypeDef(p)
elif isKeyw(p, "index"):
parent.add parseIndexDef(p)
else:
sqlError(p, "TABLE expected")
elif isKeyw(p, "insert"):
parent.add parseInsert(p)
elif isKeyw(p, "update"):
parent.add parseUpdate(p)
elif isKeyw(p, "delete"):
parent.add parseDelete(p)
elif isKeyw(p, "select"):
parent.add parseSelect(p)
elif isKeyw(p, "begin"):
getTok(p)
else:
sqlError(p, "SELECT, CREATE, UPDATE or DELETE expected")
proc open(p: var SqlParser, input: Stream, filename: string) =
## opens the parser `p` and assigns the input stream `input` to it.
## `filename` is only used for error messages.
open(SqlLexer(p), input, filename)
p.tok.kind = tkInvalid
p.tok.literal = ""
getTok(p)
proc parse(p: var SqlParser): SqlNode =
## parses the content of `p`'s input stream and returns the SQL AST.
  ## Syntax errors raise an `SqlParseError` exception.
result = newNode(nkStmtList)
while p.tok.kind != tkEof:
parseStmt(p, result)
if p.tok.kind == tkEof:
break
eat(p, tkSemicolon)
if result.len == 1:
result = result.sons[0]
proc close(p: var SqlParser) =
## closes the parser `p`. The associated input stream is closed too.
close(SqlLexer(p))
proc parseSQL*(input: Stream, filename: string): SqlNode =
## parses the SQL from `input` into an AST and returns the AST.
## `filename` is only used for error messages.
  ## Syntax errors raise an `SqlParseError` exception.
var p: SqlParser
open(p, input, filename)
try:
result = parse(p)
finally:
close(p)
proc parseSQL*(input: string, filename=""): SqlNode =
## parses the SQL from `input` into an AST and returns the AST.
## `filename` is only used for error messages.
  ## Syntax errors raise an `SqlParseError` exception.
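  ##
  ## A minimal sketch of handling invalid input (uses only this module's
  ## exports):
  ##
  ## .. code-block:: nim
  ##
  ##   try:
  ##     discard parseSQL("INSERT INTO;")   # incomplete statement
  ##   except SqlParseError:
  ##     echo getCurrentExceptionMsg()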
parseSQL(newStringStream(input), "")
type
SqlWriter = object
indent: int
upperCase: bool
buffer: string
proc add(s: var SqlWriter, thing: string) =
s.buffer.add(thing)
proc add(s: var SqlWriter, thing: char) =
s.buffer.add(thing)
proc addKeyw(s: var SqlWriter, thing: string) =
if s.buffer.len > 0 and s.buffer[^1] notin " ,\L(":
s.buffer.add(" ")
if s.upperCase:
s.buffer.add(thing.toUpper())
else:
s.buffer.add(thing)
s.buffer.add(" ")
proc rm(s: var SqlWriter, chars = " \L,") =
while s.buffer[^1] in chars:
s.buffer = s.buffer[0..^2]
proc newLine(s: var SqlWriter) =
s.rm(" \L")
s.buffer.add("\L")
for i in 0..<s.indent:
s.buffer.add(" ")
template inner(s: var SqlWriter, body: untyped) =
inc s.indent
s.newLine()
body
dec s.indent
template innerKeyw(s: var SqlWriter, keyw: string, body: untyped) =
s.newLine()
s.addKeyw(keyw)
inc s.indent
s.newLine()
body
dec s.indent
proc ra(n: SqlNode, s: var SqlWriter)
proc rs(n: SqlNode, s: var SqlWriter, prefix = "(", suffix = ")", sep = ", ") =
if n.len > 0:
s.add(prefix)
for i in 0 .. n.len-1:
if i > 0: s.add(sep)
ra(n.sons[i], s)
s.add(suffix)
proc ra(n: SqlNode, s: var SqlWriter) =
if n == nil: return
case n.kind
of nkNone: discard
of nkIdent:
if allCharsInSet(n.strVal, {'\33'..'\127'}):
s.add(n.strVal)
else:
s.add("\"" & replace(n.strVal, "\"", "\"\"") & "\"")
of nkStringLit:
s.add(escape(n.strVal, "'", "'"))
of nkBitStringLit:
s.add("b'" & n.strVal & "'")
of nkHexStringLit:
s.add("x'" & n.strVal & "'")
of nkIntegerLit, nkNumericLit:
s.add(n.strVal)
of nkPrimaryKey:
s.addKeyw("primary key")
rs(n, s)
of nkForeignKey:
s.addKeyw("foreign key")
rs(n, s)
of nkNotNull:
s.addKeyw("not null")
of nkNull:
s.addKeyw("null")
of nkDot:
ra(n.sons[0], s)
s.add(".")
ra(n.sons[1], s)
of nkDotDot:
ra(n.sons[0], s)
s.add(". .")
ra(n.sons[1], s)
of nkPrefix:
s.add('(')
ra(n.sons[0], s)
s.add(' ')
ra(n.sons[1], s)
s.add(')')
of nkInfix:
s.add('(')
ra(n.sons[1], s)
s.add(' ')
ra(n.sons[0], s)
s.add(' ')
ra(n.sons[2], s)
s.add(')')
of nkCall, nkColumnReference:
ra(n.sons[0], s)
s.add('(')
for i in 1..n.len-1:
if i > 1: s.add(", ")
ra(n.sons[i], s)
s.add(')')
of nkReferences:
s.addKeyw("references")
ra(n.sons[0], s)
of nkDefault:
s.addKeyw("default")
ra(n.sons[0], s)
of nkCheck:
s.addKeyw("check")
ra(n.sons[0], s)
of nkConstraint:
s.addKeyw("constraint")
ra(n.sons[0], s)
s.addKeyw("check")
ra(n.sons[1], s)
of nkUnique:
s.addKeyw("unique")
rs(n, s)
of nkIdentity:
s.addKeyw("identity")
of nkColumnDef:
s.add("\n ")
rs(n, s, "", "", " ")
of nkStmtList:
for i in 0..n.len-1:
ra(n.sons[i], s)
s.add("\n")
of nkInsert:
assert n.len == 3
s.addKeyw("insert into")
ra(n.sons[0], s)
ra(n.sons[1], s)
if n.sons[2].kind == nkDefault:
s.addKeyw("default values")
else:
s.newLine()
ra(n.sons[2], s)
s.add(';')
of nkUpdate:
s.addKeyw("update")
ra(n.sons[0], s)
s.addKeyw("set")
var L = n.len
for i in 1 .. L-2:
if i > 1: s.add(", ")
var it = n.sons[i]
assert it.kind == nkAsgn
ra(it, s)
ra(n.sons[L-1], s)
s.add(';')
of nkDelete:
s.addKeyw("delete from")
ra(n.sons[0], s)
ra(n.sons[1], s)
s.add(';')
of nkSelect, nkSelectDistinct:
s.addKeyw("select")
if n.kind == nkSelectDistinct:
s.addKeyw("distinct")
s.inner:
for son in n.sons[0].sons:
ra(son, s)
s.add(',')
s.newLine()
s.rm()
for i in 1 .. n.len-1:
ra(n.sons[i], s)
s.add(';')
of nkSelectColumns:
assert(false)
of nkSelectPair:
ra(n.sons[0], s)
if n.sons.len == 2:
s.addKeyw("as")
ra(n.sons[1], s)
of nkFromItemPair:
if n.sons[0].kind == nkIdent:
ra(n.sons[0], s)
else:
assert n.sons[0].kind == nkSelect
s.add("(")
s.inner:
ra(n.sons[0], s)
s.rm("; \L")
s.newLine()
s.add(")")
if n.sons.len == 2:
s.addKeyw("as")
ra(n.sons[1], s)
of nkAsgn:
ra(n.sons[0], s)
s.add(" = ")
ra(n.sons[1], s)
of nkFrom:
s.innerKeyw("from"):
rs(n, s, "", "", ", ")
of nkGroup:
s.innerKeyw("group by"):
rs(n, s, "", "", ", ")
of nkLimit:
s.innerKeyw("limit"):
rs(n, s, "", "", ", ")
of nkHaving:
s.innerKeyw("having"):
rs(n, s, "", "", ", ")
of nkOrder:
s.addKeyw("order by")
rs(n, s, "", "", ", ")
of nkJoin:
var joinType = n.sons[0].strVal
if joinType == "":
joinType = "join"
else:
joinType &= " " & "join"
s.innerKeyw(joinType):
ra(n.sons[1], s)
s.innerKeyw("on"):
ra(n.sons[2], s)
of nkDesc:
ra(n.sons[0], s)
s.addKeyw("desc")
of nkUnion:
s.addKeyw("union")
of nkIntersect:
s.addKeyw("intersect")
of nkExcept:
s.addKeyw("except")
of nkColumnList:
rs(n, s)
of nkValueList:
s.addKeyw("values")
rs(n, s)
of nkWhere:
s.newLine()
s.addKeyw("where")
s.inner:
ra(n.sons[0], s)
of nkCreateTable, nkCreateTableIfNotExists:
s.addKeyw("create table")
if n.kind == nkCreateTableIfNotExists:
s.addKeyw("if not exists")
ra(n.sons[0], s)
s.add('(')
for i in 1..n.len-1:
if i > 1: s.add(",")
ra(n.sons[i], s)
s.add(");")
of nkCreateType, nkCreateTypeIfNotExists:
s.addKeyw("create type")
if n.kind == nkCreateTypeIfNotExists:
s.addKeyw("if not exists")
ra(n.sons[0], s)
s.addKeyw("as")
ra(n.sons[1], s)
s.add(';')
of nkCreateIndex, nkCreateIndexIfNotExists:
s.addKeyw("create index")
if n.kind == nkCreateIndexIfNotExists:
s.addKeyw("if not exists")
ra(n.sons[0], s)
s.addKeyw("on")
ra(n.sons[1], s)
s.add('(')
for i in 2..n.len-1:
if i > 2: s.add(", ")
ra(n.sons[i], s)
s.add(");")
of nkEnumDef:
s.addKeyw("enum")
rs(n, s)
proc renderSQL*(n: SqlNode, upperCase=false): string =
## Converts an SQL abstract syntax tree to its string representation.
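  ##
  ## For instance (a small sketch; ``upperCase`` only changes how keywords are
  ## emitted, not the structure of the output):
  ##
  ## .. code-block:: nim
  ##
  ##   let ast = parseSQL("select name from users;")
  ##   echo renderSQL(ast, upperCase = true)  # keywords come out as SELECT, FROM, ...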
var s: SqlWriter
s.buffer = ""
s.upperCase = upperCase
ra(n, s)
return s.buffer
proc `$`*(n: SqlNode): string =
## an alias for `renderSQL`.
renderSQL(n)