Skip to content

Commit

Permalink
snowpipe: support interpolated tables
Browse files Browse the repository at this point in the history
  • Loading branch information
rockwotj committed Dec 17, 2024
1 parent 5c43ce3 commit c189654
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 71 deletions.
236 changes: 165 additions & 71 deletions internal/impl/snowflake/output_snowflake_streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ You can monitor the output batch size using the `+"`snowflake_compressed_output_
service.NewStringField(ssoFieldRole).Description("The role for the `user` field. The role must have the https://docs.snowflake.com/en/user-guide/data-load-snowpipe-streaming-overview#required-access-privileges[required privileges^] to call the Snowpipe Streaming APIs. See https://docs.snowflake.com/en/user-guide/admin-user-management#user-roles[Snowflake Documentation^] for more information about roles.").Example("ACCOUNTADMIN"),
service.NewStringField(ssoFieldDB).Description("The Snowflake database to ingest data into.").Example("MY_DATABASE"),
service.NewStringField(ssoFieldSchema).Description("The Snowflake schema to ingest data into.").Example("PUBLIC"),
service.NewStringField(ssoFieldTable).Description("The Snowflake table to ingest data into.").Example("MY_TABLE"),
service.NewInterpolatedStringField(ssoFieldTable).Description("The Snowflake table to ingest data into.").Example("MY_TABLE"),
service.NewStringField(ssoFieldKey).Description("The PEM encoded private RSA key to use for authenticating with Snowflake. Either this or `private_key_file` must be specified.").Optional().Secret(),
service.NewStringField(ssoFieldKeyFile).Description("The file to load the private RSA key from. This should be a `.p8` PEM encoded file. Either this or `private_key` must be specified.").Optional(),
service.NewStringField(ssoFieldKeyPass).Description("The RSA key passphrase if the RSA key is encrypted.").Optional().Secret(),
Expand Down Expand Up @@ -387,7 +387,7 @@ func newSnowflakeStreamer(
if err != nil {
return nil, err
}
table, err := conf.FieldString(ssoFieldTable)
dynamicTable, err := conf.FieldInterpolatedString(ssoFieldTable)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -431,11 +431,6 @@ func newSnowflakeStreamer(
if err != nil {
return nil, err
}
} else if !conf.Contains(ssoFieldChannelName) {
// There is a limit of 10k channels, so we can't dynamically create them.
// The only other good default is to create one and only allow a single
// stream to write to a single table.
channelPrefix = fmt.Sprintf("Redpanda_Connect_%s.%s.%s", db, schema, table)
}

var channelName *service.InterpolatedString
Expand Down Expand Up @@ -511,71 +506,104 @@ func newSnowflakeStreamer(
if err != nil {
return nil, err
}
var schemaEvolver *snowpipeSchemaEvolver
if schemaEvolutionMapping != nil {
schemaEvolver = &snowpipeSchemaEvolver{
schemaEvolutionMapping: schemaEvolutionMapping,
restClient: restClient,
logger: mgr.Logger(),
db: db,
schema: schema,
table: table,
role: role,
}
}
var impl service.BatchOutput
if channelName != nil {
indexed := &snowpipeIndexedOutput{
channelName: channelName,
client: client,
db: db,
schema: schema,
table: table,
role: role,
logger: mgr.Logger(),
metrics: newSnowpipeMetrics(mgr.Metrics()),
buildOpts: buildOpts,
offsetToken: offsetToken,
schemaMigrationEnabled: schemaEvolver != nil,
}
indexed.channelPool = pool.NewIndexed(func(ctx context.Context, name string) (*streaming.SnowflakeIngestionChannel, error) {
hash := sha256.Sum256([]byte(name))
id := binary.BigEndian.Uint16(hash[:])
return indexed.openChannel(ctx, name, int16(id))
})
impl = indexed
} else {
pooled := &snowpipePooledOutput{
channelPrefix: channelPrefix,
client: client,
db: db,
schema: schema,
table: table,
role: role,
logger: mgr.Logger(),
metrics: newSnowpipeMetrics(mgr.Metrics()),
buildOpts: buildOpts,
offsetToken: offsetToken,
schemaMigrationEnabled: schemaEvolver != nil,
}
pooled.channelPool = pool.NewCapped(maxInFlight, func(ctx context.Context, id int) (*streaming.SnowflakeIngestionChannel, error) {
name := fmt.Sprintf("%s_%d", pooled.channelPrefix, id)
return pooled.openChannel(ctx, name, int16(id))
})
impl = pooled

mgr.SetGeneric(SnowflakeClientResourceForTesting, restClient)
makeImpl := func(table string) (*snowpipeSchemaEvolver, service.BatchOutput) {
var schemaEvolver *snowpipeSchemaEvolver
if schemaEvolutionMapping != nil {
schemaEvolver = &snowpipeSchemaEvolver{
schemaEvolutionMapping: schemaEvolutionMapping,
restClient: restClient,
logger: mgr.Logger(),
db: db,
schema: schema,
table: table,
role: role,
}
}
var impl service.BatchOutput
if channelName != nil {
indexed := &snowpipeIndexedOutput{
channelName: channelName,
client: client,
db: db,
schema: schema,
table: table,
role: role,
logger: mgr.Logger(),
metrics: newSnowpipeMetrics(mgr.Metrics()),
buildOpts: buildOpts,
offsetToken: offsetToken,
schemaMigrationEnabled: schemaEvolver != nil,
}
indexed.channelPool = pool.NewIndexed(func(ctx context.Context, name string) (*streaming.SnowflakeIngestionChannel, error) {
hash := sha256.Sum256([]byte(name))
id := binary.BigEndian.Uint16(hash[:])
return indexed.openChannel(ctx, name, int16(id))
})
impl = indexed
} else {
if channelPrefix == "" {
// There is a limit of 10k channels, so we can't dynamically create them.
// The only other good default is to create one and only allow a single
// stream to write to a single table.
channelPrefix = fmt.Sprintf("Redpanda_Connect_%s.%s.%s", db, schema, table)
}
pooled := &snowpipePooledOutput{
channelPrefix: channelPrefix,
client: client,
db: db,
schema: schema,
table: table,
role: role,
logger: mgr.Logger(),
metrics: newSnowpipeMetrics(mgr.Metrics()),
buildOpts: buildOpts,
offsetToken: offsetToken,
schemaMigrationEnabled: schemaEvolver != nil,
}
pooled.channelPool = pool.NewCapped(maxInFlight, func(ctx context.Context, id int) (*streaming.SnowflakeIngestionChannel, error) {
name := fmt.Sprintf("%s_%d", pooled.channelPrefix, id)
return pooled.openChannel(ctx, name, int16(id))
})
impl = pooled
}
return schemaEvolver, impl
}
foo := &snowpipeStreamingOutput{
initStatementsFn: initStatementsFn,
client: client,
restClient: restClient,
mapping: mapping,
logger: mgr.Logger(),
schemaEvolver: schemaEvolver,

impl: impl,
if table, ok := dynamicTable.Static(); ok {
schemaEvolver, impl := makeImpl(table)
return &snowpipeStreamingOutput{
initStatementsFn: initStatementsFn,
client: client,
restClient: restClient,
mapping: mapping,
logger: mgr.Logger(),
schemaEvolver: schemaEvolver,

impl: impl,
}, nil
} else {
return &dynamicSnowpipeStreamingOutput{
table: dynamicTable,
byTable: pool.NewIndexed(func(ctx context.Context, table string) (service.BatchOutput, error) {
schemaEvolver, impl := makeImpl(table)
return &snowpipeStreamingOutput{
initStatementsFn: nil,
client: nil,
restClient: nil,
mapping: mapping,
logger: mgr.Logger(),
schemaEvolver: schemaEvolver,

impl: impl,
}, nil
}),
initStatementsFn: initStatementsFn,
client: client,
restClient: restClient,
}, nil
}
mgr.SetGeneric(SnowflakeClientResourceForTesting, restClient)
return foo, nil
}

type snowflakeClientForTesting string
Expand All @@ -584,6 +612,68 @@ type snowflakeClientForTesting string
// which can remove boilerplate from tests to setup a new REST client.
const SnowflakeClientResourceForTesting snowflakeClientForTesting = "SnowflakeClientResourceForTesting"

// dynamicSnowpipeStreamingOutput is a batch output that routes each message
// to a per-table output, where the destination table is resolved per message
// from an interpolated string.
type dynamicSnowpipeStreamingOutput struct {
	// table is evaluated against every message of a batch in WriteBatch to
	// determine which table that message is written to.
	table *service.InterpolatedString
	// byTable holds one output per resolved table name; WriteBatch acquires
	// entries from it by table name and Close closes every entry in it.
	byTable pool.Indexed[service.BatchOutput]

	// initStatementsFn, when non-nil, is invoked by Connect with restClient
	// and is set to nil once it has succeeded so it only runs a single time.
	initStatementsFn func(context.Context, *streaming.SnowflakeRestClient) error
	// client is closed by Close.
	client *streaming.SnowflakeServiceClient
	// restClient is passed to initStatementsFn in Connect and closed by Close.
	restClient *streaming.SnowflakeRestClient
}

// Connect runs the initialization statements, if any are still pending,
// against the REST client.
//
// On success the pending function is cleared, so later calls to Connect become
// no-ops. On failure the function is kept and the wrapped error is returned,
// which means a subsequent Connect will attempt the statements again.
func (o *dynamicSnowpipeStreamingOutput) Connect(ctx context.Context) error {
	runInit := o.initStatementsFn
	if runInit == nil {
		// Nothing configured, or the statements already ran successfully.
		return nil
	}
	if err := runInit(ctx, o.restClient); err != nil {
		return fmt.Errorf("unable to run initialization statement: %w", err)
	}
	// The init statements only need to be executed once.
	o.initStatementsFn = nil
	return nil
}

// WriteBatch splits the incoming batch by destination table and forwards each
// sub-batch to the output registered for that table.
//
// The table name is resolved for every message by evaluating the interpolated
// table field. If interpolation fails for any message the whole call fails
// before anything is written. Sub-batches are written one table at a time (in
// map iteration order, which is unspecified) and the first error encountered
// is returned immediately.
func (o *dynamicSnowpipeStreamingOutput) WriteBatch(ctx context.Context, batch service.MessageBatch) error {
	// Group messages by the table they resolve to, preserving the relative
	// order of messages within each table.
	grouped := map[string]service.MessageBatch{}
	interp := batch.InterpolationExecutor(o.table)
	for idx := range batch {
		name, err := interp.TryString(idx)
		if err != nil {
			return fmt.Errorf("unable to interpolate `%s`: %w", ssoFieldTable, err)
		}
		grouped[name] = append(grouped[name], batch[idx])
	}

	for name, sub := range grouped {
		out, err := o.byTable.Acquire(ctx, name)
		if err != nil {
			return err
		}
		// Hand the entry straight back to the pool: per the original author's
		// note the per-table outputs are thread safe, so other writers may use
		// the same output while we still hold a reference to it.
		o.byTable.Release(name, out)
		if err = out.WriteBatch(ctx, sub); err != nil {
			return err
		}
	}
	return nil
}

// Close shuts down every per-table output that has been created, empties the
// pool, and then closes the service and REST clients.
//
// Shutdown is best-effort: a failure to acquire or close one table's output
// does not stop the remaining outputs or the clients from being closed, so a
// single failing table cannot leak the others. The first error encountered is
// returned; nil is returned when everything closed cleanly.
func (o *dynamicSnowpipeStreamingOutput) Close(ctx context.Context) error {
	var firstErr error
	// keep records only the first failure so the caller sees the root cause.
	keep := func(err error) {
		if err != nil && firstErr == nil {
			firstErr = err
		}
	}

	for _, key := range o.byTable.Keys() {
		out, err := o.byTable.Acquire(ctx, key)
		if err != nil {
			keep(err)
			continue
		}
		// Return the entry to the pool before closing it, mirroring how
		// WriteBatch treats acquired outputs.
		o.byTable.Release(key, out)
		keep(out.Close(ctx))
	}

	// Always drop the pooled outputs and release the clients, even when one
	// of the outputs above failed to close.
	o.byTable.Reset()
	o.client.Close()
	o.restClient.Close()
	return firstErr
}

type snowpipeStreamingOutput struct {
initStatementsFn func(context.Context, *streaming.SnowflakeRestClient) error
client *streaming.SnowflakeServiceClient
Expand Down Expand Up @@ -705,8 +795,12 @@ func (o *snowpipeStreamingOutput) Close(ctx context.Context) error {
if err := o.impl.Close(ctx); err != nil {
return err
}
o.client.Close()
o.restClient.Close()
if o.client != nil {
o.client.Close()
}
if o.restClient != nil {
o.restClient.Close()
}
return nil
}

Expand Down
12 changes: 12 additions & 0 deletions internal/impl/snowflake/pool/indexed.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ type (
Release(name string, item T)
// Reset all items in the pool
Reset()
// Get all the keys in the pool
Keys() []string
}
indexedImpl[T any] struct {
ctor func(context.Context, string) (T, error)
Expand Down Expand Up @@ -97,3 +99,13 @@ func (p *indexedImpl[T]) Reset() {
clear(p.items)
p.unlock()
}

// Keys returns the names of all items currently held in the pool.
//
// The result is a freshly allocated, non-nil slice (empty when the pool has no
// items) in unspecified order, since it is built by ranging over a map.
func (p *indexedImpl[T]) Keys() []string {
	names := []string{}
	// The lock error is deliberately discarded.
	// NOTE(review): presumably lock can only fail when its context is done,
	// which cannot happen for context.Background() — confirm against lock.
	_ = p.lock(context.Background())
	defer p.unlock()
	for name := range p.items {
		names = append(names, name)
	}
	return names
}

0 comments on commit c189654

Please sign in to comment.