From 6377aa90546fbd0395c05772cd81638a1cc6a5a0 Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Wed, 4 Sep 2024 21:12:25 +0800 Subject: [PATCH] fix: convert decimal to the type expected by go-mysql-server (#49) --- .../binlog_replication_alltypes_test.go | 42 +++++------ executor.go | 5 +- iter.go | 70 ++++++++++++++----- 3 files changed, 73 insertions(+), 44 deletions(-) diff --git a/binlogreplication/binlog_replication_alltypes_test.go b/binlogreplication/binlog_replication_alltypes_test.go index a0e6ca21..11f25e74 100644 --- a/binlogreplication/binlog_replication_alltypes_test.go +++ b/binlogreplication/binlog_replication_alltypes_test.go @@ -262,27 +262,27 @@ var allTypes = []typeDescription{ newTypeDescriptionAssertion("65535"), }, }, - // { - // TypeDefinition: "decimal", - // Assertions: [2]typeDescriptionAssertion{ - // newTypeDescriptionAssertion("0"), - // newTypeDescriptionAssertion("1234567890"), - // }, - // }, - // { - // TypeDefinition: "decimal(10,2)", - // Assertions: [2]typeDescriptionAssertion{ - // newTypeDescriptionAssertion("0.00"), - // newTypeDescriptionAssertion("12345678.00"), - // }, - // }, - // { - // TypeDefinition: "decimal(20,8)", - // Assertions: [2]typeDescriptionAssertion{ - // newTypeDescriptionAssertion("-1234567890.12345678"), - // newTypeDescriptionAssertion("999999999999.00000001"), - // }, - // }, + { + TypeDefinition: "decimal", + Assertions: [2]typeDescriptionAssertion{ + newTypeDescriptionAssertion("0"), + newTypeDescriptionAssertion("1234567890"), + }, + }, + { + TypeDefinition: "decimal(10,2)", + Assertions: [2]typeDescriptionAssertion{ + newTypeDescriptionAssertion("0.00"), + newTypeDescriptionAssertion("12345678.00"), + }, + }, + { + TypeDefinition: "decimal(20,8)", + Assertions: [2]typeDescriptionAssertion{ + newTypeDescriptionAssertion("-1234567890.12345678"), + newTypeDescriptionAssertion("999999999999.00000001"), + }, + }, // Floating point types { diff --git a/executor.go b/executor.go index 2eb7060e..d120faf3 100644 --- a/executor.go +++ b/executor.go @@ -197,10 +197,7 @@ func (b *DuckBuilder) executeQuery(ctx *sql.Context, n sql.Node, conn *stdsql.Co return nil, err } - // Create a new iterator - iter := &SQLRowIter{rows: rows, schema: n.Schema()} - - return iter, nil + return NewSQLRowIter(rows, n.Schema()) } func (b *DuckBuilder) executeDML(ctx *sql.Context, n sql.Node, conn *stdsql.Conn) (sql.RowIter, error) { diff --git a/iter.go b/iter.go index ff89f836..d63a945e 100644 --- a/iter.go +++ b/iter.go @@ -17,16 +17,45 @@ package main import ( stdsql "database/sql" "io" + "strings" "github.com/dolthub/go-mysql-server/sql" + "github.com/marcboeker/go-duckdb" + "github.com/shopspring/decimal" ) var _ sql.RowIter = (*SQLRowIter)(nil) // SQLRowIter wraps a standard sql.Rows as a RowIter. type SQLRowIter struct { - rows *stdsql.Rows - schema sql.Schema + rows *stdsql.Rows + columns []*stdsql.ColumnType + schema sql.Schema + buffer []any // pre-allocated buffer for scanning values + pointers []any // pointers to the buffer + decimals []int +} + +func NewSQLRowIter(rows *stdsql.Rows, schema sql.Schema) (*SQLRowIter, error) { + columns, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + + var decimals []int + for i, t := range columns { + if strings.HasPrefix(t.DatabaseTypeName(), "DECIMAL") { + decimals = append(decimals, i) + } + } + + width := max(len(columns), len(schema)) + buf := make([]any, width) + ptrs := make([]any, width) + for i := range buf { + ptrs[i] = &buf[i] + } + return &SQLRowIter{rows, columns, schema, buf, ptrs, decimals}, nil } // Next retrieves the next row. It will return io.EOF if it's the last row. @@ -38,31 +67,34 @@ func (iter *SQLRowIter) Next(ctx *sql.Context) (sql.Row, error) { return nil, io.EOF } - columns, err := iter.rows.Columns() - if err != nil { + // Scan the values into the buffer + if err := iter.rows.Scan(iter.pointers[:len(iter.columns)]...); err != nil { return nil, err } - values := make([]interface{}, len(columns)) - valuePtrs := make([]interface{}, len(columns)) - for i := range values { - valuePtrs[i] = &values[i] - } - - if err := iter.rows.Scan(valuePtrs...); err != nil { - return nil, err + // Process decimal values + for _, idx := range iter.decimals { + switch v := iter.buffer[idx].(type) { + case nil: + // nothing to do + case duckdb.Decimal: + iter.buffer[idx] = decimal.NewFromBigInt(v.Value, -int32(v.Scale)) + case string: + iter.buffer[idx], _ = decimal.NewFromString(v) + default: + // nothing to do + } } - // Prune the values to match the schema - if len(values) > len(iter.schema) { - values = values[:len(iter.schema)] - } else if len(values) < len(iter.schema) { - for i := len(values); i < len(iter.schema); i++ { - values = append(values, nil) + // Prune or fill the values to match the schema + width := len(iter.schema) // the desired width + if len(iter.columns) < width { + for i := len(iter.columns); i < width; i++ { + iter.buffer[i] = nil } } - return sql.NewRow(values...), nil + return sql.NewRow(iter.buffer[:width]...), nil } // Close closes the underlying sql.Rows.