From 20e90bf7e02401e603184a67141b4ad4456df8cd Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 13 Dec 2024 09:49:44 +0000 Subject: [PATCH] pgcdc: introduce TableFQN type There is currently a mess of "what does this string mean?", which means it's time to introduce some typesafety to this problem. TableFQN is a Schema+Table pair that is prevalidated to not have SQL injection opportunities and we can pass these around to make things a bit more clear as well as ensure we're handling quoted stuff correctly. --- internal/impl/postgresql/input_pg_stream.go | 2 +- internal/impl/postgresql/integration_test.go | 12 ++-- .../pglogicalstream/logical_stream.go | 47 +++++++------- .../postgresql/pglogicalstream/monitor.go | 33 ++++------ .../postgresql/pglogicalstream/pglogrepl.go | 59 +++++++----------- .../pglogicalstream/pglogrepl_test.go | 62 ++++++++++--------- .../pglogicalstream/sanitize/sanitize.go | 18 ++++++ .../pglogicalstream/sanitize/sanitize_test.go | 3 +- .../postgresql/pglogicalstream/snapshotter.go | 32 +++++----- .../impl/postgresql/pglogicalstream/types.go | 15 +++++ 10 files changed, 148 insertions(+), 135 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 5c52274196..32a7ca0627 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -321,7 +321,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher // Periodically collect stats report := pgStream.GetProgress() for name, progress := range report.TableProgress { - p.snapshotMetrics.SetFloat64(progress, name) + p.snapshotMetrics.SetFloat64(progress, name.String()) } p.replicationLag.Set(report.WalLagInBytes) }) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index c13b04b5fc..c470c30857 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -128,7 +128,7 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version } _, err = db.Exec(` - CREATE TABLE IF NOT EXISTS flights_composite_pks ( + CREATE TABLE IF NOT EXISTS "FlightsCompositePK" ( id serial, seq integer, name VARCHAR(50), created_at TIMESTAMP, PRIMARY KEY (id, seq) );`) @@ -171,7 +171,7 @@ func TestIntegrationPostgresNoTxnMarkers(t *testing.T) { for i := 0; i < 10; i++ { f := GetFakeFlightRecord() - _, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + _, err = db.Exec(`INSERT INTO "FlightsCompositePK" (seq, name, created_at) VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -184,7 +184,7 @@ pg_stream: snapshot_batch_size: 5 schema: public tables: - - flights_composite_pks + - '"FlightsCompositePK"' `, databaseURL) cacheConf := fmt.Sprintf(` @@ -226,9 +226,9 @@ file: for i := 10; i < 20; i++ { f := GetFakeFlightRecord() - _, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + _, err = db.Exec(`INSERT INTO "FlightsCompositePK" (seq, name, created_at) VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + _, err = db.Exec(`INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);`, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -270,7 +270,7 @@ file: time.Sleep(time.Second * 5) for i := 20; i < 30; i++ { f := GetFakeFlightRecord() - _, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + _, err = db.Exec(`INSERT INTO "FlightsCompositePK" (seq, name, created_at) VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 6f87e463b3..3a98b5cb12 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -13,8 +13,6 @@ import ( "database/sql" "errors" "fmt" - "slices" - "strings" "sync" "time" @@ -46,12 +44,10 @@ type Stream struct { messages chan StreamMessage errors chan error - includeTxnMarkers bool - snapshotName string - slotName string - schema string - // includes schema - tableQualifiedName []string + includeTxnMarkers bool + snapshotName string + slotName string + tables []TableFQN snapshotBatchSize int decodingPluginArguments []string snapshotMemorySafetyFactor float64 @@ -90,13 +86,16 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { return nil, err } - tableNames := slices.Clone(config.DBTables) - for i, table := range tableNames { + if err := sanitize.ValidatePostgresIdentifier(config.DBSchema); err != nil { + return nil, fmt.Errorf("invalid schema name %q: %w", config.DBSchema, err) + } + + tables := []TableFQN{} + for _, table := range config.DBTables { if err := sanitize.ValidatePostgresIdentifier(table); err != nil { return nil, fmt.Errorf("invalid table name %q: %w", table, err) } - - tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table) + tables = append(tables, TableFQN{Schema: config.DBSchema, Table: table}) } stream := &Stream{ pgConn: dbConn, @@ -105,8 +104,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { slotName: config.ReplicationSlotName, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, snapshotBatchSize: config.BatchSize, - schema: config.DBSchema, - tableQualifiedName: tableNames, + tables: tables, maxParallelSnapshotTables: config.MaxParallelSnapshotTables, logger: config.Logger, shutSig: shutdown.NewSignaller(), @@ -143,8 +141,8 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.decodingPluginArguments = pluginArguments pubName := "pglog_stream_" + config.ReplicationSlotName - stream.logger.Infof("Creating publication %s for tables: %s", pubName, tableNames) - if err = CreatePublication(ctx, stream.pgConn, pubName, config.DBSchema, tableNames); err != nil { + stream.logger.Infof("Creating publication %s for tables: %s", pubName, tables) + if err = CreatePublication(ctx, stream.pgConn, pubName, tables); err != nil { return nil, err } cleanups = append(cleanups, func() { @@ -219,7 +217,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.standbyMessageTimeout = config.PgStandbyTimeout stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) - monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorInterval) + monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tables, stream.slotName, config.WalMonitorInterval) if err != nil { return nil, err } @@ -481,7 +479,7 @@ func (s *Stream) processSnapshot() error { var wg errgroup.Group wg.SetLimit(s.maxParallelSnapshotTables) - for _, table := range s.tableQualifiedName { + for _, table := range s.tables { tableName := table wg.Go(func() (err error) { s.logger.Debugf("Processing snapshot for table: %v", table) @@ -551,7 +549,6 @@ func (s *Stream) processSnapshot() error { totalScanDuration := time.Duration(0) totalWaitingFromBenthos := time.Duration(0) - tableWithoutSchema := strings.Split(table, ".")[1] for snapshotRows.Next() { rowsCount += 1 @@ -581,13 +578,13 @@ func (s *Stream) processSnapshot() error { snapshotChangePacket := StreamMessage{ LSN: nil, Operation: ReadOpType, - Table: tableWithoutSchema, - Schema: s.schema, + Table: table.Table, + Schema: table.Schema, Data: data, } if rowsCount%100 == 0 { - s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset) + s.monitor.UpdateSnapshotProgressForTable(table, rowsCount+offset) } waitingFromBenthos := time.Now() @@ -627,7 +624,7 @@ func (s *Stream) Errors() chan error { return s.errors } -func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map[string]any, []string, error) { +func (s *Stream) getPrimaryKeyColumn(ctx context.Context, table TableFQN) (map[string]any, []string, error) { /// Query to get all primary key columns in their correct order q, err := sanitize.SQLQuery(` SELECT a.attname @@ -637,7 +634,7 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map WHERE i.indrelid = $1::regclass AND i.indisprimary ORDER BY array_position(i.indkey, a.attnum); - `, tableName) + `, table.String()) if err != nil { return nil, nil, fmt.Errorf("failed to sanitize query: %w", err) @@ -650,7 +647,7 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map } if len(data) == 0 || len(data[0].Rows) == 0 { - return nil, nil, fmt.Errorf("no primary key found for table %s", tableName) + return nil, nil, fmt.Errorf("no primary key found for table %s", table) } // Extract all primary key column names diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index 598a81ae62..add7429862 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -21,23 +21,22 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" "github.com/redpanda-data/connect/v4/internal/asyncroutine" - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" ) // Report is a structure that contains the current state of the Monitor type Report struct { WalLagInBytes int64 - TableProgress map[string]float64 + TableProgress map[TableFQN]float64 } // Monitor is a structure that allows monitoring the progress of snapshot ingestion and replication lag type Monitor struct { // tableStat contains numbers of rows for each table determined at the moment of the snapshot creation // this is used to calculate snapshot ingestion progress - tableStat map[string]int64 + tableStat map[TableFQN]int64 lock sync.Mutex // snapshotProgress is a map of table names to the percentage of rows ingested from the snapshot - snapshotProgress map[string]float64 + snapshotProgress map[TableFQN]float64 // replicationLagInBytes is the replication lag in bytes measured by // finding the difference between the latest LSN and the last confirmed LSN for the replication slot replicationLagInBytes int64 @@ -53,7 +52,7 @@ func NewMonitor( ctx context.Context, dbDSN string, logger *service.Logger, - tables []string, + tables []TableFQN, slotName string, interval time.Duration, ) (*Monitor, error) { @@ -66,7 +65,7 @@ func NewMonitor( } m := &Monitor{ - snapshotProgress: map[string]float64{}, + snapshotProgress: map[TableFQN]float64{}, replicationLagInBytes: 0, dbConn: dbConn, slotName: slotName, @@ -81,40 +80,32 @@ func NewMonitor( } // UpdateSnapshotProgressForTable updates the snapshot ingestion progress for a given table -func (m *Monitor) UpdateSnapshotProgressForTable(table string, position int) { +func (m *Monitor) UpdateSnapshotProgressForTable(table TableFQN, position int) { m.lock.Lock() defer m.lock.Unlock() m.snapshotProgress[table] = math.Round(float64(position) / float64(m.tableStat[table]) * 100) } // we need to read the tables stat to calculate the snapshot ingestion progress -func (m *Monitor) readTablesStat(ctx context.Context, tables []string) error { - results := make(map[string]int64) +func (m *Monitor) readTablesStat(ctx context.Context, tables []TableFQN) error { + results := make(map[TableFQN]int64) for _, table := range tables { - tableWithoutSchema := strings.Split(table, ".")[1] - err := sanitize.ValidatePostgresIdentifier(tableWithoutSchema) - - if err != nil { - return fmt.Errorf("error sanitizing query: %w", err) - } - var count int64 - // tableWithoutSchema has been validated so its safe to use in the query - err = m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tableWithoutSchema).Scan(&count) + err := m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+table.String()).Scan(&count) if err != nil { // If the error is because the table doesn't exist, we'll set the count to 0 // and continue. You might want to log this situation. if strings.Contains(err.Error(), "does not exist") { - results[tableWithoutSchema] = 0 + results[table] = 0 continue } // For any other error, we'll return it - return fmt.Errorf("error counting rows in table %s: %w", tableWithoutSchema, err) + return fmt.Errorf("error counting rows in table %s: %w", table, err) } - results[tableWithoutSchema] = count + results[table] = count } m.tableStat = results diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 795675cf60..bf3c9cad3a 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -21,7 +21,6 @@ import ( "context" "database/sql/driver" "encoding/binary" - "errors" "fmt" "slices" "strconv" @@ -343,7 +342,7 @@ func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName stri } // CreatePublication creates a new PostgreSQL publication with the given name for a list of tables and drop if exists flag -func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, schema string, tables []string) error { +func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []TableFQN) error { // Check if publication exists pubQuery, err := sanitize.SQLQuery(` SELECT pubname, puballtables @@ -354,17 +353,6 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName return fmt.Errorf("failed to sanitize publication query: %w", err) } - // Since we need to pass table names without quoting, we need to validate it - for _, table := range tables { - if err := sanitize.ValidatePostgresIdentifier(table); err != nil { - return errors.New("invalid table name") - } - } - // the same for publication name - if err := sanitize.ValidatePostgresIdentifier(publicationName); err != nil { - return errors.New("invalid publication name") - } - result := conn.Exec(ctx, pubQuery) rows, err := result.ReadAll() @@ -374,16 +362,13 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName tablesClause := "FOR ALL TABLES" if len(tables) > 0 { - // quotedTables := make([]string, len(tables)) - // for i, table := range tables { - // // Use sanitize.SQLIdentifier to properly quote and escape table names - // quoted, err := sanitize.SQLIdentifier(table) - // if err != nil { - // return fmt.Errorf("invalid table name %q: %w", table, err) - // } - // quotedTables[i] = quoted - // } - tablesClause = "FOR TABLE " + strings.Join(tables, ", ") + tablesClause = "FOR TABLE " + for i, table := range tables { + if i > 0 { + tablesClause += ", " + } + tablesClause += table.String() + } } if len(rows) == 0 || len(rows[0].Rows) == 0 { @@ -403,7 +388,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // assuming publication already exists // get a list of tables in the publication - pubTables, forAllTables, err := GetPublicationTables(ctx, conn, publicationName, schema) + pubTables, forAllTables, err := GetPublicationTables(ctx, conn, publicationName) if err != nil { return fmt.Errorf("failed to get publication tables: %w", err) } @@ -414,8 +399,8 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName return nil } - var tablesToRemoveFromPublication = []string{} - var tablesToAddToPublication = []string{} + var tablesToRemoveFromPublication = []TableFQN{} + var tablesToAddToPublication = []TableFQN{} for _, table := range tables { if !slices.Contains(pubTables, table) { tablesToAddToPublication = append(tablesToAddToPublication, table) @@ -430,7 +415,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // remove tables from publication for _, dropTable := range tablesToRemoveFromPublication { - sq, err := sanitize.SQLQuery(fmt.Sprintf(`ALTER PUBLICATION %s DROP TABLE %s."%s";`, publicationName, schema, dropTable)) + sq, err := sanitize.SQLQuery(fmt.Sprintf(`ALTER PUBLICATION %s DROP TABLE %s;`, publicationName, dropTable.String())) if err != nil { return fmt.Errorf("failed to sanitize drop table query: %w", err) } @@ -442,7 +427,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // add tables to publication for _, addTable := range tablesToAddToPublication { - sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable.String())) if err != nil { return fmt.Errorf("failed to sanitize add table query: %w", err) } @@ -457,15 +442,15 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // GetPublicationTables returns a list of tables currently in the publication // Arguments, in order: list of the tables, exist for all tables, errror -func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationName, schema string) ([]string, bool, error) { +func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationName string) ([]TableFQN, bool, error) { query, err := sanitize.SQLQuery(` SELECT DISTINCT - tablename as table_name + tablename as table_name, + schemaname as schema_name FROM pg_publication_tables WHERE pubname = $1 - AND schemaname = $2 - ORDER BY table_name; - `, publicationName, strings.Trim(schema, "\"")) + ORDER BY schema_name, table_name; + `, publicationName) if err != nil { return nil, false, fmt.Errorf("failed to get publication tables: %w", err) } @@ -482,9 +467,13 @@ func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationN return nil, true, nil // Publication exists and is for all tables } - tables := make([]string, 0, len(rows)) + tables := make([]TableFQN, 0, len(rows)) for _, row := range rows[0].Rows { - tables = append(tables, string(row[0])) + // These come from postgres so they are valid, but we have to quote them + // to prevent normalization + table := sanitize.QuotePostgresIdentifier(string(row[0])) + schema := sanitize.QuotePostgresIdentifier(string(row[1])) + tables = append(tables, TableFQN{Table: table, Schema: schema}) } return tables, false, nil diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index c18eb7a850..e87bcc1b7d 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -233,11 +233,11 @@ func TestIntegrationCreatePublication(t *testing.T) { defer closeConn(t, conn) publicationName := "test_publication" - schema := "public" - err = CreatePublication(context.Background(), conn, publicationName, schema, []string{}) + schema := `"public"` + err = CreatePublication(context.Background(), conn, publicationName, []TableFQN{}) require.NoError(t, err) - tables, forAllTables, err := GetPublicationTables(context.Background(), conn, publicationName, schema) + tables, forAllTables, err := GetPublicationTables(context.Background(), conn, publicationName) require.NoError(t, err) assert.Empty(t, tables) assert.True(t, forAllTables) @@ -247,13 +247,14 @@ func TestIntegrationCreatePublication(t *testing.T) { require.NoError(t, err) publicationWithTables := "test_pub_with_tables" - err = CreatePublication(context.Background(), conn, publicationWithTables, schema, []string{"test_table"}) + err = CreatePublication(context.Background(), conn, publicationWithTables, []TableFQN{{schema, `"test_table"`}}) require.NoError(t, err) - tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationName, schema) + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationName) require.NoError(t, err) assert.NotEmpty(t, tables) - assert.Contains(t, tables, "test_table") + assert.Len(t, tables, 1) + assert.Contains(t, tables, TableFQN{schema, `"test_table"`}) assert.False(t, forAllTables) // Add more tables to publication @@ -262,41 +263,43 @@ func TestIntegrationCreatePublication(t *testing.T) { require.NoError(t, err) // Pass more tables to the publication - err = CreatePublication(context.Background(), conn, publicationWithTables, schema, []string{ - "test_table2", - "test_table", + err = CreatePublication(context.Background(), conn, publicationWithTables, []TableFQN{ + {schema, "test_table2"}, + {schema, "test_table"}, }) require.NoError(t, err) - tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables, schema) + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables) require.NoError(t, err) assert.NotEmpty(t, tables) - assert.Contains(t, tables, "test_table") - assert.Contains(t, tables, "test_table2") + assert.Len(t, tables, 2) + assert.Contains(t, tables, TableFQN{schema, `"test_table"`}) + assert.Contains(t, tables, TableFQN{schema, `"test_table2"`}) assert.False(t, forAllTables) // Remove one table from the publication - err = CreatePublication(context.Background(), conn, publicationWithTables, schema, []string{ - "test_table", + err = CreatePublication(context.Background(), conn, publicationWithTables, []TableFQN{ + {schema, "test_table"}, }) require.NoError(t, err) - tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables, schema) + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables) require.NoError(t, err) assert.NotEmpty(t, tables) - assert.Contains(t, tables, "test_table") + assert.Len(t, tables, 1) + assert.Contains(t, tables, TableFQN{schema, `"test_table"`}) assert.False(t, forAllTables) // Add one table and remove one at the same time - err = CreatePublication(context.Background(), conn, publicationWithTables, schema, []string{ - "test_table2", + err = CreatePublication(context.Background(), conn, publicationWithTables, []TableFQN{ + {schema, "test_table2"}, }) require.NoError(t, err) - tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables, schema) + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables) require.NoError(t, err) assert.NotEmpty(t, tables) - assert.Contains(t, tables, "test_table2") + assert.Contains(t, tables, TableFQN{schema, `"test_table2"`}) assert.False(t, forAllTables) // Create a schema with a quoted identifier @@ -317,22 +320,22 @@ func TestIntegrationCreatePublication(t *testing.T) { // Pass tables to the schema with quoted identifiers publicationQuotedIdentifiers := "quoted_identifiers" - err = CreatePublication(context.Background(), conn, publicationQuotedIdentifiers, caseSensitiveSchema, []string{ - caseSensitiveSchema + "." + caseSensitiveTable, - caseSensitiveSchema + "." + caseSensitiveTable2, + err = CreatePublication(context.Background(), conn, publicationQuotedIdentifiers, []TableFQN{ + {caseSensitiveSchema, caseSensitiveTable}, + {caseSensitiveSchema, caseSensitiveTable2}, }) require.NoError(t, err) // Remove one table with a quoted identifier from the publication - err = CreatePublication(context.Background(), conn, publicationQuotedIdentifiers, caseSensitiveSchema, []string{ - caseSensitiveSchema + "." + caseSensitiveTable, + err = CreatePublication(context.Background(), conn, publicationQuotedIdentifiers, []TableFQN{ + {caseSensitiveSchema, caseSensitiveTable}, }) require.NoError(t, err) - tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationQuotedIdentifiers, caseSensitiveSchema) + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationQuotedIdentifiers) require.NoError(t, err) - assert.NotEmpty(t, tables) - assert.Contains(t, tables, "Foo") + assert.Len(t, tables, 1) + assert.Contains(t, tables, TableFQN{`"FooBar"`, `"Foo"`}) assert.False(t, forAllTables) } @@ -357,8 +360,7 @@ func TestIntegrationStartReplication(t *testing.T) { // create publication publicationName := "test_publication" - schema := "public" - err = CreatePublication(context.Background(), conn, publicationName, schema, []string{}) + err = CreatePublication(context.Background(), conn, publicationName, []TableFQN{}) require.NoError(t, err) _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}, 16, nil) diff --git a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go index 145c2378d7..6ae239be3e 100644 --- a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go +++ b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go @@ -363,6 +363,24 @@ func SQLQuery(sql string, args ...any) (string, error) { return query.Sanitize(args...) } +// QuotePostgresIdentifier returns the valid escaped identifier. +func QuotePostgresIdentifier(name string) string { + var quoted strings.Builder + // Default to assume we're just going to add quotes and there won't + // be any double quotes inside the string that needs escaped. + quoted.Grow(len(name) + 2) + quoted.WriteByte('"') + for _, r := range name { + if r == '"' { + quoted.WriteString(`""`) + } else { + quoted.WriteRune(r) + } + } + quoted.WriteByte('"') + return quoted.String() +} + // ValidatePostgresIdentifier checks if a string is a valid PostgreSQL identifier // This follows PostgreSQL's standard naming rules func ValidatePostgresIdentifier(name string) error { diff --git a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go index 805bb68da7..ebb70bb759 100644 --- a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go +++ b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go @@ -28,8 +28,9 @@ import ( "testing" "time" - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" "github.com/stretchr/testify/require" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" ) func TestNewQuery(t *testing.T) { diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 72a76eb815..1d0cbad2f2 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -117,7 +117,7 @@ func (s *Snapshotter) prepare() error { return nil } -func (s *Snapshotter) findAvgRowSize(ctx context.Context, table string) (sql.NullInt64, error) { +func (s *Snapshotter) findAvgRowSize(ctx context.Context, table TableFQN) (sql.NullInt64, error) { var ( avgRowSize sql.NullInt64 rows *sql.Rows @@ -144,24 +144,24 @@ func (s *Snapshotter) findAvgRowSize(ctx context.Context, table string) (sql.Nul return avgRowSize, nil } -func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ([]interface{}, []func(interface{}) (interface{}, error)) { - scanArgs := make([]interface{}, len(columnTypes)) - valueGetters := make([]func(interface{}) (interface{}, error), len(columnTypes)) +func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ([]any, []func(any) (any, error)) { + scanArgs := make([]any, len(columnTypes)) + valueGetters := make([]func(any) (any, error), len(columnTypes)) for i, v := range columnTypes { switch v.DatabaseTypeName() { case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": scanArgs[i] = new(sql.NullString) - valueGetters[i] = func(v interface{}) (interface{}, error) { return v.(*sql.NullString).String, nil } + valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullString).String, nil } case "BOOL": scanArgs[i] = new(sql.NullBool) - valueGetters[i] = func(v interface{}) (interface{}, error) { return v.(*sql.NullBool).Bool, nil } + valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullBool).Bool, nil } case "INT4": scanArgs[i] = new(sql.NullInt64) - valueGetters[i] = func(v interface{}) (interface{}, error) { return v.(*sql.NullInt64).Int64, nil } + valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullInt64).Int64, nil } case "JSONB": scanArgs[i] = new(sql.NullString) - valueGetters[i] = func(v interface{}) (interface{}, error) { + valueGetters[i] = func(v any) (any, error) { payload := v.(*sql.NullString).String if payload == "" { return payload, nil @@ -175,7 +175,7 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( } case "INET": scanArgs[i] = new(sql.NullString) - valueGetters[i] = func(v interface{}) (interface{}, error) { + valueGetters[i] = func(v any) (any, error) { inet := pgtype.Inet{} val := v.(*sql.NullString).String if err := inet.Scan(val); err != nil { @@ -186,7 +186,7 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( } case "TSRANGE": scanArgs[i] = new(sql.NullString) - valueGetters[i] = func(v interface{}) (interface{}, error) { + valueGetters[i] = func(v any) (any, error) { newArray := pgtype.Tsrange{} val := v.(*sql.NullString).String if err := newArray.Scan(val); err != nil { @@ -198,7 +198,7 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( } case "_INT4": scanArgs[i] = new(sql.NullString) - valueGetters[i] = func(v interface{}) (interface{}, error) { + valueGetters[i] = func(v any) (any, error) { newArray := pgtype.Int4Array{} val := v.(*sql.NullString).String if err := newArray.Scan(val); err != nil { @@ -209,7 +209,7 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( } case "_TEXT": scanArgs[i] = new(sql.NullString) - valueGetters[i] = func(v interface{}) (interface{}, error) { + valueGetters[i] = func(v any) (any, error) { newArray := pgtype.TextArray{} val := v.(*sql.NullString).String if err := newArray.Scan(val); err != nil { @@ -220,7 +220,7 @@ func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ( } default: scanArgs[i] = new(sql.NullString) - valueGetters[i] = func(v interface{}) (interface{}, error) { return v.(*sql.NullString).String, nil } + valueGetters[i] = func(v any) (any, error) { return v.(*sql.NullString).String, nil } } } @@ -239,13 +239,13 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz return batchSize } -func (s *Snapshotter) querySnapshotData(ctx context.Context, table string, lastSeenPk map[string]any, pkColumns []string, limit int) (rows *sql.Rows, err error) { +func (s *Snapshotter) querySnapshotData(ctx context.Context, table TableFQN, lastSeenPk map[string]any, pkColumns []string, limit int) (rows *sql.Rows, err error) { s.logger.Debugf("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pkColumns) if lastSeenPk == nil { // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. - sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, strings.Join(pkColumns, ", "), limit)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table.String(), strings.Join(pkColumns, ", "), limit)) if err != nil { return nil, err } @@ -269,7 +269,7 @@ func (s *Snapshotter) querySnapshotData(ctx context.Context, table string, lastS pkAsTuple := "(" + strings.Join(pkColumns, ", ") + ")" // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. - sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table, pkAsTuple, lastSeenPlaceHolders, strings.Join(pkColumns, ", "), limit), lastSeenPksValues...) + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table.String(), pkAsTuple, lastSeenPlaceHolders, strings.Join(pkColumns, ", "), limit), lastSeenPksValues...) if err != nil { return nil, err } diff --git a/internal/impl/postgresql/pglogicalstream/types.go b/internal/impl/postgresql/pglogicalstream/types.go index 2d1d0ff3ad..6596f8010a 100644 --- a/internal/impl/postgresql/pglogicalstream/types.go +++ b/internal/impl/postgresql/pglogicalstream/types.go @@ -7,3 +7,18 @@ // https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md package pglogicalstream + +import "fmt" + +// TableFQN is both a table name AND a schema name +// +// TableFQN should always be SAFE and validated before creating +type TableFQN struct { + Schema string + Table string +} + +// String satifies the Stringer interface +func (t TableFQN) String() string { + return fmt.Sprintf("%s.%s", t.Schema, t.Table) +}