summary refs log tree commit diff stats
path: root/lib
diff options
context:
space:
mode:
authorBung <crc32@qq.com>2020-05-26 13:44:47 +0800
committerGitHub <noreply@github.com>2020-05-26 07:44:47 +0200
commit2a4a43b05fa3fa6ad3f85543ccfbd9ea450c9be0 (patch)
tree77d638a33367b6e053a85efb98e979e471d19616 /lib
parent55758920f446ebf1eecab980460bf6e664b8dba7 (diff)
downloadNim-2a4a43b05fa3fa6ad3f85543ccfbd9ea450c9be0.tar.gz
add bindParams to db_sqlite (#14408)
* add bindParams to db_sqlite
* no need typeinfo
* remove extro spaces
* reduce bindParams to two branches,raise DbError
* Update lib/impure/db_sqlite.nim
* change bindParams to macro,accept varargs[untyped] as params
* change bind blob val to openArray[byte]
* remove unused err type
* explicitly using i32 param
* using import std/private/since
* SQLITE_OK to right hand

* bindParam val int using bindParam overload
* copy data by default
* change exec to template
* remove SqlPrepared procs unused varargs
* fix setupquery for prepared,reset first for exec prepared,add bindNull for literal nil

Co-authored-by: alaviss <leorize+oss@disroot.org>
Diffstat (limited to 'lib')
-rw-r--r--lib/impure/db_sqlite.nim200
1 files changed, 159 insertions, 41 deletions
diff --git a/lib/impure/db_sqlite.nim b/lib/impure/db_sqlite.nim
index ff00dd86e..80eee59fb 100644
--- a/lib/impure/db_sqlite.nim
+++ b/lib/impure/db_sqlite.nim
@@ -113,7 +113,9 @@
 ## * `db_mysql module <db_mysql.html>`_ for MySQL database wrapper
 ## * `db_postgres module <db_postgres.html>`_ for PostgreSQL database wrapper
 
-import sqlite3
+{.experimental: "codeReordering".}
+
+import sqlite3, macros
 
 import db_common
 export db_common
@@ -167,7 +169,8 @@ proc dbFormat(formatstr: SqlQuery, args: varargs[string]): string =
     else:
       add(result, c)
 
-proc prepare*(db: DbConn; q: string): SqlPrepared =
+proc prepare*(db: DbConn; q: string): SqlPrepared {.since: (1, 3).} =
+  ## Creates a new ``SqlPrepared`` statement.
   if prepare_v2(db, q, q.len.cint,result.PStmt, nil) != SQLITE_OK:
     discard finalize(result.PStmt)
     dbError(db)
@@ -196,8 +199,7 @@ proc tryExec*(db: DbConn, query: SqlQuery,
       discard finalize(stmt)
       result = false
 
-proc tryExec*(db: DbConn, stmtName: SqlPrepared,
-              args: varargs[string, `$`]): bool {.
+proc tryExec*(db: DbConn, stmtName: SqlPrepared): bool {.
               tags: [ReadDbEffect, WriteDbEffect].} =
     let x = step(stmtName.PStmt)
     if x in {SQLITE_DONE, SQLITE_ROW}:
@@ -224,10 +226,6 @@ proc exec*(db: DbConn, query: SqlQuery, args: varargs[string, `$`])  {.
   ##      db.close()
   if not tryExec(db, query, args): dbError(db)
 
-proc exec*(db: DbConn, stmtName: SqlPrepared,
-          args: varargs[string]) {.tags: [ReadDbEffect, WriteDbEffect].} =
-    if not tryExec(db, stmtName, args): dbError(db)
-
 proc newRow(L: int): Row =
   newSeq(result, L)
   for i in 0..L-1: result[i] = ""
@@ -238,10 +236,9 @@ proc setupQuery(db: DbConn, query: SqlQuery,
   var q = dbFormat(query, args)
   if prepare_v2(db, q, q.len.cint, result, nil) != SQLITE_OK: dbError(db)
 
-proc setupQuery(db: DbConn, stmtName: SqlPrepared,
-                args: varargs[string]): PStmt =
+proc setupQuery(db: DbConn, stmtName: SqlPrepared): SqlPrepared {.since: (1, 3).} =
   assert(not db.isNil, "Database not connected.")
-  if not tryExec(db, stmtName, args): dbError(db)
+  result = stmtName
 
 proc setRow(stmt: PStmt, r: var Row, cols: cint) =
   for col in 0'i32..cols-1:
@@ -291,14 +288,14 @@ iterator fastRows*(db: DbConn, query: SqlQuery,
   finally:
     if finalize(stmt) != SQLITE_OK: dbError(db)
 
-iterator fastRows*(db: DbConn, stmtName: SqlPrepared,
-                   args: varargs[string, `$`]): Row {.tags: [ReadDbEffect,WriteDbEffect].} =
-  var stmt = setupQuery(db, stmtName, args)
-  var L = (column_count(stmt))
+iterator fastRows*(db: DbConn, stmtName: SqlPrepared): Row 
+                  {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} =
+  discard setupQuery(db, stmtName)
+  var L = (column_count(stmtName.PStmt))
   var result = newRow(L)
   try:
-    while step(stmt) == SQLITE_ROW:
-      setRow(stmt, result, L)
+    while step(stmtName.PStmt) == SQLITE_ROW:
+      setRow(stmtName.PStmt, result, L)
       yield result
   except:
     dbError(db)
@@ -343,10 +340,9 @@ iterator instantRows*(db: DbConn, query: SqlQuery,
   finally:
     if finalize(stmt) != SQLITE_OK: dbError(db)
 
-iterator instantRows*(db: DbConn, stmtName: SqlPrepared,
-                      args: varargs[string, `$`]): InstantRow
-                      {.tags: [ReadDbEffect,WriteDbEffect].} =
-  var stmt = setupQuery(db, stmtName, args)
+iterator instantRows*(db: DbConn, stmtName: SqlPrepared): InstantRow
+                      {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} =
+  var stmt = setupQuery(db, stmtName).PStmt
   try:
     while step(stmt) == SQLITE_ROW:
       yield stmt
@@ -495,10 +491,10 @@ proc getAllRows*(db: DbConn, query: SqlQuery,
   for r in fastRows(db, query, args):
     result.add(r)
 
-proc getAllRows*(db: DbConn, stmtName: SqlPrepared,
-                 args: varargs[string, `$`]): seq[Row] {.tags: [ReadDbEffect,WriteDbEffect].} =
+proc getAllRows*(db: DbConn, stmtName: SqlPrepared): seq[Row] 
+                {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} =
   result = @[]
-  for r in fastRows(db, stmtName, args):
+  for r in fastRows(db, stmtName):
     result.add(r)
 
 iterator rows*(db: DbConn, query: SqlQuery,
@@ -528,9 +524,9 @@ iterator rows*(db: DbConn, query: SqlQuery,
   ##    db.close()
   for r in fastRows(db, query, args): yield r
 
-iterator rows*(db: DbConn, stmtName: SqlPrepared,
-               args: varargs[string, `$`]): Row {.tags: [ReadDbEffect,WriteDbEffect].} =
-  for r in fastRows(db, stmtName, args): yield r
+iterator rows*(db: DbConn, stmtName: SqlPrepared): Row 
+              {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} =
+  for r in fastRows(db, stmtName): yield r
 
 proc getValue*(db: DbConn, query: SqlQuery,
                args: varargs[string, `$`]): string {.tags: [ReadDbEffect].} =
@@ -568,9 +564,9 @@ proc getValue*(db: DbConn, query: SqlQuery,
     result = ""
   if finalize(stmt) != SQLITE_OK: dbError(db)
 
-proc getValue*(db: DbConn,  stmtName: SqlPrepared,
-               args: varargs[string, `$`]): string {.tags: [ReadDbEffect,WriteDbEffect].} =
-  var stmt = setupQuery(db, stmtName, args)
+proc getValue*(db: DbConn,  stmtName: SqlPrepared): string 
+              {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} =
+  var stmt = setupQuery(db, stmtName).PStmt
   if step(stmt) == SQLITE_ROW:
     let cb = column_bytes(stmt, 0)
     if cb == 0:
@@ -675,10 +671,9 @@ proc execAffectedRows*(db: DbConn, query: SqlQuery,
   exec(db, query, args)
   result = changes(db)
 
-proc execAffectedRows*(db: DbConn, stmtName: SqlPrepared,
-                       args: varargs[string, `$`]): int64 {.
-                       tags: [ReadDbEffect, WriteDbEffect].} =
-  exec(db, stmtName, args)
+proc execAffectedRows*(db: DbConn, stmtName: SqlPrepared): int64 
+                      {.tags: [ReadDbEffect, WriteDbEffect],since: (1, 3).} =
+  exec(db, stmtName)
   result = changes(db)
 
 proc close*(db: DbConn) {.tags: [DbEffect].} =
@@ -728,22 +723,96 @@ proc setEncoding*(connection: DbConn, encoding: string): bool {.
   exec(connection, sql"PRAGMA encoding = ?", [encoding])
   result = connection.getValue(sql"PRAGMA encoding") == encoding
 
-proc finalize*(sqlPrepared:SqlPrepared){.discardable.} = 
+proc finalize*(sqlPrepared:SqlPrepared) {.discardable, since: (1, 3).} = 
   discard finalize(sqlPrepared.PStmt)
 
+template dbBindParamError*(paramIdx: int, val: varargs[untyped]) =
+  ## Raises a `DbError` exception.
+  var e: ref DbError
+  new(e)
+  e.msg = "error binding param in position " & $paramIdx
+  raise e
+
+proc bindParam*(ps: SqlPrepared, paramIdx: int, val: int32) {.since: (1, 3).} =
+  ## Binds a int32  to the specified paramIndex.
+  if bind_int(ps.PStmt, paramIdx.int32, val) != SQLITE_OK:
+    dbBindParamError(paramIdx, val)
+
+proc bindParam*(ps: SqlPrepared, paramIdx: int, val: int64) {.since: (1, 3).} =
+  ## Binds a int64  to the specified paramIndex.
+  if bind_int64(ps.PStmt, paramIdx.int32, val) != SQLITE_OK:
+    dbBindParamError(paramIdx, val)
+
+proc bindParam*(ps: SqlPrepared, paramIdx: int, val: int) {.since: (1, 3).} =
+  ## Binds a int  to the specified paramIndex.
+  when sizeof(int) == 8:
+    bindParam(ps, paramIdx, val.int64)
+  else:
+    bindParam(ps, paramIdx, val.int32)
+
+proc bindParam*(ps: SqlPrepared, paramIdx: int, val: float64) {.since: (1, 3).} =
+  ## Binds a 64bit float to the specified paramIndex.
+  if bind_double(ps.PStmt, paramIdx.int32, val) != SQLITE_OK:
+    dbBindParamError(paramIdx, val)
+
+proc bindNull*(ps: SqlPrepared, paramIdx: int) {.since: (1, 3).} =
+  ## Sets the bindparam at the specified paramIndex to null 
+  ## (default behaviour by sqlite).
+  if bind_null(ps.PStmt, paramIdx.int32) != SQLITE_OK:
+    dbBindParamError(paramIdx)
+
+proc bindParam*(ps: SqlPrepared, paramIdx: int, val: string, copy = true) {.since: (1, 3).} =
+  ## Binds a string to the specified paramIndex.
+  ## if copy is true then SQLite makes its own private copy of the data immediately
+  if bind_text(ps.PStmt, paramIdx.int32, val.cstring, val.len.int32, if copy: SQLITE_TRANSIENT else: SQLITE_STATIC) != SQLITE_OK:
+    dbBindParamError(paramIdx, val)
+
+proc bindParam*(ps: SqlPrepared, paramIdx: int,val: openArray[byte], copy = true) {.since: (1, 3).} =
+  ## binds a blob to the specified paramIndex.
+  ## if copy is true then SQLite makes its own private copy of the data immediately
+  let len = val.len
+  if bind_blob(ps.PStmt, paramIdx.int32, val[0].unsafeAddr, len.int32, if copy: SQLITE_TRANSIENT else: SQLITE_STATIC) != SQLITE_OK:
+    dbBindParamError(paramIdx, val)
+
+macro bindParams*(ps: SqlPrepared, params: varargs[untyped]): untyped {.since: (1, 3).} =
+  let bindParam = bindSym("bindParam", brOpen)
+  let bindNull = bindSym("bindNull")
+  let preparedStatement = genSym()
+  result = newStmtList()
+  # Store `ps` in a temporary variable. This prevents `ps` from being evaluated every call.
+  result.add newNimNode(nnkLetSection).add(newIdentDefs(preparedStatement, newEmptyNode(), ps))
+  for idx, param in params:
+    if param.kind != nnkNilLit:
+      result.add newCall(bindParam, preparedStatement, newIntLitNode idx + 1, param)
+    else:
+      result.add newCall(bindNull, preparedStatement, newIntLitNode idx + 1)
+
+macro untypedLen(args: varargs[untyped]): int =
+  newLit(args.len)
+
+template exec*(db: DbConn, stmtName: SqlPrepared,
+          args: varargs[typed]): untyped =
+  when args.untypedLen > 0:
+    if reset(stmtName.PStmt) != SQLITE_OK:
+      dbError(db)
+    if clear_bindings(stmtName.PStmt) != SQLITE_OK:
+      dbError(db)
+    stmtName.bindParams(args)
+  if not tryExec(db, stmtName): dbError(db)
+
 when not defined(testing) and isMainModule:
-  var db = open("db.sql", "", "", "")
+  var db = open(":memory:", "", "", "")
   exec(db, sql"create table tbl1(one varchar(10), two smallint)", [])
   exec(db, sql"insert into tbl1 values('hello!',10)", [])
   exec(db, sql"insert into tbl1 values('goodbye', 20)", [])
   var p1 = db.prepare "create table tbl2(one varchar(10), two smallint)"
-  exec(db, p1, [])
+  exec(db, p1)
   finalize(p1)
   var p2 = db.prepare "insert into tbl2 values('hello!',10)"
-  exec(db, p2, [])
+  exec(db, p2)
   finalize(p2)
   var p3 = db.prepare "insert into tbl2 values('goodbye', 20)"
-  exec(db, p3, [])
+  exec(db, p3)
   finalize(p3)
   #db.query("create table tbl1(one varchar(10), two smallint)")
   #db.query("insert into tbl1 values('hello!',10)")
@@ -753,17 +822,66 @@ when not defined(testing) and isMainModule:
   for r in db.instantRows(sql"select * from tbl1", []):
     echo(r[0], r[1])
   var p4 =  db.prepare "select * from tbl2"
-  for r in db.rows(p4, []):
+  for r in db.rows(p4):
     echo(r[0], r[1])
   finalize(p4)
+  var i5 = 0
   var p5 =  db.prepare "select * from tbl2"
-  for r in db.instantRows(p5, []):
+  for r in db.instantRows(p5):
+    inc i5
     echo(r[0], r[1])
+  assert i5 == 2
   finalize(p5)
 
   for r in db.rows(sql"select * from tbl2", []):
     echo(r[0], r[1])
   for r in db.instantRows(sql"select * from tbl2", []):
     echo(r[0], r[1])
+  var p6 = db.prepare "select * from tbl2 where one = ? "
+  p6.bindParams("goodbye")
+  var rowsP3 = 0
+  for r in db.rows(p6):
+    rowsP3 = 1
+    echo(r[0], r[1])
+  assert rowsP3 == 1
+  finalize(p6)
+
+  var p7 = db.prepare "select * from tbl2 where two=?"
+  p7.bindParams(20'i32)
+  when sizeof(int) == 4:
+    p7.bindParams(20)
+  var rowsP = 0
+  for r in db.rows(p7):
+    rowsP = 1
+    echo(r[0], r[1])
+  assert rowsP == 1
+  finalize(p7)
+
+  exec(db, sql"CREATE TABLE photos(ID INTEGER PRIMARY KEY AUTOINCREMENT, photo BLOB)")
+  var p8 = db.prepare "INSERT INTO photos (ID,PHOTO) VALUES (?,?)"
+  var d = "abcdefghijklmnopqrstuvwxyz"
+  p8.bindParams(1'i32, "abcdefghijklmnopqrstuvwxyz")
+  exec(db, p8)
+  finalize(p8)
+  var p10 = db.prepare "INSERT INTO photos (ID,PHOTO) VALUES (?,?)"
+  p10.bindParams(2'i32,nil)
+  exec(db, p10)
+  exec( db, p10, 3, nil)
+  finalize(p10)
+  for r in db.rows(sql"select * from photos where ID = 1", []):
+    assert r[1].len == d.len
+    assert r[1] == d
+  var i6 = 0
+  for r in db.rows(sql"select * from photos where ID = 3", []):
+    i6 = 1
+  assert i6 == 1
+  var p9 = db.prepare("select * from photos where PHOTO is ?")
+  p9.bindParams(nil)
+  var rowsP2 = 0
+  for r in db.rows(p9):
+    rowsP2 = 1
+    echo(r[0], repr r[1])
+  assert rowsP2 == 1
+  finalize(p9)
 
   db_sqlite.close(db)