Skip to content

Commit 1aa45e0

Browse files
authored
Merge pull request #5 from mashiike/feature/ordinal-parameters
support parameter $ and ?
2 parents 7aacdfc + 36f9452 commit 1aa45e0

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

conn.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func (conn *redshiftDataConn) QueryContext(ctx context.Context, query string, ar
6666

6767
func (conn *redshiftDataConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
6868
params := &redshiftdata.ExecuteStatementInput{
69-
Sql: nullif(query),
69+
Sql: nullif(rewriteQuery(query, len(args))),
7070
Parameters: convertArgsToParameters(args),
7171
}
7272
_, output, err := conn.executeStatement(ctx, params)
@@ -76,6 +76,40 @@ func (conn *redshiftDataConn) ExecContext(ctx context.Context, query string, arg
7676
return newResult(output), nil
7777
}
7878

79+
func rewriteQuery(query string, paramsCount int) string {
80+
if paramsCount == 0 {
81+
return query
82+
}
83+
runes := make([]rune, 0, len(query))
84+
stack := make([]rune, 0)
85+
var exclamationCount int
86+
for _, r := range query {
87+
if len(stack) > 0 {
88+
if r == stack[len(stack)-1] {
89+
stack = stack[:len(stack)-1]
90+
runes = append(runes, r)
91+
continue
92+
}
93+
} else {
94+
switch r {
95+
case '?':
96+
exclamationCount++
97+
runes = append(runes, []rune(fmt.Sprintf(":%d", exclamationCount))...)
98+
continue
99+
case '$':
100+
runes = append(runes, ':')
101+
continue
102+
}
103+
}
104+
switch r {
105+
case '"', '\'':
106+
stack = append(stack, r)
107+
}
108+
runes = append(runes, r)
109+
}
110+
return string(runes)
111+
}
112+
79113
func convertArgsToParameters(args []driver.NamedValue) []types.SqlParameter {
80114
if len(args) == 0 {
81115
return nil

conn_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package redshiftdatasqldriver
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestRewriteQuery(t *testing.T) {
10+
cases := []struct {
11+
casename string
12+
query string
13+
paramsCount int
14+
expected string
15+
}{
16+
{
17+
casename: "no params",
18+
query: `SELECT * FROM pg_user`,
19+
paramsCount: 0,
20+
expected: `SELECT * FROM pg_user`,
21+
},
22+
{
23+
casename: "no change",
24+
query: `SELECT * FROM pg_user WHERE usename = :name`,
25+
paramsCount: 1,
26+
expected: `SELECT * FROM pg_user WHERE usename = :name`,
27+
},
28+
{
29+
casename: "? rewrite",
30+
query: `SELECT 'hoge?' FROM pg_user WHERE usename = ? AND usesysid > ?`,
31+
paramsCount: 1,
32+
expected: `SELECT 'hoge?' FROM pg_user WHERE usename = :1 AND usesysid > :2`,
33+
},
34+
{
35+
casename: "$ rewrite",
36+
query: `SELECT '3$1$' FROM table WHERE "$column" = $1 AND column1 > $2 AND column2 < $1`,
37+
paramsCount: 1,
38+
expected: `SELECT '3$1$' FROM table WHERE "$column" = :1 AND column1 > :2 AND column2 < :1`,
39+
},
40+
}
41+
for _, c := range cases {
42+
t.Run(c.casename, func(t *testing.T) {
43+
actual := rewriteQuery(c.query, c.paramsCount)
44+
require.Equal(t, c.expected, actual)
45+
})
46+
}
47+
}

0 commit comments

Comments
 (0)