Skip to content

Commit

Permalink
MongoDB: comparing *Ydb.Type
Browse files Browse the repository at this point in the history
  • Loading branch information
ninaiad committed Feb 3, 2025
1 parent d975e33 commit 5be187b
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 7 deletions.
2 changes: 1 addition & 1 deletion app/server/data_source_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (dsc *DataSourceCollection) DoReadSplit(
Query: retry.NewRetrierFromConfig(mongoDbCfg.ExponentialBackoff, retry.ErrorCheckerNoop),
}, mongoDbCfg)

return readSplit(logger, stream, request, split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg)
return readSplit[string](logger, stream, request, split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg)
default:
return fmt.Errorf("unsupported data source type '%v': %w", kind, common.ErrDataSourceNotSupported)
}
Expand Down
3 changes: 3 additions & 0 deletions app/server/datasource/mongodb/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ func (ds *dataSource) DescribeTable(
err := ds.retrierSet.MakeConnection.Run(ctx, logger,
func() error {
var makeConnErr error

uri := fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?%v&authSource=admin",
dsi.Credentials.GetBasic().Username,
Expand All @@ -75,10 +76,12 @@ func (ds *dataSource) DescribeTable(
if err := conn.Disconnect(ctx); err != nil {
logger.Fatal(fmt.Sprintf("conn.Disconnect: %v", err))
}

return fmt.Errorf("conn.Ping: %w", makeConnErr)
}

logger.Debug("Connected to MongoDB!")

return nil
},
)
Expand Down
8 changes: 2 additions & 6 deletions app/server/datasource/mongodb/type_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ var errNull = errors.New("can't determine field type for null")
const idColumn string = "_id"
const objectIdTag string = "ObjectId"

func typesEqual(lhs, rhs *Ydb.Type) bool {
return lhs.String() == rhs.String()
}

func typeMap(logger *zap.Logger, v bson.RawValue, omitUnsupported bool) (*Ydb.Type, error) {
switch v.Type {
case bson.TypeInt32:
Expand Down Expand Up @@ -76,7 +72,7 @@ func typeMapArray(logger *zap.Logger, elements []bson.RawElement, omitUnsupporte
continue
}

if !typesEqual(newInnerType, innerType) {
if !common.TypesEqual(newInnerType, innerType) {
return common.MakeListType(common.MakePrimitiveType(Ydb.Type_UTF8)), nil
}
}
Expand Down Expand Up @@ -137,7 +133,7 @@ func bsonToYqlColumn(
// Leaving fields that have inconsistent types serialized
// Extra check for arrays because we might have encountered an empty one:
// we know it is an array, but prevType is not determined yet
if (prevTypeExists && !typesEqual(prevType, t)) || (prevIsArray && t.GetListType() == nil) {
if (prevTypeExists && !common.TypesEqual(prevType, t)) || (prevIsArray && t.GetListType() == nil) {
deducedTypes[key] = common.MakePrimitiveType(Ydb.Type_UTF8)

logger.Debug(fmt.Sprintf("bsonToYqlColumn: keeping serialized %v. prev: %v curr: %v", key, prevType.String(), tString))
Expand Down
94 changes: 94 additions & 0 deletions common/ydb_type_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"google.golang.org/protobuf/types/known/structpb"

api_service_protos "github.com/ydb-platform/fq-connector-go/api/service/protos"
)
Expand Down Expand Up @@ -87,3 +88,96 @@ func MakeYdbDateTimeType(ydbTypeID Ydb.Type_PrimitiveTypeId, format api_service_
return nil, fmt.Errorf("unexpected datetime format '%s': %w", format, ErrInvalidRequest)
}
}

//nolint:gocyclo
func TypesEqual(lhs, rhs *Ydb.Type) bool {
switch lhsType := lhs.Type.(type) {
case *Ydb.Type_TypeId:
return lhsType.TypeId == rhs.GetTypeId()
case *Ydb.Type_NullType:
return rhs.GetNullType() != structpb.NullValue(0)
case *Ydb.Type_OptionalType:
rhsType := rhs.GetOptionalType()
return rhsType != nil &&
TypesEqual(rhsType.Item, lhsType.OptionalType.Item)
case *Ydb.Type_DictType:
rhsType := rhs.GetDictType()

return rhsType != nil &&
TypesEqual(rhsType.Key, lhsType.DictType.Key) &&
TypesEqual(rhsType.Payload, lhsType.DictType.Payload)
case *Ydb.Type_ListType:
rhsType := rhs.GetListType()
return rhsType != nil &&
TypesEqual(rhsType.Item, lhsType.ListType.Item)
case *Ydb.Type_DecimalType:
rhsType := rhs.GetDecimalType()

return rhsType != nil &&
rhsType.Precision == lhsType.DecimalType.Precision &&
rhsType.Scale == lhsType.DecimalType.Scale
case *Ydb.Type_TupleType:
rhsType := rhs.GetTupleType()
return rhsType != nil && tuplesEqual(rhsType, lhsType.TupleType)
case *Ydb.Type_StructType:
rhsType := rhs.GetStructType()
return rhsType != nil && structsEqual(rhsType, lhsType.StructType)
case *Ydb.Type_VariantType:
rhsType := rhs.GetVariantType()
return rhsType != nil && variantsEqual(rhsType, lhsType.VariantType)
case *Ydb.Type_TaggedType:
rhsType := rhs.GetTaggedType()
return rhsType.Tag == lhsType.TaggedType.Tag &&
TypesEqual(rhsType.Type, lhsType.TaggedType.Type)
case *Ydb.Type_VoidType:
return rhs.GetVoidType() != structpb.NullValue(0)
case *Ydb.Type_EmptyListType:
return rhs.GetEmptyListType() != structpb.NullValue(0)
case *Ydb.Type_EmptyDictType:
return rhs.GetEmptyDictType() != structpb.NullValue(0)
case *Ydb.Type_PgType:
rhsType := rhs.GetPgType()
return rhsType != nil && rhs.GetPgType().TypeName == lhsType.PgType.TypeName
}

panic("unreachable")
}

func tuplesEqual(lhs, rhs *Ydb.TupleType) bool {
if len(lhs.Elements) != len(rhs.Elements) {
return false
}

for i := range len(rhs.Elements) {
if !TypesEqual(rhs.Elements[i], lhs.Elements[i]) {
return false
}
}

return true
}

func structsEqual(lhs, rhs *Ydb.StructType) bool {
if len(rhs.Members) != len(lhs.Members) {
return false
}

for i := range len(rhs.Members) {
if rhs.Members[i].Name != lhs.Members[i].Name || !TypesEqual(rhs.Members[i].Type, lhs.Members[i].Type) {
return false
}
}

return true
}

func variantsEqual(lhs, rhs *Ydb.VariantType) bool {
switch innerType := lhs.Type.(type) {
case *Ydb.VariantType_TupleItems:
return tuplesEqual(innerType.TupleItems, rhs.GetTupleItems())
case *Ydb.VariantType_StructItems:
return structsEqual(innerType.StructItems, rhs.GetStructItems())
}

panic("unreachable")
}

0 comments on commit 5be187b

Please sign in to comment.