summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorjfilby <jason.filby@gmail.com>2022-11-27 20:28:38 +0200
committerGitHub <noreply@github.com>2022-11-27 19:28:38 +0100
commit5a848a070759c4b5fa868741b430f96586a9ac18 (patch)
treee41f357fc3c48ca0a8ecbd2963ce99f16756c217
parentf644f046540a1a3a59aecf5468bc4d9ae0ea101f (diff)
downloadNim-5a848a070759c4b5fa868741b430f96586a9ac18.tar.gz
Fix several memory leaks in the Postgres wrapper. (#20940)
-rw-r--r--lib/impure/db_postgres.nim29
-rw-r--r--tests/untestable/tpostgres.nim14
2 files changed, 27 insertions, 16 deletions
diff --git a/lib/impure/db_postgres.nim b/lib/impure/db_postgres.nim
index 47f071475..43ab94dbc 100644
--- a/lib/impure/db_postgres.nim
+++ b/lib/impure/db_postgres.nim
@@ -201,7 +201,8 @@ proc prepare*(db: DbConn; stmtName: string, query: SqlQuery;
     dbError("parameter substitution expects \"$1\"")
   var res = pqprepare(db, stmtName, query.cstring, int32(nParams), nil)
   if pqResultStatus(res) != PGRES_COMMAND_OK: dbError(db)
-  return SqlPrepared(stmtName)
+  result = SqlPrepared(stmtName)
+  pqclear(res)
 
 proc setRow(res: PPGresult, r: var Row, line, cols: int32) =
   for col in 0'i32..cols-1:
@@ -468,12 +469,12 @@ proc getRow*(db: DbConn, query: SqlQuery,
   ## retrieves a single row. If the query doesn't return any rows, this proc
   ## will return a Row with empty strings for each column.
   let res = setupQuery(db, query, args)
-  getRow(res)
+  result = getRow(res)
 
 proc getRow*(db: DbConn, stmtName: SqlPrepared,
              args: varargs[string, `$`]): Row {.tags: [ReadDbEffect].} =
   let res = setupQuery(db, stmtName, args)
-  getRow(res)
+  result = getRow(res)
 
 proc getAllRows(res: PPGresult): seq[Row] =
   let N = pqntuples(res)
@@ -490,14 +491,14 @@ proc getAllRows*(db: DbConn, query: SqlQuery,
                  tags: [ReadDbEffect].} =
   ## executes the query and returns the whole result dataset.
   let res = setupQuery(db, query, args)
-  getAllRows(res)
+  result = getAllRows(res)
 
 proc getAllRows*(db: DbConn, stmtName: SqlPrepared,
                  args: varargs[string, `$`]): seq[Row] {.tags:
                  [ReadDbEffect].} =
   ## executes the prepared query and returns the whole result dataset.
   let res = setupQuery(db, stmtName, args)
-  getAllRows(res)
+  result = getAllRows(res)
 
 iterator rows*(db: DbConn, query: SqlQuery,
                args: varargs[string, `$`]): Row {.tags: [ReadDbEffect].} =
@@ -523,7 +524,8 @@ proc getValue*(db: DbConn, query: SqlQuery,
   ## result dataset. Returns "" if the dataset contains no rows or the database
   ## value is NULL.
   let res = setupQuery(db, query, args)
-  getValue(res)
+  result = getValue(res)
+  pqclear(res)
 
 proc getValue*(db: DbConn, stmtName: SqlPrepared,
                args: varargs[string, `$`]): string {.
@@ -532,7 +534,8 @@ proc getValue*(db: DbConn, stmtName: SqlPrepared,
   ## result dataset. Returns "" if the dataset contains no rows or the database
   ## value is NULL.
   let res = setupQuery(db, stmtName, args)
-  getValue(res)
+  result = getValue(res)
+  pqclear(res)
 
 proc tryInsertID*(db: DbConn, query: SqlQuery,
                   args: varargs[string, `$`]): int64 {.
@@ -541,12 +544,14 @@ proc tryInsertID*(db: DbConn, query: SqlQuery,
   ## generated ID for the row or -1 in case of an error. For Postgre this adds
   ## `RETURNING id` to the query, so it only works if your primary key is
   ## named `id`.
-  var x = pqgetvalue(setupQuery(db, SqlQuery(string(query) & " RETURNING id"),
-    args), 0, 0)
+  let res = setupQuery(db, SqlQuery(string(query) & " RETURNING id"),
+                       args)
+  var x = pqgetvalue(res, 0, 0)
   if not isNil(x):
     result = parseBiggestInt($x)
   else:
     result = -1
+  pqclear(res)
 
 proc insertID*(db: DbConn, query: SqlQuery,
                args: varargs[string, `$`]): int64 {.
@@ -563,12 +568,14 @@ proc tryInsert*(db: DbConn, query: SqlQuery,pkName: string,
                {.tags: [WriteDbEffect], since: (1, 3).}=
   ## executes the query (typically "INSERT") and returns the
   ## generated ID for the row or -1 in case of an error.
-  var x = pqgetvalue(setupQuery(db, SqlQuery(string(query) & " RETURNING " & pkName),
-    args), 0, 0)
+  let res = setupQuery(db, SqlQuery(string(query) & " RETURNING " & pkName),
+                       args)
+  var x = pqgetvalue(res, 0, 0)
   if not isNil(x):
     result = parseBiggestInt($x)
   else:
     result = -1
+  pqclear(res)
 
 proc insert*(db: DbConn, query: SqlQuery, pkName: string,
              args: varargs[string, `$`]): int64
diff --git a/tests/untestable/tpostgres.nim b/tests/untestable/tpostgres.nim
index d3397e53a..e9f2403e2 100644
--- a/tests/untestable/tpostgres.nim
+++ b/tests/untestable/tpostgres.nim
@@ -2,6 +2,7 @@ import db_postgres, strutils
 
 
 let db = open("localhost", "dom", "", "test")
+
 db.exec(sql"DROP TABLE IF EXISTS myTable")
 db.exec(sql("""CREATE TABLE myTable (
                   id integer PRIMARY KEY,
@@ -49,9 +50,12 @@ try:
   doAssert false, "Exception expected"
 except DbError:
   let msg = getCurrentExceptionMsg().normalize
-  doAssert "expects" in msg
-  doAssert "?" in msg
-  doAssert "parameter substitution" in msg
+
+  info "DbError",
+    msg = $msg
+
+  doAssert "no parameter" in msg
+  doAssert "$1" in msg
 
 doAssert db.getValue(sql("select filename from files where id = ?"), 1) == "hello.tmp"
 
@@ -315,11 +319,11 @@ db.exec(sql("""CREATE TABLE DICTIONARY(
 var entry = "あっそ"
 var definition = "(int) (See ああそうそう) oh, really (uninterested)/oh yeah?/hmmmmm"
 discard db.getRow(
-  SqlQuery("INSERT INTO DICTIONARY(entry, definition) VALUES(\'$1\', \'$2\') RETURNING id" % [entry, definition]))
+  sql("INSERT INTO DICTIONARY(entry, definition) VALUES(?, ?) RETURNING id"), entry, definition)
 doAssert db.getValue(sql"SELECT definition FROM DICTIONARY WHERE entry = ?", entry) == definition
 entry = "Format string entry"
 definition = "Format string definition"
-db.exec(sql"INSERT INTO DICTIONARY(entry, definition) VALUES (?, ?)", entry, definition)
+db.exec(SqlQuery("INSERT INTO DICTIONARY(entry, definition) VALUES (?, ?)"), entry, definition)
 doAssert db.getValue(sql"SELECT definition FROM DICTIONARY WHERE entry = ?", entry) == definition
 
 echo("All tests succeeded!")