From 72ba76f205b8bbeae998722cf97fdcbb9206a35b Mon Sep 17 00:00:00 2001 From: David Coe <> Date: Thu, 30 Apr 2026 00:46:56 -0400 Subject: [PATCH 1/6] fix but with proposed Reset --- go/connection.go | 21 +- go/database.go | 19 ++ go/driver.go | 9 + go/record_reader.go | 330 +++++++++++++++++++++--- go/record_reader_test.go | 535 +++++++++++++++++++++++++++++++++++++++ go/statement.go | 22 +- 6 files changed, 895 insertions(+), 41 deletions(-) create mode 100644 go/record_reader_test.go diff --git a/go/connection.go b/go/connection.go index 8f10771..132ac76 100644 --- a/go/connection.go +++ b/go/connection.go @@ -79,6 +79,7 @@ type connectionImpl struct { activeTransaction bool useHighPrecision bool + streamRetryEnabled bool maxTimestampPrecision MaxTimestampPrecision } @@ -365,9 +366,11 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, xdbcDataType := driverbase.ToXdbcDataType(field.Type) if field.Type != nil { - getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = new(int16(field.Type.ID())) + v := int16(field.Type.ID()) + getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = &v } - getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = new(int16(xdbcDataType)) + sqlDT := int16(xdbcDataType) + getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = &sqlDT } } } @@ -758,6 +761,7 @@ func (c *connectionImpl) NewStatement() (adbc.Statement, error) { queueSize: defaultStatementQueueSize, prefetchConcurrency: defaultPrefetchConcurrency, useHighPrecision: c.useHighPrecision, + streamRetryEnabled: c.streamRetryEnabled, maxTimestampPrecision: c.maxTimestampPrecision, ingestOptions: DefaultIngestOptions(), } @@ -809,6 +813,19 @@ func (c *connectionImpl) SetOption(key, value string) error { } } return nil + case OptionStreamRetryEnabled: + switch value { + case adbc.OptionValueEnabled: + c.streamRetryEnabled = true + case adbc.OptionValueDisabled: + c.streamRetryEnabled = false + default: + return adbc.Error{ + Msg: "[Snowflake] invalid value for option " + key + ": " + value, + Code: adbc.StatusInvalidArgument, + } + } + return nil default: return c.Base().SetOption(key, value) } diff --git a/go/database.go b/go/database.go index fb0e9d2..576eb77 100644 --- a/go/database.go +++ b/go/database.go @@ -82,6 +82,7 @@ type databaseImpl struct { cfg *gosnowflake.Config useHighPrecision bool + streamRetryEnabled bool maxTimestampPrecision MaxTimestampPrecision defaultAppName string } @@ -163,6 +164,11 @@ func (d *databaseImpl) GetOption(key string) (string, error) { return adbc.OptionValueEnabled, nil } return adbc.OptionValueDisabled, nil + case OptionStreamRetryEnabled: + if d.streamRetryEnabled { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil case OptionMaxTimestampPrecision: switch d.maxTimestampPrecision { case Microseconds: @@ -510,6 +516,18 @@ func (d *databaseImpl) SetOptionInternal(k string, v string, cnOptions *map[stri Code: adbc.StatusInvalidArgument, } } + case OptionStreamRetryEnabled: + switch v { + case adbc.OptionValueEnabled: + d.streamRetryEnabled = true + case adbc.OptionValueDisabled: + d.streamRetryEnabled = false + default: + return adbc.Error{ + Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'", OptionStreamRetryEnabled, v), + Code: adbc.StatusInvalidArgument, + } + } case OptionMaxTimestampPrecision: switch v { case OptionValueNanoseconds, OptionValueNanosecondsNoOverflow, OptionValueMicroseconds: @@ -551,6 +569,7 @@ func (d *databaseImpl) Open(ctx context.Context) (adbcConnection adbc.Connection // SetOption(OptionUseHighPrecision, adbc.OptionValueDisabled) to // get Int64/Float64 instead useHighPrecision: d.useHighPrecision, + streamRetryEnabled: d.streamRetryEnabled, maxTimestampPrecision: d.maxTimestampPrecision, ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase), } diff --git a/go/driver.go b/go/driver.go index 754b230..cc4b779 100644 --- a/go/driver.go +++ b/go/driver.go @@ -89,6 +89,15 @@ const ( // `microseconds`: Limits the max Timestamp precision to microseconds, which is safe for all values. OptionMaxTimestampPrecision = "adbc.snowflake.sql.client_option.max_timestamp_precision" + // OptionStreamRetryEnabled controls whether batch reads from Snowflake + // use a buffered approach that reads the entire HTTP response body into + // memory before IPC parsing, with retry on failure. When enabled, this + // reduces the TCP connection open time (mitigating connection resets from + // cloud storage) and provides clearer diagnostics on network errors. + // When disabled, batches are read in the original streaming mode directly + // from the network. Default is disabled. + OptionStreamRetryEnabled = "adbc.snowflake.sql.client_option.stream_retry_enabled" + OptionApplicationName = "adbc.snowflake.sql.client_option.app_name" OptionSSLSkipVerify = "adbc.snowflake.sql.client_option.tls_skip_verify" OptionOCSPFailOpenMode = "adbc.snowflake.sql.client_option.ocsp_fail_open_mode" diff --git a/go/record_reader.go b/go/record_reader.go index 7576766..34e722a 100644 --- a/go/record_reader.go +++ b/go/record_reader.go @@ -43,6 +43,8 @@ import ( "github.com/apache/arrow-go/v18/arrow/ipc" "github.com/apache/arrow-go/v18/arrow/memory" "github.com/snowflakedb/gosnowflake" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" ) @@ -559,7 +561,193 @@ type reader struct { done chan struct{} // signals all producer goroutines have finished } -func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake.ArrowStreamLoader, bufferSize, prefetchConcurrency int, useHighPrecision bool, maxTimestampPrecision MaxTimestampPrecision) (array.RecordReader, error) { +const defaultStreamMaxRetries = 3 + +// batchStreamer is the subset of gosnowflake.ArrowStreamBatch needed for reading. +type batchStreamer interface { + GetStream(ctx context.Context) (io.ReadCloser, error) +} + +// batchResetter is an optional interface that a batchStreamer may implement +// to allow clearing the cached stream for retry. gosnowflake's +// ArrowStreamBatch.GetStream caches its internal reader after the first +// successful HTTP response; without Reset, a mid-stream failure (e.g. TCP +// RST) leaves the batch permanently broken. When Reset is available, +// bufferBatchBody calls it before each retry to force a fresh download. +type batchResetter interface { + Reset() error +} + +// countingReadCloser wraps an io.ReadCloser and counts bytes read. +// Used for diagnosing truncated Arrow IPC streams. +type countingReadCloser struct { + inner io.ReadCloser + bytesRead int64 +} + +func (c *countingReadCloser) Read(p []byte) (int, error) { + n, err := c.inner.Read(p) + c.bytesRead += int64(n) + return n, err +} + +func (c *countingReadCloser) Close() error { + return c.inner.Close() +} + +// readBatchRecords reads all Arrow records from a Snowflake batch with retries. +// It buffers the entire stream body into memory before IPC parsing to isolate +// network I/O from Arrow deserialization. If the download fails, it retries +// up to maxRetries times. Records are only returned on full success. +// +// NOTE: Retry only works when GetStream itself fails (rr stays nil in +// gosnowflake). Mid-stream TCP resets cannot be retried because +// ArrowStreamBatch.GetStream caches its internal reader. The buffering +// still helps by reading the body at full network speed (reducing the +// window for connection resets) and providing clear diagnostics. +func readBatchRecords(ctx context.Context, batch batchStreamer, alloc memory.Allocator, transform recordTransformer, maxRetries int) ([]arrow.RecordBatch, error) { + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + recs, err := tryReadBatch(ctx, batch, alloc, transform) + if err == nil { + trace.SpanFromContext(ctx).AddEvent("readBatchRecords.success", trace.WithAttributes( + attribute.Int("attempt", attempt), + attribute.Int("records", len(recs)), + )) + return recs, nil + } + trace.SpanFromContext(ctx).AddEvent("readBatchRecords.failed", trace.WithAttributes( + attribute.Int("attempt", attempt), + attribute.String("error", err.Error()), + )) + // Release any partial records from the failed attempt + for _, r := range recs { + r.Release() + } + lastErr = err + } + return nil, fmt.Errorf("failed to read Arrow batch after %d attempts: %w", maxRetries+1, lastErr) +} + +// tryReadBatch downloads the full stream body into memory, then parses +// Arrow IPC records from the buffer. Buffering the body first means: +// 1. The HTTP body is consumed at full network speed (no IPC parsing backpressure) +// 2. The TCP connection is held open for a shorter time +// 3. Network errors are caught before any IPC state is created +func tryReadBatch(ctx context.Context, batch batchStreamer, alloc memory.Allocator, transform recordTransformer) (recs []arrow.RecordBatch, err error) { + raw, err := batch.GetStream(ctx) + if err != nil { + return nil, err + } + + // Buffer the entire stream body into memory to isolate network I/O + data, err := io.ReadAll(raw) + closeErr := raw.Close() + if err != nil { + return nil, fmt.Errorf("failed to buffer stream body (read %d bytes): %w", len(data), err) + } + if closeErr != nil { + return nil, fmt.Errorf("failed to close stream after buffering %d bytes: %w", len(data), closeErr) + } + + // Parse IPC from the in-memory buffer — this cannot fail due to network issues + rr, err := ipc.NewReader(bytes.NewReader(data), ipc.WithAllocator(alloc)) + if err != nil { + return nil, fmt.Errorf("ipc.NewReader failed on %d buffered bytes: %w", len(data), err) + } + defer rr.Release() + + for rr.Next() && ctx.Err() == nil { + rec := rr.RecordBatch() + rec, err = transform(ctx, rec) + if err != nil { + return recs, err + } + recs = append(recs, rec) + } + if err = rr.Err(); err != nil { + return recs, err + } + if ctx.Err() != nil { + return recs, ctx.Err() + } + return recs, nil +} + +// bufferBatchBody downloads the full batch stream body into memory with +// retry on failure, returning the raw bytes. This isolates the network +// I/O so that IPC parsing can proceed from an in-memory buffer. +// The raw []byte is only held until the caller finishes parsing; callers +// should stream records to their destination as they parse rather than +// accumulating them, to minimize peak memory. +// +// If the batch implements batchResetter (i.e. has a Reset() method), +// it is called before each retry to clear the cached stream, enabling +// a fresh HTTP download. Without Reset, retries only help when +// GetStream itself fails before the HTTP response starts streaming. +func bufferBatchBody(ctx context.Context, batch batchStreamer, maxRetries int) ([]byte, error) { + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // On retries, reset the batch's cached stream so GetStream + // will re-download from cloud storage. + if attempt > 0 { + if resetter, ok := batch.(batchResetter); ok { + if err := resetter.Reset(); err != nil { + trace.SpanFromContext(ctx).AddEvent("bufferBatchBody.resetFailed", trace.WithAttributes( + attribute.Int("attempt", attempt), + attribute.String("error", err.Error()), + )) + // Reset failure is not fatal — GetStream may still + // return the stale stream, which will likely fail again. + } + } + } + + raw, err := batch.GetStream(ctx) + if err != nil { + trace.SpanFromContext(ctx).AddEvent("bufferBatchBody.getStreamFailed", trace.WithAttributes( + attribute.Int("attempt", attempt), + attribute.String("error", err.Error()), + )) + lastErr = err + continue + } + + data, err := io.ReadAll(raw) + closeErr := raw.Close() + if err != nil { + trace.SpanFromContext(ctx).AddEvent("bufferBatchBody.readFailed", trace.WithAttributes( + attribute.Int("attempt", attempt), + attribute.Int("bytesRead", len(data)), + attribute.String("error", err.Error()), + )) + lastErr = fmt.Errorf("failed to buffer stream body (read %d bytes): %w", len(data), err) + continue + } + if closeErr != nil { + lastErr = fmt.Errorf("failed to close stream after buffering %d bytes: %w", len(data), closeErr) + continue + } + + trace.SpanFromContext(ctx).AddEvent("bufferBatchBody.success", trace.WithAttributes( + attribute.Int("attempt", attempt), + attribute.Int("bytes", len(data)), + attribute.Int("capacityBytes", cap(data)), + )) + return data, nil + } + return nil, fmt.Errorf("failed to buffer batch body after %d attempts: %w", maxRetries+1, lastErr) +} + +func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake.ArrowStreamLoader, bufferSize, prefetchConcurrency int, useHighPrecision, streamRetryEnabled bool, maxTimestampPrecision MaxTimestampPrecision) (array.RecordReader, error) { batches, err := ld.GetBatches() if err != nil { return nil, errToAdbcErr(adbc.StatusInternal, err) @@ -683,17 +871,29 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake return rdr, nil } + trace.SpanFromContext(ctx).AddEvent("newRecordReader", trace.WithAttributes( + attribute.Int("batches", len(batches)), + attribute.Int64("totalRows", ld.TotalRows()), + attribute.Bool("streamRetryEnabled", streamRetryEnabled), + )) + // Do all error-prone initialization first, before starting goroutines - r, err := batches[0].GetStream(ctx) + raw0, err := batches[0].GetStream(ctx) if err != nil { return nil, errToAdbcErr(adbc.StatusIO, err) } + r := &countingReadCloser{inner: raw0} + rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc)) if err != nil { - _ = r.Close() // Clean up the stream + trace.SpanFromContext(ctx).AddEvent("newRecordReader.ipcReaderFailed", trace.WithAttributes( + attribute.Int64("bytesRead", r.bytesRead), + attribute.String("error", err.Error()), + )) + _ = r.Close() return nil, adbc.Error{ - Msg: err.Error(), + Msg: fmt.Sprintf("batch[0]: ipc.NewReader failed after reading %d bytes: %s", r.bytesRead, err.Error()), Code: adbc.StatusInvalidState, } } @@ -749,51 +949,100 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake return rr.Err() }) + // Track cumulative buffer allocations across all batches for diagnostics + var totalBufferedBytes atomic.Int64 + var totalBufferCapacity atomic.Int64 + lastChannelIndex := len(chs) - 1 go func() { - for i, b := range batches[1:] { - batch, batchIdx := b, i+1 + for i := range batches[1:] { + batch, batchIdx := &batches[i+1], i+1 // Channels already initialized above, no need to create them here - group.Go(func() (err error) { - // close channels (except the last) so that Next can move on to the next channel properly - if batchIdx != lastChannelIndex { - defer close(chs[batchIdx]) - } + group.Go(func(batch batchStreamer, batchIdx int) func() error { + return func() (err error) { + // close channels (except the last) so that Next can move on to the next channel properly + if batchIdx != lastChannelIndex { + defer close(chs[batchIdx]) + } - rdr, err := batch.GetStream(ctx) - if err != nil { - return err - } - defer func() { - err = errors.Join(err, rdr.Close()) - }() + if streamRetryEnabled { + // Buffer the HTTP body into memory with retry, then parse IPC + // from the buffer while streaming records directly to the channel. + // This avoids accumulating all records in a local slice. + data, err := bufferBatchBody(ctx, batch, defaultStreamMaxRetries) + if err != nil { + trace.SpanFromContext(ctx).AddEvent("batch.bufferBody.failed", trace.WithAttributes( + attribute.Int("batchIndex", batchIdx), + attribute.String("error", err.Error()), + )) + return err + } + totalBufferedBytes.Add(int64(len(data))) + totalBufferCapacity.Add(int64(cap(data))) - rr, err := ipc.NewReader(rdr, ipc.WithAllocator(alloc)) - if err != nil { - return err - } - defer rr.Release() + rr, err := ipc.NewReader(bytes.NewReader(data), ipc.WithAllocator(alloc)) + if err != nil { + return fmt.Errorf("batch[%d]: ipc.NewReader failed on %d buffered bytes: %w", batchIdx, len(data), err) + } + defer rr.Release() + + for rr.Next() && ctx.Err() == nil { + rec := rr.RecordBatch() + rec, err = recTransform(ctx, rec) + if err != nil { + return err + } + select { + case chs[batchIdx] <- rec: + case <-ctx.Done(): + rec.Release() + return ctx.Err() + } + } + return rr.Err() + } - for rr.Next() && ctx.Err() == nil { - rec := rr.RecordBatch() - rec, err = recTransform(ctx, rec) + // Original streaming path: read directly from stream without buffering + rawStream, err := batch.GetStream(ctx) if err != nil { + trace.SpanFromContext(ctx).AddEvent("batch.GetStream.failed", trace.WithAttributes( + attribute.Int("batchIndex", batchIdx), + attribute.String("error", err.Error()), + )) return err } + countingStream := &countingReadCloser{inner: rawStream} + defer func() { + err = errors.Join(err, countingStream.Close()) + }() - // Use context-aware send to prevent deadlock - select { - case chs[batchIdx] <- rec: - // Successfully sent - case <-ctx.Done(): - // Context cancelled, clean up and exit - rec.Release() - return ctx.Err() + rr, err := ipc.NewReader(countingStream, ipc.WithAllocator(alloc)) + if err != nil { + trace.SpanFromContext(ctx).AddEvent("batch.ipcReader.failed", trace.WithAttributes( + attribute.Int("batchIndex", batchIdx), + attribute.Int64("bytesRead", countingStream.bytesRead), + attribute.String("error", err.Error()), + )) + return fmt.Errorf("batch[%d]: ipc.NewReader failed after reading %d bytes: %w", batchIdx, countingStream.bytesRead, err) } - } + defer rr.Release() - return rr.Err() - }) + for rr.Next() && ctx.Err() == nil { + rec := rr.RecordBatch() + rec, err = recTransform(ctx, rec) + if err != nil { + return err + } + select { + case chs[batchIdx] <- rec: + case <-ctx.Done(): + rec.Release() + return ctx.Err() + } + } + return rr.Err() + } + }(batch, batchIdx)) } // place this here so that we always clean up, but they can't be in a @@ -801,6 +1050,13 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake // the call to wait and the calls to group.Go to kick off the jobs // to perform the pre-fetching (GH-1283). rdr.err = group.Wait() + if streamRetryEnabled { + trace.SpanFromContext(ctx).AddEvent("streamRetry.summary", trace.WithAttributes( + attribute.Int64("totalBufferedBytes", totalBufferedBytes.Load()), + attribute.Int64("totalBufferCapacityBytes", totalBufferCapacity.Load()), + attribute.Int("batchCount", len(batches)-1), + )) + } // don't close the last channel until after the group is finished, // so that Next() can only return after reader.err may have been set close(chs[lastChannelIndex]) diff --git a/go/record_reader_test.go b/go/record_reader_test.go new file mode 100644 index 0000000..dc0abf6 --- /dev/null +++ b/go/record_reader_test.go @@ -0,0 +1,535 @@ +// Copyright (c) 2025 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package snowflake + +import ( + "bytes" + "context" + "fmt" + "io" + "testing" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockBatch implements batchStreamer for testing. +type mockBatch struct { + streams []func() (io.ReadCloser, error) + call int +} + +func (m *mockBatch) GetStream(ctx context.Context) (io.ReadCloser, error) { + if m.call >= len(m.streams) { + return nil, fmt.Errorf("no more streams configured") + } + fn := m.streams[m.call] + m.call++ + return fn() +} + +// mockResettableBatch implements both batchStreamer and batchResetter. +// Reset() clears the cached stream, simulating what gosnowflake's +// ArrowStreamBatch.Reset() does (clears the cached rr field). +type mockResettableBatch struct { + mockBatch + resetCalls int +} + +func (m *mockResettableBatch) Reset() error { + m.resetCalls++ + return nil +} + +// buildIPCBytes writes Arrow IPC record batches to a byte buffer. +func buildIPCBytes(alloc memory.Allocator, schema *arrow.Schema, records []arrow.RecordBatch) []byte { + var buf bytes.Buffer + w := ipc.NewWriter(&buf, ipc.WithSchema(schema), ipc.WithAllocator(alloc)) + for _, rec := range records { + _ = w.Write(rec) + } + _ = w.Close() + return buf.Bytes() +} + +func testSchema() *arrow.Schema { + return arrow.NewSchema([]arrow.Field{ + {Name: "id", Type: arrow.PrimitiveTypes.Int64}, + }, nil) +} + +func buildTestRecord(alloc memory.Allocator, schema *arrow.Schema, values []int64) arrow.RecordBatch { + bldr := array.NewRecordBuilder(alloc, schema) + defer bldr.Release() + for _, v := range values { + bldr.Field(0).(*array.Int64Builder).Append(v) + } + return bldr.NewRecordBatch() +} + +func identityTransform(_ context.Context, r arrow.RecordBatch) (arrow.RecordBatch, error) { + r.Retain() + return r, nil +} + +func failingTransform(msg string) recordTransformer { + return func(_ context.Context, r arrow.RecordBatch) (arrow.RecordBatch, error) { + return nil, fmt.Errorf("%s", msg) + } +} + +func streamFromBytes(data []byte) func() (io.ReadCloser, error) { + return func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil + } +} + +func streamError(err error) func() (io.ReadCloser, error) { + return func() (io.ReadCloser, error) { + return nil, err + } +} + +// truncatedReader returns partial data then an error, simulating a TCP RST +type truncatedReader struct { + data []byte + offset int + errAfter int + err error +} + +func (t *truncatedReader) Read(p []byte) (int, error) { + if t.offset >= t.errAfter { + return 0, t.err + } + end := min(t.offset+len(p), t.errAfter) + n := copy(p, t.data[t.offset:end]) + t.offset += n + if t.offset >= t.errAfter { + return n, t.err + } + return n, nil +} + +func (t *truncatedReader) Close() error { return nil } + +// --- tryReadBatch tests --- + +func TestTryReadBatch_Success(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + schema := testSchema() + rec := buildTestRecord(alloc, schema, []int64{1, 2, 3}) + defer rec.Release() + + data := buildIPCBytes(alloc, schema, []arrow.RecordBatch{rec}) + batch := &mockBatch{streams: []func() (io.ReadCloser, error){streamFromBytes(data)}} + + recs, err := tryReadBatch(context.Background(), batch, alloc, identityTransform) + require.NoError(t, err) + require.Len(t, recs, 1) + defer recs[0].Release() + + assert.EqualValues(t, 3, recs[0].NumRows()) + col := recs[0].Column(0).(*array.Int64) + assert.EqualValues(t, 1, col.Value(0)) + assert.EqualValues(t, 2, col.Value(1)) + assert.EqualValues(t, 3, col.Value(2)) +} + +func TestTryReadBatch_MultipleRecords(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + schema := testSchema() + rec1 := buildTestRecord(alloc, schema, []int64{10, 20}) + defer rec1.Release() + rec2 := buildTestRecord(alloc, schema, []int64{30, 40}) + defer rec2.Release() + + data := buildIPCBytes(alloc, schema, []arrow.RecordBatch{rec1, rec2}) + batch := &mockBatch{streams: []func() (io.ReadCloser, error){streamFromBytes(data)}} + + recs, err := tryReadBatch(context.Background(), batch, alloc, identityTransform) + require.NoError(t, err) + require.Len(t, recs, 2) + defer func() { + for _, r := range recs { + r.Release() + } + }() + + assert.EqualValues(t, 2, recs[0].NumRows()) + assert.EqualValues(t, 2, recs[1].NumRows()) +} + +func TestTryReadBatch_EmptyStream(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + schema := testSchema() + data := buildIPCBytes(alloc, schema, nil) // no records, just schema + batch := &mockBatch{streams: []func() (io.ReadCloser, error){streamFromBytes(data)}} + + recs, err := tryReadBatch(context.Background(), batch, alloc, identityTransform) + require.NoError(t, err) + assert.Empty(t, recs) +} + +func TestTryReadBatch_GetStreamError(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + streamError(fmt.Errorf("network down")), + }} + + recs, err := tryReadBatch(context.Background(), batch, alloc, identityTransform) + require.Error(t, err) + assert.Contains(t, err.Error(), "network down") + assert.Nil(t, recs) +} + +func TestTryReadBatch_TransformError(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + schema := testSchema() + rec := buildTestRecord(alloc, schema, []int64{1}) + defer rec.Release() + + data := buildIPCBytes(alloc, schema, []arrow.RecordBatch{rec}) + batch := &mockBatch{streams: []func() (io.ReadCloser, error){streamFromBytes(data)}} + + recs, err := tryReadBatch(context.Background(), batch, alloc, failingTransform("bad transform")) + require.Error(t, err) + assert.Contains(t, err.Error(), "bad transform") + // partial recs may be returned; caller is responsible for releasing them + for _, r := range recs { + r.Release() + } +} + +func TestTryReadBatch_CancelledContext(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + schema := testSchema() + rec := buildTestRecord(alloc, schema, []int64{1}) + defer rec.Release() + + data := buildIPCBytes(alloc, schema, []arrow.RecordBatch{rec}) + batch := &mockBatch{streams: []func() (io.ReadCloser, error){streamFromBytes(data)}} + + recs, err := tryReadBatch(ctx, batch, alloc, identityTransform) + // Either GetStream or context check will surface the error + if err != nil { + for _, r := range recs { + r.Release() + } + assert.ErrorIs(t, err, context.Canceled) + return + } + for _, r := range recs { + r.Release() + } +} + +func TestTryReadBatch_TruncatedStream(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + schema := testSchema() + rec := buildTestRecord(alloc, schema, []int64{1, 2, 3}) + defer rec.Release() + + data := buildIPCBytes(alloc, schema, []arrow.RecordBatch{rec}) + + // Simulate a TCP RST after reading only half the data + tcpErr := fmt.Errorf("read tcp: wsarecv: connection forcibly closed") + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + func() (io.ReadCloser, error) { + return &truncatedReader{ + data: data, + errAfter: len(data) / 2, + err: tcpErr, + }, nil + }, + }} + + recs, err := tryReadBatch(context.Background(), batch, alloc, identityTransform) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to buffer stream body") + assert.Contains(t, err.Error(), "connection forcibly closed") + for _, r := range recs { + r.Release() + } +} + +// --- readBatchRecords tests --- + +func TestReadBatchRecords_SuccessFirstAttempt(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + schema := testSchema() + rec := buildTestRecord(alloc, schema, []int64{5, 6}) + defer rec.Release() + + data := buildIPCBytes(alloc, schema, []arrow.RecordBatch{rec}) + batch := &mockBatch{streams: []func() (io.ReadCloser, error){streamFromBytes(data)}} + + recs, err := readBatchRecords(context.Background(), batch, alloc, identityTransform, 3) + require.NoError(t, err) + require.Len(t, recs, 1) + defer recs[0].Release() + + assert.EqualValues(t, 2, recs[0].NumRows()) +} + +func TestReadBatchRecords_SuccessAfterRetries(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + schema := testSchema() + rec := buildTestRecord(alloc, schema, []int64{7, 8, 9}) + defer rec.Release() + + goodData := buildIPCBytes(alloc, schema, []arrow.RecordBatch{rec}) + + // First two calls fail, third succeeds + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + streamError(fmt.Errorf("fail 1")), + streamError(fmt.Errorf("fail 2")), + streamFromBytes(goodData), + }} + + recs, err := readBatchRecords(context.Background(), batch, alloc, identityTransform, 3) + require.NoError(t, err) + require.Len(t, recs, 1) + defer recs[0].Release() + + assert.EqualValues(t, 3, recs[0].NumRows()) +} + +func TestReadBatchRecords_ExhaustsRetries(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + maxRetries := 2 + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + streamError(fmt.Errorf("fail 1")), + streamError(fmt.Errorf("fail 2")), + streamError(fmt.Errorf("fail 3")), + }} + + recs, err := readBatchRecords(context.Background(), batch, alloc, identityTransform, maxRetries) + require.Error(t, err) + assert.Nil(t, recs) + assert.Contains(t, err.Error(), "failed to read Arrow batch after 3 attempts") + assert.Contains(t, err.Error(), "fail 3") +} + +func TestReadBatchRecords_ZeroRetries(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + streamError(fmt.Errorf("only chance")), + }} + + recs, err := readBatchRecords(context.Background(), batch, alloc, identityTransform, 0) + require.Error(t, err) + assert.Nil(t, recs) + assert.Contains(t, err.Error(), "failed to read Arrow batch after 1 attempts") +} + +func TestReadBatchRecords_CancelledBeforeRetry(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + ctx, cancel := context.WithCancel(context.Background()) + + // First call fails, context cancelled before retry + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + func() (io.ReadCloser, error) { + cancel() // cancel after first failure + return nil, fmt.Errorf("fail 1") + }, + streamError(fmt.Errorf("should not reach")), + }} + + recs, err := readBatchRecords(ctx, batch, alloc, identityTransform, 3) + require.Error(t, err) + assert.Nil(t, recs) + assert.ErrorIs(t, err, context.Canceled) +} + +// --- bufferBatchBody + batchResetter tests --- + +func TestBufferBatchBody_CallsResetBeforeRetry(t *testing.T) { + expected := []byte("good data") + batch := &mockResettableBatch{ + mockBatch: mockBatch{streams: []func() (io.ReadCloser, error){ + streamError(fmt.Errorf("fail 1")), + streamError(fmt.Errorf("fail 2")), + func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(expected)), nil + }, + }}, + } + + data, err := bufferBatchBody(context.Background(), batch, 3) + require.NoError(t, err) + assert.Equal(t, expected, data) + // Reset should have been called before attempt 1 and attempt 2 (not before attempt 0) + assert.Equal(t, 2, batch.resetCalls) +} + +func TestBufferBatchBody_NoResetOnFirstAttempt(t *testing.T) { + expected := []byte("first try works") + batch := &mockResettableBatch{ + mockBatch: mockBatch{streams: []func() (io.ReadCloser, error){ + func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(expected)), nil + }, + }}, + } + + data, err := bufferBatchBody(context.Background(), batch, 3) + require.NoError(t, err) + assert.Equal(t, expected, data) + assert.Equal(t, 0, batch.resetCalls) // no retries needed, no Reset calls +} + +func TestBufferBatchBody_ResetNotCalledWithoutInterface(t *testing.T) { + // mockBatch does NOT implement batchResetter + expected := []byte("good data") + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + streamError(fmt.Errorf("fail")), + func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(expected)), nil + }, + }} + + data, err := bufferBatchBody(context.Background(), batch, 1) + require.NoError(t, err) + assert.Equal(t, expected, data) + // No panic, no error — gracefully skips Reset when not available +} + +// --- countingReadCloser tests --- + +func TestCountingReadCloser(t *testing.T) { + data := []byte("hello world") + rc := &countingReadCloser{inner: io.NopCloser(bytes.NewReader(data))} + + buf := make([]byte, 5) + n, err := rc.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.EqualValues(t, 5, rc.bytesRead) + + n, err = rc.Read(buf) + require.NoError(t, err) + assert.Equal(t, 5, n) + assert.EqualValues(t, 10, rc.bytesRead) + + require.NoError(t, rc.Close()) +} + +// --- bufferBatchBody tests --- + +func TestBufferBatchBody_Success(t *testing.T) { + expected := []byte("hello arrow data") + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(expected)), nil + }, + }} + + data, err := bufferBatchBody(context.Background(), batch, 3) + require.NoError(t, err) + assert.Equal(t, expected, data) +} + +func TestBufferBatchBody_SuccessAfterRetries(t *testing.T) { + expected := []byte("good data") + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + streamError(fmt.Errorf("net fail 1")), + streamError(fmt.Errorf("net fail 2")), + func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(expected)), nil + }, + }} + + data, err := bufferBatchBody(context.Background(), batch, 3) + require.NoError(t, err) + assert.Equal(t, expected, data) +} + +func TestBufferBatchBody_ExhaustsRetries(t *testing.T) { + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + streamError(fmt.Errorf("fail 1")), + streamError(fmt.Errorf("fail 2")), + }} + + data, err := bufferBatchBody(context.Background(), batch, 1) + require.Error(t, err) + assert.Nil(t, data) + assert.Contains(t, err.Error(), "failed to buffer batch body after 2 attempts") +} + +func TestBufferBatchBody_TruncatedStream(t *testing.T) { + fullData := []byte("this is a fairly long payload that will be truncated") + tcpErr := fmt.Errorf("read tcp: wsarecv: connection forcibly closed") + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + func() (io.ReadCloser, error) { + return &truncatedReader{ + data: fullData, + errAfter: 10, + err: tcpErr, + }, nil + }, + }} + + data, err := bufferBatchBody(context.Background(), batch, 0) + require.Error(t, err) + assert.Nil(t, data) + assert.Contains(t, err.Error(), "failed to buffer stream body") + assert.Contains(t, err.Error(), "connection forcibly closed") +} + +func TestBufferBatchBody_CancelledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + batch := &mockBatch{streams: []func() (io.ReadCloser, error){ + streamError(fmt.Errorf("should not reach")), + }} + + data, err := bufferBatchBody(ctx, batch, 3) + require.Error(t, err) + assert.Nil(t, data) + assert.ErrorIs(t, err, context.Canceled) +} diff --git a/go/statement.go b/go/statement.go index df15bc1..f3d7d36 100644 --- a/go/statement.go +++ b/go/statement.go @@ -60,6 +60,7 @@ type statement struct { queueSize int prefetchConcurrency int useHighPrecision bool + streamRetryEnabled bool maxTimestampPrecision MaxTimestampPrecision query string @@ -130,6 +131,11 @@ func (st *statement) GetOption(key string) (string, error) { switch key { case OptionStatementQueryTag: return st.queryTag, nil + case OptionStreamRetryEnabled: + if st.streamRetryEnabled { + return adbc.OptionValueEnabled, nil + } + return adbc.OptionValueDisabled, nil default: return st.Base().GetOption(key) } @@ -253,6 +259,18 @@ func (st *statement) SetOption(key string, val string) error { Code: adbc.StatusInvalidArgument, } } + case OptionStreamRetryEnabled: + switch val { + case adbc.OptionValueEnabled: + st.streamRetryEnabled = true + case adbc.OptionValueDisabled: + st.streamRetryEnabled = false + default: + return adbc.Error{ + Msg: fmt.Sprintf("[Snowflake] invalid statement option %s=%s", key, val), + Code: adbc.StatusInvalidArgument, + } + } case OptionStatementVectorizedScanner: vectorized, err := strconv.ParseBool(val) if err != nil { @@ -548,7 +566,7 @@ func (st *statement) ExecuteQuery(ctx context.Context) (reader array.RecordReade return nil, err } - reader, err = newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision, st.maxTimestampPrecision) + reader, err = newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision, st.streamRetryEnabled, st.maxTimestampPrecision) return reader, err }, currentBatch: st.bound, @@ -573,7 +591,7 @@ func (st *statement) ExecuteQuery(ctx context.Context) (reader array.RecordReade return } - reader, err = newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision, st.maxTimestampPrecision) + reader, err = newRecordReader(ctx, st.alloc, loader, st.queueSize, st.prefetchConcurrency, st.useHighPrecision, st.streamRetryEnabled, st.maxTimestampPrecision) nRows = loader.TotalRows() return } From f14cc4aa672874f7de1c03dae50c79bebf4f0509 Mon Sep 17 00:00:00 2001 From: David Coe <> Date: Mon, 8 Jun 2026 10:58:18 -0400 Subject: [PATCH 2/6] update mod --- csharp/arrow-adbc | 2 +- go/go.mod | 10 +++++----- go/go.sum | 20 ++++++++++---------- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/csharp/arrow-adbc b/csharp/arrow-adbc index ffa5f0a..514e13b 160000 --- a/csharp/arrow-adbc +++ b/csharp/arrow-adbc @@ -1 +1 @@ -Subproject commit ffa5f0a2f32d83dcecf159ad9a25da0492c4f759 +Subproject commit 514e13b5753dd0f2005960656eb0238653590268 diff --git a/go/go.mod b/go/go.mod index 0ec93c5..e8737ff 100644 --- a/go/go.mod +++ b/go/go.mod @@ -17,13 +17,13 @@ module github.com/adbc-drivers/snowflake/go go 1.26.1 require ( - github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260531215508-9a56b1c7bd6d - github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260427080211-5b908ab0cfd8 - github.com/adbc-drivers/driverbase-go/validation v0.0.0-20260427080211-5b908ab0cfd8 + github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260608064711-7f3f9a9f3990 + github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260608002410-49f9e21a1d4a + github.com/adbc-drivers/driverbase-go/validation v0.0.0-20260608064711-7f3f9a9f3990 github.com/apache/arrow-adbc/go/adbc v1.11.0 github.com/apache/arrow-go/v18 v18.6.0 github.com/google/uuid v1.6.0 - github.com/snowflakedb/gosnowflake/v2 v2.0.2 + github.com/snowflakedb/gosnowflake/v2 v2.1.0 github.com/stretchr/testify v1.11.1 github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 go.opentelemetry.io/otel v1.44.0 @@ -89,7 +89,7 @@ require ( go.opentelemetry.io/otel/sdk v1.44.0 // indirect go.opentelemetry.io/proto/otlp v1.10.0 // indirect golang.org/x/crypto v0.52.0 // indirect - golang.org/x/exp v0.0.0-20260529124908-c761662dc8c9 // indirect + golang.org/x/exp v0.0.0-20260603202125-055de637280b // indirect golang.org/x/mod v0.36.0 // indirect golang.org/x/net v0.55.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect diff --git a/go/go.sum b/go/go.sum index e8ade07..36808a4 100644 --- a/go/go.sum +++ b/go/go.sum @@ -16,12 +16,12 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.7.2 h1:RHK7bS+HQMs github.com/AzureAD/microsoft-authentication-library-for-go v1.7.2/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= github.com/BurntSushi/toml v1.6.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260531215508-9a56b1c7bd6d h1:xv3wwpcS1ByVK0KYqUmNsxVUvTuZtfoAm5foFQ9a8oQ= -github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260531215508-9a56b1c7bd6d/go.mod h1:doFujhe7BcZTCWPPFlzT34PEAKyV7uqQgGIa2Teoxg8= -github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260427080211-5b908ab0cfd8 h1:cMopE5au+VwajOS3B3netWOJb4d/1hJ1O/1ez6ilsTM= -github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260427080211-5b908ab0cfd8/go.mod h1:uKkYVA+iUXij7mPFproEi6sOIWFVNOT87UcqucqqnBg= -github.com/adbc-drivers/driverbase-go/validation v0.0.0-20260427080211-5b908ab0cfd8 h1:AbVlFTLcSlVOrVuPC6guN5dhuEWFW1Nof/eTm8xdkog= -github.com/adbc-drivers/driverbase-go/validation v0.0.0-20260427080211-5b908ab0cfd8/go.mod h1:Tq+xMUZfJqVDBtdvzzn8BQoj6psg6INAngiAUZQhmTA= +github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260608064711-7f3f9a9f3990 h1:b8vaQHcgHqmalpmIqqPF+bVWrMQISUUjxwFfRbT0uy0= +github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260608064711-7f3f9a9f3990/go.mod h1:rCDugBvob6aEFZu7zJ9SqPSyyys6ipNWcsfvzF/Ns74= +github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260608002410-49f9e21a1d4a h1:j1a8iNxRTV6ZaHnEuH8/NVlBMrBejVTBPxu4wueB72c= +github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260608002410-49f9e21a1d4a/go.mod h1:dUNGWra6WmjWWeqEC6WcFZFuw6wrbJNjcZIiS0REizQ= +github.com/adbc-drivers/driverbase-go/validation v0.0.0-20260608064711-7f3f9a9f3990 h1:pHtKruOLTXwNrK1IJVjt+f00lGPOXFQaWFuUjvISZjQ= +github.com/adbc-drivers/driverbase-go/validation v0.0.0-20260608064711-7f3f9a9f3990/go.mod h1:UDTbsQIddSqov8hUA+rrchZEyzv9P5pqDOZGZR3kthk= github.com/andybalholm/brotli v1.2.1 h1:R+f5xP285VArJDRgowrfb9DqL18yVK0gKAW/F+eTWro= github.com/andybalholm/brotli v1.2.1/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apache/arrow-adbc/go/adbc v1.11.0 h1:zcSLtV8CQ27chkYWZmySvd4+pkkDtWhRtHz0LpglRAU= @@ -126,8 +126,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= -github.com/snowflakedb/gosnowflake/v2 v2.0.2 h1:8UZo+v1T2Y9sgoPk3JYT3RatAUd9o6q6yjL40TyHluA= -github.com/snowflakedb/gosnowflake/v2 v2.0.2/go.mod h1:c0hIqJ/dxgaMl7g1o8n4Ca3Mf5YCiiVx9igio/PNqC8= +github.com/snowflakedb/gosnowflake/v2 v2.1.0 h1:rfjs6NAMnbLKCBYlOarqQX/UKgQVrXi43TZNHCP5/jw= +github.com/snowflakedb/gosnowflake/v2 v2.1.0/go.mod h1:c0hIqJ/dxgaMl7g1o8n4Ca3Mf5YCiiVx9igio/PNqC8= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= @@ -166,8 +166,8 @@ go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= -golang.org/x/exp v0.0.0-20260529124908-c761662dc8c9 h1:4d4PbuBNwaxMXkXI8yiIYjydtMU+04RHeuSxJdgKftM= -golang.org/x/exp v0.0.0-20260529124908-c761662dc8c9/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= +golang.org/x/exp v0.0.0-20260603202125-055de637280b h1:v1uXiEBHo8QA0LiGCo7UgHMzHT4Kdfpl2zmtH5vaP1Q= +golang.org/x/exp v0.0.0-20260603202125-055de637280b/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4= golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ= golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= From c33713790c5ca29a2b228d38c08879738f1726d8 Mon Sep 17 00:00:00 2001 From: David Coe <> Date: Mon, 8 Jun 2026 17:09:15 -0400 Subject: [PATCH 3/6] clean up --- go/connection.go | 6 ++---- go/database.go | 4 +--- go/driver_test.go | 1 + 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/go/connection.go b/go/connection.go index 7af591f..561b33f 100644 --- a/go/connection.go +++ b/go/connection.go @@ -374,11 +374,9 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, xdbcDataType := driverbase.ToXdbcDataType(field.Type) if field.Type != nil { - v := int16(field.Type.ID()) - getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = &v + getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcDataType = driverbase.Nullable(int16(field.Type.ID())) } - sqlDT := int16(xdbcDataType) - getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = &sqlDT + getObjectsCatalog.CatalogDbSchemas[i].DbSchemaTables[j].TableColumns[k].XdbcSqlDataType = driverbase.Nullable(int16(xdbcDataType)) } } } diff --git a/go/database.go b/go/database.go index ee5faf4..e29d021 100644 --- a/go/database.go +++ b/go/database.go @@ -93,7 +93,6 @@ type databaseImpl struct { defaultAppName string } -//nolint:staticcheck // ignore snowflake deprecated warnings for now func (d *databaseImpl) GetOption(ctx context.Context, key string) (string, error) { switch key { case adbc.OptionKeyUsername: @@ -209,7 +208,6 @@ func ParseSnowflakeURI(uri string) (*gosnowflake.Config, error) { return gosnowflake.ParseDSN(uri) } -//nolint:staticcheck // ignore snowflake deprecated warnings for now func (d *databaseImpl) SetOptions(ctx context.Context, cnOptions map[string]string) error { uri, ok := cnOptions[adbc.OptionKeyURI] if ok { @@ -247,7 +245,7 @@ func (d *databaseImpl) SetOptions(ctx context.Context, cnOptions map[string]stri // // cnOptions is nil if the option is being set post-initialiation. // -//nolint:staticcheck // ignore snowflake deprecated warnings for now + func (d *databaseImpl) SetOptionInternal(k string, v string, cnOptions *map[string]string) error { var err error var ok bool diff --git a/go/driver_test.go b/go/driver_test.go index d434c03..b4682d4 100644 --- a/go/driver_test.go +++ b/go/driver_test.go @@ -234,6 +234,7 @@ func (s *SnowflakeQuirks) SupportsConcurrentStatements() bool { return func (s *SnowflakeQuirks) SupportsCurrentCatalogSchema() bool { return true } func (s *SnowflakeQuirks) SupportsExecuteSchema() bool { return true } func (s *SnowflakeQuirks) SupportsGetSetOptions() bool { return true } +func (s *SnowflakeQuirks) SupportsGetTableSchema() bool { return true } func (s *SnowflakeQuirks) SupportsPartitionedData() bool { return false } func (s *SnowflakeQuirks) SupportsStatistics() bool { return true } func (s *SnowflakeQuirks) SupportsTransactions() bool { return true } From 64c3142b49413448ca77a96fe9b67d500e4688a4 Mon Sep 17 00:00:00 2001 From: David Coe <> Date: Tue, 9 Jun 2026 08:55:23 -0400 Subject: [PATCH 4/6] linter fix 2 --- go/database.go | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/go/database.go b/go/database.go index e29d021..eced4c7 100644 --- a/go/database.go +++ b/go/database.go @@ -120,13 +120,13 @@ func (d *databaseImpl) GetOption(ctx context.Context, key string) (string, error case OptionAuthType: return d.cfg.Authenticator.String(), nil case OptionLoginTimeout: - return strconv.FormatFloat(d.cfg.LoginTimeout.Seconds(), 'f', -1, 64), nil + return strconv.FormatFloat(d.cfg.LoginTimeout.Seconds(), 'f', -1, 64), nil //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionRequestTimeout: - return strconv.FormatFloat(d.cfg.RequestTimeout.Seconds(), 'f', -1, 64), nil + return strconv.FormatFloat(d.cfg.RequestTimeout.Seconds(), 'f', -1, 64), nil //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionJwtExpireTimeout: - return strconv.FormatFloat(d.cfg.JWTExpireTimeout.Seconds(), 'f', -1, 64), nil + return strconv.FormatFloat(d.cfg.JWTExpireTimeout.Seconds(), 'f', -1, 64), nil //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionClientTimeout: - return strconv.FormatFloat(d.cfg.ClientTimeout.Seconds(), 'f', -1, 64), nil + return strconv.FormatFloat(d.cfg.ClientTimeout.Seconds(), 'f', -1, 64), nil //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionApplicationName: return d.cfg.Application, nil case OptionSSLSkipVerify: @@ -163,7 +163,7 @@ func (d *databaseImpl) GetOption(ctx context.Context, key string) (string, error } return adbc.OptionValueDisabled, nil case OptionLogTracing: - return d.cfg.Tracing, nil + return d.cfg.Tracing, nil //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionClientConfigFile: return d.cfg.ClientConfigFile, nil case OptionUseHighPrecision: @@ -225,8 +225,7 @@ func (d *databaseImpl) SetOptions(ctx context.Context, cnOptions map[string]stri } // XXX(https://github.com/apache/arrow-adbc/issues/2792): Snowflake // has a tendency to spam the log by default, so set the log level - - d.cfg.Tracing = "fatal" + d.cfg.Tracing = "fatal" //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now // set default application name to track // unless user overrides it @@ -244,8 +243,6 @@ func (d *databaseImpl) SetOptions(ctx context.Context, cnOptions map[string]stri // SetOptionInternal sets the option for the database. // // cnOptions is nil if the option is being set post-initialiation. -// - func (d *databaseImpl) SetOptionInternal(k string, v string, cnOptions *map[string]string) error { var err error var ok bool @@ -299,7 +296,7 @@ func (d *databaseImpl) SetOptionInternal(k string, v string, cnOptions *map[stri if dur < 0 { dur = -dur } - d.cfg.LoginTimeout = dur + d.cfg.LoginTimeout = dur //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionRequestTimeout: dur, err := time.ParseDuration(v) if err != nil { @@ -311,7 +308,7 @@ func (d *databaseImpl) SetOptionInternal(k string, v string, cnOptions *map[stri if dur < 0 { dur = -dur } - d.cfg.RequestTimeout = dur + d.cfg.RequestTimeout = dur //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionJwtExpireTimeout: dur, err := time.ParseDuration(v) if err != nil { @@ -323,7 +320,7 @@ func (d *databaseImpl) SetOptionInternal(k string, v string, cnOptions *map[stri if dur < 0 { dur = -dur } - d.cfg.JWTExpireTimeout = dur + d.cfg.JWTExpireTimeout = dur //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionClientTimeout: dur, err := time.ParseDuration(v) if err != nil { @@ -335,7 +332,7 @@ func (d *databaseImpl) SetOptionInternal(k string, v string, cnOptions *map[stri if dur < 0 { dur = -dur } - d.cfg.ClientTimeout = dur + d.cfg.ClientTimeout = dur //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionApplicationName: if !strings.HasPrefix(v, "[ADBC]") { v = d.defaultAppName + v @@ -511,7 +508,7 @@ func (d *databaseImpl) SetOptionInternal(k string, v string, cnOptions *map[stri } } case OptionLogTracing: - d.cfg.Tracing = v + d.cfg.Tracing = v //nolint:staticcheck,nolintlint // ignore snowflake deprecated warnings for now case OptionClientConfigFile: d.cfg.ClientConfigFile = v case OptionUseHighPrecision: From 118f0ddead6e61b9ee0d3a95fbd7505f9f82ed3a Mon Sep 17 00:00:00 2001 From: David Coe <> Date: Tue, 9 Jun 2026 14:55:43 -0400 Subject: [PATCH 5/6] add test --- go/record_reader_test.go | 146 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/go/record_reader_test.go b/go/record_reader_test.go index 76d3da0..a27c531 100644 --- a/go/record_reader_test.go +++ b/go/record_reader_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "io" + "strings" "sync" "testing" "time" @@ -807,3 +808,148 @@ func TestReaderCancellationSetsErrBeforeNextAndReleaseReturns(t *testing.T) { require.FailNow(t, "reader.Release should not block after cancellation") } } + +// shortReadStream delivers data in chunks no larger than chunkSize per Read. +// If errAtOffset > 0, it returns injectedErr once that many bytes have been +// delivered, simulating mid-frame truncation (gosnowflake#1781). +type shortReadStream struct { + data []byte + pos int + chunkSize int + errAtOffset int + injectedErr error +} + +func (s *shortReadStream) Read(p []byte) (int, error) { + if s.errAtOffset > 0 && s.pos >= s.errAtOffset { + return 0, s.injectedErr + } + if s.pos >= len(s.data) { + return 0, io.EOF + } + n := min(len(p), s.chunkSize) + remaining := len(s.data) - s.pos + if n > remaining { + n = remaining + } + if s.errAtOffset > 0 && s.pos+n > s.errAtOffset { + n = s.errAtOffset - s.pos + } + copy(p, s.data[s.pos:s.pos+n]) + s.pos += n + return n, nil +} + +func (s *shortReadStream) Close() error { return nil } + +func streamShortReads(data []byte, chunkSize int) func(context.Context) (io.ReadCloser, error) { + return func(context.Context) (io.ReadCloser, error) { + return &shortReadStream{data: data, chunkSize: chunkSize}, nil + } +} + +func streamShortReadsThenError(data []byte, chunkSize, errAtOffset int, err error) func(context.Context) (io.ReadCloser, error) { + return func(context.Context) (io.ReadCloser, error) { + return &shortReadStream{ + data: data, + chunkSize: chunkSize, + errAtOffset: errAtOffset, + injectedErr: err, + }, nil + } +} + +// Sanity check that legal short reads alone (n < len(p), nil err) don't break +// IPC decoding — the failure in gosnowflake#1781 requires a mid-frame error. +func TestTryReadBatch_ShortReadsSucceed(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + schema := testSchema() + rec := buildTestRecord(alloc, schema, []int64{1, 2, 3, 4, 5}) + defer rec.Release() + + data := buildIPCBytes(alloc, schema, []arrow.RecordBatch{rec}) + batch := &mockBatch{numRows: 5, streams: []func(context.Context) (io.ReadCloser, error){ + streamShortReads(data, 1), + }} + + recs, err := tryReadBatch(context.Background(), batch, alloc, identityTransform) + require.NoError(t, err) + require.Len(t, recs, 1) + defer recs[0].Release() + assert.EqualValues(t, 5, recs[0].NumRows()) +} + +// Reproduces gosnowflake#1781 and validates the streamRetryEnabled flag: +// with retries disabled (matching the production "no retry" branch in +// newRecordReader) the mid-frame IPC error surfaces; with retries enabled +// the same broken first stream is recovered by a second attempt. +func TestStreamRetryEnabled_RecoversFromShortReadMidFrameError(t *testing.T) { + schema := testSchema() + + cases := []struct { + name string + streamRetryEnabled bool + expectErr bool + }{ + {name: "disabled_surfacesIPCError", streamRetryEnabled: false, expectErr: true}, + {name: "enabled_recovers", streamRetryEnabled: true, expectErr: false}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + rec := buildTestRecord(alloc, schema, []int64{11, 22, 33, 44}) + defer rec.Release() + + data := buildIPCBytes(alloc, schema, []arrow.RecordBatch{rec}) + require.Greater(t, len(data), 32) + truncationOffset := len(data) - 8 + + batch := &mockBatch{numRows: 4, streams: []func(context.Context) (io.ReadCloser, error){ + streamShortReadsThenError(data, 16, truncationOffset, io.ErrUnexpectedEOF), + streamShortReads(data, 16), + }} + + // Mirror the production dispatch in newRecordReader. + ctx := context.Background() + var recs []arrow.RecordBatch + var err error + if tc.streamRetryEnabled { + recs, err = readBatchRecords(ctx, batch, alloc, identityTransform, defaultStreamMaxRetries) + } else { + out := make(chan arrow.RecordBatch, 4) + target := newBatchStreamTarget(0, batch, out, nil) + err = streamBatchToChannel(ctx, 0, batch, alloc, identityTransform, target) + close(out) + for r := range out { + recs = append(recs, r) + } + } + + if tc.expectErr { + require.Error(t, err, "expected IPC error to surface without retry") + msg := err.Error() + assert.True(t, + strings.Contains(msg, "could not read message body") || + strings.Contains(msg, "unexpected EOF") || + strings.Contains(msg, "row count mismatch"), + "expected IPC body-read failure, got: %s", msg, + ) + for _, r := range recs { + r.Release() + } + return + } + + require.NoError(t, err, "retry should recover from mid-frame truncation") + require.Len(t, recs, 1) + defer recs[0].Release() + assert.EqualValues(t, 4, recs[0].NumRows()) + assert.Equal(t, 2, batch.call, "retry should have been invoked") + }) + } +} From 877bb3f4f76fff29b286819262b0a0f9d89beec1 Mon Sep 17 00:00:00 2001 From: David Coe <> Date: Tue, 9 Jun 2026 15:00:07 -0400 Subject: [PATCH 6/6] clean up --- go/record_reader.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/go/record_reader.go b/go/record_reader.go index 4425eaa..4a4a436 100644 --- a/go/record_reader.go +++ b/go/record_reader.go @@ -1258,11 +1258,6 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake attribute.Int("batches", len(batches)), attribute.Int64("totalRows", ld.TotalRows()), attribute.Bool("streamRetryEnabled", streamRetryEnabled), - )) - - trace.SpanFromContext(ctx).AddEvent("newRecordReader", trace.WithAttributes( - attribute.Int("batches", len(batches)), - attribute.Int64("totalRows", ld.TotalRows()), attribute.Int("jsonDataLen", len(ld.JSONData())), ))