Skip to content

Commit

Permalink
MongoDB: comparing *Ydb.Type
Browse files Browse the repository at this point in the history
  • Loading branch information
ninaiad committed Jan 31, 2025
1 parent 88af95c commit 6af56bd
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 7 deletions.
2 changes: 1 addition & 1 deletion app/server/data_source_collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (dsc *DataSourceCollection) DoReadSplit(
Query: retry.NewRetrierFromConfig(mongoDbCfg.ExponentialBackoff, retry.ErrorCheckerNoop),
}, mongoDbCfg)

return readSplit(logger, stream, request, split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg)
return readSplit[string](logger, stream, request, split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg)
default:
return fmt.Errorf("unsupported data source type '%v': %w", kind, common.ErrDataSourceNotSupported)
}
Expand Down
8 changes: 2 additions & 6 deletions app/server/datasource/mongodb/type_mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ var errNull = errors.New("can't determine field type for null")
const idColumn string = "_id"
const objectIdTag string = "ObjectId"

func typesEqual(lhs, rhs *Ydb.Type) bool {
return lhs.String() == rhs.String()
}

func typeMap(logger *zap.Logger, v bson.RawValue, omitUnsupported bool) (*Ydb.Type, error) {
switch v.Type {
case bson.TypeInt32:
Expand Down Expand Up @@ -76,7 +72,7 @@ func typeMapArray(logger *zap.Logger, elements []bson.RawElement, omitUnsupporte
continue
}

if !typesEqual(newInnerType, innerType) {
if !common.TypesEqual(newInnerType, innerType) {
return common.MakeListType(common.MakePrimitiveType(Ydb.Type_UTF8)), nil
}
}
Expand Down Expand Up @@ -137,7 +133,7 @@ func bsonToYqlColumn(
// Leaving fields that have inconsistent types serialized
// Extra check for arrays because we might have encountered an empty one:
// we know it is an array, but prevType is not determined yet
if (prevTypeExists && !typesEqual(prevType, t)) || (prevIsArray && t.GetListType() == nil) {
if (prevTypeExists && !common.TypesEqual(prevType, t)) || (prevIsArray && t.GetListType() == nil) {
deducedTypes[key] = common.MakePrimitiveType(Ydb.Type_UTF8)

logger.Debug(fmt.Sprintf("bsonToYqlColumn: keeping serialized %v. prev: %v curr: %v", key, prevType.String(), tString))
Expand Down
93 changes: 93 additions & 0 deletions common/ydb_type_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
"google.golang.org/protobuf/types/known/structpb"

api_service_protos "github.com/ydb-platform/fq-connector-go/api/service/protos"
)
Expand Down Expand Up @@ -87,3 +88,95 @@ func MakeYdbDateTimeType(ydbTypeID Ydb.Type_PrimitiveTypeId, format api_service_
return nil, fmt.Errorf("unexpected datetime format '%s': %w", format, ErrInvalidRequest)
}
}

func TypesEqual(lhs, rhs *Ydb.Type) bool {
switch lhsType := lhs.Type.(type) {
case *Ydb.Type_TypeId:
return lhsType.TypeId == rhs.GetTypeId()
case *Ydb.Type_NullType:
return rhs.GetNullType() != structpb.NullValue(0)
case *Ydb.Type_OptionalType:
rhsType := rhs.GetOptionalType()
return rhsType != nil &&
TypesEqual(rhsType.Item, lhsType.OptionalType.Item)
case *Ydb.Type_DictType:
rhsType := rhs.GetDictType()
return rhsType != nil &&
TypesEqual(rhsType.Key, lhsType.DictType.Key) &&
TypesEqual(rhsType.Payload, lhsType.DictType.Payload)
case *Ydb.Type_ListType:
rhsType := rhs.GetListType()
return rhsType != nil &&
TypesEqual(rhsType.Item, lhsType.ListType.Item)
case *Ydb.Type_DecimalType:
rhsType := rhs.GetDecimalType()
return rhsType != nil &&
rhsType.Precision == lhsType.DecimalType.Precision &&
rhsType.Scale == lhsType.DecimalType.Scale
case *Ydb.Type_TupleType:
rhsType := rhs.GetTupleType()

return rhsType != nil && tuplesEqual(rhsType, lhsType.TupleType)
case *Ydb.Type_StructType:
rhsType := rhs.GetStructType()

return rhsType != nil && structsEqual(rhsType, lhsType.StructType)
case *Ydb.Type_VariantType:
rhsType := rhs.GetVariantType()
if rhsType == nil {
return false
}

switch innerType := lhsType.VariantType.Type.(type) {
case *Ydb.VariantType_TupleItems:
return tuplesEqual(innerType.TupleItems, lhsType.VariantType.GetTupleItems())
case *Ydb.VariantType_StructItems:
return structsEqual(innerType.StructItems, lhsType.VariantType.GetStructItems())
}

case *Ydb.Type_TaggedType:
rhsType := rhs.GetTaggedType()
return rhsType.Tag == lhsType.TaggedType.Tag && TypesEqual(rhsType.Type, lhsType.TaggedType.Type)

case *Ydb.Type_VoidType:
return rhs.GetVoidType() != structpb.NullValue(0)
case *Ydb.Type_EmptyListType:
return rhs.GetEmptyListType() != structpb.NullValue(0)
case *Ydb.Type_EmptyDictType:
return rhs.GetEmptyDictType() != structpb.NullValue(0)
case *Ydb.Type_PgType:
rhsType := rhs.GetPgType()
return rhsType != nil &&
rhs.GetPgType().TypeName == lhsType.PgType.TypeName
}

panic("unreachable")
}

func tuplesEqual(lhs, rhs *Ydb.TupleType) bool {
if len(lhs.Elements) != len(rhs.Elements) {
return false
}

for i := range len(rhs.Elements) {
if !TypesEqual(rhs.Elements[i], lhs.Elements[i]) {
return false
}
}

return true
}

func structsEqual(lhs, rhs *Ydb.StructType) bool {
if len(rhs.Members) != len(lhs.Members) {
return false
}

for i := range len(rhs.Members) {
if rhs.Members[i].Name != lhs.Members[i].Name || !TypesEqual(rhs.Members[i].Type, lhs.Members[i].Type) {
return false
}
}

return true
}

0 comments on commit 6af56bd

Please sign in to comment.