diff --git a/internal/impl/snowflake/output_snowflake_streaming.go b/internal/impl/snowflake/output_snowflake_streaming.go index 6749393b7a..92175d8aa8 100644 --- a/internal/impl/snowflake/output_snowflake_streaming.go +++ b/internal/impl/snowflake/output_snowflake_streaming.go @@ -100,7 +100,7 @@ You can monitor the output batch size using the `+"`snowflake_compressed_output_ service.NewStringField(ssoFieldRole).Description("The role for the `user` field. The role must have the https://docs.snowflake.com/en/user-guide/data-load-snowpipe-streaming-overview#required-access-privileges[required privileges^] to call the Snowpipe Streaming APIs. See https://docs.snowflake.com/en/user-guide/admin-user-management#user-roles[Snowflake Documentation^] for more information about roles.").Example("ACCOUNTADMIN"), service.NewStringField(ssoFieldDB).Description("The Snowflake database to ingest data into.").Example("MY_DATABASE"), service.NewStringField(ssoFieldSchema).Description("The Snowflake schema to ingest data into.").Example("PUBLIC"), - service.NewStringField(ssoFieldTable).Description("The Snowflake table to ingest data into.").Example("MY_TABLE"), + service.NewInterpolatedStringField(ssoFieldTable).Description("The Snowflake table to ingest data into.").Example("MY_TABLE"), service.NewStringField(ssoFieldKey).Description("The PEM encoded private RSA key to use for authenticating with Snowflake. Either this or `private_key_file` must be specified.").Optional().Secret(), service.NewStringField(ssoFieldKeyFile).Description("The file to load the private RSA key from. This should be a `.p8` PEM encoded file. Either this or `private_key` must be specified.").Optional(), service.NewStringField(ssoFieldKeyPass).Description("The RSA key passphrase if the RSA key is encrypted.").Optional().Secret(), @@ -387,7 +387,7 @@ func newSnowflakeStreamer( if err != nil { return nil, err } - table, err := conf.FieldString(ssoFieldTable) + dynamicTable, err := conf.FieldInterpolatedString(ssoFieldTable) if err != nil { return nil, err } @@ -431,11 +431,6 @@ func newSnowflakeStreamer( if err != nil { return nil, err } - } else if !conf.Contains(ssoFieldChannelName) { - // There is a limit of 10k channels, so we can't dynamically create them. - // The only other good default is to create one and only allow a single - // stream to write to a single table. - channelPrefix = fmt.Sprintf("Redpanda_Connect_%s.%s.%s", db, schema, table) } var channelName *service.InterpolatedString @@ -511,71 +506,104 @@ func newSnowflakeStreamer( if err != nil { return nil, err } - var schemaEvolver *snowpipeSchemaEvolver - if schemaEvolutionMapping != nil { - schemaEvolver = &snowpipeSchemaEvolver{ - schemaEvolutionMapping: schemaEvolutionMapping, - restClient: restClient, - logger: mgr.Logger(), - db: db, - schema: schema, - table: table, - role: role, - } - } - var impl service.BatchOutput - if channelName != nil { - indexed := &snowpipeIndexedOutput{ - channelName: channelName, - client: client, - db: db, - schema: schema, - table: table, - role: role, - logger: mgr.Logger(), - metrics: newSnowpipeMetrics(mgr.Metrics()), - buildOpts: buildOpts, - offsetToken: offsetToken, - schemaMigrationEnabled: schemaEvolver != nil, - } - indexed.channelPool = pool.NewIndexed(func(ctx context.Context, name string) (*streaming.SnowflakeIngestionChannel, error) { - hash := sha256.Sum256([]byte(name)) - id := binary.BigEndian.Uint16(hash[:]) - return indexed.openChannel(ctx, name, int16(id)) - }) - impl = indexed - } else { - pooled := &snowpipePooledOutput{ - channelPrefix: channelPrefix, - client: client, - db: db, - schema: schema, - table: table, - role: role, - logger: mgr.Logger(), - metrics: newSnowpipeMetrics(mgr.Metrics()), - buildOpts: buildOpts, - offsetToken: offsetToken, - schemaMigrationEnabled: schemaEvolver != nil, - } - pooled.channelPool = pool.NewCapped(maxInFlight, func(ctx context.Context, id int) (*streaming.SnowflakeIngestionChannel, error) { - name := fmt.Sprintf("%s_%d", pooled.channelPrefix, id) - return pooled.openChannel(ctx, name, int16(id)) - }) - impl = pooled + + mgr.SetGeneric(SnowflakeClientResourceForTesting, restClient) + makeImpl := func(table string) (*snowpipeSchemaEvolver, service.BatchOutput) { + var schemaEvolver *snowpipeSchemaEvolver + if schemaEvolutionMapping != nil { + schemaEvolver = &snowpipeSchemaEvolver{ + schemaEvolutionMapping: schemaEvolutionMapping, + restClient: restClient, + logger: mgr.Logger(), + db: db, + schema: schema, + table: table, + role: role, + } + } + var impl service.BatchOutput + if channelName != nil { + indexed := &snowpipeIndexedOutput{ + channelName: channelName, + client: client, + db: db, + schema: schema, + table: table, + role: role, + logger: mgr.Logger(), + metrics: newSnowpipeMetrics(mgr.Metrics()), + buildOpts: buildOpts, + offsetToken: offsetToken, + schemaMigrationEnabled: schemaEvolver != nil, + } + indexed.channelPool = pool.NewIndexed(func(ctx context.Context, name string) (*streaming.SnowflakeIngestionChannel, error) { + hash := sha256.Sum256([]byte(name)) + id := binary.BigEndian.Uint16(hash[:]) + return indexed.openChannel(ctx, name, int16(id)) + }) + impl = indexed + } else { + if channelPrefix == "" { + // There is a limit of 10k channels, so we can't dynamically create them. + // The only other good default is to create one and only allow a single + // stream to write to a single table. + channelPrefix = fmt.Sprintf("Redpanda_Connect_%s.%s.%s", db, schema, table) + } + pooled := &snowpipePooledOutput{ + channelPrefix: channelPrefix, + client: client, + db: db, + schema: schema, + table: table, + role: role, + logger: mgr.Logger(), + metrics: newSnowpipeMetrics(mgr.Metrics()), + buildOpts: buildOpts, + offsetToken: offsetToken, + schemaMigrationEnabled: schemaEvolver != nil, + } + pooled.channelPool = pool.NewCapped(maxInFlight, func(ctx context.Context, id int) (*streaming.SnowflakeIngestionChannel, error) { + name := fmt.Sprintf("%s_%d", pooled.channelPrefix, id) + return pooled.openChannel(ctx, name, int16(id)) + }) + impl = pooled + } + return schemaEvolver, impl } - foo := &snowpipeStreamingOutput{ - initStatementsFn: initStatementsFn, - client: client, - restClient: restClient, - mapping: mapping, - logger: mgr.Logger(), - schemaEvolver: schemaEvolver, - impl: impl, + if table, ok := dynamicTable.Static(); ok { + schemaEvolver, impl := makeImpl(table) + return &snowpipeStreamingOutput{ + initStatementsFn: initStatementsFn, + client: client, + restClient: restClient, + mapping: mapping, + logger: mgr.Logger(), + schemaEvolver: schemaEvolver, + + impl: impl, + }, nil + } else { + return &dynamicSnowpipeStreamingOutput{ + table: dynamicTable, + byTable: pool.NewIndexed(func(ctx context.Context, table string) (service.BatchOutput, error) { + schemaEvolver, impl := makeImpl(table) + return &snowpipeStreamingOutput{ + initStatementsFn: nil, + client: nil, + restClient: nil, + mapping: mapping, + logger: mgr.Logger(), + schemaEvolver: schemaEvolver, + + impl: impl, + }, nil + }), + initStatementsFn: initStatementsFn, + client: client, + restClient: restClient, + }, nil } - mgr.SetGeneric(SnowflakeClientResourceForTesting, restClient) - return foo, nil } type snowflakeClientForTesting string @@ -584,6 +612,68 @@ type snowflakeClientForTesting string // which can remove boilerplate from tests to setup a new REST client. const SnowflakeClientResourceForTesting snowflakeClientForTesting = "SnowflakeClientResourceForTesting" +type dynamicSnowpipeStreamingOutput struct { + table *service.InterpolatedString + byTable pool.Indexed[service.BatchOutput] + + initStatementsFn func(context.Context, *streaming.SnowflakeRestClient) error + client *streaming.SnowflakeServiceClient + restClient *streaming.SnowflakeRestClient +} + +func (o *dynamicSnowpipeStreamingOutput) Connect(ctx context.Context) error { + if o.initStatementsFn != nil { + if err := o.initStatementsFn(ctx, o.restClient); err != nil { + return fmt.Errorf("unable to run initialization statement: %w", err) + } + // We've already executed our init statement, we don't need to do that anymore + o.initStatementsFn = nil + } + return nil +} + +func (o *dynamicSnowpipeStreamingOutput) WriteBatch(ctx context.Context, batch service.MessageBatch) error { + executor := batch.InterpolationExecutor(o.table) + tableBatches := map[string]service.MessageBatch{} + for i, msg := range batch { + table, err := executor.TryString(i) + if err != nil { + return fmt.Errorf("unable to interpolate `%s`: %w", ssoFieldTable, err) + } + tableBatches[table] = append(tableBatches[table], msg) + } + for table, batch := range tableBatches { + output, err := o.byTable.Acquire(ctx, table) + if err != nil { + return err + } + // Immediately release, these are thread safe, so we can let other + // threads modify them while we have a reference. + o.byTable.Release(table, output) + if err := output.WriteBatch(ctx, batch); err != nil { + return err + } + } + return nil +} + +func (o *dynamicSnowpipeStreamingOutput) Close(ctx context.Context) error { + for _, key := range o.byTable.Keys() { + out, err := o.byTable.Acquire(ctx, key) + if err != nil { + return err + } + o.byTable.Release(key, out) + if err := out.Close(ctx); err != nil { + return err + } + } + o.byTable.Reset() + o.client.Close() + o.restClient.Close() + return nil +} + type snowpipeStreamingOutput struct { initStatementsFn func(context.Context, *streaming.SnowflakeRestClient) error client *streaming.SnowflakeServiceClient @@ -705,8 +795,12 @@ func (o *snowpipeStreamingOutput) Close(ctx context.Context) error { if err := o.impl.Close(ctx); err != nil { return err } - o.client.Close() - o.restClient.Close() + if o.client != nil { + o.client.Close() + } + if o.restClient != nil { + o.restClient.Close() + } return nil } diff --git a/internal/impl/snowflake/pool/indexed.go b/internal/impl/snowflake/pool/indexed.go index 7c5aece865..0751fbb50f 100644 --- a/internal/impl/snowflake/pool/indexed.go +++ b/internal/impl/snowflake/pool/indexed.go @@ -30,6 +30,8 @@ type ( Release(name string, item T) // Reset all items in the pool Reset() + // Get all the keys in the pool + Keys() []string } indexedImpl[T any] struct { ctor func(context.Context, string) (T, error) @@ -97,3 +99,13 @@ func (p *indexedImpl[T]) Reset() { clear(p.items) p.unlock() } + +func (p *indexedImpl[T]) Keys() []string { + keys := []string{} + _ = p.lock(context.Background()) + for k := range p.items { + keys = append(keys, k) + } + p.unlock() + return keys +}