diff --git a/internal/value/cast.go b/internal/value/cast.go index ddde56ace..5615a3cc8 100644 --- a/internal/value/cast.go +++ b/internal/value/cast.go @@ -1,5 +1,12 @@ package value +import ( + "database/sql" + "database/sql/driver" + + "github.com/google/uuid" +) + func CastTo(v Value, dst interface{}) error { if dst == nil { return errNilDestination @@ -10,6 +17,21 @@ func CastTo(v Value, dst interface{}) error { return nil } + if _, ok := dst.(*uuid.UUID); ok { + return v.castTo(dst) + } + + if scanner, has := dst.(sql.Scanner); has { + dv := new(driver.Value) + + err := v.castTo(dv) + if err != nil { + return err + } + + return scanner.Scan(*dv) + } + if scanner, has := dst.(Scanner); has { return scanner.UnmarshalYDBValue(v) } diff --git a/internal/value/cast_test.go b/internal/value/cast_test.go index 8886bb4b5..8a8890d53 100644 --- a/internal/value/cast_test.go +++ b/internal/value/cast_test.go @@ -2,6 +2,7 @@ package value import ( "database/sql/driver" + "errors" "reflect" "testing" "time" @@ -32,12 +33,32 @@ func loadLocation(t *testing.T, name string) *time.Location { return loc } -type testStringValueScanner struct { - field string -} +type testStringValueScanner string func (s *testStringValueScanner) UnmarshalYDBValue(v Value) error { - return CastTo(v, &s.field) + var tmp string + + err := CastTo(v, &tmp) + if err != nil { + return err + } + + *s = testStringValueScanner(tmp) + + return nil +} + +type testStringSQLScanner string + +func (s *testStringSQLScanner) Scan(value any) error { + ts, ok := value.(string) + if !ok { + return errors.New("can't cast from " + reflect.TypeOf(value).String() + " to string") + } + + *s = testStringSQLScanner(ts) + + return nil } func TestCastTo(t *testing.T) { @@ -440,7 +461,14 @@ func TestCastTo(t *testing.T) { name: xtest.CurrentFileLine(), value: TextValue("text-string"), dst: ptr[testStringValueScanner](), - exp: testStringValueScanner{field: "text-string"}, + exp: testStringValueScanner("text-string"), + err: nil, + }, + { + name: xtest.CurrentFileLine(), + value: TextValue("text-string"), + dst: ptr[testStringSQLScanner](), + exp: testStringSQLScanner("text-string"), err: nil, }, } diff --git a/internal/value/value.go b/internal/value/value.go index 0b8331874..866fa37ad 100644 --- a/internal/value/value.go +++ b/internal/value/value.go @@ -1303,7 +1303,7 @@ func (v *listValue) castTo(dst any) error { inner.Set(newSlice) for i, item := range v.ListItems() { - if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil { + if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil { return xerrors.WithStackTrace(fmt.Errorf( "%w '%s(%+v)' to '%T' destination", ErrCannotCast, v.Type().Yql(), v, dstValue, @@ -1437,7 +1437,7 @@ func (v *setValue) castTo(dst any) error { inner.Set(newSlice) for i, item := range v.items { - if err := item.castTo(inner.Index(i).Addr().Interface()); err != nil { + if err := CastTo(item, inner.Index(i).Addr().Interface()); err != nil { return xerrors.WithStackTrace(fmt.Errorf( "%w '%s(%+v)' to '%T' destination", ErrCannotCast, v.Type().Yql(), v, dstValue, @@ -1545,7 +1545,7 @@ func (v *optionalValue) castTo(dst any) error { return nil } - if err := v.value.castTo(ptr.Interface()); err != nil { + if err := CastTo(v.value, (ptr.Interface())); err != nil { return xerrors.WithStackTrace(err) } @@ -1560,7 +1560,7 @@ func (v *optionalValue) castTo(dst any) error { inner.Set(reflect.New(inner.Type().Elem())) - if err := v.value.castTo(inner.Interface()); err != nil { + if err := CastTo(v.value, inner.Interface()); err != nil { return xerrors.WithStackTrace(err) } @@ -1641,7 +1641,7 @@ func (v *structValue) castTo(dst any) error { } for i, field := range v.fields { - if err := field.V.castTo(inner.Field(i).Addr().Interface()); err != nil { + if err := CastTo(field.V, inner.Field(i).Addr().Interface()); err != nil { return xerrors.WithStackTrace(fmt.Errorf( "scan error on struct field name '%s': %w", field.Name, err, @@ -1768,7 +1768,7 @@ func (v *tupleValue) TupleItems() []Value { func (v *tupleValue) castTo(dst any) error { if len(v.items) == 1 { - return v.items[0].castTo(dst) + return CastTo(v.items[0], dst) } switch dstValue := dst.(type) {