Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 94 additions & 51 deletions sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -602,12 +602,13 @@ type SQLiteTx struct {

// SQLiteStmt implements driver.Stmt.
type SQLiteStmt struct {
mu sync.Mutex
c *SQLiteConn
s *C.sqlite3_stmt
closed bool
cls bool // True if the statement was created by SQLiteConn.Query
reset bool // True if the statement needs to reset before re-use
mu sync.Mutex
c *SQLiteConn
s *C.sqlite3_stmt
namedParams map[string]int // Map of named parameter to index
closed bool
cls bool // True if the statement was created by SQLiteConn.Query
reset bool // True if the statement needs to reset before re-use
}

// SQLiteResult implements sql.Result.
Expand Down Expand Up @@ -2328,8 +2329,35 @@ func (s *SQLiteStmt) bind(args []driver.NamedValue) error {
return nil
}

func (s *SQLiteStmt) bindIndices(args []driver.NamedValue) error {
// Find the longest named parameter name.
func (s *SQLiteStmt) missingNamedParams(args []driver.NamedValue) bool {
for _, v := range args {
if v.Name != "" {
if _, ok := s.namedParams[v.Name]; !ok {
return true
}
}
}
return false
}

func (s *SQLiteStmt) createBindIndices(args []driver.NamedValue) {
if len(args) == 0 {
return
}
s.mu.Lock()
defer s.mu.Unlock()

// Check if we need to update or create the named params map.
// This check is necessary because a bad set of params could
// be passed to a prepared statement on its first invocation.
if !s.missingNamedParams(args) {
return
}
if s.namedParams == nil {
s.namedParams = make(map[string]int, len(args))
}

// Find the longest parameter name
n := 0
for _, v := range args {
if m := len(v.Name); m > n {
Expand All @@ -2338,60 +2366,75 @@ func (s *SQLiteStmt) bindIndices(args []driver.NamedValue) error {
}
buf := make([]byte, 0, n+2) // +2 for placeholder and null terminator

bindIndices := make([][3]int, len(args))
for _, v := range args {
if v.Name == "" {
continue
}
for _, c := range []byte{':', '@', '$'} {
buf = append(buf[:0], c)
buf = append(buf, v.Name...)
buf = append(buf, 0)
idx := int(C.sqlite3_bind_parameter_index(s.s, (*C.char)(unsafe.Pointer(&buf[0]))))
if idx != 0 {
s.namedParams[v.Name] = idx
break
}
}
}
}

func (s *SQLiteStmt) bindIndices(args []driver.NamedValue) error {
s.createBindIndices(args)

bindIndices := make([]int, len(args))
for i, v := range args {
bindIndices[i][0] = args[i].Ordinal
bindIndices[i] = args[i].Ordinal
if v.Name != "" {
for j, c := range []byte{':', '@', '$'} {
buf = append(buf[:0], c)
buf = append(buf, v.Name...)
buf = append(buf, 0)
bindIndices[i][j] = int(C.sqlite3_bind_parameter_index(s.s, (*C.char)(unsafe.Pointer(&buf[0]))))
}
args[i].Ordinal = bindIndices[i][0]
// NB: Parameters with unrecognized names should be ignored:
// https://github.com/mattn/go-sqlite3/issues/697
bindIndices[i] = s.namedParams[v.Name]
args[i].Ordinal = bindIndices[i]
}
}

var rv C.int
for i, arg := range args {
for j := range bindIndices[i] {
if bindIndices[i][j] == 0 {
continue
if bindIndices[i] == 0 {
continue
}
n := C.int(bindIndices[i])
switch v := arg.Value.(type) {
case nil:
rv = C.sqlite3_bind_null(s.s, n)
case string:
p := stringData(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v)))
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case bool:
val := 0
if v {
val = 1
}
n := C.int(bindIndices[i][j])
switch v := arg.Value.(type) {
case nil:
rv = C.sqlite3_bind_int(s.s, n, C.int(val))
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
if v == nil {
rv = C.sqlite3_bind_null(s.s, n)
case string:
p := stringData(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(p)), C.int(len(v)))
case int64:
rv = C.sqlite3_bind_int64(s.s, n, C.sqlite3_int64(v))
case bool:
val := 0
if v {
val = 1
}
rv = C.sqlite3_bind_int(s.s, n, C.int(val))
case float64:
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
case []byte:
if v == nil {
rv = C.sqlite3_bind_null(s.s, n)
} else {
ln := len(v)
if ln == 0 {
v = placeHolder
}
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
} else {
ln := len(v)
if ln == 0 {
v = placeHolder
}
case time.Time:
b := timefmt.Format(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
if rv != C.SQLITE_OK {
return s.c.lastError(int(rv))
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
}
case time.Time:
b := timefmt.Format(v)
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
}
if rv != C.SQLITE_OK {
return s.c.lastError(int(rv))
}
}
return nil
Expand Down
160 changes: 157 additions & 3 deletions sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1093,8 +1093,6 @@ func TestExecer(t *testing.T) {
}

func newTestDB(t testing.TB) *sql.DB {
// fmt.Sprintf("file:%s?mode=rwc", filename)
// ?mode=rwc
db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=rwc", t.TempDir()+"/test.sqlite3"))
if err != nil {
t.Fatal("Failed to open database:", err)
Expand Down Expand Up @@ -2346,7 +2344,9 @@ func TestNamedParam(t *testing.T) {
}
defer db.Close()

_, err = db.Exec("drop table foo")
if _, err = db.Exec("drop table if exists foo"); err != nil {
t.Fatal(err)
}
_, err = db.Exec("create table foo (id integer, name text, amount integer)")
if err != nil {
t.Fatal("Failed to create table:", err)
Expand Down Expand Up @@ -2376,6 +2376,67 @@ func TestNamedParam(t *testing.T) {
}
}

func TestNamedParamReorder(t *testing.T) {
db := newTestDB(t)
const createTableStmt = `
CREATE TABLE IF NOT EXISTS test_named_params (
r0 INTEGER NOT NULL,
r1 INTEGER NOT NULL
);
DELETE FROM test_named_params;
INSERT INTO test_named_params VALUES (10, 11);
INSERT INTO test_named_params VALUES (20, 21);`

if _, err := db.Exec(createTableStmt); err != nil {
t.Fatal(err)
}

const query = `
SELECT
r0, r1
FROM
test_named_params
WHERE r0 = :v1 AND r1 = :v2;
`
stmt, err := db.Prepare(query)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()

test := func(t testing.TB, arg1, arg2 sql.NamedArg, v1, v2 int64) {
t.Helper()
var i1, i2 int64
err := stmt.QueryRow(arg1, arg2).Scan(&i1, &i2)
if err != nil {
t.Error(err)
return
}
if i1 != v1 && i2 != v2 {
t.Errorf("got: v1=%d v2=%d want: v1=%d v2=%d", i1, i2, v1, v2)
}
}

// Deliberately add invalid named params to make sure that they
// don't poison the named param cache.
test(ignoreError{t}, sql.Named("v1", 10), sql.Named("foo", 11), 10, 11)
test(ignoreError{t}, sql.Named("bar", 10), sql.Named("foo", 11), 10, 11)

test(t, sql.Named("v1", 10), sql.Named("v2", 11), 10, 11)
test(t, sql.Named("v2", 11), sql.Named("v1", 10), 10, 11) // Reverse arg order

// Change argument values
test(t, sql.Named("v1", 20), sql.Named("v2", 21), 20, 21)
test(t, sql.Named("v2", 21), sql.Named("v1", 20), 20, 21) // Reverse arg order

// Extra argument should error
var v1, v2 int64
err = stmt.QueryRow(sql.Named("v1", 10), sql.Named("v2", 11), sql.Named("v3", 12)).Scan(&v1, &v2)
if err == nil {
t.Fatal(err)
}
}

func TestRawBytes(t *testing.T) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
Expand Down Expand Up @@ -2555,6 +2616,7 @@ var benchmarks = []testing.InternalBenchmark{
{Name: "BenchmarkScanRawBytes", F: benchmarkScanRawBytes},
{Name: "BenchmarkQueryParallel", F: benchmarkQueryParallel},
{Name: "BenchmarkOpen", F: benchmarkOpen},
{Name: "BenchmarkNamedParams", F: benchmarkNamedParams},
{Name: "BenchmarkParseTime", F: benchmarkParseTime},
}

Expand Down Expand Up @@ -3337,6 +3399,69 @@ func benchmarkOpen(b *testing.B) {
}
}

func benchmarkNamedParams(b *testing.B) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
b.Fatal(err)
}
defer db.Close()

const createTableStmt = `
DROP TABLE IF EXISTS bench_named_params;
VACUUM;
CREATE TABLE bench_named_params (
r0 INTEGER NOT NULL,
r1 INTEGER NOT NULL,
r2 INTEGER NOT NULL,
r3 INTEGER NOT NULL
);`
if _, err := db.Exec(createTableStmt); err != nil {
b.Fatal(err)
}
for i := int64(0); i < 1; i++ {
_, err := db.Exec("INSERT INTO bench_named_params VALUES (?, ?, ?, ?);", i, i, i, i)
if err != nil {
b.Fatal(err)
}
}
// _, err = db.Exec("insert into foo(id, name, amount) values(:id, @name, $amount)",
const query = `
SELECT
r0
FROM
bench_named_params
WHERE
r0 >= :v0 AND r1 >= :v1 AND r2 >= :v2 AND r3 >= :v3;`

stmt, err := db.Prepare(query)
if err != nil {
b.Fatal(err)
}
defer stmt.Close()

args := []any{
sql.Named("v0", 0),
sql.Named("v1", 0),
sql.Named("v2", 0),
sql.Named("v3", 0),
}
for i := 0; i < b.N; i++ {
rows, err := stmt.Query(args...)
if err != nil {
b.Fatal(err)
}
var v int64
for rows.Next() {
if err := rows.Scan(&v); err != nil {
b.Fatal(err)
}
}
if err := rows.Err(); err != nil {
b.Fatal(err)
}
}
}

func benchmarkParseTime(b *testing.B) {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
Expand Down Expand Up @@ -3379,3 +3504,32 @@ func benchmarkParseTime(b *testing.B) {
}
}
}

var _ testing.TB = ignoreError{}

// ignoreError prevents a testing.T from error'ing
type ignoreError struct {
*testing.T
}

func (t ignoreError) FailNow() {}

func (t ignoreError) Error(args ...any) {
t.Helper()
t.T.Log(args...)
}

func (t ignoreError) Errorf(format string, args ...any) {
t.Helper()
t.T.Logf(format, args...)
}

func (t ignoreError) Fatal(args ...any) {
t.Helper()
t.T.Log(args...)
}

func (t ignoreError) Fatalf(format string, args ...any) {
t.Helper()
t.T.Logf(format, args...)
}