Skip to content

Commit

Permalink
extracted ExplainQueryMode to internal/connector, other query modes m…
Browse files Browse the repository at this point in the history
…oved from xcontext to internal/table/conn
  • Loading branch information
asmyasnikov committed Dec 9, 2024
1 parent 3d61b65 commit 9f7962d
Show file tree
Hide file tree
Showing 12 changed files with 179 additions and 213 deletions.
7 changes: 3 additions & 4 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/ydb-platform/ydb-go-sdk/v3/internal/connector"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn"
tableSql "github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
)

Expand Down Expand Up @@ -61,21 +60,21 @@ func parseConnectionString(dataSourceName string) (opts []Option, _ error) {
opts = append(opts, WithBalancer(balancers.FromConfig(balancer)))
}
if queryMode := info.Params.Get("go_query_mode"); queryMode != "" {
mode := xcontext.QueryModeFromString(queryMode)
mode := tableSql.QueryModeFromString(queryMode)
if mode == tableSql.UnknownQueryMode {
return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
}
opts = append(opts, withConnectorOptions(connector.WithDefaultQueryMode(mode)))
} else if queryMode := info.Params.Get("query_mode"); queryMode != "" {
mode := xcontext.QueryModeFromString(queryMode)
mode := tableSql.QueryModeFromString(queryMode)
if mode == tableSql.UnknownQueryMode {
return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
}
opts = append(opts, withConnectorOptions(connector.WithDefaultQueryMode(mode)))
}
if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" {
for _, queryMode := range strings.Split(fakeTx, ",") {
mode := xcontext.QueryModeFromString(queryMode)
mode := tableSql.QueryModeFromString(queryMode)
if mode == tableSql.UnknownQueryMode {
return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
}
Expand Down
54 changes: 54 additions & 0 deletions internal/connector/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package connector

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"path"
"strings"
"time"
Expand Down Expand Up @@ -31,6 +33,8 @@ type (
driver.Validator
driver.NamedValueChecker

Explain(ctx context.Context, sql string) (ast string, plan string, err error)

LastUsage() time.Time
ID() string
}
Expand All @@ -39,8 +43,58 @@ type (

connector *Connector
}
singleRow struct {
values []sql.NamedArg
readAll bool
}
)

func (r *singleRow) Columns() (columns []string) {
for i := range r.values {
columns = append(columns, r.values[i].Name)
}

return columns
}

func (r *singleRow) Close() error {
return nil
}

func (r *singleRow) Next(dst []driver.Value) error {
if r.values == nil || r.readAll {
return io.EOF
}
for i := range r.values {
dst[i] = r.values[i].Value
}
r.readAll = true

return nil
}

func (c *connWrapper) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
if isExplain(ctx) {
ast, plan, err := c.conn.Explain(ctx, query)
if err != nil {
return nil, xerrors.WithStackTrace(err)
}

return &singleRow{
values: []sql.NamedArg{
sql.Named("AST", ast),
sql.Named("Plan", plan),
},
}, nil
}

return c.conn.QueryContext(ctx, query, args)
}

func (c *connWrapper) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
return c.conn.ExecContext(ctx, query, args)
}

func (c *connWrapper) GetDatabaseName() string {
return c.connector.Name()
}
Expand Down
15 changes: 15 additions & 0 deletions internal/connector/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package connector

import "context"

type ctxExplainQueryModeKey struct{}

func WithExplain(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxExplainQueryModeKey{}, true)
}

func isExplain(ctx context.Context) bool {
v, has := ctx.Value(ctxExplainQueryModeKey{}).(bool)

return has && v
}
54 changes: 18 additions & 36 deletions internal/query/conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package conn

import (
"context"
"database/sql"
"database/sql/driver"
"sync/atomic"

Expand Down Expand Up @@ -57,6 +56,22 @@ type Conn struct {
lastUsage atomic.Int64
}

func (c *Conn) Explain(ctx context.Context, sql string) (ast string, plan string, _ error) {
_, err := c.session.Query(
ctx, sql,
options.WithExecMode(options.ExecModeExplain),
options.WithStatsMode(options.StatsModeNone, func(stats stats.QueryStats) {
ast = stats.QueryAST()
plan = stats.QueryPlan()
}),
)
if err != nil {
return "", "", xerrors.WithStackTrace(err)
}

return ast, plan, nil
}

func New(ctx context.Context, parent Parent, s *query.Session, opts ...Option) *Conn {
cc := &Conn{
ctx: ctx,
Expand Down Expand Up @@ -134,7 +149,7 @@ func (c *Conn) execContext(

onDone := trace.DatabaseSQLOnConnExec(c.parent.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*Conn).execContext"),
query, xcontext.UnknownQueryMode.String(), xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()),
query, "query", xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()),
)
defer func() {
onDone(finalErr)
Expand Down Expand Up @@ -170,7 +185,7 @@ func (c *Conn) queryContext(ctx context.Context, queryString string, args []driv

onDone := trace.DatabaseSQLOnConnQuery(c.parent.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query/conn.(*Conn).queryContext"),
queryString, xcontext.UnknownQueryMode.String(), xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()),
queryString, "query", xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()),
)

defer func() {
Expand All @@ -182,12 +197,6 @@ func (c *Conn) queryContext(ctx context.Context, queryString string, args []driv
return nil, xerrors.WithStackTrace(err)
}

queryMode := xcontext.QueryModeFromContext(ctx, xcontext.UnknownQueryMode)

if queryMode == xcontext.ExplainQueryMode {
return c.queryContextExplain(ctx, normalizedQuery, parameters)
}

return c.queryContextOther(ctx, normalizedQuery, parameters)
}

Expand All @@ -209,30 +218,3 @@ func (c *Conn) queryContextOther(
result: res,
}, nil
}

