Skip to content

Commit

Permalink
fix: convert decimal to the type expected by go-mysql-server (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
fanyang01 authored Sep 4, 2024
1 parent f04088e commit 6377aa9
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 44 deletions.
42 changes: 21 additions & 21 deletions binlogreplication/binlog_replication_alltypes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,27 +262,27 @@ var allTypes = []typeDescription{
newTypeDescriptionAssertion("65535"),
},
},
// {
// TypeDefinition: "decimal",
// Assertions: [2]typeDescriptionAssertion{
// newTypeDescriptionAssertion("0"),
// newTypeDescriptionAssertion("1234567890"),
// },
// },
// {
// TypeDefinition: "decimal(10,2)",
// Assertions: [2]typeDescriptionAssertion{
// newTypeDescriptionAssertion("0.00"),
// newTypeDescriptionAssertion("12345678.00"),
// },
// },
// {
// TypeDefinition: "decimal(20,8)",
// Assertions: [2]typeDescriptionAssertion{
// newTypeDescriptionAssertion("-1234567890.12345678"),
// newTypeDescriptionAssertion("999999999999.00000001"),
// },
// },
{
TypeDefinition: "decimal",
Assertions: [2]typeDescriptionAssertion{
newTypeDescriptionAssertion("0"),
newTypeDescriptionAssertion("1234567890"),
},
},
{
TypeDefinition: "decimal(10,2)",
Assertions: [2]typeDescriptionAssertion{
newTypeDescriptionAssertion("0.00"),
newTypeDescriptionAssertion("12345678.00"),
},
},
{
TypeDefinition: "decimal(20,8)",
Assertions: [2]typeDescriptionAssertion{
newTypeDescriptionAssertion("-1234567890.12345678"),
newTypeDescriptionAssertion("999999999999.00000001"),
},
},

// Floating point types
{
Expand Down
5 changes: 1 addition & 4 deletions executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,7 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co
return nil, err
}

// Create a new iterator
iter := &SQLRowIter{rows: rows, schema: n.Schema()}

return iter, nil
return NewSQLRowIter(rows, n.Schema())
}

func (b *DuckBuilder) executeDML(ctx *sql.Context, n sql.Node, conn *stdsql.Conn) (sql.RowIter, error) {
Expand Down
70 changes: 51 additions & 19 deletions iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,45 @@ package main
import (
stdsql "database/sql"
"io"
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/marcboeker/go-duckdb"
"github.com/shopspring/decimal"
)

var _ sql.RowIter = (*SQLRowIter)(nil)

// SQLRowIter wraps a standard sql.Rows as a RowIter.
type SQLRowIter struct {
rows *stdsql.Rows
schema sql.Schema
rows *stdsql.Rows
columns []*stdsql.ColumnType
schema sql.Schema
buffer []any // pre-allocated buffer for scanning values
pointers []any // pointers to the buffer
decimals []int
}

func NewSQLRowIter(rows *stdsql.Rows, schema sql.Schema) (*SQLRowIter, error) {
columns, err := rows.ColumnTypes()
if err != nil {
return nil, err
}

var decimals []int
for i, t := range columns {
if strings.HasPrefix(t.DatabaseTypeName(), "DECIMAL") {
decimals = append(decimals, i)
}
}

width := max(len(columns), len(schema))
buf := make([]any, width)
ptrs := make([]any, width)
for i := range buf {
ptrs[i] = &buf[i]
}
return &SQLRowIter{rows, columns, schema, buf, ptrs, decimals}, nil
}

// Next retrieves the next row. It will return io.EOF if it's the last row.
Expand All @@ -38,31 +67,34 @@ func (iter *SQLRowIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, io.EOF
}

columns, err := iter.rows.Columns()
if err != nil {
// Scan the values into the buffer
if err := iter.rows.Scan(iter.pointers[:len(iter.columns)]...); err != nil {
return nil, err
}

values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}

if err := iter.rows.Scan(valuePtrs...); err != nil {
return nil, err
// Process decimal values
for _, idx := range iter.decimals {
switch v := iter.buffer[idx].(type) {
case nil:
// nothing to do
case duckdb.Decimal:
iter.buffer[idx] = decimal.NewFromBigInt(v.Value, -int32(v.Scale))
case string:
iter.buffer[idx], _ = decimal.NewFromString(v)
default:
// nothing to do
}
}

// Prune the values to match the schema
if len(values) > len(iter.schema) {
values = values[:len(iter.schema)]
} else if len(values) < len(iter.schema) {
for i := len(values); i < len(iter.schema); i++ {
values = append(values, nil)
// Prune or fill the values to match the schema
width := len(iter.schema) // the desired width
if len(iter.columns) < width {
for i := len(iter.columns); i < width; i++ {
iter.buffer[i] = nil
}
}

return sql.NewRow(values...), nil
return sql.NewRow(iter.buffer[:width]...), nil
}

// Close closes the underlying sql.Rows.
Expand Down

0 comments on commit 6377aa9

Please sign in to comment.