@@ -31,6 +31,7 @@ package sqlite3
31
31
#endif
32
32
#include <stdlib.h>
33
33
#include <string.h>
34
+ #include <ctype.h>
34
35
35
36
#ifdef __CYGWIN__
36
37
# include <errno.h>
@@ -79,6 +80,16 @@ _sqlite3_exec(sqlite3* db, const char* pcmd, long long* rowid, long long* change
79
80
return rv;
80
81
}
81
82
83
+ static const char *
84
+ _trim_leading_spaces(const char *str) {
85
+ if (str) {
86
+ while (isspace(*str)) {
87
+ str++;
88
+ }
89
+ }
90
+ return str;
91
+ }
92
+
82
93
#ifdef SQLITE_ENABLE_UNLOCK_NOTIFY
83
94
extern int _sqlite3_step_blocking(sqlite3_stmt *stmt);
84
95
extern int _sqlite3_step_row_blocking(sqlite3_stmt* stmt, long long* rowid, long long* changes);
@@ -99,7 +110,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
99
110
static int
100
111
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
101
112
{
102
- return _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
113
+ int rv = _sqlite3_prepare_v2_blocking(db, zSql, nBytes, ppStmt, pzTail);
114
+ if (pzTail) {
115
+ *pzTail = _trim_leading_spaces(*pzTail);
116
+ }
117
+ return rv;
103
118
}
104
119
105
120
#else
@@ -122,7 +137,11 @@ _sqlite3_step_row_internal(sqlite3_stmt* stmt, long long* rowid, long long* chan
122
137
static int
123
138
_sqlite3_prepare_v2_internal(sqlite3 *db, const char *zSql, int nBytes, sqlite3_stmt **ppStmt, const char **pzTail)
124
139
{
125
- return sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
140
+ int rv = sqlite3_prepare_v2(db, zSql, nBytes, ppStmt, pzTail);
141
+ if (pzTail) {
142
+ *pzTail = _trim_leading_spaces(*pzTail);
143
+ }
144
+ return rv;
126
145
}
127
146
#endif
128
147
@@ -848,24 +867,32 @@ func (c *SQLiteConn) Exec(query string, args []driver.Value) (driver.Result, err
848
867
}
849
868
850
869
func (c * SQLiteConn ) exec (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
870
+ pquery := C .CString (query )
871
+ op := pquery // original pointer
872
+ defer C .free (unsafe .Pointer (op ))
873
+
874
+ var stmtArgs []driver.NamedValue
875
+ var tail * C.char
876
+ s := new (SQLiteStmt ) // escapes to the heap so reuse it
851
877
start := 0
852
878
for {
853
- s , err := c .prepare (ctx , query )
854
- if err != nil {
855
- return nil , err
879
+ * s = SQLiteStmt {c : c } // reset
880
+ rv := C ._sqlite3_prepare_v2_internal (c .db , pquery , C .int (- 1 ), & s .s , & tail )
881
+ if rv != C .SQLITE_OK {
882
+ return nil , c .lastError ()
856
883
}
884
+
857
885
var res driver.Result
858
- if s .(* SQLiteStmt ).s != nil {
859
- stmtArgs := make ([]driver.NamedValue , 0 , len (args ))
886
+ if s .s != nil {
860
887
na := s .NumInput ()
861
888
if len (args )- start < na {
862
- s .Close ()
889
+ s .finalize ()
863
890
return nil , fmt .Errorf ("not enough args to execute query: want %d got %d" , na , len (args ))
864
891
}
865
892
// consume the number of arguments used in the current
866
893
// statement and append all named arguments not
867
894
// contained therein
868
- stmtArgs = append (stmtArgs , args [start :start + na ]... )
895
+ stmtArgs = append (stmtArgs [: 0 ] , args [start :start + na ]... )
869
896
for i := range args {
870
897
if (i < start || i >= na ) && args [i ].Name != "" {
871
898
stmtArgs = append (stmtArgs , args [i ])
@@ -874,23 +901,23 @@ func (c *SQLiteConn) exec(ctx context.Context, query string, args []driver.Named
874
901
for i := range stmtArgs {
875
902
stmtArgs [i ].Ordinal = i + 1
876
903
}
877
- res , err = s .(* SQLiteStmt ).exec (ctx , stmtArgs )
904
+ var err error
905
+ res , err = s .exec (ctx , stmtArgs )
878
906
if err != nil && err != driver .ErrSkip {
879
- s .Close ()
907
+ s .finalize ()
880
908
return nil , err
881
909
}
882
910
start += na
883
911
}
884
- tail := s .(* SQLiteStmt ).t
885
- s .Close ()
886
- if tail == "" {
912
+ s .finalize ()
913
+ if tail == nil || * tail == '\000' {
887
914
if res == nil {
888
915
// https://github.com/mattn/go-sqlite3/issues/963
889
916
res = & SQLiteResult {0 , 0 }
890
917
}
891
918
return res , nil
892
919
}
893
- query = tail
920
+ pquery = tail
894
921
}
895
922
}
896
923
@@ -907,22 +934,29 @@ func (c *SQLiteConn) Query(query string, args []driver.Value) (driver.Rows, erro
907
934
}
908
935
909
936
func (c * SQLiteConn ) query (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
937
+ pquery := C .CString (query )
938
+ op := pquery // original pointer
939
+ defer C .free (unsafe .Pointer (op ))
940
+
941
+ var stmtArgs []driver.NamedValue
942
+ var tail * C.char
943
+ s := new (SQLiteStmt ) // escapes to the heap so reuse it
910
944
start := 0
911
945
for {
912
- stmtArgs := make ([]driver. NamedValue , 0 , len ( args ))
913
- s , err := c . prepare ( ctx , query )
914
- if err != nil {
915
- return nil , err
946
+ * s = SQLiteStmt { c : c , cls : true } // reset
947
+ rv := C . _sqlite3_prepare_v2_internal ( c . db , pquery , C . int ( - 1 ), & s . s , & tail )
948
+ if rv != C . SQLITE_OK {
949
+ return nil , c . lastError ()
916
950
}
917
- s .( * SQLiteStmt ). cls = true
951
+
918
952
na := s .NumInput ()
919
953
if len (args )- start < na {
920
954
return nil , fmt .Errorf ("not enough args to execute query: want %d got %d" , na , len (args )- start )
921
955
}
922
956
// consume the number of arguments used in the current
923
957
// statement and append all named arguments not contained
924
958
// therein
925
- stmtArgs = append (stmtArgs , args [start :start + na ]... )
959
+ stmtArgs = append (stmtArgs [: 0 ] , args [start :start + na ]... )
926
960
for i := range args {
927
961
if (i < start || i >= na ) && args [i ].Name != "" {
928
962
stmtArgs = append (stmtArgs , args [i ])
@@ -931,19 +965,18 @@ func (c *SQLiteConn) query(ctx context.Context, query string, args []driver.Name
931
965
for i := range stmtArgs {
932
966
stmtArgs [i ].Ordinal = i + 1
933
967
}
934
- rows , err := s .( * SQLiteStmt ). query (ctx , stmtArgs )
968
+ rows , err := s .query (ctx , stmtArgs )
935
969
if err != nil && err != driver .ErrSkip {
936
- s .Close ()
970
+ s .finalize ()
937
971
return rows , err
938
972
}
939
973
start += na
940
- tail := s .(* SQLiteStmt ).t
941
- if tail == "" {
974
+ if tail == nil || * tail == '\000' {
942
975
return rows , nil
943
976
}
944
977
rows .Close ()
945
- s .Close ()
946
- query = tail
978
+ s .finalize ()
979
+ pquery = tail
947
980
}
948
981
}
949
982
@@ -1805,8 +1838,11 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
1805
1838
return nil , c .lastError ()
1806
1839
}
1807
1840
var t string
1808
- if tail != nil && * tail != '\000' {
1809
- t = strings .TrimSpace (C .GoString (tail ))
1841
+ if tail != nil && * tail != 0 {
1842
+ n := int (uintptr (unsafe .Pointer (tail ))) - int (uintptr (unsafe .Pointer (pquery )))
1843
+ if 0 <= n && n < len (query ) {
1844
+ t = strings .TrimSpace (query [n :])
1845
+ }
1810
1846
}
1811
1847
ss := & SQLiteStmt {c : c , s : s , t : t }
1812
1848
runtime .SetFinalizer (ss , (* SQLiteStmt ).Close )
@@ -1899,6 +1935,13 @@ func (s *SQLiteStmt) Close() error {
1899
1935
return nil
1900
1936
}
1901
1937
1938
+ func (s * SQLiteStmt ) finalize () {
1939
+ if s .s != nil {
1940
+ C .sqlite3_finalize (s .s )
1941
+ s .s = nil
1942
+ }
1943
+ }
1944
+
1902
1945
// NumInput return a number of parameters.
1903
1946
func (s * SQLiteStmt ) NumInput () int {
1904
1947
return int (C .sqlite3_bind_parameter_count (s .s ))
0 commit comments