Merge pull request #3075 from redpanda-data/mihaitodor-fix-postgres-cdc
Fix `postgres_cdc` input
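
In short: the `schema` and `tables` fields of the `postgres_cdc` (and legacy `pg_stream`) input now accept double-quoted, case-sensitive identifiers. A minimal sketch of the resulting configuration, assuming the usual `input:` nesting and omitting connection, slot, and snapshot settings (field names and values are taken from the docs changes below):

```yml
input:
  postgres_cdc:
    # connection, slot and snapshot settings omitted
    schema: '"MyCaseSensitiveSchemaNeedingQuotes"'
    tables:
      - my_table_1                             # unquoted identifiers are folded to lower case
      - '"MyCaseSensitiveTableNeedingQuotes"'  # quoted identifiers keep their exact case
```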
rockwotj authored Dec 13, 2024
2 parents a495eb0 + 5606ca9 commit 881e6b2
Showing 13 changed files with 296 additions and 156 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file.
### Fixed

- `gcp_bigquery` output with parquet format no longer returns errors incorrectly. (@rockwotj)
- `postgres_cdc` input now allows quoted identifiers for the table names. (@mihaitodor, @rockwotj)

## 4.43.1 - 2024-12-09

9 changes: 5 additions & 4 deletions docs/modules/components/pages/inputs/pg_stream.adoc
@@ -193,6 +193,8 @@ The PostgreSQL schema from which to replicate data.
# Examples
schema: public
schema: '"MyCaseSensitiveSchemaNeedingQuotes"'
```
=== `tables`
@@ -206,10 +208,9 @@ A list of table names to include in the logical replication. Each table should be specified as a separate item.
```yml
# Examples
tables: |2-
- my_table
- my_table_2
tables:
- my_table_1
- '"MyCaseSensitiveTableNeedingQuotes"'
```
=== `checkpoint_limit`
9 changes: 5 additions & 4 deletions docs/modules/components/pages/inputs/postgres_cdc.adoc
@@ -188,6 +188,8 @@ The PostgreSQL schema from which to replicate data.
# Examples
schema: public
schema: '"MyCaseSensitiveSchemaNeedingQuotes"'
```
=== `tables`
@@ -201,10 +203,9 @@ A list of table names to include in the logical replication. Each table should be specified as a separate item.
```yml
# Examples
tables: |2-
- my_table
- my_table_2
tables:
- my_table_1
- '"MyCaseSensitiveTableNeedingQuotes"'
```
=== `checkpoint_limit`
10 changes: 4 additions & 6 deletions internal/impl/postgresql/input_pg_stream.go
@@ -85,13 +85,11 @@ This input adds the following metadata fields to each message:
Default(0)).
Field(service.NewStringField(fieldSchema).
Description("The PostgreSQL schema from which to replicate data.").
Example("public")).
Examples("public", `"MyCaseSensitiveSchemaNeedingQuotes"`),
).
Field(service.NewStringListField(fieldTables).
Description("A list of table names to include in the logical replication. Each table should be specified as a separate item.").
Example(`
- my_table
- my_table_2
`)).
Example([]string{"my_table_1", `"MyCaseSensitiveTableNeedingQuotes"`})).
Field(service.NewIntField(fieldCheckpointLimit).
Description("The maximum number of messages that can be processed at a given time. Increasing this limit enables parallel processing and batching at the output level. Any given LSN will not be acknowledged unless all messages under that offset are delivered in order to preserve at least once delivery guarantees.").
Default(1024)).
@@ -321,7 +319,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher
// Periodically collect stats
report := pgStream.GetProgress()
for name, progress := range report.TableProgress {
p.snapshotMetrics.SetFloat64(progress, name)
p.snapshotMetrics.SetFloat64(progress, name.String())
}
p.replicationLag.Set(report.WalLagInBytes)
})
27 changes: 16 additions & 11 deletions internal/impl/postgresql/integration_test.go
@@ -127,10 +127,11 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version
return err
}

// This table explicitly uses identifiers that need quoting to ensure we work with those correctly.
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS flights_composite_pks (
id serial, seq integer, name VARCHAR(50), created_at TIMESTAMP,
PRIMARY KEY (id, seq)
CREATE TABLE IF NOT EXISTS "FlightsCompositePK" (
"ID" serial, "Seq" integer, "Name" VARCHAR(50), "CreatedAt" TIMESTAMP,
PRIMARY KEY ("ID", "Seq")
);`)
if err != nil {
return err
@@ -171,7 +172,7 @@ func TestIntegrationPostgresNoTxnMarkers(t *testing.T) {

for i := 0; i < 10; i++ {
f := GetFakeFlightRecord()
_, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
_, err = db.Exec(`INSERT INTO "FlightsCompositePK" ("Seq", "Name", "CreatedAt") VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
require.NoError(t, err)
}

@@ -184,7 +185,7 @@ pg_stream:
snapshot_batch_size: 5
schema: public
tables:
- flights_composite_pks
- '"FlightsCompositePK"'
`, databaseURL)

cacheConf := fmt.Sprintf(`
Expand All @@ -194,7 +195,7 @@ file:
`, tmpDir)

streamOutBuilder := service.NewStreamBuilder()
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`))
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: TRACE`))
require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf))
require.NoError(t, streamOutBuilder.AddInputYAML(template))

@@ -226,9 +227,9 @@ file:

for i := 10; i < 20; i++ {
f := GetFakeFlightRecord()
_, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
_, err = db.Exec(`INSERT INTO "FlightsCompositePK" ("Seq", "Name", "CreatedAt") VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
require.NoError(t, err)
_, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
_, err = db.Exec(`INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);`, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
require.NoError(t, err)
}

@@ -270,7 +271,7 @@ file:
time.Sleep(time.Second * 5)
for i := 20; i < 30; i++ {
f := GetFakeFlightRecord()
_, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
_, err = db.Exec(`INSERT INTO "FlightsCompositePK" ("Seq", "Name", "CreatedAt") VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
require.NoError(t, err)
}

@@ -648,9 +649,13 @@ pg_stream:
slot_name: test_slot_native_decoder
stream_snapshot: true
include_transaction_markers: false
schema: public
# This is intentionally uppercase - we want to validate that
# we treat identifiers the same way Postgres queries do.
schema: PuBliC
tables:
- flights
# This is intentionally uppercase - we want to validate that
# we treat identifiers the same way Postgres queries do.
- FLIGHTS
`, databaseURL)

cacheConf := fmt.Sprintf(`
67 changes: 34 additions & 33 deletions internal/impl/postgresql/pglogicalstream/logical_stream.go
@@ -13,8 +13,6 @@ import (
"database/sql"
"errors"
"fmt"
"slices"
"strings"
"sync"
"time"

@@ -46,12 +44,10 @@ type Stream struct {
messages chan StreamMessage
errors chan error

includeTxnMarkers bool
snapshotName string
slotName string
schema string
// includes schema
tableQualifiedName []string
includeTxnMarkers bool
snapshotName string
slotName string
tables []TableFQN
snapshotBatchSize int
decodingPluginArguments []string
snapshotMemorySafetyFactor float64
@@ -90,13 +86,18 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
return nil, err
}

tableNames := slices.Clone(config.DBTables)
for i, table := range tableNames {
if err := sanitize.ValidatePostgresIdentifier(table); err != nil {
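// Normalization mirrors Postgres' own case-folding rules: unquoted identifiers
// are folded to lower case, while double-quoted identifiers keep their exact case.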
schema, err := sanitize.NormalizePostgresIdentifier(config.DBSchema)
if err != nil {
return nil, fmt.Errorf("invalid schema name %q: %w", config.DBSchema, err)
}

tables := []TableFQN{}
for _, table := range config.DBTables {
normalized, err := sanitize.NormalizePostgresIdentifier(table)
if err != nil {
return nil, fmt.Errorf("invalid table name %q: %w", table, err)
}

tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table)
tables = append(tables, TableFQN{Schema: schema, Table: normalized})
}
stream := &Stream{
pgConn: dbConn,
@@ -105,8 +106,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
slotName: config.ReplicationSlotName,
snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor,
snapshotBatchSize: config.BatchSize,
schema: config.DBSchema,
tableQualifiedName: tableNames,
tables: tables,
maxParallelSnapshotTables: config.MaxParallelSnapshotTables,
logger: config.Logger,
shutSig: shutdown.NewSignaller(),
@@ -143,8 +143,8 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
stream.decodingPluginArguments = pluginArguments

pubName := "pglog_stream_" + config.ReplicationSlotName
stream.logger.Infof("Creating publication %s for tables: %s", pubName, tableNames)
if err = CreatePublication(ctx, stream.pgConn, pubName, tableNames); err != nil {
stream.logger.Infof("Creating publication %s for tables: %s", pubName, tables)
if err = CreatePublication(ctx, stream.pgConn, pubName, tables); err != nil {
return nil, err
}
cleanups = append(cleanups, func() {
@@ -219,7 +219,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
stream.standbyMessageTimeout = config.PgStandbyTimeout
stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout)

monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorInterval)
monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tables, stream.slotName, config.WalMonitorInterval)
if err != nil {
return nil, err
}
@@ -481,7 +481,7 @@ func (s *Stream) processSnapshot() error {
var wg errgroup.Group
wg.SetLimit(s.maxParallelSnapshotTables)

for _, table := range s.tableQualifiedName {
for _, table := range s.tables {
tableName := table
wg.Go(func() (err error) {
s.logger.Debugf("Processing snapshot for table: %v", table)
@@ -551,7 +551,6 @@ func (s *Stream) processSnapshot() error {
totalScanDuration := time.Duration(0)
totalWaitingFromBenthos := time.Duration(0)

tableWithoutSchema := strings.Split(table, ".")[1]
for snapshotRows.Next() {
rowsCount += 1

@@ -567,27 +566,28 @@ func (s *Stream) processSnapshot() error {

var data = make(map[string]any)
for i, getter := range valueGetters {
if data[columnNames[i]], err = getter(scanArgs[i]); err != nil {
col := columnNames[i]
var val any
if val, err = getter(scanArgs[i]); err != nil {
return err
}

if _, ok := lastPrimaryKey[columnNames[i]]; ok {
if lastPkVals[columnNames[i]], err = getter(scanArgs[i]); err != nil {
return err
}
data[col] = val
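// lastPrimaryKey is keyed by quoted identifiers (see getPrimaryKeyColumn below),
// so quote the column name the same way before the lookup.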
normalized := sanitize.QuotePostgresIdentifier(col)
if _, ok := lastPrimaryKey[normalized]; ok {
lastPkVals[normalized] = val
}
}

snapshotChangePacket := StreamMessage{
LSN: nil,
Operation: ReadOpType,
Table: tableWithoutSchema,
Schema: s.schema,
Table: table.Table,
Schema: table.Schema,
Data: data,
}

if rowsCount%100 == 0 {
s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset)
s.monitor.UpdateSnapshotProgressForTable(table, rowsCount+offset)
}

waitingFromBenthos := time.Now()
@@ -627,7 +627,7 @@ func (s *Stream) Errors() chan error {
return s.errors
}

func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map[string]any, []string, error) {
func (s *Stream) getPrimaryKeyColumn(ctx context.Context, table TableFQN) (map[string]any, []string, error) {
/// Query to get all primary key columns in their correct order
q, err := sanitize.SQLQuery(`
SELECT a.attname
Expand All @@ -637,7 +637,7 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map
WHERE i.indrelid = $1::regclass
AND i.indisprimary
ORDER BY array_position(i.indkey, a.attnum);
`, tableName)
`, table.String())

if err != nil {
return nil, nil, fmt.Errorf("failed to sanitize query: %w", err)
@@ -650,13 +650,14 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map
}

if len(data) == 0 || len(data[0].Rows) == 0 {
return nil, nil, fmt.Errorf("no primary key found for table %s", tableName)
return nil, nil, fmt.Errorf("no primary key found for table %s", table)
}

// Extract all primary key column names
pkColumns := make([]string, len(data[0].Rows))
for i, row := range data[0].Rows {
pkColumns[i] = string(row[0])
// Postgres gives us back normalized identifiers here - we need to quote them.
pkColumns[i] = sanitize.QuotePostgresIdentifier(string(row[0]))
}

var pksMap = make(map[string]any)