diff options
author | Bung <crc32@qq.com> | 2020-05-26 13:44:47 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-26 07:44:47 +0200 |
commit | 2a4a43b05fa3fa6ad3f85543ccfbd9ea450c9be0 (patch) | |
tree | 77d638a33367b6e053a85efb98e979e471d19616 /lib | |
parent | 55758920f446ebf1eecab980460bf6e664b8dba7 (diff) | |
download | Nim-2a4a43b05fa3fa6ad3f85543ccfbd9ea450c9be0.tar.gz |
add bindParams to db_sqlite (#14408)
* add bindParams to db_sqlite * no need typeinfo * remove extro spaces * reduce bindParams to two branches,raise DbError * Update lib/impure/db_sqlite.nim * change bindParams to macro,accept varargs[untyped] as params * change bind blob val to openArray[byte] * remove unused err type * explicitly using i32 param * using import std/private/since * SQLITE_OK to right hand * bindParam val int using bindParam overload * copy data by default * change exec to template * remove SqlPrepared procs unused varargs * fix setupquery for prepared,reset first for exec prepared,add bindNull for literal nil Co-authored-by: alaviss <leorize+oss@disroot.org>
Diffstat (limited to 'lib')
-rw-r--r-- | lib/impure/db_sqlite.nim | 200 |
1 files changed, 159 insertions, 41 deletions
diff --git a/lib/impure/db_sqlite.nim b/lib/impure/db_sqlite.nim index ff00dd86e..80eee59fb 100644 --- a/lib/impure/db_sqlite.nim +++ b/lib/impure/db_sqlite.nim @@ -113,7 +113,9 @@ ## * `db_mysql module <db_mysql.html>`_ for MySQL database wrapper ## * `db_postgres module <db_postgres.html>`_ for PostgreSQL database wrapper -import sqlite3 +{.experimental: "codeReordering".} + +import sqlite3, macros import db_common export db_common @@ -167,7 +169,8 @@ proc dbFormat(formatstr: SqlQuery, args: varargs[string]): string = else: add(result, c) -proc prepare*(db: DbConn; q: string): SqlPrepared = +proc prepare*(db: DbConn; q: string): SqlPrepared {.since: (1, 3).} = + ## Creates a new ``SqlPrepared`` statement. if prepare_v2(db, q, q.len.cint,result.PStmt, nil) != SQLITE_OK: discard finalize(result.PStmt) dbError(db) @@ -196,8 +199,7 @@ proc tryExec*(db: DbConn, query: SqlQuery, discard finalize(stmt) result = false -proc tryExec*(db: DbConn, stmtName: SqlPrepared, - args: varargs[string, `$`]): bool {. +proc tryExec*(db: DbConn, stmtName: SqlPrepared): bool {. tags: [ReadDbEffect, WriteDbEffect].} = let x = step(stmtName.PStmt) if x in {SQLITE_DONE, SQLITE_ROW}: @@ -224,10 +226,6 @@ proc exec*(db: DbConn, query: SqlQuery, args: varargs[string, `$`]) {. ## db.close() if not tryExec(db, query, args): dbError(db) -proc exec*(db: DbConn, stmtName: SqlPrepared, - args: varargs[string]) {.tags: [ReadDbEffect, WriteDbEffect].} = - if not tryExec(db, stmtName, args): dbError(db) - proc newRow(L: int): Row = newSeq(result, L) for i in 0..L-1: result[i] = "" @@ -238,10 +236,9 @@ proc setupQuery(db: DbConn, query: SqlQuery, var q = dbFormat(query, args) if prepare_v2(db, q, q.len.cint, result, nil) != SQLITE_OK: dbError(db) -proc setupQuery(db: DbConn, stmtName: SqlPrepared, - args: varargs[string]): PStmt = +proc setupQuery(db: DbConn, stmtName: SqlPrepared): SqlPrepared {.since: (1, 3).} = assert(not db.isNil, "Database not connected.") - if not tryExec(db, stmtName, args): dbError(db) + result = stmtName proc setRow(stmt: PStmt, r: var Row, cols: cint) = for col in 0'i32..cols-1: @@ -291,14 +288,14 @@ iterator fastRows*(db: DbConn, query: SqlQuery, finally: if finalize(stmt) != SQLITE_OK: dbError(db) -iterator fastRows*(db: DbConn, stmtName: SqlPrepared, - args: varargs[string, `$`]): Row {.tags: [ReadDbEffect,WriteDbEffect].} = - var stmt = setupQuery(db, stmtName, args) - var L = (column_count(stmt)) +iterator fastRows*(db: DbConn, stmtName: SqlPrepared): Row + {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} = + discard setupQuery(db, stmtName) + var L = (column_count(stmtName.PStmt)) var result = newRow(L) try: - while step(stmt) == SQLITE_ROW: - setRow(stmt, result, L) + while step(stmtName.PStmt) == SQLITE_ROW: + setRow(stmtName.PStmt, result, L) yield result except: dbError(db) @@ -343,10 +340,9 @@ iterator instantRows*(db: DbConn, query: SqlQuery, finally: if finalize(stmt) != SQLITE_OK: dbError(db) -iterator instantRows*(db: DbConn, stmtName: SqlPrepared, - args: varargs[string, `$`]): InstantRow - {.tags: [ReadDbEffect,WriteDbEffect].} = - var stmt = setupQuery(db, stmtName, args) +iterator instantRows*(db: DbConn, stmtName: SqlPrepared): InstantRow + {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} = + var stmt = setupQuery(db, stmtName).PStmt try: while step(stmt) == SQLITE_ROW: yield stmt @@ -495,10 +491,10 @@ proc getAllRows*(db: DbConn, query: SqlQuery, for r in fastRows(db, query, args): result.add(r) -proc getAllRows*(db: DbConn, stmtName: SqlPrepared, - args: varargs[string, `$`]): seq[Row] {.tags: [ReadDbEffect,WriteDbEffect].} = +proc getAllRows*(db: DbConn, stmtName: SqlPrepared): seq[Row] + {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} = result = @[] - for r in fastRows(db, stmtName, args): + for r in fastRows(db, stmtName): result.add(r) iterator rows*(db: DbConn, query: SqlQuery, @@ -528,9 +524,9 @@ iterator rows*(db: DbConn, query: SqlQuery, ## db.close() for r in fastRows(db, query, args): yield r -iterator rows*(db: DbConn, stmtName: SqlPrepared, - args: varargs[string, `$`]): Row {.tags: [ReadDbEffect,WriteDbEffect].} = - for r in fastRows(db, stmtName, args): yield r +iterator rows*(db: DbConn, stmtName: SqlPrepared): Row + {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} = + for r in fastRows(db, stmtName): yield r proc getValue*(db: DbConn, query: SqlQuery, args: varargs[string, `$`]): string {.tags: [ReadDbEffect].} = @@ -568,9 +564,9 @@ proc getValue*(db: DbConn, query: SqlQuery, result = "" if finalize(stmt) != SQLITE_OK: dbError(db) -proc getValue*(db: DbConn, stmtName: SqlPrepared, - args: varargs[string, `$`]): string {.tags: [ReadDbEffect,WriteDbEffect].} = - var stmt = setupQuery(db, stmtName, args) +proc getValue*(db: DbConn, stmtName: SqlPrepared): string + {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} = + var stmt = setupQuery(db, stmtName).PStmt if step(stmt) == SQLITE_ROW: let cb = column_bytes(stmt, 0) if cb == 0: @@ -675,10 +671,9 @@ proc execAffectedRows*(db: DbConn, query: SqlQuery, exec(db, query, args) result = changes(db) -proc execAffectedRows*(db: DbConn, stmtName: SqlPrepared, - args: varargs[string, `$`]): int64 {. - tags: [ReadDbEffect, WriteDbEffect].} = - exec(db, stmtName, args) +proc execAffectedRows*(db: DbConn, stmtName: SqlPrepared): int64 + {.tags: [ReadDbEffect, WriteDbEffect],since: (1, 3).} = + exec(db, stmtName) result = changes(db) proc close*(db: DbConn) {.tags: [DbEffect].} = @@ -728,22 +723,96 @@ proc setEncoding*(connection: DbConn, encoding: string): bool {. exec(connection, sql"PRAGMA encoding = ?", [encoding]) result = connection.getValue(sql"PRAGMA encoding") == encoding -proc finalize*(sqlPrepared:SqlPrepared){.discardable.} = +proc finalize*(sqlPrepared:SqlPrepared) {.discardable, since: (1, 3).} = discard finalize(sqlPrepared.PStmt) +template dbBindParamError*(paramIdx: int, val: varargs[untyped]) = + ## Raises a `DbError` exception. + var e: ref DbError + new(e) + e.msg = "error binding param in position " & $paramIdx + raise e + +proc bindParam*(ps: SqlPrepared, paramIdx: int, val: int32) {.since: (1, 3).} = + ## Binds a int32 to the specified paramIndex. + if bind_int(ps.PStmt, paramIdx.int32, val) != SQLITE_OK: + dbBindParamError(paramIdx, val) + +proc bindParam*(ps: SqlPrepared, paramIdx: int, val: int64) {.since: (1, 3).} = + ## Binds a int64 to the specified paramIndex. + if bind_int64(ps.PStmt, paramIdx.int32, val) != SQLITE_OK: + dbBindParamError(paramIdx, val) + +proc bindParam*(ps: SqlPrepared, paramIdx: int, val: int) {.since: (1, 3).} = + ## Binds a int to the specified paramIndex. + when sizeof(int) == 8: + bindParam(ps, paramIdx, val.int64) + else: + bindParam(ps, paramIdx, val.int32) + +proc bindParam*(ps: SqlPrepared, paramIdx: int, val: float64) {.since: (1, 3).} = + ## Binds a 64bit float to the specified paramIndex. + if bind_double(ps.PStmt, paramIdx.int32, val) != SQLITE_OK: + dbBindParamError(paramIdx, val) + +proc bindNull*(ps: SqlPrepared, paramIdx: int) {.since: (1, 3).} = + ## Sets the bindparam at the specified paramIndex to null + ## (default behaviour by sqlite). + if bind_null(ps.PStmt, paramIdx.int32) != SQLITE_OK: + dbBindParamError(paramIdx) + +proc bindParam*(ps: SqlPrepared, paramIdx: int, val: string, copy = true) {.since: (1, 3).} = + ## Binds a string to the specified paramIndex. + ## if copy is true then SQLite makes its own private copy of the data immediately + if bind_text(ps.PStmt, paramIdx.int32, val.cstring, val.len.int32, if copy: SQLITE_TRANSIENT else: SQLITE_STATIC) != SQLITE_OK: + dbBindParamError(paramIdx, val) + +proc bindParam*(ps: SqlPrepared, paramIdx: int,val: openArray[byte], copy = true) {.since: (1, 3).} = + ## binds a blob to the specified paramIndex. + ## if copy is true then SQLite makes its own private copy of the data immediately + let len = val.len + if bind_blob(ps.PStmt, paramIdx.int32, val[0].unsafeAddr, len.int32, if copy: SQLITE_TRANSIENT else: SQLITE_STATIC) != SQLITE_OK: + dbBindParamError(paramIdx, val) + +macro bindParams*(ps: SqlPrepared, params: varargs[untyped]): untyped {.since: (1, 3).} = + let bindParam = bindSym("bindParam", brOpen) + let bindNull = bindSym("bindNull") + let preparedStatement = genSym() + result = newStmtList() + # Store `ps` in a temporary variable. This prevents `ps` from being evaluated every call. + result.add newNimNode(nnkLetSection).add(newIdentDefs(preparedStatement, newEmptyNode(), ps)) + for idx, param in params: + if param.kind != nnkNilLit: + result.add newCall(bindParam, preparedStatement, newIntLitNode idx + 1, param) + else: + result.add newCall(bindNull, preparedStatement, newIntLitNode idx + 1) + +macro untypedLen(args: varargs[untyped]): int = + newLit(args.len) + +template exec*(db: DbConn, stmtName: SqlPrepared, + args: varargs[typed]): untyped = + when args.untypedLen > 0: + if reset(stmtName.PStmt) != SQLITE_OK: + dbError(db) + if clear_bindings(stmtName.PStmt) != SQLITE_OK: + dbError(db) + stmtName.bindParams(args) + if not tryExec(db, stmtName): dbError(db) + when not defined(testing) and isMainModule: - var db = open("db.sql", "", "", "") + var db = open(":memory:", "", "", "") exec(db, sql"create table tbl1(one varchar(10), two smallint)", []) exec(db, sql"insert into tbl1 values('hello!',10)", []) exec(db, sql"insert into tbl1 values('goodbye', 20)", []) var p1 = db.prepare "create table tbl2(one varchar(10), two smallint)" - exec(db, p1, []) + exec(db, p1) finalize(p1) var p2 = db.prepare "insert into tbl2 values('hello!',10)" - exec(db, p2, []) + exec(db, p2) finalize(p2) var p3 = db.prepare "insert into tbl2 values('goodbye', 20)" - exec(db, p3, []) + exec(db, p3) finalize(p3) #db.query("create table tbl1(one varchar(10), two smallint)") #db.query("insert into tbl1 values('hello!',10)") @@ -753,17 +822,66 @@ when not defined(testing) and isMainModule: for r in db.instantRows(sql"select * from tbl1", []): echo(r[0], r[1]) var p4 = db.prepare "select * from tbl2" - for r in db.rows(p4, []): + for r in db.rows(p4): echo(r[0], r[1]) finalize(p4) + var i5 = 0 var p5 = db.prepare "select * from tbl2" - for r in db.instantRows(p5, []): + for r in db.instantRows(p5): + inc i5 echo(r[0], r[1]) + assert i5 == 2 finalize(p5) for r in db.rows(sql"select * from tbl2", []): echo(r[0], r[1]) for r in db.instantRows(sql"select * from tbl2", []): echo(r[0], r[1]) + var p6 = db.prepare "select * from tbl2 where one = ? " + p6.bindParams("goodbye") + var rowsP3 = 0 + for r in db.rows(p6): + rowsP3 = 1 + echo(r[0], r[1]) + assert rowsP3 == 1 + finalize(p6) + + var p7 = db.prepare "select * from tbl2 where two=?" + p7.bindParams(20'i32) + when sizeof(int) == 4: + p7.bindParams(20) + var rowsP = 0 + for r in db.rows(p7): + rowsP = 1 + echo(r[0], r[1]) + assert rowsP == 1 + finalize(p7) + + exec(db, sql"CREATE TABLE photos(ID INTEGER PRIMARY KEY AUTOINCREMENT, photo BLOB)") + var p8 = db.prepare "INSERT INTO photos (ID,PHOTO) VALUES (?,?)" + var d = "abcdefghijklmnopqrstuvwxyz" + p8.bindParams(1'i32, "abcdefghijklmnopqrstuvwxyz") + exec(db, p8) + finalize(p8) + var p10 = db.prepare "INSERT INTO photos (ID,PHOTO) VALUES (?,?)" + p10.bindParams(2'i32,nil) + exec(db, p10) + exec( db, p10, 3, nil) + finalize(p10) + for r in db.rows(sql"select * from photos where ID = 1", []): + assert r[1].len == d.len + assert r[1] == d + var i6 = 0 + for r in db.rows(sql"select * from photos where ID = 3", []): + i6 = 1 + assert i6 == 1 + var p9 = db.prepare("select * from photos where PHOTO is ?") + p9.bindParams(nil) + var rowsP2 = 0 + for r in db.rows(p9): + rowsP2 = 1 + echo(r[0], repr r[1]) + assert rowsP2 == 1 + finalize(p9) db_sqlite.close(db) |