func (c *Conn) queryContextExplain(
ctx context.Context,
queryString string,
parameters params.Parameters,
) (driver.Rows, error) {
var ast, plan string
_, err := c.session.Query(
ctx, queryString,
options.WithParameters(parameters),
options.WithExecMode(options.ExecModeExplain),
options.WithStatsMode(options.StatsModeNone, func(stats stats.QueryStats) {
ast = stats.QueryAST()
plan = stats.QueryPlan()
}),
)
if err != nil {
return nil, xerrors.WithStackTrace(err)
}

return &single{
values: []sql.NamedArg{
sql.Named("AST", ast),
sql.Named("Plan", plan),
},
}, nil
}
31 changes: 0 additions & 31 deletions internal/query/conn/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package conn

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"io"
Expand All @@ -19,7 +18,6 @@ var (
_ driver.RowsNextResultSet = &rows{}
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
_ driver.RowsColumnTypeNullable = &rows{}
_ driver.Rows = &single{}

ignoreColumnPrefixName = "__discard_column_"
)
Expand Down Expand Up @@ -166,32 +164,3 @@ func (r *rows) Close() error {

return r.result.Close(ctx)
}

type single struct {
values []sql.NamedArg
readAll bool
}

func (r *single) Columns() (columns []string) {
for i := range r.values {
columns = append(columns, r.values[i].Name)
}

return columns
}

func (r *single) Close() error {
return nil
}

func (r *single) Next(dst []driver.Value) error {
if r.values == nil || r.readAll {
return io.EOF
}
for i := range r.values {
dst[i] = r.values[i].Value
}
r.readAll = true

return nil
}
32 changes: 12 additions & 20 deletions internal/table/conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package conn

import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"io"
Expand Down Expand Up @@ -60,6 +59,15 @@ type (
}
)

func (c *Conn) Explain(ctx context.Context, sql string) (ast string, plan string, err error) {
exp, err := c.session.Explain(ctx, sql)
if err != nil {
return "", "", badconn.Map(xerrors.WithStackTrace(err))
}

return exp.AST, exp.Plan, nil
}

func (c *Conn) LastUsage() time.Time {
return time.Unix(c.lastUsage.Load(), 0)
}
Expand Down Expand Up @@ -162,7 +170,7 @@ func (c *Conn) execContext(
return c.currentTx.ExecContext(ctx, query, args)
}

m := xcontext.QueryModeFromContext(ctx, c.defaultQueryMode)
m := queryModeFromContext(ctx, c.defaultQueryMode)
onDone := trace.DatabaseSQLOnConnExec(c.parent.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn.(*Conn).execContext"),
query, m.String(), xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()),
Expand Down Expand Up @@ -285,7 +293,7 @@ func (c *Conn) queryContext(ctx context.Context, query string, args []driver.Nam
}

var (
queryMode = xcontext.QueryModeFromContext(ctx, c.defaultQueryMode)
queryMode = queryModeFromContext(ctx, c.defaultQueryMode)
onDone = trace.DatabaseSQLOnConnQuery(c.parent.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/table/conn.(*Conn).queryContext"),
query, queryMode.String(), xcontext.IsIdempotent(ctx), c.parent.Clock().Since(c.LastUsage()),
Expand All @@ -305,8 +313,6 @@ func (c *Conn) queryContext(ctx context.Context, query string, args []driver.Nam
return c.execDataQuery(ctx, normalizedQuery, parameters)
case ScanQueryMode:
return c.execScanQuery(ctx, normalizedQuery, parameters)
case ExplainQueryMode:
return c.explainQuery(ctx, normalizedQuery)
case ScriptingQueryMode:
return c.execScriptingQuery(ctx, normalizedQuery, parameters)
default:
Expand Down Expand Up @@ -349,20 +355,6 @@ func (c *Conn) execScanQuery(ctx context.Context, query string, params params.Pa
}, nil
}

func (c *Conn) explainQuery(ctx context.Context, query string) (driver.Rows, error) {
exp, err := c.session.Explain(ctx, query)
if err != nil {
return nil, badconn.Map(xerrors.WithStackTrace(err))
}

return &single{
values: []sql.NamedArg{
sql.Named("AST", exp.AST),
sql.Named("Plan", exp.Plan),
},
}, nil
}

func (c *Conn) execScriptingQuery(ctx context.Context, query string, params params.Params) (driver.Rows, error) {
res, err := c.parent.Scripting().StreamExecute(ctx, query, &params)
if err != nil {
Expand Down Expand Up @@ -462,7 +454,7 @@ func (c *Conn) beginTx(ctx context.Context, txOptions driver.TxOptions) (tx curr
)
}

m := xcontext.QueryModeFromContext(ctx, c.defaultQueryMode)
m := queryModeFromContext(ctx, c.defaultQueryMode)

if slices.Contains(c.fakeTxModes, m) {
return beginTxFake(ctx, c), nil
Expand Down
Loading

0 comments on commit 9f7962d

Please sign in to comment.