diff --git a/app/server/data_source_collection.go b/app/server/data_source_collection.go index 838b4f50..c4d5fcf9 100644 --- a/app/server/data_source_collection.go +++ b/app/server/data_source_collection.go @@ -92,7 +92,7 @@ func (dsc *DataSourceCollection) DoReadSplit( Query: retry.NewRetrierFromConfig(mongoDbCfg.ExponentialBackoff, retry.ErrorCheckerNoop), }, mongoDbCfg) - return readSplit(logger, stream, request, split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg) + return readSplit[string](logger, stream, request, split, ds, dsc.memoryAllocator, dsc.readLimiterFactory, dsc.cfg) default: return fmt.Errorf("unsupported data source type '%v': %w", kind, common.ErrDataSourceNotSupported) } diff --git a/app/server/datasource/mongodb/type_mapping.go b/app/server/datasource/mongodb/type_mapping.go index 31c8f231..0121fa62 100644 --- a/app/server/datasource/mongodb/type_mapping.go +++ b/app/server/datasource/mongodb/type_mapping.go @@ -17,10 +17,6 @@ var errNull = errors.New("can't determine field type for null") const idColumn string = "_id" const objectIdTag string = "ObjectId" -func typesEqual(lhs, rhs *Ydb.Type) bool { - return lhs.String() == rhs.String() -} - func typeMap(logger *zap.Logger, v bson.RawValue, omitUnsupported bool) (*Ydb.Type, error) { switch v.Type { case bson.TypeInt32: @@ -76,7 +72,7 @@ func typeMapArray(logger *zap.Logger, elements []bson.RawElement, omitUnsupporte continue } - if !typesEqual(newInnerType, innerType) { + if !common.TypesEqual(newInnerType, innerType) { return common.MakeListType(common.MakePrimitiveType(Ydb.Type_UTF8)), nil } } @@ -137,7 +133,7 @@ func bsonToYqlColumn( // Leaving fields that have inconsistent types serialized // Extra check for arrays because we might have encountered an empty one: // we know it is an array, but prevType is not determined yet - if (prevTypeExists && !typesEqual(prevType, t)) || (prevIsArray && t.GetListType() == nil) { + if (prevTypeExists && !common.TypesEqual(prevType, t)) || (prevIsArray && t.GetListType() == nil) { deducedTypes[key] = common.MakePrimitiveType(Ydb.Type_UTF8) logger.Debug(fmt.Sprintf("bsonToYqlColumn: keeping serialized %v. prev: %v curr: %v", key, prevType.String(), tString)) diff --git a/common/ydb_type_helpers.go b/common/ydb_type_helpers.go index 52cf2f59..2c332939 100644 --- a/common/ydb_type_helpers.go +++ b/common/ydb_type_helpers.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" + "google.golang.org/protobuf/types/known/structpb" api_service_protos "github.com/ydb-platform/fq-connector-go/api/service/protos" ) @@ -87,3 +88,95 @@ func MakeYdbDateTimeType(ydbTypeID Ydb.Type_PrimitiveTypeId, format api_service_ return nil, fmt.Errorf("unexpected datetime format '%s': %w", format, ErrInvalidRequest) } } + +func TypesEqual(lhs, rhs *Ydb.Type) bool { + switch lhsType := lhs.Type.(type) { + case *Ydb.Type_TypeId: + return lhsType.TypeId == rhs.GetTypeId() + case *Ydb.Type_NullType: + return rhs.GetNullType() != structpb.NullValue(0) + case *Ydb.Type_OptionalType: + rhsType := rhs.GetOptionalType() + return rhsType != nil && + TypesEqual(rhsType.Item, lhsType.OptionalType.Item) + case *Ydb.Type_DictType: + rhsType := rhs.GetDictType() + return rhsType != nil && + TypesEqual(rhsType.Key, lhsType.DictType.Key) && + TypesEqual(rhsType.Payload, lhsType.DictType.Payload) + case *Ydb.Type_ListType: + rhsType := rhs.GetListType() + return rhsType != nil && + TypesEqual(rhsType.Item, lhsType.ListType.Item) + case *Ydb.Type_DecimalType: + rhsType := rhs.GetDecimalType() + return rhsType != nil && + rhsType.Precision == lhsType.DecimalType.Precision && + rhsType.Scale == lhsType.DecimalType.Scale + case *Ydb.Type_TupleType: + rhsType := rhs.GetTupleType() + + return rhsType != nil && tuplesEqual(rhsType, lhsType.TupleType) + case *Ydb.Type_StructType: + rhsType := rhs.GetStructType() + + return rhsType != nil && structsEqual(rhsType, lhsType.StructType) + case *Ydb.Type_VariantType: + rhsType := rhs.GetVariantType() + if rhsType == nil { + return false + } + + switch innerType := lhsType.VariantType.Type.(type) { + case *Ydb.VariantType_TupleItems: + return tuplesEqual(innerType.TupleItems, lhsType.VariantType.GetTupleItems()) + case *Ydb.VariantType_StructItems: + return structsEqual(innerType.StructItems, lhsType.VariantType.GetStructItems()) + } + + case *Ydb.Type_TaggedType: + rhsType := rhs.GetTaggedType() + return rhsType.Tag == lhsType.TaggedType.Tag && TypesEqual(rhsType.Type, lhsType.TaggedType.Type) + + case *Ydb.Type_VoidType: + return rhs.GetVoidType() != structpb.NullValue(0) + case *Ydb.Type_EmptyListType: + return rhs.GetEmptyListType() != structpb.NullValue(0) + case *Ydb.Type_EmptyDictType: + return rhs.GetEmptyDictType() != structpb.NullValue(0) + case *Ydb.Type_PgType: + rhsType := rhs.GetPgType() + return rhsType != nil && + rhs.GetPgType().TypeName == lhsType.PgType.TypeName + } + + panic("unreachable") +} + +func tuplesEqual(lhs, rhs *Ydb.TupleType) bool { + if len(lhs.Elements) != len(rhs.Elements) { + return false + } + + for i := range len(rhs.Elements) { + if !TypesEqual(rhs.Elements[i], lhs.Elements[i]) { + return false + } + } + + return true +} + +func structsEqual(lhs, rhs *Ydb.StructType) bool { + if len(rhs.Members) != len(lhs.Members) { + return false + } + + for i := range len(rhs.Members) { + if rhs.Members[i].Name != lhs.Members[i].Name || !TypesEqual(rhs.Members[i].Type, lhs.Members[i].Type) { + return false + } + } + + return true +}