Skip to content

Commit

Permalink
Avoid MISUSE in Stmt.Bind
Browse files Browse the repository at this point in the history
If a Conn is interrupted then Conn.Prep will return a Stmt with
Stmt.stmt = nil. This causes SQLITE MISUSE errors to occur when any of
the Bind functions are called on the Stmt. This is confusing to users as
the error returned by Step is misleading and anyone using sqlite.Logger
will see misleading MISUSE messages. Now we check if Stmt.stmt is not
nil prior to using it in the Bind statements, which solves these issues
and makes the interrupt obvious when Stmt.Step is later called.

fix crawshaw#86
  • Loading branch information
AdamSLevy committed Feb 10, 2020
1 parent 469650a commit 6ab0960
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 24 deletions.
10 changes: 5 additions & 5 deletions export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ func ConnCount(conn *Conn) int { return conn.count }

func InterruptedStmt(conn *Conn, query string) *Stmt {
return &Stmt{
conn: conn,
query: query,
bindNames: make(map[string]int),
colNames: make(map[string]int),
prepInterupt: true,
conn: conn,
query: query,
bindNames: make(map[string]int),
colNames: make(map[string]int),
prepInterrupt: true,
}
}
63 changes: 44 additions & 19 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,11 @@ func (conn *Conn) Prep(query string) *Stmt {
if err != nil {
if ErrCode(err) == SQLITE_INTERRUPT {
return &Stmt{
conn: conn,
query: query,
bindNames: make(map[string]int),
colNames: make(map[string]int),
prepInterupt: true,
conn: conn,
query: query,
bindNames: make(map[string]int),
colNames: make(map[string]int),
prepInterrupt: true,
}
}
panic(err)
Expand Down Expand Up @@ -532,19 +532,20 @@ func reserr(loc, query, msg string, res C.int) error {
// When a Stmt is no longer needed it should be cleaned up
// by calling the Finalize method.
type Stmt struct {
conn *Conn
stmt *C.sqlite3_stmt
query string
bindNames map[string]int
colNames map[string]int
bindErr error
prepInterupt bool // set if Prep was interrupted
lastHasRow bool // last bool returned by Step
tracerTask TracerTask
conn *Conn
stmt *C.sqlite3_stmt
query string
bindNames map[string]int
colNames map[string]int
bindErr error
prepInterrupt bool // set if Prep was interrupted
lastHasRow bool // last bool returned by Step
tracerTask TracerTask
}

func (stmt *Stmt) interrupted(loc string) error {
if stmt.prepInterupt {
loc = "Stmt." + loc
if stmt.prepInterrupt {
return reserr(loc, stmt.query, "", C.SQLITE_INTERRUPT)
}
return stmt.conn.interrupted(loc, stmt.query)
Expand Down Expand Up @@ -599,7 +600,7 @@ func (stmt *Stmt) Reset() error {
// https://www.sqlite.org/c3ref/clear_bindings.html
func (stmt *Stmt) ClearBindings() error {
stmt.conn.count++
if err := stmt.interrupted("Stmt.ClearBindings"); err != nil {
if err := stmt.interrupted("ClearBindings"); err != nil {
return err
}
res := C.sqlite3_clear_bindings(stmt.stmt)
Expand Down Expand Up @@ -668,7 +669,7 @@ func (stmt *Stmt) Step() (rowReturned bool, err error) {
func (stmt *Stmt) step() (bool, error) {
for {
stmt.conn.count++
if err := stmt.interrupted("Stmt.Step"); err != nil {
if err := stmt.interrupted("Step"); err != nil {
return false, err
}
switch res := C.sqlite3_step(stmt.stmt); uint8(res) { // reduce to non-extended error code
Expand Down Expand Up @@ -699,14 +700,14 @@ func (stmt *Stmt) step() (bool, error) {

func (stmt *Stmt) handleBindErr(loc string, res C.int) {
if stmt.bindErr == nil {
stmt.bindErr = stmt.conn.reserr(loc, stmt.query, res)
stmt.bindErr = stmt.conn.reserr("Stmt."+loc, stmt.query, res)
}
}

func (stmt *Stmt) findBindName(loc string, param string) int {
pos := stmt.bindNames[param]
if pos == 0 && stmt.bindErr == nil {
stmt.bindErr = reserr(loc, stmt.query, "unknown parameter: "+param, C.SQLITE_ERROR)
stmt.bindErr = reserr("Stmt."+loc, stmt.query, "unknown parameter: "+param, C.SQLITE_ERROR)
}
return pos
}
Expand Down Expand Up @@ -739,6 +740,9 @@ func (stmt *Stmt) ColumnName(col int) string {
//
// https://www.sqlite.org/c3ref/bind_parameter_count.html
func (stmt *Stmt) BindParamCount() int {
if stmt.stmt == nil {
return 0
}
return int(C.sqlite3_bind_parameter_count(stmt.stmt))
}

Expand All @@ -748,6 +752,9 @@ func (stmt *Stmt) BindParamCount() int {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (stmt *Stmt) BindInt64(param int, value int64) {
if stmt.stmt == nil {
return
}
res := C.sqlite3_bind_int64(stmt.stmt, C.int(param), C.sqlite3_int64(value))
stmt.handleBindErr("BindInt64", res)
}
Expand All @@ -758,6 +765,9 @@ func (stmt *Stmt) BindInt64(param int, value int64) {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (stmt *Stmt) BindBool(param int, value bool) {
if stmt.stmt == nil {
return
}
v := 0
if value {
v = 1
Expand All @@ -775,6 +785,9 @@ func (stmt *Stmt) BindBool(param int, value bool) {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (stmt *Stmt) BindBytes(param int, value []byte) {
if stmt.stmt == nil {
return
}
var v *C.char
if len(value) != 0 {
v = (*C.char)(unsafe.Pointer(&value[0]))
Expand All @@ -792,6 +805,9 @@ var emptyCstr = C.CString("")
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (stmt *Stmt) BindText(param int, value string) {
if stmt.stmt == nil {
return
}
var v *C.char
var free *[0]byte
if len(value) == 0 {
Expand All @@ -810,6 +826,9 @@ func (stmt *Stmt) BindText(param int, value string) {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (stmt *Stmt) BindFloat(param int, value float64) {
if stmt.stmt == nil {
return
}
res := C.sqlite3_bind_double(stmt.stmt, C.int(param), C.double(value))
stmt.handleBindErr("BindFloat", res)
}
Expand All @@ -820,6 +839,9 @@ func (stmt *Stmt) BindFloat(param int, value float64) {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (stmt *Stmt) BindNull(param int) {
if stmt.stmt == nil {
return
}
res := C.sqlite3_bind_null(stmt.stmt, C.int(param))
stmt.handleBindErr("BindNull", res)
}
Expand All @@ -830,6 +852,9 @@ func (stmt *Stmt) BindNull(param int) {
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (stmt *Stmt) BindZeroBlob(param int, len int64) {
if stmt.stmt == nil {
return
}
res := C.sqlite3_bind_zeroblob64(stmt.stmt, C.int(param), C.sqlite3_uint64(len))
stmt.handleBindErr("BindZeroBlob", res)
}
Expand Down

0 comments on commit 6ab0960

Please sign in to comment.