diff --git a/go/connection.go b/go/connection.go index 665b2a2..1010822 100644 --- a/go/connection.go +++ b/go/connection.go @@ -19,15 +19,20 @@ import ( "database/sql" "errors" "fmt" + "io" "strings" + "sync/atomic" "github.com/adbc-drivers/driverbase-go/driverbase" sqlwrapper "github.com/adbc-drivers/driverbase-go/sqlwrapper" "github.com/apache/arrow-adbc/go/adbc" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" + gomysql "github.com/go-sql-driver/mysql" ) +var loadReaderCounter atomic.Uint64 + const ( // Default num of rows per batch for batched INSERT MySQLDefaultIngestBatchSize = 1000 @@ -214,13 +219,17 @@ func (c *mysqlConnectionImpl) GetPlaceholder(field *arrow.Field, index int) stri // Ensure mysqlConnectionImpl implements BulkIngester var _ sqlwrapper.BulkIngester = (*mysqlConnectionImpl)(nil) -// ExecuteBulkIngest performs MySQL bulk ingest using batched INSERT statements. +// ExecuteBulkIngest performs MySQL bulk ingest using LOAD DATA LOCAL INFILE with a fallback to batched INSERTs. func (c *mysqlConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (rowCount int64, err error) { schema := stream.Schema() if err := c.createTableIfNeeded(ctx, conn, options.TableName, schema, options); err != nil { return -1, c.ErrorHelper.WrapIO(err, "failed to create table") } + if c.isLoadDataEnabled(ctx, conn) { + return c.executeLoadDataIngest(ctx, conn, options, stream) + } + // Validate MySQL-specific options if options.MaxQuerySizeBytes > 0 { return -1, c.ErrorHelper.InvalidArgument( @@ -246,6 +255,73 @@ func (c *mysqlConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwr ) } +// isLoadDataEnabled checks if LOAD DATA LOCAL INFILE is enabled on the server. +func (c *mysqlConnectionImpl) isLoadDataEnabled(ctx context.Context, conn *sqlwrapper.LoggingConn) bool { + var localInfile int + err := conn.QueryRowContext(ctx, "SELECT @@local_infile").Scan(&localInfile) + return err == nil && localInfile == 1 +} + +// executeLoadDataIngest performs bulk ingestion using the LOAD DATA LOCAL INFILE command. +func (c *mysqlConnectionImpl) executeLoadDataIngest(ctx context.Context, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (int64, error) { + r, w := io.Pipe() + readerId := loadReaderCounter.Add(1) + readerName := fmt.Sprintf("adbc_ingest_%s_%d", options.TableName, readerId) + + gomysql.RegisterReaderHandler(readerName, func() io.Reader { + return r + }) + defer gomysql.DeregisterReaderHandler(readerName) + batchSize := options.IngestBatchSize + if batchSize <= 0 { + batchSize = 10000 // Default batch size for streaming chunks + } + + it, err := sqlwrapper.NewRowBufferIterator(stream, batchSize, c.TypeConverter) + if err != nil { + return -1, c.ErrorHelper.WrapIO(err, "failed to create row buffer iterator") + } + + numCols := len(stream.Schema().Fields()) + go func() { + config := CSVConfig{ + FieldDelimiter: '\t', + LineTerminator: '\n', + NullValue: "\\N", + EscapeBackslash: true, + } + err := arrowToCSV(ctx, w, it, numCols, config) + if err != nil { + _ = w.CloseWithError(err) + } else { + _ = w.Close() + } + }() + + var colNames []string + for _, field := range stream.Schema().Fields() { + colNames = append(colNames, quoteIdentifier(field.Name)) + } + colsList := strings.Join(colNames, ", ") + + query := fmt.Sprintf( + "LOAD DATA LOCAL INFILE 'Reader::%s' INTO TABLE %s CHARACTER SET utf8mb4 FIELDS TERMINATED BY '\\t' ESCAPED BY '\\\\' LINES TERMINATED BY '\\n' (%s)", + readerName, c.QuoteIdentifier(options.TableName), colsList, + ) + + res, err := conn.ExecContext(ctx, query) + if err != nil { + return -1, c.ErrorHelper.WrapIO(err, "failed to execute LOAD DATA statement") + } + + rowCount, err := res.RowsAffected() + if err != nil { + return -1, c.ErrorHelper.WrapIO(err, "failed to get rows affected") + } + + return rowCount, nil +} + // createTableIfNeeded creates the table based on the ingest mode func (c *mysqlConnectionImpl) createTableIfNeeded(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, options *driverbase.BulkIngestOptions) error { switch options.Mode { diff --git a/go/csv_helper.go b/go/csv_helper.go new file mode 100644 index 0000000..8167d76 --- /dev/null +++ b/go/csv_helper.go @@ -0,0 +1,107 @@ +// Copyright (c) 2025 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "context" + "fmt" + "io" + "strings" + + "github.com/adbc-drivers/driverbase-go/sqlwrapper" +) + +// CSVConfig defines the configuration for Arrow-to-CSV/TSV conversion. +type CSVConfig struct { + FieldDelimiter byte + LineTerminator byte + NullValue string + EscapeBackslash bool +} + +// arrowToCSV reads from a RowBufferIterator and streams data in CSV/TSV format into the provided io.Writer. +func arrowToCSV(ctx context.Context, w io.Writer, it *sqlwrapper.RowBufferIterator, numCols int, config CSVConfig) error { + var buf strings.Builder + + for it.Next() { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + rows, rowCount := it.CurrentBatch() + + buf.Reset() + for rowIdx := 0; rowIdx < rowCount; rowIdx++ { + for colIdx := 0; colIdx < numCols; colIdx++ { + if colIdx > 0 { + buf.WriteByte(config.FieldDelimiter) + } + + val := rows[rowIdx*numCols+colIdx] + buf.WriteString(formatValueForCSV(val, config)) + } + buf.WriteByte(config.LineTerminator) + } + if _, err := io.WriteString(w, buf.String()); err != nil { + return fmt.Errorf("failed to write batch to pipe: %w", err) + } + } + + return it.Err() +} + +// escapeCSV escapes special characters based on the provided CSVConfig. +func escapeCSV(s string, config CSVConfig) string { + if config.EscapeBackslash { + s = strings.ReplaceAll(s, "\\", "\\\\") + s = strings.ReplaceAll(s, "\b", "\\b") + s = strings.ReplaceAll(s, "\x1a", "\\Z") + s = strings.ReplaceAll(s, "\x00", "\\0") + // Always escape \r if we are escaping backslashes, as it's a common special char + s = strings.ReplaceAll(s, "\r", "\\r") + } + + if config.FieldDelimiter == '\t' { + s = strings.ReplaceAll(s, "\t", "\\t") + } + if config.LineTerminator == '\n' { + s = strings.ReplaceAll(s, "\n", "\\n") + } + + return s +} + +// formatValueForCSV converts a Go interface{} to a string suitable for CSV/TSV, handling escaping. +func formatValueForCSV(val any, config CSVConfig) string { + if val == nil { + return config.NullValue + } + + switch v := val.(type) { + case string: + return escapeCSV(v, config) + case []byte: + return escapeCSV(string(v), config) + case bool: + if v { + return "1" + } + return "0" + default: + return fmt.Sprintf("%v", v) + } +} diff --git a/go/csv_helper_test.go b/go/csv_helper_test.go new file mode 100644 index 0000000..6d353df --- /dev/null +++ b/go/csv_helper_test.go @@ -0,0 +1,257 @@ +// Copyright (c) 2025 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "bytes" + "context" + "testing" + + "github.com/adbc-drivers/driverbase-go/sqlwrapper" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestArrowToCSV_Basic(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "int", Type: arrow.PrimitiveTypes.Int32}, + {Name: "str", Type: arrow.BinaryTypes.String}, + }, nil) + + b := array.NewRecordBuilder(mem, schema) + defer b.Release() + + b.Field(0).(*array.Int32Builder).AppendValues([]int32{1, 2}, nil) + b.Field(1).(*array.StringBuilder).AppendValues([]string{"a", "b"}, nil) + + rec := b.NewRecordBatch() + defer rec.Release() + + rr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec}) + require.NoError(t, err) + defer rr.Release() + + tc := &mySQLTypeConverter{ + DefaultTypeConverter: sqlwrapper.DefaultTypeConverter{VendorName: "MySQL"}, + } + it, err := sqlwrapper.NewRowBufferIterator(rr, 5, tc) + require.NoError(t, err) + + config := CSVConfig{ + FieldDelimiter: '\t', + LineTerminator: '\n', + NullValue: "\\N", + EscapeBackslash: true, + } + + var buf bytes.Buffer + err = arrowToCSV(context.Background(), &buf, it, 2, config) + require.NoError(t, err) + + expected := "1\ta\n2\tb\n" + assert.Equal(t, expected, buf.String()) +} + +func TestArrowToCSV_Escaping(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "str", Type: arrow.BinaryTypes.String}, + }, nil) + + b := array.NewRecordBuilder(mem, schema) + defer b.Release() + + // Testing backslash, newline, tab, and carriage return + b.Field(0).(*array.StringBuilder).AppendValues([]string{ + "back\\slash", + "new\nline", + "ta\bt", + "car\rriage", + "null\x00byte", + }, nil) + + rec := b.NewRecordBatch() + defer rec.Release() + + rr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec}) + require.NoError(t, err) + defer rr.Release() + + tc := &mySQLTypeConverter{ + DefaultTypeConverter: sqlwrapper.DefaultTypeConverter{VendorName: "MySQL"}, + } + it, err := sqlwrapper.NewRowBufferIterator(rr, 5, tc) + require.NoError(t, err) + + config := CSVConfig{ + FieldDelimiter: '\t', + LineTerminator: '\n', + NullValue: "\\N", + EscapeBackslash: true, + } + + var buf bytes.Buffer + err = arrowToCSV(context.Background(), &buf, it, 1, config) + require.NoError(t, err) + + // In TSV, backslash is escaped as \\, newline as \n, tab as \t + // We handle \, \n, \t, \r, \b, \Z (Ctrl+Z), and \0 (Null byte) + expected := "back\\\\slash\nnew\\nline\nta\\bt\ncar\\rriage\nnull\\0byte\n" + assert.Equal(t, expected, buf.String()) +} + +func TestArrowToCSV_Nulls(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "int", Type: arrow.PrimitiveTypes.Int32}, + {Name: "str", Type: arrow.BinaryTypes.String}, + }, nil) + + b := array.NewRecordBuilder(mem, schema) + defer b.Release() + + b.Field(0).(*array.Int32Builder).AppendValues([]int32{0, 1}, []bool{false, true}) + b.Field(1).(*array.StringBuilder).AppendValues([]string{"", "val"}, []bool{false, true}) + + rec := b.NewRecord() + defer rec.Release() + + rr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec}) + require.NoError(t, err) + defer rr.Release() + + tc := &mySQLTypeConverter{ + DefaultTypeConverter: sqlwrapper.DefaultTypeConverter{VendorName: "MySQL"}, + } + it, err := sqlwrapper.NewRowBufferIterator(rr, 5, tc) + require.NoError(t, err) + + config := CSVConfig{ + FieldDelimiter: '\t', + LineTerminator: '\n', + NullValue: "\\N", + EscapeBackslash: true, + } + + var buf bytes.Buffer + err = arrowToCSV(context.Background(), &buf, it, 2, config) + require.NoError(t, err) + + // Row 1: Null int, Null string -> \N\t\N\n + // Row 2: 1 int, val string -> 1\tval\n + expected := "\\N\t\\N\n1\tval\n" + assert.Equal(t, expected, buf.String()) +} + +func TestArrowToCSV_Batching(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "int", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + + b := array.NewRecordBuilder(mem, schema) + defer b.Release() + + for i := 1; i <= 10; i++ { + b.Field(0).(*array.Int32Builder).Append(int32(i)) + } + + rec := b.NewRecordBatch() + defer rec.Release() + + rr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec}) + require.NoError(t, err) + defer rr.Release() + + tc := &mySQLTypeConverter{ + DefaultTypeConverter: sqlwrapper.DefaultTypeConverter{VendorName: "MySQL"}, + } + // Set batch size to 3 to test multiple writes + it, err := sqlwrapper.NewRowBufferIterator(rr, 3, tc) + require.NoError(t, err) + + config := CSVConfig{ + FieldDelimiter: ',', + LineTerminator: '\n', + NullValue: "NULL", + EscapeBackslash: false, + } + + var buf bytes.Buffer + err = arrowToCSV(context.Background(), &buf, it, 1, config) + require.NoError(t, err) + + expected := "1\n2\n3\n4\n5\n6\n7\n8\n9\n10\n" + assert.Equal(t, expected, buf.String()) +} + +func TestArrowToCSV_Types(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean}, + {Name: "float", Type: arrow.PrimitiveTypes.Float64}, + {Name: "bin", Type: arrow.BinaryTypes.Binary}, + }, nil) + + b := array.NewRecordBuilder(mem, schema) + defer b.Release() + + b.Field(0).(*array.BooleanBuilder).AppendValues([]bool{true, false}, nil) + b.Field(1).(*array.Float64Builder).AppendValues([]float64{1.23, 4.56}, nil) + b.Field(2).(*array.BinaryBuilder).AppendValues([][]byte{[]byte("bin1"), []byte("bin2")}, nil) + + rec := b.NewRecordBatch() + defer rec.Release() + + rr, err := array.NewRecordReader(schema, []arrow.RecordBatch{rec}) + require.NoError(t, err) + defer rr.Release() + + tc := &mySQLTypeConverter{ + DefaultTypeConverter: sqlwrapper.DefaultTypeConverter{VendorName: "MySQL"}, + } + it, err := sqlwrapper.NewRowBufferIterator(rr, 5, tc) + require.NoError(t, err) + + config := CSVConfig{ + FieldDelimiter: ',', + LineTerminator: '\n', + NullValue: "\\N", + EscapeBackslash: true, + } + + var buf bytes.Buffer + err = arrowToCSV(context.Background(), &buf, it, 3, config) + require.NoError(t, err) + + // Boolean true/false might be formatted as 1/0 or true/false depending on TypeConverter. + // MySQL driver usually prefers 1/0 for boolean. + expected := "1,1.23,bin1\n0,4.56,bin2\n" + assert.Equal(t, expected, buf.String()) +} diff --git a/go/mysql_ingest_test.go b/go/mysql_ingest_test.go new file mode 100644 index 0000000..b53ed5e --- /dev/null +++ b/go/mysql_ingest_test.go @@ -0,0 +1,358 @@ +// Copyright (c) 2025 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql_test + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + "github.com/apache/arrow-adbc/go/adbc" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/suite" +) + +type MySQLIngestTestSuite struct { + suite.Suite + + Quirks *MySQLQuirks + mem *memory.CheckedAllocator + ctx context.Context + driver adbc.Driver + db adbc.Database + cnxn adbc.Connection +} + +func (s *MySQLIngestTestSuite) SetupSuite() { + dsn := os.Getenv("MYSQL_DSN") + if dsn == "" { + s.T().Skip("MYSQL_DSN not set") + } + s.Quirks = &MySQLQuirks{dsn: dsn} + s.mem = memory.NewCheckedAllocator(memory.DefaultAllocator) + s.ctx = context.Background() + s.driver = s.Quirks.SetupDriver(s.T()) + + var err error + s.db, err = s.driver.NewDatabase(s.Quirks.DatabaseOptions()) + s.NoError(err) + s.cnxn, err = s.db.Open(s.ctx) + s.NoError(err) +} + +func (s *MySQLIngestTestSuite) TearDownSuite() { + if s.cnxn != nil { + s.NoError(s.cnxn.Close()) + } + if s.db != nil { + s.NoError(s.db.Close()) + } + s.Quirks.TearDownDriver(s.T(), s.driver) +} + +func (s *MySQLIngestTestSuite) getTableNames(testName string) []string { + switch testName { + case "TestConcurrentIngest": + var names []string + for i := 0; i < 5; i++ { + names = append(names, fmt.Sprintf("concurrent_ingest_%d", i)) + } + return names + case "TestLargeIngest": + return []string{"large_ingest_test"} + case "TestSchemaMismatch": + return []string{"schema_mismatch_test"} + case "TestIngestFallback": + return []string{"fallback_ingest_test"} + case "TestComplexTypes": + return []string{"complex_types_ingest"} + default: + return nil + } +} + +func (s *MySQLIngestTestSuite) cleanupTables(testName string) { + tables := s.getTableNames(testName) + if len(tables) == 0 { + return + } + + stmt, err := s.cnxn.NewStatement() + s.NoError(err) + defer stmt.Close() + + for _, table := range tables { + s.NoError(stmt.SetSqlQuery("DROP TABLE IF EXISTS " + table)) + _, _ = stmt.ExecuteUpdate(s.ctx) + } +} + +func (s *MySQLIngestTestSuite) BeforeTest(suiteName, testName string) { + s.cleanupTables(testName) +} + +func (s *MySQLIngestTestSuite) AfterTest(suiteName, testName string) { + s.cleanupTables(testName) +} + +// TestLargeIngest verifies that chunked ingestion works for datasets larger than BatchSize +func (s *MySQLIngestTestSuite) TestLargeIngest() { + const numRows = 15000 + tableName := "large_ingest_test" + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int32}, + {Name: "val", Type: arrow.BinaryTypes.String}, + }, nil) + + bldr := array.NewRecordBuilder(s.mem, schema) + defer bldr.Release() + + for i := 0; i < numRows; i++ { + bldr.Field(0).(*array.Int32Builder).Append(int32(i)) + bldr.Field(1).(*array.StringBuilder).Append(fmt.Sprintf("value-%d", i)) + } + + rec := bldr.NewRecordBatch() + defer rec.Release() + + stmt, err := s.cnxn.NewStatement() + s.NoError(err) + defer stmt.Close() + + s.NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, tableName)) + s.NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate)) + s.NoError(stmt.Bind(s.ctx, rec)) + + _, err = stmt.ExecuteUpdate(s.ctx) + s.NoError(err) + + s.NoError(stmt.SetSqlQuery("SELECT COUNT(*) FROM " + tableName)) + rdr, _, err := stmt.ExecuteQuery(s.ctx) + s.NoError(err) + defer rdr.Release() + + s.True(rdr.Next()) + countRec := rdr.RecordBatch() + s.Equal(int64(numRows), countRec.Column(0).(*array.Int64).Value(0)) +} + +// TestConcurrentIngest ensures that unique reader names prevent collisions during parallel ingestion +func (s *MySQLIngestTestSuite) TestConcurrentIngest() { + const numThreads = 5 + const rowsPerThread = 100 + + var wg sync.WaitGroup + wg.Add(numThreads) + + for i := 0; i < numThreads; i++ { + go func(id int) { + defer wg.Done() + + tableName := fmt.Sprintf("concurrent_ingest_%d", id) + + // Use a fresh connection per thread to avoid state conflicts if any + cnxn, err := s.db.Open(s.ctx) + s.NoError(err) + defer cnxn.Close() + + schema := arrow.NewSchema([]arrow.Field{{Name: "id", Type: arrow.PrimitiveTypes.Int32}}, nil) + bldr := array.NewRecordBuilder(s.mem, schema) + defer bldr.Release() + for r := 0; r < rowsPerThread; r++ { + bldr.Field(0).(*array.Int32Builder).Append(int32(r)) + } + rec := bldr.NewRecordBatch() + defer rec.Release() + + stmt, err := cnxn.NewStatement() + s.NoError(err) + defer stmt.Close() + + s.NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, tableName)) + s.NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate)) + s.NoError(stmt.Bind(s.ctx, rec)) + + _, err = stmt.ExecuteUpdate(s.ctx) + s.NoError(err) + + s.NoError(stmt.SetSqlQuery("SELECT COUNT(*) FROM " + tableName)) + rdr, _, err := stmt.ExecuteQuery(s.ctx) + s.NoError(err) + defer rdr.Release() + s.True(rdr.Next()) + s.Equal(int64(rowsPerThread), rdr.Record().Column(0).(*array.Int64).Value(0)) + }(i) + } + + wg.Wait() +} + +// TestSchemaMismatch verifies that appending data with extra columns returns an error +func (s *MySQLIngestTestSuite) TestSchemaMismatch() { + tableName := "schema_mismatch_test" + + // Create table with 1 column + schema1 := arrow.NewSchema([]arrow.Field{{Name: "col1", Type: arrow.PrimitiveTypes.Int32}}, nil) + bldr1 := array.NewRecordBuilder(s.mem, schema1) + defer bldr1.Release() + bldr1.Field(0).(*array.Int32Builder).Append(1) + rec1 := bldr1.NewRecordBatch() + defer rec1.Release() + + stmt, err := s.cnxn.NewStatement() + s.NoError(err) + defer stmt.Close() + + s.NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, tableName)) + s.NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate)) + s.NoError(stmt.Bind(s.ctx, rec1)) + _, err = stmt.ExecuteUpdate(s.ctx) + s.NoError(err) + + // Try to append with 2 columns + schema2 := arrow.NewSchema([]arrow.Field{ + {Name: "col1", Type: arrow.PrimitiveTypes.Int32}, + {Name: "col2", Type: arrow.PrimitiveTypes.Int32}, + }, nil) + bldr2 := array.NewRecordBuilder(s.mem, schema2) + defer bldr2.Release() + bldr2.Field(0).(*array.Int32Builder).Append(2) + bldr2.Field(1).(*array.Int32Builder).Append(3) + rec2 := bldr2.NewRecordBatch() + defer rec2.Release() + + s.NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeAppend)) + s.NoError(stmt.Bind(s.ctx, rec2)) + _, err = stmt.ExecuteUpdate(s.ctx) + + s.Error(err) + s.Contains(err.Error(), "Unknown column 'col2' in 'field list'") +} + +// TestIngestFallback verifies that the driver falls back to INSERT statements when LOAD DATA is disabled +func (s *MySQLIngestTestSuite) TestIngestFallback() { + db, err := s.driver.NewDatabase(s.Quirks.DatabaseOptions()) + s.NoError(err) + defer db.Close() + + cnxn, err := db.Open(s.ctx) + s.NoError(err) + defer cnxn.Close() + + adminStmt, err := cnxn.NewStatement() + s.NoError(err) + defer adminStmt.Close() + + // Try to disable. Note: This might require SUPER privilege on the server. + // If it fails, we might just skip the test or use a different approach. + err = adminStmt.SetSqlQuery("SET GLOBAL local_infile = 0") + s.NoError(err) + _, err = adminStmt.ExecuteUpdate(s.ctx) + if err != nil { + s.T().Skip("Skipping fallback test: failed to disable global local_infile (requires SUPER privilege)") + return + } + defer func() { + backStmt, err := cnxn.NewStatement() + if err == nil { + _ = backStmt.SetSqlQuery("SET GLOBAL local_infile = 1") + _, _ = backStmt.ExecuteUpdate(s.ctx) + _ = backStmt.Close() + } + }() + + tableName := "fallback_ingest_test" + schema := arrow.NewSchema([]arrow.Field{{Name: "col1", Type: arrow.PrimitiveTypes.Int32}}, nil) + bldr := array.NewRecordBuilder(s.mem, schema) + defer bldr.Release() + bldr.Field(0).(*array.Int32Builder).Append(123) + rec := bldr.NewRecordBatch() + defer rec.Release() + + stmt, err := cnxn.NewStatement() + s.NoError(err) + defer stmt.Close() + + s.NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, tableName)) + s.NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate)) + s.NoError(stmt.Bind(s.ctx, rec)) + + // This should now use INSERT statements because LOAD DATA is disabled + affected, err := stmt.ExecuteUpdate(s.ctx) + s.NoError(err) + s.EqualValues(1, affected) + + s.NoError(stmt.SetSqlQuery("SELECT * FROM " + tableName)) + rdr, _, err := stmt.ExecuteQuery(s.ctx) + s.NoError(err) + defer rdr.Release() + s.True(rdr.Next()) + s.Equal(int32(123), rdr.RecordBatch().Column(0).(*array.Int32).Value(0)) +} + +func (s *MySQLIngestTestSuite) TestComplexTypes() { + tableName := "complex_types_ingest" + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int32}, + {Name: "ts", Type: arrow.FixedWidthTypes.Timestamp_us}, + {Name: "b", Type: arrow.FixedWidthTypes.Boolean}, + }, nil) + + bldr := array.NewRecordBuilder(s.mem, schema) + defer bldr.Release() + + // 2026-03-13 12:00:00 UTC + tsValue := int64(1773403200000000) + + bldr.Field(0).(*array.Int32Builder).Append(1) + bldr.Field(1).(*array.TimestampBuilder).Append(arrow.Timestamp(tsValue)) + bldr.Field(2).(*array.BooleanBuilder).Append(true) + + rec := bldr.NewRecordBatch() + defer rec.Release() + + stmt, err := s.cnxn.NewStatement() + s.NoError(err) + defer stmt.Close() + + s.NoError(stmt.SetOption(adbc.OptionKeyIngestTargetTable, tableName)) + s.NoError(stmt.SetOption(adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeCreate)) + s.NoError(stmt.Bind(s.ctx, rec)) + + _, err = stmt.ExecuteUpdate(s.ctx) + s.NoError(err) + + s.NoError(stmt.SetSqlQuery("SELECT * FROM " + tableName)) + rdr, _, err := stmt.ExecuteQuery(s.ctx) + s.NoError(err) + defer rdr.Release() + s.True(rdr.Next()) + + recOut := rdr.RecordBatch() + s.Equal(int32(1), recOut.Column(0).(*array.Int32).Value(0)) + s.NotNil(recOut.Column(1).(*array.Timestamp)) + // MySQL TINYINT(1) used for boolean is returned as Int8 by the ADBC driver + s.Equal(int8(1), recOut.Column(2).(*array.Int8).Value(0)) +} + +func TestMySQLIngestSuite(t *testing.T) { + suite.Run(t, new(MySQLIngestTestSuite)) +}