forked from ydb-platform/ydb-go-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdsn.go
126 lines (113 loc) · 4.37 KB
/
dsn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package ydb
import (
"errors"
"fmt"
"regexp"
"strings"
"github.com/ydb-platform/ydb-go-sdk/v3/balancers"
"github.com/ydb-platform/ydb-go-sdk/v3/credentials"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/bind"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql"
)
const (
	// tablePathPrefixTransformer is the name of the go_query_bind transformer
	// that sets a table path prefix; it is written as table_path_prefix(<prefix>).
	tablePathPrefixTransformer = "table_path_prefix"
)
// dsnParsers is the list of DSN parsers registered through RegisterDsnParser.
// Entry 0 is the built-in parser backed by parseConnectionString. A slot is set
// to nil by UnregisterDsnParser, so entries may be nil.
var dsnParsers = []func(dsn string) (opts []Option, _ error){
	func(dsn string) ([]Option, error) {
		parsed, err := parseConnectionString(dsn)
		if err == nil {
			return parsed, nil
		}

		return nil, xerrors.WithStackTrace(err)
	},
}
// RegisterDsnParser registers DSN parser for ydb.Open and sql.Open driver constructors.
// It returns the parser's registration ID, its index in the parser list, which
// can later be passed to UnregisterDsnParser.
//
// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
func RegisterDsnParser(parser func(dsn string) (opts []Option, _ error)) (registrationID int) {
	// The new parser lands at the current end of the list, so its index is the
	// list length before appending.
	registrationID = len(dsnParsers)
	dsnParsers = append(dsnParsers, parser)

	return registrationID
}
// UnregisterDsnParser unregisters the DSN parser identified by registrationID,
// the value returned by RegisterDsnParser.
//
// The parser's slot is cleared rather than removed, so the registration IDs of
// the other parsers stay valid. IDs outside the registered range are ignored
// instead of causing an index-out-of-range panic.
//
// Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
func UnregisterDsnParser(registrationID int) {
	if registrationID < 0 || registrationID >= len(dsnParsers) {
		return
	}
	// NOTE(review): code iterating dsnParsers must skip nil slots; confirm at call sites.
	dsnParsers[registrationID] = nil
}
// parseConnectionString converts a data source name into driver options.
//
// In addition to the options produced by dsn.Parse, it interprets these query
// parameters:
//   - token: access token, applied via credentials.NewAccessTokenCredentials;
//   - go_balancer, or balancer when go_balancer is empty: balancer config
//     passed to balancers.FromConfig;
//   - go_query_mode, or query_mode when go_query_mode is empty: default query
//     mode for the database/sql connector;
//   - go_fake_tx: comma-separated query modes, each passed to xsql.WithFakeTx;
//   - go_query_bind: comma-separated query binders (see dsnQueryBinders).
//
// It returns an error if dsn.Parse fails or if a query mode or query binder
// name is not recognized.
//
//nolint:funlen
func parseConnectionString(dataSourceName string) (opts []Option, _ error) {
	info, err := dsn.Parse(dataSourceName)
	if err != nil {
		return nil, xerrors.WithStackTrace(err)
	}
	opts = append(opts, With(info.Options...))

	if token := info.Params.Get("token"); token != "" {
		opts = append(opts, WithCredentials(credentials.NewAccessTokenCredentials(token)))
	}

	// go_-prefixed parameters take precedence over their unprefixed names; the
	// unprefixed name is consulted only when the prefixed one is empty.
	balancer := info.Params.Get("go_balancer")
	if balancer == "" {
		balancer = info.Params.Get("balancer")
	}
	if balancer != "" {
		opts = append(opts, WithBalancer(balancers.FromConfig(balancer)))
	}

	queryMode := info.Params.Get("go_query_mode")
	if queryMode == "" {
		queryMode = info.Params.Get("query_mode")
	}
	if queryMode != "" {
		mode := xsql.QueryModeFromString(queryMode)
		if mode == xsql.UnknownQueryMode {
			return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
		}
		opts = append(opts, withConnectorOptions(xsql.WithDefaultQueryMode(mode)))
	}

	if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" {
		for _, name := range strings.Split(fakeTx, ",") {
			mode := xsql.QueryModeFromString(name)
			if mode == xsql.UnknownQueryMode {
				return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", name))
			}
			opts = append(opts, withConnectorOptions(xsql.WithFakeTx(mode)))
		}
	}

	// Has (not Get) is used so that an explicitly empty go_query_bind is still
	// validated, and rejected, by dsnQueryBinders.
	if info.Params.Has("go_query_bind") {
		binders, bindErr := dsnQueryBinders(info.Params.Get("go_query_bind"))
		if bindErr != nil {
			// dsnQueryBinders has already attached a stack trace.
			return nil, bindErr
		}
		opts = append(opts, withConnectorOptions(binders...))
	}

	return opts, nil
}

// dsnQueryBinders converts the comma-separated value of the go_query_bind DSN
// parameter into connector options. Recognized names are "declare",
// "positional", "numeric", and table_path_prefix(<prefix>). Any other name,
// including an empty one, yields an "unknown query rewriter" error.
func dsnQueryBinders(value string) ([]xsql.ConnectorOption, error) {
	names := strings.Split(value, ",")
	binders := make([]xsql.ConnectorOption, 0, len(names))
	for _, name := range names {
		switch {
		case name == "declare":
			binders = append(binders, xsql.WithQueryBind(bind.AutoDeclare{}))
		case name == "positional":
			binders = append(binders, xsql.WithQueryBind(bind.PositionalArgs{}))
		case name == "numeric":
			binders = append(binders, xsql.WithQueryBind(bind.NumericArgs{}))
		case strings.HasPrefix(name, tablePathPrefixTransformer):
			prefix, err := extractTablePathPrefixFromBinderName(name)
			if err != nil {
				return nil, xerrors.WithStackTrace(err)
			}
			binders = append(binders, xsql.WithTablePathPrefix(prefix))
		default:
			return nil, xerrors.WithStackTrace(
				fmt.Errorf("unknown query rewriter: %s", name),
			)
		}
	}

	return binders, nil
}
var (
	// tablePathPrefixRe matches a complete table_path_prefix(<prefix>) binder
	// name and captures <prefix>. It is anchored at both ends so that trailing
	// text, as in "table_path_prefix(/a)junk", is rejected instead of silently
	// yielding "/a".
	tablePathPrefixRe = regexp.MustCompile(`^` + regexp.QuoteMeta(tablePathPrefixTransformer) + `\((.*)\)$`)

	// errWrongTablePathPrefix is the sentinel wrapped (via %w) into errors for
	// table_path_prefix binder names that do not match tablePathPrefixRe.
	errWrongTablePathPrefix = errors.New("wrong '" + tablePathPrefixTransformer + "' query transformer")
)
// extractTablePathPrefixFromBinderName returns the prefix captured by
// tablePathPrefixRe from binderName, e.g. "/local" for
// "table_path_prefix(/local)". It fails with an error wrapping
// errWrongTablePathPrefix unless the pattern matches exactly once and the
// captured prefix is non-empty.
func extractTablePathPrefixFromBinderName(binderName string) (string, error) {
	matches := tablePathPrefixRe.FindAllStringSubmatch(binderName, -1)
	if len(matches) == 1 {
		if groups := matches[0]; len(groups) == 2 && groups[1] != "" {
			return groups[1], nil
		}
	}

	return "", xerrors.WithStackTrace(fmt.Errorf("%w: %s", errWrongTablePathPrefix, binderName))
}