Fix postgres_cdc input #3075

Merged · 8 commits · Dec 13, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -8,6 +8,7 @@ All notable changes to this project will be documented in this file.
### Fixed

- `gcp_bigquery` output with parquet format no longer returns errors incorrectly. (@rockwotj)
- `postgres_cdc` input now allows quoted identifiers for the table names. (@mihaitodor, @rockwotj)

## 4.43.1 - 2024-12-09

9 changes: 5 additions & 4 deletions docs/modules/components/pages/inputs/pg_stream.adoc
Expand Up @@ -193,6 +193,8 @@ The PostgreSQL schema from which to replicate data.
# Examples

schema: public

schema: '"MyCaseSensitiveSchemaNeedingQuotes"'
```

=== `tables`
Expand All @@ -206,10 +208,9 @@ A list of table names to include in the logical replication. Each table should be specified as a separate item.
```yml
# Examples

tables: |2-
- my_table
- my_table_2

tables:
- my_table_1
- '"MyCaseSensitiveTableNeedingQuotes"'
```

=== `checkpoint_limit`
9 changes: 5 additions & 4 deletions docs/modules/components/pages/inputs/postgres_cdc.adoc
Expand Up @@ -188,6 +188,8 @@ The PostgreSQL schema from which to replicate data.
# Examples

schema: public

schema: '"MyCaseSensitiveSchemaNeedingQuotes"'
```

=== `tables`
Expand All @@ -201,10 +203,9 @@ A list of table names to include in the logical replication. Each table should be specified as a separate item.
```yml
# Examples

tables: |2-
- my_table
- my_table_2

tables:
- my_table_1
- '"MyCaseSensitiveTableNeedingQuotes"'
```

=== `checkpoint_limit`
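Both documentation pages now pair the plain lowercase example with a quoted one because of PostgreSQL's identifier folding: an unquoted name is folded to lower case before lookup, while a double-quoted name keeps its exact spelling, so case-sensitive tables can only be referenced with quotes. The Go sketch below is a minimal model of that folding rule; it is not the connector's `sanitize` package, just an illustration of the behaviour the examples rely on.

```go
package main

import (
	"fmt"
	"strings"
)

// foldIdentifier approximates how PostgreSQL resolves an identifier written
// in a query: unquoted names are case-folded to lower case, while
// double-quoted names keep their exact spelling ("" unescapes to a single
// double quote). Illustrative only; not the repository's implementation.
func foldIdentifier(ident string) string {
	if len(ident) >= 2 && strings.HasPrefix(ident, `"`) && strings.HasSuffix(ident, `"`) {
		return strings.ReplaceAll(ident[1:len(ident)-1], `""`, `"`)
	}
	return strings.ToLower(ident)
}

func main() {
	fmt.Println(foldIdentifier("FLIGHTS"))                             // flights
	fmt.Println(foldIdentifier(`"MyCaseSensitiveTableNeedingQuotes"`)) // MyCaseSensitiveTableNeedingQuotes
}
```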
10 changes: 4 additions & 6 deletions internal/impl/postgresql/input_pg_stream.go
Expand Up @@ -85,13 +85,11 @@ This input adds the following metadata fields to each message:
Default(0)).
Field(service.NewStringField(fieldSchema).
Description("The PostgreSQL schema from which to replicate data.").
Example("public")).
Examples("public", `"MyCaseSensitiveSchemaNeedingQuotes"`),
).
Field(service.NewStringListField(fieldTables).
Description("A list of table names to include in the logical replication. Each table should be specified as a separate item.").
Example(`
- my_table
- my_table_2
`)).
Example([]string{"my_table_1", `"MyCaseSensitiveTableNeedingQuotes"`})).
Field(service.NewIntField(fieldCheckpointLimit).
Description("The maximum number of messages that can be processed at a given time. Increasing this limit enables parallel processing and batching at the output level. Any given LSN will not be acknowledged unless all messages under that offset are delivered in order to preserve at least once delivery guarantees.").
Default(1024)).
Expand Down Expand Up @@ -321,7 +319,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher
// Periodically collect stats
report := pgStream.GetProgress()
for name, progress := range report.TableProgress {
p.snapshotMetrics.SetFloat64(progress, name)
p.snapshotMetrics.SetFloat64(progress, name.String())
}
p.replicationLag.Set(report.WalLagInBytes)
})
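Pieced together, the field definitions above translate into a config fragment like the sketch below. Only fields visible in this diff are used (schema, tables, checkpoint_limit); the connection and replication settings a real `postgres_cdc` input needs are deliberately omitted, so treat this as an illustrative assumption rather than a complete config. It follows the same build-the-YAML-in-Go style the integration tests use.

```go
package main

import "fmt"

func main() {
	// Field names are taken from the builder definitions in this diff; the
	// quoted table entry shows the new case-sensitive identifier support.
	// Connection details (DSN, slot name, snapshot options) are left out.
	template := `
postgres_cdc:
  schema: public
  tables:
    - my_table_1
    - '"MyCaseSensitiveTableNeedingQuotes"'
  checkpoint_limit: 1024
`
	// The integration tests feed a string like this to
	// service.NewStreamBuilder().AddInputYAML(...); here it is only printed.
	fmt.Print(template)
}
```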
27 changes: 16 additions & 11 deletions internal/impl/postgresql/integration_test.go
Expand Up @@ -127,10 +127,11 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version
return err
}

// This table explicitly uses identifiers that need quoting to ensure we work with those correctly.
_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS flights_composite_pks (
id serial, seq integer, name VARCHAR(50), created_at TIMESTAMP,
PRIMARY KEY (id, seq)
CREATE TABLE IF NOT EXISTS "FlightsCompositePK" (
"ID" serial, "Seq" integer, "Name" VARCHAR(50), "CreatedAt" TIMESTAMP,
PRIMARY KEY ("ID", "Seq")
);`)
if err != nil {
return err
Expand Down Expand Up @@ -171,7 +172,7 @@ func TestIntegrationPostgresNoTxnMarkers(t *testing.T) {

for i := 0; i < 10; i++ {
f := GetFakeFlightRecord()
_, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
_, err = db.Exec(`INSERT INTO "FlightsCompositePK" ("Seq", "Name", "CreatedAt") VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
require.NoError(t, err)
}

Expand All @@ -184,7 +185,7 @@ pg_stream:
snapshot_batch_size: 5
schema: public
tables:
- flights_composite_pks
- '"FlightsCompositePK"'
`, databaseURL)

cacheConf := fmt.Sprintf(`
Expand All @@ -194,7 +195,7 @@ file:
`, tmpDir)

streamOutBuilder := service.NewStreamBuilder()
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`))
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: TRACE`))
require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf))
require.NoError(t, streamOutBuilder.AddInputYAML(template))

Expand Down Expand Up @@ -226,9 +227,9 @@ file:

for i := 10; i < 20; i++ {
f := GetFakeFlightRecord()
_, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
_, err = db.Exec(`INSERT INTO "FlightsCompositePK" ("Seq", "Name", "CreatedAt") VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
require.NoError(t, err)
_, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
_, err = db.Exec(`INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);`, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
require.NoError(t, err)
}

Expand Down Expand Up @@ -270,7 +271,7 @@ file:
time.Sleep(time.Second * 5)
for i := 20; i < 30; i++ {
f := GetFakeFlightRecord()
_, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
_, err = db.Exec(`INSERT INTO "FlightsCompositePK" ("Seq", "Name", "CreatedAt") VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
require.NoError(t, err)
}

Expand Down Expand Up @@ -648,9 +649,13 @@ pg_stream:
slot_name: test_slot_native_decoder
stream_snapshot: true
include_transaction_markers: false
schema: public
# This is intentionally with uppercase - we want to validate
# we treat identifiers the same as Postgres Queries.
schema: PuBliC
tables:
- flights
# This is intentionally with uppercase - we want to validate
# we treat identifiers the same as Postgres Queries.
- FLIGHTS
`, databaseURL)

cacheConf := fmt.Sprintf(`
67 changes: 34 additions & 33 deletions internal/impl/postgresql/pglogicalstream/logical_stream.go
Expand Up @@ -13,8 +13,6 @@ import (
"database/sql"
"errors"
"fmt"
"slices"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -46,12 +44,10 @@ type Stream struct {
messages chan StreamMessage
errors chan error

includeTxnMarkers bool
snapshotName string
slotName string
schema string
// includes schema
tableQualifiedName []string
includeTxnMarkers bool
snapshotName string
slotName string
tables []TableFQN
snapshotBatchSize int
decodingPluginArguments []string
snapshotMemorySafetyFactor float64
Expand Down Expand Up @@ -90,13 +86,18 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
return nil, err
}

tableNames := slices.Clone(config.DBTables)
for i, table := range tableNames {
if err := sanitize.ValidatePostgresIdentifier(table); err != nil {
schema, err := sanitize.NormalizePostgresIdentifier(config.DBSchema)
if err != nil {
return nil, fmt.Errorf("invalid schema name %q: %w", config.DBSchema, err)
}

tables := []TableFQN{}
for _, table := range config.DBTables {
normalized, err := sanitize.NormalizePostgresIdentifier(table)
if err != nil {
return nil, fmt.Errorf("invalid table name %q: %w", table, err)
}

tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table)
tables = append(tables, TableFQN{Schema: schema, Table: normalized})
}
stream := &Stream{
pgConn: dbConn,
Expand All @@ -105,8 +106,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
slotName: config.ReplicationSlotName,
snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor,
snapshotBatchSize: config.BatchSize,
schema: config.DBSchema,
tableQualifiedName: tableNames,
tables: tables,
maxParallelSnapshotTables: config.MaxParallelSnapshotTables,
logger: config.Logger,
shutSig: shutdown.NewSignaller(),
Expand Down Expand Up @@ -143,8 +143,8 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
stream.decodingPluginArguments = pluginArguments

pubName := "pglog_stream_" + config.ReplicationSlotName
stream.logger.Infof("Creating publication %s for tables: %s", pubName, tableNames)
if err = CreatePublication(ctx, stream.pgConn, pubName, tableNames); err != nil {
stream.logger.Infof("Creating publication %s for tables: %s", pubName, tables)
if err = CreatePublication(ctx, stream.pgConn, pubName, tables); err != nil {
return nil, err
}
cleanups = append(cleanups, func() {
Expand Down Expand Up @@ -219,7 +219,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
stream.standbyMessageTimeout = config.PgStandbyTimeout
stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout)

monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorInterval)
monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tables, stream.slotName, config.WalMonitorInterval)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -481,7 +481,7 @@ func (s *Stream) processSnapshot() error {
var wg errgroup.Group
wg.SetLimit(s.maxParallelSnapshotTables)

for _, table := range s.tableQualifiedName {
for _, table := range s.tables {
tableName := table
wg.Go(func() (err error) {
s.logger.Debugf("Processing snapshot for table: %v", table)
Expand Down Expand Up @@ -551,7 +551,6 @@ func (s *Stream) processSnapshot() error {
totalScanDuration := time.Duration(0)
totalWaitingFromBenthos := time.Duration(0)

tableWithoutSchema := strings.Split(table, ".")[1]
for snapshotRows.Next() {
rowsCount += 1

Expand All @@ -567,27 +566,28 @@ func (s *Stream) processSnapshot() error {

var data = make(map[string]any)
for i, getter := range valueGetters {
if data[columnNames[i]], err = getter(scanArgs[i]); err != nil {
col := columnNames[i]
var val any
if val, err = getter(scanArgs[i]); err != nil {
return err
}

if _, ok := lastPrimaryKey[columnNames[i]]; ok {
if lastPkVals[columnNames[i]], err = getter(scanArgs[i]); err != nil {
return err
}
data[col] = val
normalized := sanitize.QuotePostgresIdentifier(col)
if _, ok := lastPrimaryKey[normalized]; ok {
lastPkVals[normalized] = val
}
}

snapshotChangePacket := StreamMessage{
LSN: nil,
Operation: ReadOpType,
Table: tableWithoutSchema,
Schema: s.schema,
Table: table.Table,
Schema: table.Schema,
Data: data,
}

if rowsCount%100 == 0 {
s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset)
s.monitor.UpdateSnapshotProgressForTable(table, rowsCount+offset)
}

waitingFromBenthos := time.Now()
Expand Down Expand Up @@ -627,7 +627,7 @@ func (s *Stream) Errors() chan error {
return s.errors
}

func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map[string]any, []string, error) {
func (s *Stream) getPrimaryKeyColumn(ctx context.Context, table TableFQN) (map[string]any, []string, error) {
/// Query to get all primary key columns in their correct order
q, err := sanitize.SQLQuery(`
SELECT a.attname
Expand All @@ -637,7 +637,7 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map[string]any, []string, error) {
WHERE i.indrelid = $1::regclass
AND i.indisprimary
ORDER BY array_position(i.indkey, a.attnum);
`, tableName)
`, table.String())

if err != nil {
return nil, nil, fmt.Errorf("failed to sanitize query: %w", err)
Expand All @@ -650,13 +650,14 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map[string]any, []string, error) {
}

if len(data) == 0 || len(data[0].Rows) == 0 {
return nil, nil, fmt.Errorf("no primary key found for table %s", tableName)
return nil, nil, fmt.Errorf("no primary key found for table %s", table)
}

// Extract all primary key column names
pkColumns := make([]string, len(data[0].Rows))
for i, row := range data[0].Rows {
pkColumns[i] = string(row[0])
// Postgres gives us back normalized identifiers here - we need to quote them.
pkColumns[i] = sanitize.QuotePostgresIdentifier(string(row[0]))
}

var pksMap = make(map[string]any)
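The logical_stream.go changes replace the schema string and the schema-qualified table name strings with a `TableFQN` value whose `String()` form is used for the publication, the `::regclass` lookup, and the snapshot progress metric labels. The type itself is not part of this diff, and in the diff its components are first passed through `sanitize.NormalizePostgresIdentifier`, which may already handle quoting; the sketch below folds that into `String()` for brevity, so it is only a guess at the shape, shown to make the call sites above easier to follow.

```go
package main

import (
	"fmt"
	"strings"
)

// quoteIdent wraps a name in double quotes, escaping embedded quotes, so
// PostgreSQL preserves its exact spelling. It stands in for
// sanitize.QuotePostgresIdentifier and is not the repository's code.
func quoteIdent(ident string) string {
	return `"` + strings.ReplaceAll(ident, `"`, `""`) + `"`
}

// TableFQN is an assumed shape for the type introduced by this PR: a
// schema/table pair rendered as a fully qualified, quoted name.
type TableFQN struct {
	Schema string
	Table  string
}

func (t TableFQN) String() string {
	return quoteIdent(t.Schema) + "." + quoteIdent(t.Table)
}

func main() {
	t := TableFQN{Schema: "public", Table: "FlightsCompositePK"}
	// Prints: "public"."FlightsCompositePK", the form a ::regclass cast or a
	// CREATE PUBLICATION ... FOR TABLE clause would expect.
	fmt.Println(t)
}
```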