Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: convert decimal to the type expected by go-mysql-server #49

Merged
merged 4 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions binlogreplication/binlog_replication_alltypes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,27 +262,27 @@ var allTypes = []typeDescription{
newTypeDescriptionAssertion("65535"),
},
},
// {
// TypeDefinition: "decimal",
// Assertions: [2]typeDescriptionAssertion{
// newTypeDescriptionAssertion("0"),
// newTypeDescriptionAssertion("1234567890"),
// },
// },
// {
// TypeDefinition: "decimal(10,2)",
// Assertions: [2]typeDescriptionAssertion{
// newTypeDescriptionAssertion("0.00"),
// newTypeDescriptionAssertion("12345678.00"),
// },
// },
// {
// TypeDefinition: "decimal(20,8)",
// Assertions: [2]typeDescriptionAssertion{
// newTypeDescriptionAssertion("-1234567890.12345678"),
// newTypeDescriptionAssertion("999999999999.00000001"),
// },
// },
{
TypeDefinition: "decimal",
Assertions: [2]typeDescriptionAssertion{
newTypeDescriptionAssertion("0"),
newTypeDescriptionAssertion("1234567890"),
},
},
{
TypeDefinition: "decimal(10,2)",
Assertions: [2]typeDescriptionAssertion{
newTypeDescriptionAssertion("0.00"),
newTypeDescriptionAssertion("12345678.00"),
},
},
{
TypeDefinition: "decimal(20,8)",
Assertions: [2]typeDescriptionAssertion{
newTypeDescriptionAssertion("-1234567890.12345678"),
newTypeDescriptionAssertion("999999999999.00000001"),
},
},

// Floating point types
{
Expand Down
5 changes: 1 addition & 4 deletions executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,7 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co
return nil, err
}

// Create a new iterator
iter := &SQLRowIter{rows: rows, schema: n.Schema()}

return iter, nil
return NewSQLRowIter(rows, n.Schema())
}

func (b *DuckBuilder) executeDML(ctx *sql.Context, n sql.Node, conn *stdsql.Conn) (sql.RowIter, error) {
Expand Down
70 changes: 51 additions & 19 deletions iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,45 @@ package main
import (
stdsql "database/sql"
"io"
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/marcboeker/go-duckdb"
"github.com/shopspring/decimal"
)

var _ sql.RowIter = (*SQLRowIter)(nil)

// SQLRowIter wraps a standard sql.Rows as a RowIter.
type SQLRowIter struct {
rows *stdsql.Rows
schema sql.Schema
rows *stdsql.Rows
columns []*stdsql.ColumnType
schema sql.Schema
buffer []any // pre-allocated buffer for scanning values
pointers []any // pointers to the buffer
decimals []int
}

func NewSQLRowIter(rows *stdsql.Rows, schema sql.Schema) (*SQLRowIter, error) {
columns, err := rows.ColumnTypes()
if err != nil {
return nil, err
}

var decimals []int
for i, t := range columns {
if strings.HasPrefix(t.DatabaseTypeName(), "DECIMAL") {
decimals = append(decimals, i)
}
}

width := max(len(columns), len(schema))
buf := make([]any, width)
ptrs := make([]any, width)
for i := range buf {
ptrs[i] = &buf[i]
}
return &SQLRowIter{rows, columns, schema, buf, ptrs, decimals}, nil
}

// Next retrieves the next row. It will return io.EOF if it's the last row.
Expand All @@ -38,31 +67,34 @@ func (iter *SQLRowIter) Next(ctx *sql.Context) (sql.Row, error) {
return nil, io.EOF
}

columns, err := iter.rows.Columns()
if err != nil {
// Scan the values into the buffer
if err := iter.rows.Scan(iter.pointers[:len(iter.columns)]...); err != nil {
return nil, err
}

values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i]
}

if err := iter.rows.Scan(valuePtrs...); err != nil {
return nil, err
// Process decimal values
for _, idx := range iter.decimals {
switch v := iter.buffer[idx].(type) {
case nil:
// nothing to do
case duckdb.Decimal:
iter.buffer[idx] = decimal.NewFromBigInt(v.Value, -int32(v.Scale))
GaoYusong marked this conversation as resolved.
Show resolved Hide resolved
case string:
iter.buffer[idx], _ = decimal.NewFromString(v)
default:
// nothing to do
}
}

// Prune the values to match the schema
if len(values) > len(iter.schema) {
values = values[:len(iter.schema)]
} else if len(values) < len(iter.schema) {
for i := len(values); i < len(iter.schema); i++ {
values = append(values, nil)
// Prune or fill the values to match the schema
width := len(iter.schema) // the desired width
if len(iter.columns) < width {
for i := len(iter.columns); i < width; i++ {
iter.buffer[i] = nil
}
}

return sql.NewRow(values...), nil
return sql.NewRow(iter.buffer[:width]...), nil
}

// Close closes the underlying sql.Rows.
Expand Down