diff --git a/app/server/datasource/rdbms/ydb/connection_manager.go b/app/server/datasource/rdbms/ydb/connection_manager.go index 88cc5665..d017bfe1 100644 --- a/app/server/datasource/rdbms/ydb/connection_manager.go +++ b/app/server/datasource/rdbms/ydb/connection_manager.go @@ -84,8 +84,10 @@ func (c *connectionManager) Make( case config.TYdbConfig_MODE_UNSPECIFIED: fallthrough case config.TYdbConfig_MODE_QUERY_SERVICE_NATIVE: + logger.Debug("YDB Connector will use Native SDK over Query Service") ydbConn = newConnectionNative(ctx, c.QueryLoggerFactory.Make(logger), dsi, ydbDriver) case config.TYdbConfig_MODE_TABLE_SERVICE_STDLIB_SCAN_QUERIES: + logger.Debug("YDB Connector will use database/sql SDK with scan queries over Table Service") ydbConn, err = newConnectionDatabaseSQL(ctx, logger, c.QueryLoggerFactory.Make(logger), c.cfg, dsi, ydbDriver) default: return nil, fmt.Errorf("unknown mode: %v", c.cfg.Mode) diff --git a/app/server/datasource/rdbms/ydb/connection_native.go b/app/server/datasource/rdbms/ydb/connection_native.go index 48e784fa..6a6f6684 100644 --- a/app/server/datasource/rdbms/ydb/connection_native.go +++ b/app/server/datasource/rdbms/ydb/connection_native.go @@ -228,27 +228,21 @@ func (c *connectionNative) Close() error { return nil } -func newConnectionNative( - ctx context.Context, - queryLogger common.QueryLogger, - dsi *api_common.TDataSourceInstance, - driver *ydb_sdk.Driver, -) ydbConnection { - return &connectionNative{ - ctx: ctx, - driver: driver, - queryLogger: queryLogger, - dsi: dsi, - } -} - func (c *connectionNative) rewriteQuery(params *rdbms_utils.QueryParams) (string, error) { var buf bytes.Buffer buf.WriteString(fmt.Sprintf("PRAGMA TablePathPrefix(\"%s\");", c.dsi.Database)) //nolint:revive for i, arg := range params.QueryArgs.GetAll() { - typeName, err := primitiveYqlTypeName(arg.YdbType.GetTypeId()) + var primitiveTypeID Ydb.Type_PrimitiveTypeId + + if arg.YdbType.GetOptionalType() != nil { + primitiveTypeID = arg.YdbType.GetOptionalType().Item.GetTypeId() + } else { + primitiveTypeID = arg.YdbType.GetTypeId() + } + + typeName, err := primitiveYqlTypeName(primitiveTypeID) if err != nil { return "", fmt.Errorf("get YQL type name from value %v: %w", arg, err) } @@ -260,3 +254,17 @@ func (c *connectionNative) rewriteQuery(params *rdbms_utils.QueryParams) (string return buf.String(), nil } + +func newConnectionNative( + ctx context.Context, + queryLogger common.QueryLogger, + dsi *api_common.TDataSourceInstance, + driver *ydb_sdk.Driver, +) ydbConnection { + return &connectionNative{ + ctx: ctx, + driver: driver, + queryLogger: queryLogger, + dsi: dsi, + } +} diff --git a/tests/infra/datasource/ydb/suite.go b/tests/infra/datasource/ydb/suite.go index 6763bfe3..0eb6ef55 100644 --- a/tests/infra/datasource/ydb/suite.go +++ b/tests/infra/datasource/ydb/suite.go @@ -266,6 +266,20 @@ func (s *Suite) TestPushdownStringsUtf8() { ) } +func (s *Suite) TestPushdownStringsUtf8Optional() { + s.ValidateTable( + s.dataSource, + tables["pushdown_strings_utf8"], + suite.WithPredicate(&api_service_protos.TPredicate{ + Payload: tests_utils.MakePredicateComparisonColumn( + "col_02_utf8", + api_service_protos.TPredicate_TComparison_EQ, + common.MakeTypedValue(common.MakeOptionalType(common.MakePrimitiveType(Ydb.Type_UTF8)), "a"), + ), + }), + ) +} + func (s *Suite) TestPushdownStringsString() { s.ValidateTable( s.dataSource, @@ -280,6 +294,20 @@ func (s *Suite) TestPushdownStringsString() { ) } +func (s *Suite) TestPushdownStringsStringOptional() { + s.ValidateTable( + s.dataSource, + tables["pushdown_strings_string"], + suite.WithPredicate(&api_service_protos.TPredicate{ + Payload: tests_utils.MakePredicateComparisonColumn( + "col_03_string", + api_service_protos.TPredicate_TComparison_EQ, + common.MakeTypedValue(common.MakeOptionalType(common.MakePrimitiveType(Ydb.Type_STRING)), []byte("b")), + ), + }), + ) +} + func (s *Suite) TestLargeTable() { // For tables larger than 1000 rows, scan queries must be used, // otherwise output will be truncated. @@ -333,6 +361,7 @@ func (s *Suite) TestInvalidLogin() { if s.connectorMode == config.TYdbConfig_MODE_QUERY_SERVICE_NATIVE { s.T().Skip("Skipping test in QUERY_SERVICE_NATIVE mode") } + for _, dsi := range s.dataSource.Instances { suite.TestInvalidLogin(s.Base, dsi, tables["simple"]) } @@ -342,6 +371,7 @@ func (s *Suite) TestInvalidPassword() { if s.connectorMode == config.TYdbConfig_MODE_QUERY_SERVICE_NATIVE { s.T().Skip("Skipping test in QUERY_SERVICE_NATIVE mode") } + for _, dsi := range s.dataSource.Instances { suite.TestInvalidPassword(s.Base, dsi, tables["simple"]) } diff --git a/tests/main_test.go b/tests/main_test.go index 3f3f8d6b..b0d5053b 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -44,8 +44,8 @@ func TestPostgreSQL(t *testing.T) { func TestYDB(t *testing.T) { modes := []config.TYdbConfig_Mode{ - config.TYdbConfig_MODE_QUERY_SERVICE_NATIVE, config.TYdbConfig_MODE_TABLE_SERVICE_STDLIB_SCAN_QUERIES, + config.TYdbConfig_MODE_QUERY_SERVICE_NATIVE, } for _, mode := range modes { diff --git a/tests/suite/suite.go b/tests/suite/suite.go index 965e3f67..308747d4 100644 --- a/tests/suite/suite.go +++ b/tests/suite/suite.go @@ -212,7 +212,7 @@ func (b *Base[ID, IDBUILDER]) doValidateTable( table.MatchRecords(b.T(), records, schema) } -type SuiteOption interface { +type BaseOption interface { apply(cfg *baseConfig) } @@ -224,7 +224,7 @@ func (o *embeddedOption) apply(cfg *baseConfig) { cfg.embeddedOptions = append(cfg.embeddedOptions, o.options...) } -func WithEmbeddedOptions(options ...server.EmbeddedOption) SuiteOption { +func WithEmbeddedOptions(options ...server.EmbeddedOption) BaseOption { return &embeddedOption{ options: options, } @@ -233,7 +233,7 @@ func WithEmbeddedOptions(options ...server.EmbeddedOption) SuiteOption { func NewBase[ ID test_utils.TableIDTypes, IDBUILDER test_utils.ArrowIDBuilder[ID], -](t *testing.T, state *State, name string, suiteOptions ...SuiteOption) *Base[ID, IDBUILDER] { +](t *testing.T, state *State, name string, suiteOptions ...BaseOption) *Base[ID, IDBUILDER] { cfg := &baseConfig{ name: name, }