summary refs log tree commit diff stats
path: root/lib/impure
diff options
context:
space:
mode:
authorRegis Caillaud <35006197+Clonkk@users.noreply.github.com>2020-11-02 13:02:55 +0100
committerGitHub <noreply@github.com>2020-11-02 13:02:55 +0100
commit6fa82a5b3afbe644eef3fb41647a341d8e9b21c4 (patch)
treeeef7ff259f34ba4c072ec25e4743f68cf3c9817e /lib/impure
parent4fe56b56ce2b856ff8fae9051cceb766ab7729e0 (diff)
downloadNim-6fa82a5b3afbe644eef3fb41647a341d8e9b21c4.tar.gz
Handle BLOB column type in SQLite as binary data (#15681)
* Fixed not handling blob correctly in sqlite
* Fixed setLen commented by mistake
* Added binary example as db_sqlite doc
* Added tests for sqlite binary data
Diffstat (limited to 'lib/impure')
-rw-r--r--lib/impure/db_sqlite.nim93
1 files changed, 78 insertions, 15 deletions
diff --git a/lib/impure/db_sqlite.nim b/lib/impure/db_sqlite.nim
index 80eee59fb..ce57996be 100644
--- a/lib/impure/db_sqlite.nim
+++ b/lib/impure/db_sqlite.nim
@@ -94,6 +94,57 @@
 ##
 ##    db.close()
 ##
+## Storing binary data example
+##----------------------------
+##
+## .. code-block:: nim
+##
+##   import random
+##
+##   ## Generate random float datas
+##   var orig = newSeq[float64](150)
+##   randomize()
+##   for x in orig.mitems:
+##     x = rand(1.0)/10.0
+##
+##   let db = open("mysqlite.db", "", "", "")
+##   block: ## Create database
+##     ## Binary datas needs to be of type BLOB in SQLite
+##     let createTableStr = sql"""CREATE TABLE test(
+##       id INTEGER NOT NULL PRIMARY KEY,
+##       data BLOB
+##     )
+##     """
+##     db.exec(createTableStr)
+##
+##   block: ## Insert data
+##     var id = 1
+##     ## Data needs to be converted to seq[byte] to be interpreted as binary by bindParams
+##     var dbuf = newSeq[byte](orig.len*sizeof(float64))
+##     copyMem(unsafeAddr(dbuf[0]), unsafeAddr(orig[0]), dbuf.len)
+##
+##     ## Use prepared statement to insert binary data into database
+##     var insertStmt = db.prepare("INSERT INTO test (id, data) VALUES (?, ?)")
+##     insertStmt.bindParams(id, dbuf)
+##     let bres = db.tryExec(insertStmt)
+##     ## Check insert
+##     doAssert(bres)
+##     # Destroy statement
+##     finalize(insertStmt)
+##
+##   block: ## Use getValue to select data
+##     var dataTest = db.getValue(sql"SELECT data FROM test WHERE id = ?", 1)
+##     ## Calculate sequence size from buffer size
+##     let seqSize = int(dataTest.len*sizeof(byte)/sizeof(float64))
+##     ## Copy binary string data in dataTest into a seq
+##     var res: seq[float64] = newSeq[float64](seqSize)
+##     copyMem(unsafeAddr(res[0]), addr(dataTest[0]), dataTest.len)
+##
+##     ## Check datas obtained is identical
+##     doAssert res == orig
+##
+##   db.close()
+##
 ##
 ## Note
 ## ====
@@ -242,10 +293,14 @@ proc setupQuery(db: DbConn, stmtName: SqlPrepared): SqlPrepared {.since: (1, 3).
 
 proc setRow(stmt: PStmt, r: var Row, cols: cint) =
   for col in 0'i32..cols-1:
-    setLen(r[col], column_bytes(stmt, col)) # set capacity
-    setLen(r[col], 0)
-    let x = column_text(stmt, col)
-    if not isNil(x): add(r[col], x)
+    let cb = column_bytes(stmt, col)
+    setLen(r[col], cb) # set capacity
+    if column_type(stmt, col) == SQLITE_BLOB:
+      copyMem(addr(r[col][0]), column_blob(stmt, col), cb)
+    else:
+      setLen(r[col], 0)
+      let x = column_text(stmt, col)
+      if not isNil(x): add(r[col], x)
 
 iterator fastRows*(db: DbConn, query: SqlQuery,
                    args: varargs[string, `$`]): Row {.tags: [ReadDbEffect].} =
@@ -288,7 +343,7 @@ iterator fastRows*(db: DbConn, query: SqlQuery,
   finally:
     if finalize(stmt) != SQLITE_OK: dbError(db)
 
-iterator fastRows*(db: DbConn, stmtName: SqlPrepared): Row 
+iterator fastRows*(db: DbConn, stmtName: SqlPrepared): Row
                   {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} =
   discard setupQuery(db, stmtName)
   var L = (column_count(stmtName.PStmt))
@@ -491,7 +546,7 @@ proc getAllRows*(db: DbConn, query: SqlQuery,
   for r in fastRows(db, query, args):
     result.add(r)
 
-proc getAllRows*(db: DbConn, stmtName: SqlPrepared): seq[Row] 
+proc getAllRows*(db: DbConn, stmtName: SqlPrepared): seq[Row]
                 {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} =
   result = @[]
   for r in fastRows(db, stmtName):
@@ -524,7 +579,7 @@ iterator rows*(db: DbConn, query: SqlQuery,
   ##    db.close()
   for r in fastRows(db, query, args): yield r
 
-iterator rows*(db: DbConn, stmtName: SqlPrepared): Row 
+iterator rows*(db: DbConn, stmtName: SqlPrepared): Row
               {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} =
   for r in fastRows(db, stmtName): yield r
 
@@ -558,13 +613,17 @@ proc getValue*(db: DbConn, query: SqlQuery,
     if cb == 0:
       result = ""
     else:
-      result = newStringOfCap(cb)
-      add(result, column_text(stmt, 0))
+      if column_type(stmt, 0) == SQLITE_BLOB:
+        result.setLen(cb)
+        copyMem(addr(result[0]), column_blob(stmt, 0), cb)
+      else:
+        result = newStringOfCap(cb)
+        add(result, column_text(stmt, 0))
   else:
     result = ""
   if finalize(stmt) != SQLITE_OK: dbError(db)
 
-proc getValue*(db: DbConn,  stmtName: SqlPrepared): string 
+proc getValue*(db: DbConn,  stmtName: SqlPrepared): string
               {.tags: [ReadDbEffect,WriteDbEffect], since: (1, 3).} =
   var stmt = setupQuery(db, stmtName).PStmt
   if step(stmt) == SQLITE_ROW:
@@ -572,8 +631,12 @@ proc getValue*(db: DbConn,  stmtName: SqlPrepared): string
     if cb == 0:
       result = ""
     else:
-      result = newStringOfCap(cb)
-      add(result, column_text(stmt, 0))
+      if column_type(stmt, 0) == SQLITE_BLOB:
+        result.setLen(cb)
+        copyMem(addr(result[0]), column_blob(stmt, 0), cb)
+      else:
+        result = newStringOfCap(cb)
+        add(result, column_text(stmt, 0))
   else:
     result = ""
 
@@ -671,7 +734,7 @@ proc execAffectedRows*(db: DbConn, query: SqlQuery,
   exec(db, query, args)
   result = changes(db)
 
-proc execAffectedRows*(db: DbConn, stmtName: SqlPrepared): int64 
+proc execAffectedRows*(db: DbConn, stmtName: SqlPrepared): int64
                       {.tags: [ReadDbEffect, WriteDbEffect],since: (1, 3).} =
   exec(db, stmtName)
   result = changes(db)
@@ -723,7 +786,7 @@ proc setEncoding*(connection: DbConn, encoding: string): bool {.
   exec(connection, sql"PRAGMA encoding = ?", [encoding])
   result = connection.getValue(sql"PRAGMA encoding") == encoding
 
-proc finalize*(sqlPrepared:SqlPrepared) {.discardable, since: (1, 3).} = 
+proc finalize*(sqlPrepared:SqlPrepared) {.discardable, since: (1, 3).} =
   discard finalize(sqlPrepared.PStmt)
 
 template dbBindParamError*(paramIdx: int, val: varargs[untyped]) =
@@ -756,7 +819,7 @@ proc bindParam*(ps: SqlPrepared, paramIdx: int, val: float64) {.since: (1, 3).}
     dbBindParamError(paramIdx, val)
 
 proc bindNull*(ps: SqlPrepared, paramIdx: int) {.since: (1, 3).} =
-  ## Sets the bindparam at the specified paramIndex to null 
+  ## Sets the bindparam at the specified paramIndex to null
   ## (default behaviour by sqlite).
   if bind_null(ps.PStmt, paramIdx.int32) != SQLITE_OK:
     dbBindParamError(paramIdx)