From 5f28795096ef2f09b6260962f20f6ee375d2da27 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 18 Nov 2024 10:04:12 +0100 Subject: [PATCH] PostgreSQL CDC Plugin (#2917) * Move repos into connect * Add placeholders for logging and TODOs on panics * feat(pgstream): added support for pgoutput native plugin * feat(pgstream): added support for pgoutput native plugin * chore(pg_stream): updated table filtering * chore(): updated tests for pglogical stream * chore(): fmt applied * chore(): code re-org * chore(): added temp. replication slot and removed outdated code * chore(): fixed eslint errors && tests * chore(): removed panics * fix(): table name in snapshotter * chore(): working on stream uncomited changes * fix(postgres): correct order for message LSN ack * chore(): removed log line * chore(): working on metrics * chore(): removed test case && working on monitor testing * chore(): monitor testing * chore(): added backward compatibility for postgresql * chore(): updated tests for different pg versions && working on metrics * chore(): added WAL lag streaming * chore(): added snapshot metrics * chore(): added snapshot metrics streaming * chore(): added explicit value for snapshot batch size * chore(): updated docs * chore(): updated docs * chore(): applieds golangci-lint notes * chore(): working on faster snapshot processing * chore(): experimenting with object pool * Revert "chore(): experimenting with object pool" This reverts commit 041a55ca2a3cffd802c8cef2b91202623a0d9efe. * chore(): use common pool to process snapshot * chore(): added snapshot message rate counter * chore(): working on batches * fixed(): test * fix(): metrics * chore(): removed unused struct * chore(): stabilised batches * chore(): removed debug lines; fixed linter * chore(): updated tls config field && small refactoring * ref(): use context when create publication * pgcdc: cleanup configuration * By default we were just using a replication slot name of `rs_`. * Cleanup description * Fix typos * pgcdc: simplify stream setup Just have the user give us a DSN that is standard and our SQL* plugins already expect this format. That fixes bugs we have with special characters that need escaping, and generally simplfies setup. Also fixes: - Don't os.Exit, but bubble an error up - Use provided context instead of context.Background - Prevent SQL injection attacks in slot names * more review feedback. This got to be a lot so just checkpointing so Vlad can see where I am going. * Chan cleanup WIP Signed-off-by: Mihai Todor * chore(): addressed pull requests changes * chore(): updated tests * chore(): removed unused vars * chore(): run make deps to fix ci pipeline * fix(postgres_cdc): monitor tests * chore(postgres_cdc): added integration test skip check * fix(postgres_cdc): lint warnings * chore(): specify monitoring && standby intervals via config * chore(): removed redundant tests + deps * chore(): updated docs * pgstream: create batcher in foreground * pgstream: only check for done once * pgcdc: remove bool for operation * pgcdc: update docs for mode * pgcdc: validate slot names can't cause SQL injection * pgcdc: use error type for error handling, not bool * pgcdc: import sanitization code from pgx We are forced to use the simple query protocol for pg in replication mode, which means we need to sanitize stuff. Import some internal code from pgx for that. 
* pgcdc: add note about pk in snapshot reading * pgcdc: properly sanitize query * pgcdc add note about how waiting for commit is buggy * pgcdc: drop unused param * pgcdc: actually remove unused param * pgcdc: update docs * ref(): small code refactoring * feat(): added max_parallel_snapshot_tables config field * chore(): added pk ordering to consume snapshot * fix(): enabled integration tests * chore(): small fixes && pr notes * chore(): updated docs && fixed lint * chore(): revert integration tests * chore(): added publication updates instead of re-creation * pgcdc: prefix stat names * pgcdc: remove lsnrestart field * pgcdc: add a high watermark utility * pgcdc: use watermark for log position * pgcdc: remove layer of nesting from switch * pgcdc: use typed duration fields * pgcdc: fix waiting for txn ack * pgcdc: dedup config fields * pgcdc: fix config field defaults * pgcdc: properly implement watermark We need to be able to be cancelled if we never reach the watermark * pgcdc: properly ack only on commit messages, once everything is processed * pgcdc: there are actually 3 handlers * pgcdc: simplify plugin handling code * pgcdc: fix randomized ID uuid is invalid because we can't use dashes * pgcdc: remove unused import * pgcdc: always include mode * pgcdc: fix period batching and cleanup logic * pgcdc: fix lint error * pgcdc: regen docs * chore(): added +1 to standby update to follow postgresql requirements * chore: goimports * pgcdc: simplify shutdown in the input Still need to simplify this in the internal logical_stream package, but this is a first step * pgcdc: localize the pg stream To make lifetime semantics and handling ErrNotConnected better * pgcdc: simplify internal flow control Simplify the internal flow control of the logical stream by just returning and handling errors at the top level. 
* pgcdc: don't produce 0 messages * pgcdc: rename stream uncommitted to batch transactions * pgcdc: fix config name * pgcdc: add some TODOs * pgcdc: update docs * pgcdc: review feedback * pgcdc: cleanup monitor with periodic utility * pgcdc: fmt * pgcdc: check for non-zero duration * chore(): sanitized queries && fixed tests * chore(): removed wal2json support * chore(): updated pgstream docs * feat(): added support for composite primary keys * pgcdc: mark as enterprise licensed * chore(): applied make fmt * pgcdc/snapshot: use context for cancellation * pgcdc: fix primary key order by clause * pgcdc: fix zero batch check * update changelog --------- Signed-off-by: Mihai Todor Co-authored-by: Ashley Jeffs Co-authored-by: Tyler Rockwood Co-authored-by: Mihai Todor --- CHANGELOG.md | 10 + .../components/pages/inputs/pg_stream.adoc | 393 ++++++++ go.mod | 3 +- internal/impl/postgresql/input_pg_stream.go | 468 ++++++++++ internal/impl/postgresql/integration_test.go | 747 ++++++++++++++++ .../pglogicalstream/availablememory.go | 20 + .../impl/postgresql/pglogicalstream/config.go | 48 + .../postgresql/pglogicalstream/connection.go | 48 + .../pglogicalstream/logical_stream.go | 633 +++++++++++++ .../postgresql/pglogicalstream/monitor.go | 172 ++++ .../postgresql/pglogicalstream/pglogrepl.go | 698 +++++++++++++++ .../pglogicalstream/pglogrepl_test.go | 462 ++++++++++ .../pglogicalstream/pluginhandlers.go | 166 ++++ .../pglogicalstream/replication_message.go | 728 +++++++++++++++ .../replication_message_decoders.go | 153 ++++ .../replication_message_test.go | 838 ++++++++++++++++++ .../pglogicalstream/sanitize/sanitize.go | 390 ++++++++ .../pglogicalstream/sanitize/sanitize_test.go | 252 ++++++ .../postgresql/pglogicalstream/snapshotter.go | 239 +++++ .../pglogicalstream/stream_message.go | 44 + .../impl/postgresql/pglogicalstream/types.go | 9 + .../pglogicalstream/watermark/watermark.go | 70 ++ .../watermark/watermark_test.go | 53 ++ internal/impl/postgresql/utils.go | 54 ++ .../impl/redis/rate_limit_integration_test.go | 3 +- internal/plugins/info.csv | 1 + public/components/all/package.go | 1 + public/components/postgresql/package.go | 16 + 28 files changed, 6716 insertions(+), 3 deletions(-) create mode 100644 docs/modules/components/pages/inputs/pg_stream.adoc create mode 100644 internal/impl/postgresql/input_pg_stream.go create mode 100644 internal/impl/postgresql/integration_test.go create mode 100644 internal/impl/postgresql/pglogicalstream/availablememory.go create mode 100644 internal/impl/postgresql/pglogicalstream/config.go create mode 100644 internal/impl/postgresql/pglogicalstream/connection.go create mode 100644 internal/impl/postgresql/pglogicalstream/logical_stream.go create mode 100644 internal/impl/postgresql/pglogicalstream/monitor.go create mode 100644 internal/impl/postgresql/pglogicalstream/pglogrepl.go create mode 100644 internal/impl/postgresql/pglogicalstream/pglogrepl_test.go create mode 100644 internal/impl/postgresql/pglogicalstream/pluginhandlers.go create mode 100644 internal/impl/postgresql/pglogicalstream/replication_message.go create mode 100644 internal/impl/postgresql/pglogicalstream/replication_message_decoders.go create mode 100644 internal/impl/postgresql/pglogicalstream/replication_message_test.go create mode 100644 internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go create mode 100644 internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go create mode 100644 internal/impl/postgresql/pglogicalstream/snapshotter.go create mode 100644 
internal/impl/postgresql/pglogicalstream/stream_message.go create mode 100644 internal/impl/postgresql/pglogicalstream/types.go create mode 100644 internal/impl/postgresql/pglogicalstream/watermark/watermark.go create mode 100644 internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go create mode 100644 internal/impl/postgresql/utils.go create mode 100644 public/components/postgresql/package.go diff --git a/CHANGELOG.md b/CHANGELOG.md index b96f9e0b8e..70fdb569e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,16 @@ Changelog All notable changes to this project will be documented in this file. +## 4.40.0 - TBD + +### Added + +- New `pg_stream` input supporting change data capture (CDC) from PostgreSQL (@le-vlad) + +### Changed + +- `snowflake_streaming` with `schema_evolution.enabled` set to true can now autocreate tables. + ## 4.39.0 - 2024-11-07 ### Added diff --git a/docs/modules/components/pages/inputs/pg_stream.adoc b/docs/modules/components/pages/inputs/pg_stream.adoc new file mode 100644 index 0000000000..5410431626 --- /dev/null +++ b/docs/modules/components/pages/inputs/pg_stream.adoc @@ -0,0 +1,393 @@ += pg_stream +:type: input +:status: beta +:categories: ["Services"] + + + +//// + THIS FILE IS AUTOGENERATED! + + To make changes, edit the corresponding source file under: + + https://github.com/redpanda-data/connect/tree/main/internal/impl/. + + And: + + https://github.com/redpanda-data/connect/tree/main/cmd/tools/docs_gen/templates/plugin.adoc.tmpl +//// + +// © 2024 Redpanda Data Inc. + + +component_type_dropdown::[] + + +Streams changes from a PostgreSQL database using logical replication. + +Introduced in version 4.39.0. + + +[tabs] +====== +Common:: ++ +-- + +```yml +# Common config fields, showing default values +input: + label: "" + pg_stream: + dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable # No default (required) + batch_transactions: true + stream_snapshot: false + snapshot_memory_safety_factor: 1 + snapshot_batch_size: 0 + schema: public # No default (required) + tables: [] # No default (required) + checkpoint_limit: 1024 + temporary_slot: false + slot_name: "" + pg_standby_timeout: 10s + pg_wal_monitor_interval: 3s + max_parallel_snapshot_tables: 1 + auto_replay_nacks: true + batching: + count: 0 + byte_size: 0 + period: "" + check: "" +``` + +-- +Advanced:: ++ +-- + +```yml +# All config fields, showing default values +input: + label: "" + pg_stream: + dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable # No default (required) + batch_transactions: true + stream_snapshot: false + snapshot_memory_safety_factor: 1 + snapshot_batch_size: 0 + schema: public # No default (required) + tables: [] # No default (required) + checkpoint_limit: 1024 + temporary_slot: false + slot_name: "" + pg_standby_timeout: 10s + pg_wal_monitor_interval: 3s + max_parallel_snapshot_tables: 1 + auto_replay_nacks: true + batching: + count: 0 + byte_size: 0 + period: "" + check: "" + processors: [] # No default (optional) +``` + +-- +====== + +Streams changes from a PostgreSQL database for Change Data Capture (CDC). +Additionally, if `stream_snapshot` is set to true, then the existing data in the database is also streamed too. 
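+
+Logical replication requires the source database to run with `wal_level=logical`; the integration tests added alongside this input start PostgreSQL with that setting and fail if it is not in effect.
+
+As a quick orientation before the field reference below, here is a minimal sketch of a configuration that snapshots two tables and then follows their changes. The slot and table names are illustrative placeholders and every other field keeps its default:
+
+```yml
+input:
+  pg_stream:
+    dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable
+    slot_name: my_test_slot
+    stream_snapshot: true
+    snapshot_batch_size: 10000
+    schema: public
+    tables:
+      - my_table
+      - my_table_2
+```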
+ +== Metadata + +This input adds the following metadata fields to each message: +- mode (Either "streaming" or "snapshot" indicating whether the message is part of a streaming operation or snapshot processing) +- table (Name of the table that the message originated from) +- operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) + + +== Fields + +=== `dsn` + +The Data Source Name for the PostgreSQL database in the form of `postgres://[user[:password]@][netloc][:port][/dbname][?param1=value1&...]`. Please note that Postgres enforces SSL by default, you can override this with the parameter `sslmode=disable` if required. + + +*Type*: `string` + + +```yml +# Examples + +dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable +``` + +=== `batch_transactions` + +When set to true, transactions are batched into a single message. + + +*Type*: `bool` + +*Default*: `true` + +=== `stream_snapshot` + +When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes. In order to use this the tables that are being snapshot MUST have a primary key set so that reading from the table can be parallelized. + + +*Type*: `bool` + +*Default*: `false` + +```yml +# Examples + +stream_snapshot: true +``` + +=== `snapshot_memory_safety_factor` + +Determines the fraction of available memory that can be used for streaming the snapshot. Values between 0 and 1 represent the percentage of memory to use. Lower values make initial streaming slower but help prevent out-of-memory errors. + + +*Type*: `float` + +*Default*: `1` + +```yml +# Examples + +snapshot_memory_safety_factor: 0.2 +``` + +=== `snapshot_batch_size` + +The number of rows to fetch in each batch when querying the snapshot. A value of 0 lets the plugin determine the batch size based on `snapshot_memory_safety_factor` property. + + +*Type*: `int` + +*Default*: `0` + +```yml +# Examples + +snapshot_batch_size: 10000 +``` + +=== `schema` + +The PostgreSQL schema from which to replicate data. + + +*Type*: `string` + + +```yml +# Examples + +schema: public +``` + +=== `tables` + +A list of table names to include in the logical replication. Each table should be specified as a separate item. + + +*Type*: `array` + + +```yml +# Examples + +tables: |2- + - my_table + - my_table_2 + +``` + +=== `checkpoint_limit` + +The maximum number of messages that can be processed at a given time. Increasing this limit enables parallel processing and batching at the output level. Any given LSN will not be acknowledged unless all messages under that offset are delivered in order to preserve at least once delivery guarantees. + + +*Type*: `int` + +*Default*: `1024` + +=== `temporary_slot` + +If set to true, creates a temporary replication slot that is automatically dropped when the connection is closed. + + +*Type*: `bool` + +*Default*: `false` + +=== `slot_name` + +The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired. + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +slot_name: my_test_slot +``` + +=== `pg_standby_timeout` + +Specify the standby timeout before refreshing an idle connection. + + +*Type*: `string` + +*Default*: `"10s"` + +```yml +# Examples + +pg_standby_timeout: 30s +``` + +=== `pg_wal_monitor_interval` + +How often to report changes to the replication lag. 
+ + +*Type*: `string` + +*Default*: `"3s"` + +```yml +# Examples + +pg_wal_monitor_interval: 6s +``` + +=== `max_parallel_snapshot_tables` + +Int specifies a number of tables that will be processed in parallel during the snapshot processing stage + + +*Type*: `int` + +*Default*: `1` + +=== `auto_replay_nacks` + +Whether messages that are rejected (nacked) at the output level should be automatically replayed indefinitely, eventually resulting in back pressure if the cause of the rejections is persistent. If set to `false` these messages will instead be deleted. Disabling auto replays can greatly improve memory efficiency of high throughput streams as the original shape of the data can be discarded immediately upon consumption and mutation. + + +*Type*: `bool` + +*Default*: `true` + +=== `batching` + +Allows you to configure a xref:configuration:batching.adoc[batching policy]. + + +*Type*: `object` + + +```yml +# Examples + +batching: + byte_size: 5000 + count: 0 + period: 1s + +batching: + count: 10 + period: 1s + +batching: + check: this.contains("END BATCH") + count: 0 + period: 1m +``` + +=== `batching.count` + +A number of messages at which the batch should be flushed. If `0` disables count based batching. + + +*Type*: `int` + +*Default*: `0` + +=== `batching.byte_size` + +An amount of bytes at which the batch should be flushed. If `0` disables size based batching. + + +*Type*: `int` + +*Default*: `0` + +=== `batching.period` + +A period in which an incomplete batch should be flushed regardless of its size. + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +period: 1s + +period: 1m + +period: 500ms +``` + +=== `batching.check` + +A xref:guides:bloblang/about.adoc[Bloblang query] that should return a boolean value indicating whether a message should end a batch. + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +check: this.type == "end_of_transaction" +``` + +=== `batching.processors` + +A list of xref:components:processors/about.adoc[processors] to apply to a batch as it is flushed. This allows you to aggregate and archive the batch however you see fit. Please note that all resulting messages are flushed as a single batch, therefore splitting the batch into smaller batches using these processors is a no-op. + + +*Type*: `array` + + +```yml +# Examples + +processors: + - archive: + format: concatenate + +processors: + - archive: + format: lines + +processors: + - archive: + format: json_array +``` + + diff --git a/go.mod b/go.mod index 0154c08207..64868d3804 100644 --- a/go.mod +++ b/go.mod @@ -72,6 +72,7 @@ require ( github.com/gosimple/slug v1.14.0 github.com/influxdata/influxdb1-client v0.0.0-20220302092344-a9ab5670611c github.com/jackc/pgx/v4 v4.18.3 + github.com/jackc/pgx/v5 v5.6.0 github.com/jhump/protoreflect v1.16.0 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.13.0 @@ -302,7 +303,7 @@ require ( github.com/itchyny/timefmt-go v0.1.6 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 - github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgproto3/v2 v2.3.3 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go new file mode 100644 index 0000000000..fcdbfb217f --- /dev/null +++ b/internal/impl/postgresql/input_pg_stream.go @@ -0,0 +1,468 @@ +// Copyright 2024 Redpanda Data, Inc. 
+// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pgstream + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/Jeffail/checkpoint" + "github.com/Jeffail/shutdown" + "github.com/jackc/pgx/v5/pgconn" + gonanoid "github.com/matoous/go-nanoid/v2" + "github.com/redpanda-data/benthos/v4/public/service" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" +) + +const ( + fieldDSN = "dsn" + fieldBatchTransactions = "batch_transactions" + fieldStreamSnapshot = "stream_snapshot" + fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" + fieldSnapshotBatchSize = "snapshot_batch_size" + fieldSchema = "schema" + fieldTables = "tables" + fieldCheckpointLimit = "checkpoint_limit" + fieldTemporarySlot = "temporary_slot" + fieldPgStandbyTimeout = "pg_standby_timeout" + fieldWalMonitorInterval = "pg_wal_monitor_interval" + fieldSlotName = "slot_name" + fieldBatching = "batching" + fieldMaxParallelSnapshotTables = "max_parallel_snapshot_tables" + + shutdownTimeout = 5 * time.Second +) + +type asyncMessage struct { + msg service.MessageBatch + ackFn service.AckFunc +} + +var pgStreamConfigSpec = service.NewConfigSpec(). + Beta(). + Categories("Services"). + Version("4.39.0"). + Summary(`Streams changes from a PostgreSQL database using logical replication.`). + Description(`Streams changes from a PostgreSQL database for Change Data Capture (CDC). +Additionally, if ` + "`" + fieldStreamSnapshot + "`" + ` is set to true, then the existing data in the database is also streamed too. + +== Metadata + +This input adds the following metadata fields to each message: +- mode (Either "streaming" or "snapshot" indicating whether the message is part of a streaming operation or snapshot processing) +- table (Name of the table that the message originated from) +- operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) + `). + Field(service.NewStringField(fieldDSN). + Description("The Data Source Name for the PostgreSQL database in the form of `postgres://[user[:password]@][netloc][:port][/dbname][?param1=value1&...]`. Please note that Postgres enforces SSL by default, you can override this with the parameter `sslmode=disable` if required."). + Example("postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable")). + Field(service.NewBoolField(fieldBatchTransactions). + Description("When set to true, transactions are batched into a single message."). + Default(true)). + Field(service.NewBoolField(fieldStreamSnapshot). + Description("When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes. In order to use this the tables that are being snapshot MUST have a primary key set so that reading from the table can be parallelized."). + Example(true). + Default(false)). + Field(service.NewFloatField(fieldSnapshotMemSafetyFactor). + Description("Determines the fraction of available memory that can be used for streaming the snapshot. Values between 0 and 1 represent the percentage of memory to use. Lower values make initial streaming slower but help prevent out-of-memory errors."). + Example(0.2). + Default(1)). + Field(service.NewIntField(fieldSnapshotBatchSize). 
+ Description("The number of rows to fetch in each batch when querying the snapshot. A value of 0 lets the plugin determine the batch size based on `snapshot_memory_safety_factor` property."). + Example(10000). + Default(0)). + Field(service.NewStringField(fieldSchema). + Description("The PostgreSQL schema from which to replicate data."). + Example("public")). + Field(service.NewStringListField(fieldTables). + Description("A list of table names to include in the logical replication. Each table should be specified as a separate item."). + Example(` + - my_table + - my_table_2 + `)). + Field(service.NewIntField(fieldCheckpointLimit). + Description("The maximum number of messages that can be processed at a given time. Increasing this limit enables parallel processing and batching at the output level. Any given LSN will not be acknowledged unless all messages under that offset are delivered in order to preserve at least once delivery guarantees."). + Default(1024)). + Field(service.NewBoolField(fieldTemporarySlot). + Description("If set to true, creates a temporary replication slot that is automatically dropped when the connection is closed."). + Default(false)). + Field(service.NewStringField(fieldSlotName). + Description("The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired."). + Example("my_test_slot"). + Default("")). + Field(service.NewDurationField(fieldPgStandbyTimeout). + Description("Specify the standby timeout before refreshing an idle connection."). + Example("30s"). + Default("10s")). + Field(service.NewDurationField(fieldWalMonitorInterval). + Description("How often to report changes to the replication lag."). + Example("6s"). + Default("3s")). + Field(service.NewIntField(fieldMaxParallelSnapshotTables). + Description("Int specifies a number of tables that will be processed in parallel during the snapshot processing stage"). + Default(1)). + Field(service.NewAutoRetryNacksToggleField()). 
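+	// Standard batching policy fields; the configured policy controls how change messages are grouped into batches in processStream below.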
+ Field(service.NewBatchPolicyField(fieldBatching)) + +func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s service.BatchInput, err error) { + var ( + dsn string + dbSlotName string + temporarySlot bool + schema string + tables []string + streamSnapshot bool + snapshotMemSafetyFactor float64 + batchTransactions bool + snapshotBatchSize int + checkpointLimit int + walMonitorInterval time.Duration + maxParallelSnapshotTables int + pgStandbyTimeout time.Duration + batching service.BatchPolicy + ) + + if dsn, err = conf.FieldString(fieldDSN); err != nil { + return nil, err + } + + if dbSlotName, err = conf.FieldString(fieldSlotName); err != nil { + return nil, err + } + // Set the default to be a random string + if dbSlotName == "" { + dbSlotName, err = gonanoid.Generate("0123456789ABCDEFGHJKMNPQRSTVWXYZ", 32) + if err != nil { + return nil, err + } + } + + if err := validateSimpleString(dbSlotName); err != nil { + return nil, fmt.Errorf("invalid slot_name: %w", err) + } + + if temporarySlot, err = conf.FieldBool(fieldTemporarySlot); err != nil { + return nil, err + } + + if schema, err = conf.FieldString(fieldSchema); err != nil { + return nil, err + } + + if tables, err = conf.FieldStringList(fieldTables); err != nil { + return nil, err + } + + if checkpointLimit, err = conf.FieldInt(fieldCheckpointLimit); err != nil { + return nil, err + } + + if streamSnapshot, err = conf.FieldBool(fieldStreamSnapshot); err != nil { + return nil, err + } + + if batchTransactions, err = conf.FieldBool(fieldBatchTransactions); err != nil { + return nil, err + } + + if snapshotMemSafetyFactor, err = conf.FieldFloat(fieldSnapshotMemSafetyFactor); err != nil { + return nil, err + } + + if snapshotBatchSize, err = conf.FieldInt(fieldSnapshotBatchSize); err != nil { + return nil, err + } + + if batching, err = conf.FieldBatchPolicy(fieldBatching); err != nil { + return nil, err + } else if batching.IsNoop() { + batching.Count = 1 + } + + if pgStandbyTimeout, err = conf.FieldDuration(fieldPgStandbyTimeout); err != nil { + return nil, err + } + + if walMonitorInterval, err = conf.FieldDuration(fieldWalMonitorInterval); err != nil { + return nil, err + } + + if maxParallelSnapshotTables, err = conf.FieldInt(fieldMaxParallelSnapshotTables); err != nil { + return nil, err + } + + pgConnConfig, err := pgconn.ParseConfigWithOptions(dsn, pgconn.ParseConfigOptions{ + // Don't support dynamic reading of password + GetSSLPassword: func(context.Context) string { return "" }, + }) + if err != nil { + return nil, err + } + // This is required for postgres to understand we're interested in replication. 
+ // https://github.com/jackc/pglogrepl/issues/6 + pgConnConfig.RuntimeParams["replication"] = "database" + + snapshotMetrics := mgr.Metrics().NewGauge("postgres_snapshot_progress", "table") + replicationLag := mgr.Metrics().NewGauge("postgres_replication_lag_bytes") + + i := &pgStreamInput{ + streamConfig: &pglogicalstream.Config{ + DBConfig: pgConnConfig, + DBRawDSN: dsn, + DBSchema: schema, + DBTables: tables, + + ReplicationSlotName: "rs_" + dbSlotName, + BatchSize: snapshotBatchSize, + StreamOldData: streamSnapshot, + TemporaryReplicationSlot: temporarySlot, + BatchTransactions: batchTransactions, + SnapshotMemorySafetyFactor: snapshotMemSafetyFactor, + PgStandbyTimeout: pgStandbyTimeout, + WalMonitorInterval: walMonitorInterval, + MaxParallelSnapshotTables: maxParallelSnapshotTables, + Logger: mgr.Logger(), + }, + batching: batching, + checkpointLimit: checkpointLimit, + msgChan: make(chan asyncMessage), + + mgr: mgr, + logger: mgr.Logger(), + snapshotMetrics: snapshotMetrics, + replicationLag: replicationLag, + stopSig: shutdown.NewSignaller(), + } + + // Has stopped is how we notify that we're not connected. This will get reset at connection time. + i.stopSig.TriggerHasStopped() + + r, err := service.AutoRetryNacksBatchedToggled(conf, i) + if err != nil { + return nil, err + } + + return conf.WrapBatchInputExtractTracingSpanMapping("pg_stream", r) +} + +// validateSimpleString ensures we aren't vuln to SQL injection +func validateSimpleString(s string) error { + for _, b := range []byte(s) { + isDigit := b >= '0' && b <= '9' + isLower := b >= 'a' && b <= 'z' + isUpper := b >= 'A' && b <= 'Z' + isDelimiter := b == '_' + if !isDigit && !isLower && !isUpper && !isDelimiter { + return fmt.Errorf("invalid postgres identifier %q", s) + } + } + return nil +} + +func init() { + err := service.RegisterBatchInput("pg_stream", pgStreamConfigSpec, newPgStreamInput) + if err != nil { + panic(err) + } +} + +type pgStreamInput struct { + streamConfig *pglogicalstream.Config + logger *service.Logger + mgr *service.Resources + msgChan chan asyncMessage + batching service.BatchPolicy + checkpointLimit int + + snapshotMetrics *service.MetricGauge + replicationLag *service.MetricGauge + stopSig *shutdown.Signaller +} + +func (p *pgStreamInput) Connect(ctx context.Context) error { + pgStream, err := pglogicalstream.NewPgStream(ctx, p.streamConfig) + if err != nil { + return fmt.Errorf("unable to create replication stream: %w", err) + } + batcher, err := p.batching.NewBatcher(p.mgr) + if err != nil { + return err + } + // Reset our stop signal + p.stopSig = shutdown.NewSignaller() + go p.processStream(pgStream, batcher) + return err +} + +func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher *service.Batcher) { + ctx, _ := p.stopSig.SoftStopCtx(context.Background()) + defer func() { + ctx, _ := p.stopSig.HardStopCtx(context.Background()) + if err := batcher.Close(ctx); err != nil { + p.logger.Errorf("unable to close batcher: %s", err) + } + // TODO(rockwood): We should wait for outstanding acks to be completed (best effort) + if err := pgStream.Stop(ctx); err != nil { + p.logger.Errorf("unable to stop replication stream: %s", err) + } + p.stopSig.TriggerHasStopped() + }() + + var nextTimedBatchChan <-chan time.Time + + // offsets are nilable since we don't provide offset tracking during the snapshot phase + cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) + for !p.stopSig.IsSoftStopSignalled() { + select { + case <-nextTimedBatchChan: + nextTimedBatchChan = nil + 
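+			// The batching period elapsed before the count or byte-size limits were
+			// reached: flush whatever has accumulated so that time-based batching
+			// still makes progress while the WAL is quiet.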
flushedBatch, err := batcher.Flush(ctx) + if err != nil { + p.logger.Debugf("timed flush batch error: %s", err) + break + } + if err := p.flushBatch(ctx, pgStream, cp, flushedBatch); err != nil { + p.logger.Debugf("failed to flush batch: %s", err) + break + } + case message := <-pgStream.Messages(): + var ( + mb []byte + err error + ) + + if len(message.Changes) == 0 { + p.logger.Errorf("received empty message (LSN=%v)", message.Lsn) + break + } + + // TODO(rockwood): this should only be the message + if mb, err = json.Marshal(message.Changes); err != nil { + break + } + + batchMsg := service.NewMessage(mb) + + batchMsg.MetaSet("mode", string(message.Mode)) + batchMsg.MetaSet("table", message.Changes[0].Table) + batchMsg.MetaSet("operation", message.Changes[0].Operation) + if message.Lsn != nil { + batchMsg.MetaSet("lsn", *message.Lsn) + } + if message.Changes[0].TableSnapshotProgress != nil { + p.snapshotMetrics.SetFloat64(*message.Changes[0].TableSnapshotProgress, message.Changes[0].Table) + } + if message.WALLagBytes != nil { + p.replicationLag.Set(*message.WALLagBytes) + } + + if batcher.Add(batchMsg) { + nextTimedBatchChan = nil + flushedBatch, err := batcher.Flush(ctx) + if err != nil { + p.logger.Debugf("error flushing batch: %s", err) + break + } + if err := p.flushBatch(ctx, pgStream, cp, flushedBatch); err != nil { + p.logger.Debugf("failed to flush batch: %s", err) + break + } + } else { + d, ok := batcher.UntilNext() + if ok { + nextTimedBatchChan = time.After(d) + } + } + case err := <-pgStream.Errors(): + p.logger.Warnf("logical replication stream error: %s", err) + // If the stream has internally errored then we should stop and restart processing + p.stopSig.TriggerSoftStop() + case <-p.stopSig.SoftStopChan(): + p.logger.Debug("soft stop triggered, stopping logical replication stream") + } + } +} + +func (p *pgStreamInput) flushBatch( + ctx context.Context, + pgStream *pglogicalstream.Stream, + checkpointer *checkpoint.Capped[*int64], + batch service.MessageBatch, +) error { + if len(batch) == 0 { + return nil + } + + var lsn *int64 + lastMsg := batch[len(batch)-1] + lsnStr, ok := lastMsg.MetaGet("lsn") + if ok { + parsed, err := LSNToInt64(lsnStr) + if err != nil { + return fmt.Errorf("unable to extract LSN from last message in batch: %w", err) + } + lsn = &parsed + } + resolveFn, err := checkpointer.Track(ctx, lsn, int64(len(batch))) + if err != nil { + return fmt.Errorf("unable to checkpoint: %w", err) + } + + ackFn := func(ctx context.Context, res error) error { + maxOffset := resolveFn() + if maxOffset == nil { + return nil + } + lsn := *maxOffset + if lsn == nil { + return nil + } + if err = pgStream.AckLSN(ctx, Int64ToLSN(*lsn)); err != nil { + return fmt.Errorf("unable to ack LSN to postgres: %w", err) + } + return nil + } + select { + case p.msgChan <- asyncMessage{msg: batch, ackFn: ackFn}: + case <-ctx.Done(): + return ctx.Err() + } + return nil +} + +func (p *pgStreamInput) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { + select { + case m := <-p.msgChan: + return m.msg, m.ackFn, nil + case <-p.stopSig.HasStoppedChan(): + return nil, nil, service.ErrNotConnected + case <-ctx.Done(): + return nil, nil, ctx.Err() + } +} + +func (p *pgStreamInput) Close(ctx context.Context) error { + p.stopSig.TriggerSoftStop() + select { + case <-ctx.Done(): + case <-time.After(shutdownTimeout): + case <-p.stopSig.HasStoppedChan(): + } + p.stopSig.TriggerHardStop() + select { + case <-ctx.Done(): + return ctx.Err() + case 
<-time.After(shutdownTimeout): + case <-p.stopSig.HasStoppedChan(): + } + return nil +} diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go new file mode 100644 index 0000000000..f22878d2cb --- /dev/null +++ b/internal/impl/postgresql/integration_test.go @@ -0,0 +1,747 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pgstream + +import ( + "context" + "database/sql" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/go-faker/faker/v4" + _ "github.com/lib/pq" + _ "github.com/redpanda-data/benthos/v4/public/components/io" + _ "github.com/redpanda-data/benthos/v4/public/components/pure" + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +type FakeFlightRecord struct { + RealAddress faker.RealAddress `faker:"real_address"` + CreatedAt int64 `fake:"unix_time"` +} + +func GetFakeFlightRecord() FakeFlightRecord { + flightRecord := FakeFlightRecord{} + err := faker.FakeData(&flightRecord) + if err != nil { + panic(err) + } + + return flightRecord +} + +func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version string) (*dockertest.Resource, *sql.DB, error) { + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: version, + Env: []string{ + "POSTGRES_PASSWORD=l]YLSc|4[i56%{gY", + "POSTGRES_USER=user_name", + "POSTGRES_DB=dbname", + }, + Cmd: []string{ + "postgres", + "-c", "wal_level=logical", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, pool.Purge(resource)) + }) + + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) + + var db *sql.DB + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + if db, err = sql.Open("postgres", databaseURL); err != nil { + return err + } + + if err = db.Ping(); err != nil { + return err + } + + var walLevel string + if err = db.QueryRow("SHOW wal_level").Scan(&walLevel); err != nil { + return err + } + + var pgConfig string + if err = db.QueryRow("SHOW config_file").Scan(&pgConfig); err != nil { + return err + } + + if walLevel != "logical" { + return fmt.Errorf("wal_level is not logical") + } + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + if err != nil { + return err + } + + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS flights_composite_pks ( + id serial, seq integer, name VARCHAR(50), created_at TIMESTAMP, + PRIMARY KEY (id, seq) + );`) + if err != nil { + return err + } + + // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot 
streaming + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + + return err + }); err != nil { + panic(fmt.Errorf("could not connect to docker: %w", err)) + } + + return resource, db, nil +} + +func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { + integration.CheckSkip(t) + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + var ( + resource *dockertest.Resource + db *sql.DB + ) + + resource, db, err = ResourceWithPostgreSQLVersion(t, pool, "16") + require.NoError(t, err) + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" + + require.NoError(t, err) + + for i := 0; i < 10; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) + template := fmt.Sprintf(` +pg_stream: + dsn: %s + slot_name: test_slot_native_decoder + stream_snapshot: true + snapshot_batch_size: 5 + schema: public + tables: + - flights_composite_pks +`, databaseURL) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outBatches []string + var outBatchMut sync.Mutex + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + msgBytes, err := mb[0].AsBytes() + require.NoError(t, err) + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10 + }, time.Second*25, time.Millisecond*100) + + for i := 10; i < 20; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 20 + }, time.Second*25, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outBatches = []string{} + require.NoError(t, 
streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 20; i < 30; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10 + }, time.Second*20, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) +} + +func TestIntegrationPgStreamingFromRemoteDB(t *testing.T) { + t.Skip("This test requires a remote database to run. Aimed to test remote databases") + tmpDir := t.TempDir() + + // tables: users, products, orders, order_items + + template := ` +pg_stream: + dsn: postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable + slot_name: test_slot_native_decoder + snapshot_batch_size: 100000 + stream_snapshot: true + batch_transactions: false + temporary_slot: true + schema: public + tables: + - users + - products + - orders + - order_items +` + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outMessages int64 + var outMessagesMut sync.Mutex + + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + _, err := mb[0].AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages += 1 + outMessagesMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return outMessages == 200000 + }, time.Minute*15, time.Millisecond*100) + + t.Log("Backfill conditioins are met 🎉") + + // you need to start inserting the data somewhere in another place + time.Sleep(time.Minute * 30) + outMessages = 0 + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return outMessages == 1000000 + }, time.Minute*15, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) +} + +func TestIntegrationPgCDCForPgOutputStreamUncommittedPlugin(t *testing.T) { + integration.CheckSkip(t) + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + var ( + resource *dockertest.Resource + db *sql.DB + ) + + resource, db, err = ResourceWithPostgreSQLVersion(t, pool, "16") + require.NoError(t, err) + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" + + for i := 0; i < 10000; i++ { + f := 
GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) + template := fmt.Sprintf(` +pg_stream: + dsn: %s + slot_name: test_slot_native_decoder + snapshot_batch_size: 100 + stream_snapshot: true + batch_transactions: true + schema: public + tables: + - flights +`, databaseURL) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outBatches []string + var outBatchMut sync.Mutex + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + msgBytes, err := mb[0].AsBytes() + require.NoError(t, err) + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + err = streamOut.Run(context.Background()) + require.NoError(t, err) + }() + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10000 + }, time.Second*25, time.Millisecond*100) + + for i := 0; i < 10; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10010 + }, time.Second*25, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outBatches = []string{} + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 0; i < 10; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10 + }, time.Second*20, time.Millisecond*100) + + require.NoError(t, 
streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) +} + +func TestIntegrationPgMultiVersionsCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { + integration.CheckSkip(t) + // running tests in the look to test different PostgreSQL versions + t.Parallel() + for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + var ( + resource *dockertest.Resource + db *sql.DB + ) + + resource, db, err = ResourceWithPostgreSQLVersion(t, pool, v) + require.NoError(t, err) + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" + + for i := 0; i < 1000; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) + template := fmt.Sprintf(` +pg_stream: + dsn: %s + slot_name: test_slot_native_decoder + stream_snapshot: true + batch_transactions: true + schema: public + tables: + - flights +`, databaseURL) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outBatches []string + var outBatchMut sync.Mutex + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + msgBytes, err := mb[0].AsBytes() + require.NoError(t, err) + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 1000 + }, time.Second*25, time.Millisecond*100) + + for i := 0; i < 1000; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 2000 + }, time.Second*25, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outBatches = []string{} + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m 
*service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 0; i < 1000; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 1000 + }, time.Second*20, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) + } +} + +func TestIntegrationPgMultiVersionsCDCForPgOutputStreamComittedPlugin(t *testing.T) { + integration.CheckSkip(t) + for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + var ( + resource *dockertest.Resource + db *sql.DB + ) + + resource, db, err = ResourceWithPostgreSQLVersion(t, pool, v) + require.NoError(t, err) + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" + + for i := 0; i < 1000; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) + template := fmt.Sprintf(` +pg_stream: + dsn: %s + slot_name: test_slot_native_decoder + stream_snapshot: true + batch_transactions: false + schema: public + tables: + - flights +`, databaseURL) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outMessages []string + var outMessagesMut sync.Mutex + + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 1000 + }, time.Second*25, time.Millisecond*100) + + for i := 0; i < 1000; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) 
+ } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 2000 + }, time.Second*25, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outMessages = []string{} + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 0; i < 1000; i++ { + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 1000 + }, time.Second*20, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) + } +} diff --git a/internal/impl/postgresql/pglogicalstream/availablememory.go b/internal/impl/postgresql/pglogicalstream/availablememory.go new file mode 100644 index 0000000000..ae0ae7e42b --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/availablememory.go @@ -0,0 +1,20 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import "runtime" + +func getAvailableMemory() uint64 { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + // You can use memStats.Sys or another appropriate memory metric. + // Consider leaving some memory unused for other processes. + availableMemory := memStats.Sys - memStats.HeapInuse + return availableMemory +} diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go new file mode 100644 index 0000000000..d937813a36 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -0,0 +1,48 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/redpanda-data/benthos/v4/public/service" +) + +// Config is the configuration for the pglogicalstream plugin +type Config struct { + // DBConfig is the configuration to connect to the database with + DBConfig *pgconn.Config + DBRawDSN string + // The DB schema to lookup tables in + DBSchema string + // DbTables is the tables to stream changes from + DBTables []string + // ReplicationSlotName is the name of the replication slot to use + // + // MUST BE SQL INJECTION FREE + ReplicationSlotName string + // TemporaryReplicationSlot is whether to use a temporary replication slot + TemporaryReplicationSlot bool + // StreamOldData is whether to stream all existing data + StreamOldData bool + // SnapshotMemorySafetyFactor is the memory safety factor for streaming snapshot + SnapshotMemorySafetyFactor float64 + // BatchSize is the batch size for streaming + BatchSize int + // BatchTransactions is whether to buffer transactions as an entire single message or to send + // each row in a transaction as a message. + BatchTransactions bool + + Logger *service.Logger + + PgStandbyTimeout time.Duration + WalMonitorInterval time.Duration + MaxParallelSnapshotTables int +} diff --git a/internal/impl/postgresql/pglogicalstream/connection.go b/internal/impl/postgresql/pglogicalstream/connection.go new file mode 100644 index 0000000000..9f81c3be2b --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/connection.go @@ -0,0 +1,48 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "database/sql" + "fmt" + "regexp" + "strconv" +) + +var re = regexp.MustCompile(`^(\d+)`) + +func openPgConnectionFromConfig(dbDSN string) (*sql.DB, error) { + return sql.Open("postgres", dbDSN) +} + +func getPostgresVersion(dbDSN string) (int, error) { + conn, err := openPgConnectionFromConfig(dbDSN) + if err != nil { + return 0, fmt.Errorf("failed to connect to the database: %w", err) + } + defer conn.Close() + + var versionString string + err = conn.QueryRow("SHOW server_version").Scan(&versionString) + if err != nil { + return 0, fmt.Errorf("failed to execute query: %w", err) + } + + match := re.FindStringSubmatch(versionString) + if len(match) < 2 { + return 0, fmt.Errorf("failed to parse version string: %s", versionString) + } + + majorVersion, err := strconv.Atoi(match[1]) + if err != nil { + return 0, fmt.Errorf("failed to convert version to integer: %w", err) + } + + return majorVersion, nil +} diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go new file mode 100644 index 0000000000..dbcb2c3f7f --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -0,0 +1,633 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "context" + "database/sql" + "errors" + "fmt" + "slices" + "strings" + "time" + + "github.com/Jeffail/shutdown" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/redpanda-data/benthos/v4/public/service" + "golang.org/x/sync/errgroup" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" +) + +const decodingPlugin = "pgoutput" + +// Stream is a structure that represents a logical replication stream +// It includes the connection to the database, the context for the stream, and snapshotting functionality +type Stream struct { + pgConn *pgconn.PgConn + + shutSig *shutdown.Signaller + + clientXLogPos *watermark.Value[LSN] + + standbyMessageTimeout time.Duration + nextStandbyMessageDeadline time.Time + messages chan StreamMessage + errors chan error + + snapshotName string + slotName string + schema string + // includes schema + tableQualifiedName []string + snapshotBatchSize int + decodingPluginArguments []string + snapshotMemorySafetyFactor float64 + logger *service.Logger + monitor *Monitor + batchTransactions bool + snapshotter *Snapshotter + maxParallelSnapshotTables int +} + +// NewPgStream creates a new instance of the Stream struct +func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { + if config.ReplicationSlotName == "" { + return nil, errors.New("missing replication slot name") + } + + // Cleanup state - this will be accumulated as the function progresses and cleared + // if we successfully create a stream. 
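+	//
+	// Each resource acquired below (db connection, snapshotter, replication slot, monitor)
+	// appends its own teardown callback. The deferred loop runs them in reverse (LIFO) order,
+	// so dependents are released before the things they depend on, and `cleanups = nil` at the
+	// end of the function disarms the teardown once construction has fully succeeded.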
+ var cleanups []func() + defer func() { + for i := len(cleanups) - 1; i >= 0; i-- { + cleanups[i]() + } + }() + + dbConn, err := pgconn.ConnectConfig(ctx, config.DBConfig.Copy()) + if err != nil { + return nil, err + } + cleanups = append(cleanups, func() { + if err := dbConn.Close(ctx); err != nil { + config.Logger.Warnf("unable to properly cleanup db connection on stream creation failure: %s", err) + } + }) + + if err = dbConn.Ping(ctx); err != nil { + return nil, err + } + + tableNames := slices.Clone(config.DBTables) + for i, table := range tableNames { + if err := sanitize.ValidatePostgresIdentifier(table); err != nil { + return nil, fmt.Errorf("invalid table name %q: %w", table, err) + } + + tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table) + } + stream := &Stream{ + pgConn: dbConn, + messages: make(chan StreamMessage), + errors: make(chan error, 1), + slotName: config.ReplicationSlotName, + snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, + batchTransactions: config.BatchTransactions, + snapshotBatchSize: config.BatchSize, + schema: config.DBSchema, + tableQualifiedName: tableNames, + maxParallelSnapshotTables: config.MaxParallelSnapshotTables, + logger: config.Logger, + shutSig: shutdown.NewSignaller(), + } + + var version int + version, err = getPostgresVersion(config.DBRawDSN) + if err != nil { + return nil, err + } + + snapshotter, err := NewSnapshotter(config.DBRawDSN, stream.logger, version) + if err != nil { + return nil, err + } + stream.snapshotter = snapshotter + cleanups = append(cleanups, func() { + if err := snapshotter.closeConn(); err != nil { + config.Logger.Warnf("unable to properly cleanup snapshotter connection on stream creation failure: %s", err) + } + }) + + pluginArguments := []string{ + "proto_version '1'", + // Sprintf is safe because we validate ReplicationSlotName is alphanumeric in the config + fmt.Sprintf("publication_names 'pglog_stream_%s'", config.ReplicationSlotName), + } + + if version > 14 { + pluginArguments = append(pluginArguments, "messages 'true'") + } + + stream.decodingPluginArguments = pluginArguments + + pubName := "pglog_stream_" + config.ReplicationSlotName + stream.logger.Infof("Creating publication %s for tables: %s", pubName, tableNames) + if err = CreatePublication(ctx, stream.pgConn, pubName, tableNames); err != nil { + return nil, err + } + cleanups = append(cleanups, func() { + // TODO: Drop publication if it was created (meaning it's not existing state we might want to keep). 
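+		//
+		// One possible shape for this TODO (an illustrative sketch only, not wired up here):
+		// have CreatePublication report whether it actually created the publication, and only
+		// in that case run something like
+		//
+		//	_, _ = stream.pgConn.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", pubName)).ReadAll()
+		//
+		// Unconditional dropping is avoided because the publication may pre-exist and belong to
+		// the user.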
+ }) + + sysident, err := IdentifySystem(ctx, stream.pgConn) + if err != nil { + return nil, err + } + + var freshlyCreatedSlot = false + var confirmedLSNFromDB string + var outputPlugin string + // check is replication slot exist to get last restart SLN + + s, err := sanitize.SQLQuery("SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = $1", config.ReplicationSlotName) + if err != nil { + return nil, err + } + connExecResult, err := stream.pgConn.Exec(ctx, s).ReadAll() + if err != nil { + return nil, err + } + if len(connExecResult) == 0 || len(connExecResult[0].Rows) == 0 { + // here we create a new replication slot because there is no slot found + var createSlotResult CreateReplicationSlotResult + createSlotResult, err = CreateReplicationSlot( + ctx, + stream.pgConn, + stream.slotName, + decodingPlugin, + CreateReplicationSlotOptions{ + Temporary: config.TemporaryReplicationSlot, + SnapshotAction: "export", + }, + version, + stream.snapshotter, + ) + if err != nil { + return nil, err + } + stream.snapshotName = createSlotResult.SnapshotName + freshlyCreatedSlot = true + cleanups = append(cleanups, func() { + err := DropReplicationSlot(ctx, stream.pgConn, stream.slotName, DropReplicationSlotOptions{Wait: true}) + if err != nil { + config.Logger.Warnf("unable to properly cleanup replication slot on stream creation failure: %s", err) + } + }) + } else { + slotCheckRow := connExecResult[0].Rows[0] + confirmedLSNFromDB = string(slotCheckRow[0]) + outputPlugin = string(slotCheckRow[1]) + } + + // handling a case when replication slot already exists but with different output plugin created manually + if !freshlyCreatedSlot && outputPlugin != decodingPlugin { + return nil, fmt.Errorf("replication slot %s already exists with different output plugin: %s", config.ReplicationSlotName, outputPlugin) + } + + var lsnrestart LSN + if freshlyCreatedSlot { + lsnrestart = sysident.XLogPos + } else { + lsnrestart, _ = ParseLSN(confirmedLSNFromDB) + } + stream.clientXLogPos = watermark.New(lsnrestart) + + stream.standbyMessageTimeout = config.PgStandbyTimeout + stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) + + monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorInterval) + if err != nil { + return nil, err + } + stream.monitor = monitor + cleanups = append(cleanups, func() { + if err := monitor.Stop(); err != nil { + config.Logger.Warnf("unable to properly cleanup monitor on stream creation failure: %s", err) + } + }) + + stream.logger.Debugf("starting stream from LSN %s with clientXLogPos %s and snapshot name %s", lsnrestart.String(), stream.clientXLogPos.Get().String(), stream.snapshotName) + // TODO(le-vlad): if snapshot processing is restarted we will just skip right to streaming... 
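+	//
+	// Branching below: if the slot already existed (we resume from its confirmed LSN) or the
+	// user did not request a snapshot, start logical replication immediately; otherwise run the
+	// snapshot first and only then begin streaming from the slot's starting position.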
+ if !freshlyCreatedSlot || !config.StreamOldData { + if err = stream.startLr(ctx, lsnrestart); err != nil { + return nil, err + } + + go func() { + defer stream.shutSig.TriggerHasStopped() + if err := stream.streamMessages(); err != nil { + stream.errors <- fmt.Errorf("logical replication stream error: %w", err) + } + }() + } else { + go func() { + defer stream.shutSig.TriggerHasStopped() + if err := stream.processSnapshot(); err != nil { + stream.errors <- fmt.Errorf("failed to process snapshot: %w", err) + return + } + ctx, _ := stream.shutSig.SoftStopCtx(context.Background()) + if err := stream.startLr(ctx, lsnrestart); err != nil { + stream.errors <- fmt.Errorf("failed to start logical replication: %w", err) + return + } + if err := stream.streamMessages(); err != nil { + stream.errors <- fmt.Errorf("logical replication stream error: %w", err) + } + }() + } + + // Success! No need to cleanup + cleanups = nil + return stream, nil +} + +// GetProgress returns the progress of the stream. +// including the % of snapshot messages processed and the WAL lag in bytes. +func (s *Stream) GetProgress() *Report { + return s.monitor.Report() +} + +func (s *Stream) startLr(ctx context.Context, lsnStart LSN) error { + err := StartReplication( + ctx, + s.pgConn, + s.slotName, + lsnStart, + StartReplicationOptions{ + PluginArgs: s.decodingPluginArguments, + }, + ) + if err != nil { + return err + } + s.logger.Debugf("Started logical replication on slot slot-name: %v", s.slotName) + return nil +} + +// AckLSN acknowledges the LSN up to which the stream has processed the messages. +// This makes Postgres to remove the WAL files that are no longer needed. +func (s *Stream) AckLSN(ctx context.Context, lsn string) error { + if s.shutSig.IsHardStopSignalled() { + return fmt.Errorf("unable to ack LSN %s stream shutting down", lsn) + } + clientXLogPos, err := ParseLSN(lsn) + if err != nil { + return err + } + + err = SendStandbyStatusUpdate( + ctx, + s.pgConn, + StandbyStatusUpdate{ + WALApplyPosition: clientXLogPos + 1, + WALWritePosition: clientXLogPos + 1, + WALFlushPosition: clientXLogPos + 1, + ReplyRequested: true, + }, + ) + + if err != nil { + return fmt.Errorf("failed to send Standby status message at LSN %s: %w", clientXLogPos.String(), err) + } + + // Update client XLogPos after we ack the message + s.clientXLogPos.Set(clientXLogPos) + s.logger.Debugf("Sent Standby status message at LSN#%s", clientXLogPos.String()) + s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) + + return nil +} + +func (s *Stream) streamMessages() error { + handler := NewPgOutputPluginHandler(s.messages, s.batchTransactions, s.monitor, s.clientXLogPos) + + ctx, _ := s.shutSig.SoftStopCtx(context.Background()) + for !s.shutSig.IsSoftStopSignalled() { + if time.Now().After(s.nextStandbyMessageDeadline) { + pos := s.clientXLogPos.Get() + err := SendStandbyStatusUpdate( + ctx, + s.pgConn, + StandbyStatusUpdate{ + WALWritePosition: pos, + }, + ) + if err != nil { + return fmt.Errorf("unable to send standby status message at LSN %s: %w", pos, err) + } + s.logger.Debugf("Sent Standby status message at LSN#%s", pos.String()) + s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) + } + recvCtx, cancel := context.WithDeadline(ctx, s.nextStandbyMessageDeadline) + rawMsg, err := s.pgConn.ReceiveMessage(recvCtx) + cancel() // don't leak goroutine + hitStandbyTimeout := errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil + if err != nil { + if hitStandbyTimeout || pgconn.Timeout(err) { + 
s.logger.Info("continue") + continue + } + return fmt.Errorf("failed to receive messages from Postgres: %w", err) + } + + if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { + return fmt.Errorf("received error message from Postgres: %v", errMsg) + } + + msg, ok := rawMsg.(*pgproto3.CopyData) + if !ok { + s.logger.Warnf("received unexpected message: %T", rawMsg) + continue + } + + if len(msg.Data) == 0 { + s.logger.Warn("received malformatted with no data") + continue + } + + switch msg.Data[0] { + case PrimaryKeepaliveMessageByteID: + pkm, err := ParsePrimaryKeepaliveMessage(msg.Data[1:]) + if err != nil { + return fmt.Errorf("failed to parse PrimaryKeepaliveMessage: %w", err) + } + if pkm.ReplyRequested { + s.nextStandbyMessageDeadline = time.Time{} + } + + // XLogDataByteID is the message type for the actual WAL data + // It will cause the stream to process WAL changes and create the corresponding messages + case XLogDataByteID: + xld, err := ParseXLogData(msg.Data[1:]) + if err != nil { + return fmt.Errorf("failed to parse XLogData: %w", err) + } + clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) + commit, err := handler.Handle(ctx, clientXLogPos, xld) + if err != nil { + return fmt.Errorf("decoding postgres changes failed: %w", err) + } else if commit { + // This is a hack and we probably should not do it + if err = s.AckLSN(ctx, clientXLogPos.String()); err != nil { + s.logger.Warnf("Failed to ack commit message LSN: %v", err) + } + } + } + } + // clean shutdown, return nil + return nil +} + +func (s *Stream) processSnapshot() error { + if err := s.snapshotter.prepare(); err != nil { + return fmt.Errorf("failed to prepare database snapshot - snapshot may be expired: %w", err) + } + defer func() { + if err := s.snapshotter.releaseSnapshot(); err != nil { + s.logger.Warnf("Failed to release database snapshot: %v", err.Error()) + } + if err := s.snapshotter.closeConn(); err != nil { + s.logger.Warnf("Failed to close database connection: %v", err.Error()) + } + }() + + s.logger.Debugf("Starting snapshot processing") + var wg errgroup.Group + wg.SetLimit(s.maxParallelSnapshotTables) + + for _, table := range s.tableQualifiedName { + tableName := table + wg.Go(func() (err error) { + s.logger.Debugf("Processing snapshot for table: %v", table) + + var ( + avgRowSizeBytes sql.NullInt64 + offset = 0 + ) + + ctx, _ := s.shutSig.SoftStopCtx(context.Background()) + + avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(ctx, table) + if err != nil { + return fmt.Errorf("failed to calculate average row size for table %v: %w", table, err) + } + + availableMemory := getAvailableMemory() + batchSize := s.snapshotter.calculateBatchSize(availableMemory, uint64(avgRowSizeBytes.Int64)) + if s.snapshotBatchSize > 0 { + batchSize = s.snapshotBatchSize + } + + s.logger.Debugf("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) + + lastPrimaryKey, primaryKeyColumns, err := s.getPrimaryKeyColumn(ctx, table) + if err != nil { + return fmt.Errorf("failed to get primary key column for table %v: %w", table, err) + } + + if len(lastPrimaryKey) == 0 { + return fmt.Errorf("failed to get primary key column for table %s", table) + } + + var lastPkVals = map[string]any{} + + for { + var snapshotRows *sql.Rows + queryStart := time.Now() + if offset == 0 { + snapshotRows, err = s.snapshotter.querySnapshotData(ctx, table, nil, primaryKeyColumns, batchSize) + } else { + snapshotRows, err = s.snapshotter.querySnapshotData(ctx, table, 
lastPkVals, primaryKeyColumns, batchSize) + } + if err != nil { + return fmt.Errorf("failed to query snapshot data for table %v: %w", table, err) + } + + queryDuration := time.Since(queryStart) + s.logger.Tracef("Query duration: %v %s \n", queryDuration, tableName) + + if snapshotRows.Err() != nil { + return fmt.Errorf("failed to get snapshot data for table %v: %w", table, snapshotRows.Err()) + } + + columnTypes, err := snapshotRows.ColumnTypes() + if err != nil { + return fmt.Errorf("failed to get column types for table %v: %w", table, err) + } + + columnNames, err := snapshotRows.Columns() + if err != nil { + return fmt.Errorf("failed to get column names for table %v: %w", table, err) + } + + var rowsCount = 0 + rowsStart := time.Now() + totalScanDuration := time.Duration(0) + totalWaitingFromBenthos := time.Duration(0) + + tableWithoutSchema := strings.Split(table, ".")[1] + for snapshotRows.Next() { + rowsCount += 1 + + scanStart := time.Now() + scanArgs, valueGetters := s.snapshotter.prepareScannersAndGetters(columnTypes) + err := snapshotRows.Scan(scanArgs...) + scanEnd := time.Since(scanStart) + totalScanDuration += scanEnd + + if err != nil { + return fmt.Errorf("failed to scan row for table %v: %v", table, err.Error()) + } + + var data = make(map[string]any) + for i, getter := range valueGetters { + data[columnNames[i]] = getter(scanArgs[i]) + if _, ok := lastPrimaryKey[columnNames[i]]; ok { + lastPkVals[columnNames[i]] = getter(scanArgs[i]) + } + } + + snapshotChangePacket := StreamMessage{ + Lsn: nil, + Changes: []StreamMessageChanges{ + { + Table: tableWithoutSchema, + Operation: "insert", + Schema: s.schema, + Data: data, + }, + }, + } + + if rowsCount%100 == 0 { + s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset) + } + + tableProgress := s.monitor.GetSnapshotProgressForTable(tableWithoutSchema) + snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress + snapshotChangePacket.Mode = StreamModeSnapshot + + waitingFromBenthos := time.Now() + select { + case s.messages <- snapshotChangePacket: + case <-s.shutSig.SoftStopChan(): + return nil + } + totalWaitingFromBenthos += time.Since(waitingFromBenthos) + + } + + batchEnd := time.Since(rowsStart) + s.logger.Debugf("Batch duration: %v %s \n", batchEnd, tableName) + s.logger.Debugf("Scan duration %v %s\n", totalScanDuration, tableName) + s.logger.Debugf("Waiting from benthos duration %v %s\n", totalWaitingFromBenthos, tableName) + + offset += batchSize + + if rowsCount < batchSize { + break + } + } + return nil + }) + } + return wg.Wait() +} + +// Messages is a channel that can be used to consume messages from the plugin. 
It will contain a nil LSN for snapshot messages.
+func (s *Stream) Messages() chan StreamMessage {
+	return s.messages
+}
+
+// Errors is a channel that can be used to see if an error has occurred internally and the stream should be restarted
+func (s *Stream) Errors() chan error {
+	return s.errors
+}
+
+func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map[string]any, []string, error) {
+	// Query to get all primary key columns in their correct order
+	q, err := sanitize.SQLQuery(`
+	SELECT a.attname
+	FROM pg_index i
+	JOIN pg_attribute a ON a.attrelid = i.indrelid
+		AND a.attnum = ANY(i.indkey)
+	WHERE i.indrelid = $1::regclass
+		AND i.indisprimary
+	ORDER BY array_position(i.indkey, a.attnum);
+	`, tableName)
+
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to sanitize query: %w", err)
+	}
+
+	reader := s.pgConn.Exec(ctx, q)
+	data, err := reader.ReadAll()
+	if err != nil {
+		return nil, nil, fmt.Errorf("failed to read query results: %w", err)
+	}
+
+	if len(data) == 0 || len(data[0].Rows) == 0 {
+		return nil, nil, fmt.Errorf("no primary key found for table %s", tableName)
+	}
+
+	// Extract all primary key column names
+	pkColumns := make([]string, len(data[0].Rows))
+	for i, row := range data[0].Rows {
+		pkColumns[i] = string(row[0])
+	}
+
+	var pksMap = make(map[string]any)
+	for _, pk := range pkColumns {
+		pksMap[pk] = nil
+	}
+
+	return pksMap, pkColumns, nil
+}
+
+// Stop closes the stream (hopefully gracefully)
+func (s *Stream) Stop(ctx context.Context) error {
+	s.shutSig.TriggerSoftStop()
+	var wg errgroup.Group
+	stopNowCtx, _ := s.shutSig.HardStopCtx(ctx)
+	wg.Go(func() error {
+		return s.pgConn.Close(stopNowCtx)
+	})
+	wg.Go(func() error {
+		return s.monitor.Stop()
+	})
+	select {
+	case <-ctx.Done():
+	case <-s.shutSig.HasStoppedChan():
+		return wg.Wait()
+	}
+	s.shutSig.TriggerHardStop()
+	err := wg.Wait()
+	select {
+	case <-time.After(time.Second):
+		if err == nil {
+			return errors.New("unable to cleanly shutdown postgres logical replication stream")
+		}
+	case <-s.shutSig.HasStoppedChan():
+	}
+	return err
+}
diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go
new file mode 100644
index 0000000000..d9ed0ba4db
--- /dev/null
+++ b/internal/impl/postgresql/pglogicalstream/monitor.go
@@ -0,0 +1,172 @@
+// Copyright 2024 Redpanda Data, Inc.
+//
+// Licensed as a Redpanda Enterprise file under the Redpanda Community
+// License (the "License"); you may not use this file except in compliance with
+// the License.
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "context" + "database/sql" + "fmt" + "maps" + "math" + "strings" + "sync" + "time" + + "github.com/redpanda-data/benthos/v4/public/service" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" + "github.com/redpanda-data/connect/v4/internal/periodic" +) + +// Report is a structure that contains the current state of the Monitor +type Report struct { + WalLagInBytes int64 + TableProgress map[string]float64 +} + +// Monitor is a structure that allows monitoring the progress of snapshot ingestion and replication lag +type Monitor struct { + // tableStat contains numbers of rows for each table determined at the moment of the snapshot creation + // this is used to calculate snapshot ingestion progress + tableStat map[string]int64 + lock sync.Mutex + // snapshotProgress is a map of table names to the percentage of rows ingested from the snapshot + snapshotProgress map[string]float64 + // replicationLagInBytes is the replication lag in bytes measured by + // finding the difference between the latest LSN and the last confirmed LSN for the replication slot + replicationLagInBytes int64 + + dbConn *sql.DB + slotName string + logger *service.Logger + loop *periodic.Periodic +} + +// NewMonitor creates a new Monitor instance +func NewMonitor( + ctx context.Context, + dbDSN string, + logger *service.Logger, + tables []string, + slotName string, + interval time.Duration, +) (*Monitor, error) { + dbConn, err := openPgConnectionFromConfig(dbDSN) + if err != nil { + return nil, err + } + if interval <= 0 { + return nil, fmt.Errorf("invalid monitoring interval: %s", interval.String()) + } + + m := &Monitor{ + snapshotProgress: map[string]float64{}, + replicationLagInBytes: 0, + dbConn: dbConn, + slotName: slotName, + logger: logger, + } + m.loop = periodic.NewWithContext(interval, m.readReplicationLag) + if err = m.readTablesStat(ctx, tables); err != nil { + return nil, err + } + m.loop.Start() + return m, nil +} + +// GetSnapshotProgressForTable returns the snapshot ingestion progress for a given table +func (m *Monitor) GetSnapshotProgressForTable(table string) float64 { + m.lock.Lock() + defer m.lock.Unlock() + return m.snapshotProgress[table] +} + +// UpdateSnapshotProgressForTable updates the snapshot ingestion progress for a given table +func (m *Monitor) UpdateSnapshotProgressForTable(table string, position int) { + m.lock.Lock() + defer m.lock.Unlock() + m.snapshotProgress[table] = math.Round(float64(position) / float64(m.tableStat[table]) * 100) +} + +// we need to read the tables stat to calculate the snapshot ingestion progress +func (m *Monitor) readTablesStat(ctx context.Context, tables []string) error { + results := make(map[string]int64) + + for _, table := range tables { + tableWithoutSchema := strings.Split(table, ".")[1] + err := sanitize.ValidatePostgresIdentifier(tableWithoutSchema) + + if err != nil { + return fmt.Errorf("error sanitizing query: %w", err) + } + + var count int64 + // tableWithoutSchema has been validated so its safe to use in the query + err = m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tableWithoutSchema).Scan(&count) + + if err != nil { + // If the error is because the table doesn't exist, we'll set the count to 0 + // and continue. You might want to log this situation. 
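+			//
+			// Matching on the error text is a heuristic; a sturdier check would inspect the
+			// SQLSTATE for undefined_table (42P01) via the driver's error type, roughly
+			// (sketch only, the exact error type depends on the driver in use):
+			//
+			//	var pgErr *pgconn.PgError
+			//	if errors.As(err, &pgErr) && pgErr.Code == "42P01" { ... }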
+ if strings.Contains(err.Error(), "does not exist") { + results[tableWithoutSchema] = 0 + continue + } + // For any other error, we'll return it + return fmt.Errorf("error counting rows in table %s: %w", tableWithoutSchema, err) + } + + results[tableWithoutSchema] = count + } + + m.tableStat = results + return nil +} + +func (m *Monitor) readReplicationLag(ctx context.Context) { + result, err := m.dbConn.QueryContext(ctx, `SELECT slot_name, + pg_wal_lsn_diff(pg_current_wal_lsn(), restart_lsn) AS lag_bytes + FROM pg_replication_slots WHERE slot_name = $1;`, m.slotName) + // calculate the replication lag in bytes + // replicationLagInBytes = latestLsn - confirmedLsn + if err != nil || result.Err() != nil { + m.logger.Warnf("Error reading replication lag: %v", err) + return + } + + var slotName string + var lagbytes int64 + for result.Next() { + if err = result.Scan(&slotName, &lagbytes); err != nil { + m.logger.Warnf("Error reading replication lag: %v", err) + return + } + } + + m.lock.Lock() + m.replicationLagInBytes = lagbytes + m.lock.Unlock() +} + +// Report returns a snapshot of the monitor's state +func (m *Monitor) Report() *Report { + m.lock.Lock() + defer m.lock.Unlock() + // report the snapshot ingestion progress + // report the replication lag + return &Report{ + WalLagInBytes: m.replicationLagInBytes, + TableProgress: maps.Clone(m.snapshotProgress), + } +} + +// Stop stops the monitor +func (m *Monitor) Stop() error { + m.loop.Stop() + return m.dbConn.Close() +} diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go new file mode 100644 index 0000000000..99debf5190 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -0,0 +1,698 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +// Package pglogrepl implements PostgreSQL logical replication client functionality. +// +// pglogrepl uses package github.com/jackc/pgconn as its underlying PostgreSQL connection. +// Use pgconn to establish a connection to PostgreSQL and then use the pglogrepl functions +// on that connection. +// +// Proper use of this package requires understanding the underlying PostgreSQL concepts. +// See https://www.postgresql.org/docs/current/protocol-replication.html. + +import ( + "context" + "database/sql/driver" + "encoding/binary" + "errors" + "fmt" + "slices" + "strconv" + "strings" + "time" + + "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" +) + +const ( + // XLogDataByteID is the byte ID for XLogData messages. + XLogDataByteID = 'w' + // PrimaryKeepaliveMessageByteID is the byte ID for PrimaryKeepaliveMessage messages. + PrimaryKeepaliveMessageByteID = 'k' + // StandbyStatusUpdateByteID is the byte ID for StandbyStatusUpdate messages. + StandbyStatusUpdateByteID = 'r' +) + +// ReplicationMode is the mode of replication to use. 
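+// Only LogicalReplication is used by this plugin; String still renders any other value as
+// "PHYSICAL" so that the command text stays well formed if physical replication is ever added.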
+type ReplicationMode int + +const ( + // LogicalReplication is the only replication mode supported by this plugin + LogicalReplication ReplicationMode = iota +) + +// String formats the mode into a postgres valid string +func (mode ReplicationMode) String() string { + if mode == LogicalReplication { + return "LOGICAL" + } else { + return "PHYSICAL" + } +} + +// LSN is a PostgreSQL Log Sequence Number. See https://www.postgresql.org/docs/current/datatype-pg-lsn.html. +type LSN uint64 + +// String formats the LSN value into the XXX/XXX format which is the text format used by PostgreSQL. +func (lsn LSN) String() string { + return fmt.Sprintf("%X/%X", uint32(lsn>>32), uint32(lsn)) +} + +func (lsn *LSN) decodeText(src string) error { + lsnValue, err := ParseLSN(src) + if err != nil { + return err + } + *lsn = lsnValue + + return nil +} + +// Scan implements the Scanner interface. +func (lsn *LSN) Scan(src interface{}) error { + if lsn == nil { + return nil + } + + switch v := src.(type) { + case uint64: + *lsn = LSN(v) + case string: + if err := lsn.decodeText(v); err != nil { + return err + } + case []byte: + if err := lsn.decodeText(string(v)); err != nil { + return err + } + default: + return fmt.Errorf("can not scan %T to LSN", src) + } + + return nil +} + +// Value implements the Valuer interface. +func (lsn LSN) Value() (driver.Value, error) { + return driver.Value(lsn.String()), nil +} + +// ParseLSN parses the given XXX/XXX text format LSN used by PostgreSQL. +func ParseLSN(s string) (LSN, error) { + var upperHalf uint64 + var lowerHalf uint64 + var nparsed int + nparsed, err := fmt.Sscanf(s, "%X/%X", &upperHalf, &lowerHalf) + if err != nil { + return 0, fmt.Errorf("failed to parse LSN: %w", err) + } + + if nparsed != 2 { + return 0, fmt.Errorf("failed to parsed LSN: %s", s) + } + + return LSN((upperHalf << 32) + lowerHalf), nil +} + +// IdentifySystemResult is the parsed result of the IDENTIFY_SYSTEM command. +type IdentifySystemResult struct { + SystemID string + Timeline int32 + XLogPos LSN + DBName string +} + +// IdentifySystem executes the IDENTIFY_SYSTEM command. +func IdentifySystem(ctx context.Context, conn *pgconn.PgConn) (IdentifySystemResult, error) { + return ParseIdentifySystem(conn.Exec(ctx, "IDENTIFY_SYSTEM")) +} + +// ParseIdentifySystem parses the result of the IDENTIFY_SYSTEM command. +func ParseIdentifySystem(mrr *pgconn.MultiResultReader) (IdentifySystemResult, error) { + var isr IdentifySystemResult + results, err := mrr.ReadAll() + if err != nil { + return isr, err + } + + if len(results) != 1 { + return isr, fmt.Errorf("expected 1 result set, got %d", len(results)) + } + + result := results[0] + if len(result.Rows) != 1 { + return isr, fmt.Errorf("expected 1 result row, got %d", len(result.Rows)) + } + + row := result.Rows[0] + if len(row) != 4 { + return isr, fmt.Errorf("expected 4 result columns, got %d", len(row)) + } + + isr.SystemID = string(row[0]) + timeline, err := strconv.ParseInt(string(row[1]), 10, 32) + if err != nil { + return isr, fmt.Errorf("failed to parse timeline: %w", err) + } + isr.Timeline = int32(timeline) + + isr.XLogPos, err = ParseLSN(string(row[2])) + if err != nil { + return isr, fmt.Errorf("failed to parse xlogpos as LSN: %w", err) + } + + isr.DBName = string(row[3]) + + return isr, nil +} + +// TimelineHistoryResult is the parsed result of the TIMELINE_HISTORY command. +type TimelineHistoryResult struct { + FileName string + Content []byte +} + +// TimelineHistory executes the TIMELINE_HISTORY command. 
+func TimelineHistory(ctx context.Context, conn *pgconn.PgConn, timeline int32) (TimelineHistoryResult, error) { + sql := fmt.Sprintf("TIMELINE_HISTORY %d", timeline) + return ParseTimelineHistory(conn.Exec(ctx, sql)) +} + +// ParseTimelineHistory parses the result of the TIMELINE_HISTORY command. +func ParseTimelineHistory(mrr *pgconn.MultiResultReader) (TimelineHistoryResult, error) { + var thr TimelineHistoryResult + results, err := mrr.ReadAll() + if err != nil { + return thr, err + } + + if len(results) != 1 { + return thr, fmt.Errorf("expected 1 result set, got %d", len(results)) + } + + result := results[0] + if len(result.Rows) != 1 { + return thr, fmt.Errorf("expected 1 result row, got %d", len(result.Rows)) + } + + row := result.Rows[0] + if len(row) != 2 { + return thr, fmt.Errorf("expected 2 result columns, got %d", len(row)) + } + + thr.FileName = string(row[0]) + thr.Content = row[1] + return thr, nil +} + +// CreateReplicationSlotOptions are the options for the CREATE_REPLICATION_SLOT command. Including Mode, Temporary, and SnapshotAction. +type CreateReplicationSlotOptions struct { + Temporary bool + SnapshotAction string + Mode ReplicationMode +} + +// CreateReplicationSlotResult is the parsed results the CREATE_REPLICATION_SLOT command. +type CreateReplicationSlotResult struct { + SlotName string + ConsistentPoint string + SnapshotName string + OutputPlugin string +} + +// CreateReplicationSlot creates a logical replication slot. +func CreateReplicationSlot( + ctx context.Context, + conn *pgconn.PgConn, + slotName string, + outputPlugin string, + options CreateReplicationSlotOptions, + version int, + snapshotter *Snapshotter, +) (CreateReplicationSlotResult, error) { + var temporaryString string + if options.Temporary { + temporaryString = "TEMPORARY" + } + var snapshotString string + if options.SnapshotAction == "export" { + snapshotString = "(SNAPSHOT export)" + } else { + snapshotString = options.SnapshotAction + } + + // NOTE: All strings passed into here have been validated and are not prone to SQL injection. + newPgCreateSlotCommand := fmt.Sprintf("CREATE_REPLICATION_SLOT %s %s %s %s %s", slotName, temporaryString, options.Mode, outputPlugin, snapshotString) + oldPgCreateSlotCommand := fmt.Sprintf("SELECT * FROM pg_create_logical_replication_slot('%s', '%s', %v);", slotName, outputPlugin, temporaryString == "TEMPORARY") + + var snapshotName string + if version > 14 { + result, err := ParseCreateReplicationSlot(conn.Exec(ctx, newPgCreateSlotCommand), version, snapshotName) + if err != nil { + return CreateReplicationSlotResult{}, err + } + if snapshotter != nil { + snapshotter.setTransactionSnapshotName(result.SnapshotName) + } + + return result, nil + } + + var snapshotResponse SnapshotCreationResponse + if options.SnapshotAction == "export" { + var err error + snapshotResponse, err = snapshotter.initSnapshotTransaction() + if err != nil { + return CreateReplicationSlotResult{}, err + } + snapshotter.setTransactionSnapshotName(snapshotResponse.ExportedSnapshotName) + } + + replicationSlotCreationResponse := conn.Exec(ctx, oldPgCreateSlotCommand) + _, err := replicationSlotCreationResponse.ReadAll() + if err != nil { + return CreateReplicationSlotResult{}, err + } + + return CreateReplicationSlotResult{ + SnapshotName: snapshotResponse.ExportedSnapshotName, + }, nil +} + +// ParseCreateReplicationSlot parses the result of the CREATE_REPLICATION_SLOT command. 
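+// When version > 14 the result row is expected to carry four columns
+// (slot_name, consistent_point, snapshot_name, output_plugin) and the snapshot name is read from
+// the row; for older servers only slot_name and consistent_point are read and the snapshotName
+// argument supplied by the caller is returned instead.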
+func ParseCreateReplicationSlot(mrr *pgconn.MultiResultReader, version int, snapshotName string) (CreateReplicationSlotResult, error) { + var crsr CreateReplicationSlotResult + results, err := mrr.ReadAll() + if err != nil { + return crsr, err + } + + if len(results) != 1 { + return crsr, fmt.Errorf("expected 1 result set, got %d", len(results)) + } + + result := results[0] + if len(result.Rows) != 1 { + return crsr, fmt.Errorf("expected 1 result row, got %d", len(result.Rows)) + } + + row := result.Rows[0] + if version > 14 { + if len(row) != 4 { + return crsr, fmt.Errorf("expected 4 result columns, got %d", len(row)) + } + } + + crsr.SlotName = string(row[0]) + crsr.ConsistentPoint = string(row[1]) + + if version > 14 { + crsr.SnapshotName = string(row[2]) + } else { + crsr.SnapshotName = snapshotName + } + + return crsr, nil +} + +// DropReplicationSlotOptions are options for the DROP_REPLICATION_SLOT command. +type DropReplicationSlotOptions struct { + Wait bool +} + +// DropReplicationSlot drops a logical replication slot. +func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName string, options DropReplicationSlotOptions) error { + var waitString string + if options.Wait { + waitString = "WAIT" + } + sql := fmt.Sprintf("DROP_REPLICATION_SLOT %s %s", slotName, waitString) + _, err := conn.Exec(ctx, sql).ReadAll() + return err +} + +// CreatePublication creates a new PostgreSQL publication with the given name for a list of tables and drop if exists flag +func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []string) error { + // Check if publication exists + pubQuery, err := sanitize.SQLQuery(` + SELECT pubname, puballtables + FROM pg_publication + WHERE pubname = $1; + `, publicationName) + if err != nil { + return fmt.Errorf("failed to sanitize publication query: %w", err) + } + + // Since we need to pass table names without quoting, we need to validate it + for _, table := range tables { + if err := sanitize.ValidatePostgresIdentifier(table); err != nil { + return errors.New("invalid table name") + } + } + // the same for publication name + if err := sanitize.ValidatePostgresIdentifier(publicationName); err != nil { + return errors.New("invalid publication name") + } + + result := conn.Exec(ctx, pubQuery) + + rows, err := result.ReadAll() + if err != nil { + return fmt.Errorf("failed to check publication existence: %w", err) + } + + tablesClause := "FOR ALL TABLES" + if len(tables) > 0 { + // quotedTables := make([]string, len(tables)) + // for i, table := range tables { + // // Use sanitize.SQLIdentifier to properly quote and escape table names + // quoted, err := sanitize.SQLIdentifier(table) + // if err != nil { + // return fmt.Errorf("invalid table name %q: %w", table, err) + // } + // quotedTables[i] = quoted + // } + tablesClause = "FOR TABLE " + strings.Join(tables, ", ") + } + + if len(rows) == 0 || len(rows[0].Rows) == 0 { + // tablesClause is sanitized, so we can safely interpolate it into the query + sq, err := sanitize.SQLQuery(fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesClause)) + if err != nil { + return fmt.Errorf("failed to sanitize publication creation query: %w", err) + } + // Publication doesn't exist, create new one + result = conn.Exec(ctx, sq) + if _, err := result.ReadAll(); err != nil { + return fmt.Errorf("failed to create publication: %w", err) + } + + return nil + } + + // assuming publication already exists + // get a list of tables in the publication + pubTables, 
forAllTables, err := GetPublicationTables(ctx, conn, publicationName) + if err != nil { + return fmt.Errorf("failed to get publication tables: %w", err) + } + + // list of tables to publish is empty and publication is for all tables + // no update is needed + if forAllTables && len(pubTables) == 0 { + return nil + } + + var tablesToRemoveFromPublication = []string{} + var tablesToAddToPublication = []string{} + for _, table := range tables { + if !slices.Contains(pubTables, table) { + tablesToAddToPublication = append(tablesToAddToPublication, table) + } + } + + for _, table := range pubTables { + if !slices.Contains(tables, table) { + tablesToRemoveFromPublication = append(tablesToRemoveFromPublication, table) + } + } + + // remove tables from publication + for _, dropTable := range tablesToRemoveFromPublication { + sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable)) + if err != nil { + return fmt.Errorf("failed to sanitize drop table query: %w", err) + } + result = conn.Exec(ctx, sq) + if _, err := result.ReadAll(); err != nil { + return fmt.Errorf("failed to remove table from publication: %w", err) + } + } + + // add tables to publication + for _, addTable := range tablesToAddToPublication { + sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable)) + if err != nil { + return fmt.Errorf("failed to sanitize add table query: %w", err) + } + result = conn.Exec(ctx, sq) + if _, err := result.ReadAll(); err != nil { + return fmt.Errorf("failed to add table to publication: %w", err) + } + } + + return nil +} + +// GetPublicationTables returns a list of tables currently in the publication +// Arguments, in order: list of the tables, exist for all tables, errror +func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationName string) ([]string, bool, error) { + query, err := sanitize.SQLQuery(` + SELECT DISTINCT + tablename as table_name + FROM pg_publication_tables + WHERE pubname = $1 + ORDER BY table_name; + `, publicationName) + if err != nil { + return nil, false, fmt.Errorf("failed to get publication tables: %w", err) + } + + // Get specific tables in the publication + result := conn.Exec(ctx, query) + + rows, err := result.ReadAll() + if err != nil { + return nil, false, fmt.Errorf("failed to get publication tables: %w", err) + } + + if len(rows) == 0 || len(rows[0].Rows) == 0 { + return nil, true, nil // Publication exists and is for all tables + } + + tables := make([]string, 0, len(rows)) + for _, row := range rows[0].Rows { + tables = append(tables, string(row[0])) + } + + return tables, false, nil +} + +// StartReplicationOptions are the options for the START_REPLICATION command. +// The Timeline field is optional and defaults to 0, which means the current server timeline. +// The Mode field is required and must be either PhysicalReplication or LogicalReplication. ## PhysicalReplication is not supporter by this plugin, but still can be implemented +// The PluginArgs field is optional and only used for LogicalReplication. +type StartReplicationOptions struct { + Timeline int32 // 0 means current server timeline + Mode ReplicationMode + PluginArgs []string +} + +// StartReplication begins the replication process by executing the START_REPLICATION command. 
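+// For a logical slot the generated command looks like (illustrative values):
+//
+//	START_REPLICATION SLOT my_slot LOGICAL 0/1A2B3C4 (proto_version '1', publication_names 'pglog_stream_my_slot')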
+func StartReplication(ctx context.Context, conn *pgconn.PgConn, slotName string, startLSN LSN, options StartReplicationOptions) error { + var timelineString string + if options.Timeline > 0 { + timelineString = fmt.Sprintf("TIMELINE %d", options.Timeline) + options.PluginArgs = append(options.PluginArgs, timelineString) + } + + sql := fmt.Sprintf("START_REPLICATION SLOT %s %s %s ", slotName, options.Mode, startLSN) + if options.Mode == LogicalReplication { + if len(options.PluginArgs) > 0 { + sql += fmt.Sprintf("(%s)", strings.Join(options.PluginArgs, ", ")) + } + } else { + sql += timelineString + } + + conn.Frontend().SendQuery(&pgproto3.Query{String: sql}) + err := conn.Frontend().Flush() + if err != nil { + return fmt.Errorf("failed to send START_REPLICATION: %w", err) + } + + for { + msg, err := conn.ReceiveMessage(ctx) + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) + } + + switch msg := msg.(type) { + case *pgproto3.NoticeResponse: + case *pgproto3.ErrorResponse: + return pgconn.ErrorResponseToPgError(msg) + case *pgproto3.CopyBothResponse: + // This signals the start of the replication stream. + return nil + default: + return fmt.Errorf("unexpected response type: %T", msg) + } + } +} + +// PrimaryKeepaliveMessage is a message sent by the primary server to the replica server to keep the connection alive. +type PrimaryKeepaliveMessage struct { + ServerWALEnd LSN + ServerTime time.Time + ReplyRequested bool +} + +// ParsePrimaryKeepaliveMessage parses a Primary keepalive message from the server. +func ParsePrimaryKeepaliveMessage(buf []byte) (PrimaryKeepaliveMessage, error) { + var pkm PrimaryKeepaliveMessage + if len(buf) != 17 { + return pkm, fmt.Errorf("PrimaryKeepaliveMessage must be 17 bytes, got %d", len(buf)) + } + + pkm.ServerWALEnd = LSN(binary.BigEndian.Uint64(buf)) + pkm.ServerTime = pgTimeToTime(int64(binary.BigEndian.Uint64(buf[8:]))) + pkm.ReplyRequested = buf[16] != 0 + + return pkm, nil +} + +// XLogData is a message sent by the primary server to the replica server containing WAL data. +type XLogData struct { + WALStart LSN + ServerWALEnd LSN + ServerTime time.Time + WALData []byte +} + +// ParseXLogData parses a XLogData message from the server. +func ParseXLogData(buf []byte) (XLogData, error) { + var xld XLogData + if len(buf) < 24 { + return xld, fmt.Errorf("XLogData must be at least 24 bytes, got %d", len(buf)) + } + + xld.WALStart = LSN(binary.BigEndian.Uint64(buf)) + xld.ServerWALEnd = LSN(binary.BigEndian.Uint64(buf[8:])) + xld.ServerTime = pgTimeToTime(int64(binary.BigEndian.Uint64(buf[16:]))) + xld.WALData = buf[24:] + + return xld, nil +} + +// StandbyStatusUpdate is a message sent from the client that acknowledges receipt of WAL records. +type StandbyStatusUpdate struct { + WALWritePosition LSN // The WAL position that's been locally written + WALFlushPosition LSN // The WAL position that's been locally flushed + WALApplyPosition LSN // The WAL position that's been locally applied + ClientTime time.Time // Client system clock time + ReplyRequested bool // Request server to reply immediately. +} + +// SendStandbyStatusUpdate sends a StandbyStatusUpdate to the PostgreSQL server. +// +// The only required field in ssu is WALWritePosition. If WALFlushPosition is 0 then WALWritePosition will be assigned +// to it. If WALApplyPosition is 0 then WALWritePosition will be assigned to it. If ClientTime is the zero value then +// the current time will be assigned to it. 
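+//
+// A typical acknowledgement of everything processed up to some lastLSN looks like
+// (ctx, conn and lastLSN are placeholders):
+//
+//	err := SendStandbyStatusUpdate(ctx, conn, StandbyStatusUpdate{WALWritePosition: lastLSN})
+//
+// The encoded message is the 34-byte Standby status update of the streaming replication
+// protocol: the 'r' byte, three 8-byte LSNs, an 8-byte timestamp and a reply-requested flag.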
+func SendStandbyStatusUpdate(_ context.Context, conn *pgconn.PgConn, ssu StandbyStatusUpdate) error { + if ssu.WALFlushPosition == 0 { + ssu.WALFlushPosition = ssu.WALWritePosition + } + if ssu.WALApplyPosition == 0 { + ssu.WALApplyPosition = ssu.WALWritePosition + } + if ssu.ClientTime == (time.Time{}) { + ssu.ClientTime = time.Now() + } + + data := make([]byte, 0, 34) + data = append(data, StandbyStatusUpdateByteID) + data = pgio.AppendUint64(data, uint64(ssu.WALWritePosition)) + data = pgio.AppendUint64(data, uint64(ssu.WALFlushPosition)) + data = pgio.AppendUint64(data, uint64(ssu.WALApplyPosition)) + data = pgio.AppendInt64(data, timeToPgTime(ssu.ClientTime)) + if ssu.ReplyRequested { + data = append(data, 1) + } else { + data = append(data, 0) + } + + cd := &pgproto3.CopyData{Data: data} + buf, err := cd.Encode(nil) + if err != nil { + return err + } + + return conn.Frontend().SendUnbufferedEncodedCopyData(buf) +} + +// CopyDoneResult is the parsed result as returned by the server after the client +// sends a CopyDone to the server to confirm ending the copy-both mode. +type CopyDoneResult struct { + Timeline int32 + LSN LSN +} + +// SendStandbyCopyDone sends a StandbyCopyDone to the PostgreSQL server +// to confirm ending the copy-both mode. +func SendStandbyCopyDone(_ context.Context, conn *pgconn.PgConn) (cdr *CopyDoneResult, err error) { + // I am suspicious that this is wildly wrong, but I'm pretty sure the previous + // code was wildly wrong too -- wttw + conn.Frontend().Send(&pgproto3.CopyDone{}) + err = conn.Frontend().Flush() + if err != nil { + return + } + + for { + var msg pgproto3.BackendMessage + msg, err = conn.Frontend().Receive() + if err != nil { + return cdr, err + } + + switch m := msg.(type) { + case *pgproto3.CopyDone: + case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: + case *pgproto3.CommandComplete: + case *pgproto3.RowDescription: + case *pgproto3.DataRow: + // We are expecting just one row returned, with two columns timeline and LSN + // We should pay attention to RowDescription, but we'll take it on trust. + if len(m.Values) == 2 { + timeline, lerr := strconv.Atoi(string(m.Values[0])) + if lerr == nil { + lsn, lerr := ParseLSN(string(m.Values[1])) + if lerr == nil { + cdr = new(CopyDoneResult) + cdr.Timeline = int32(timeline) + cdr.LSN = lsn + } + } + } + case *pgproto3.EmptyQueryResponse: + case *pgproto3.ErrorResponse: + return cdr, pgconn.ErrorResponseToPgError(m) + case *pgproto3.ReadyForQuery: + // Should we eat the ReadyForQuery here, or not? + return cdr, err + } + } +} + +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +func pgTimeToTime(microsecSinceY2K int64) time.Time { + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + return time.Unix(0, microsecSinceUnixEpoch*1000) +} + +func timeToPgTime(t time.Time) int64 { + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + return microsecSinceUnixEpoch - microsecFromUnixEpochToY2K +} diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go new file mode 100644 index 0000000000..8a50b34bc1 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -0,0 +1,462 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func TestLSNSuite(t *testing.T) { + suite.Run(t, new(lsnSuite)) +} + +type lsnSuite struct { + suite.Suite +} + +func (s *lsnSuite) R() *require.Assertions { + return s.Require() +} + +func (s *lsnSuite) Equal(e, a interface{}, args ...interface{}) { + s.R().Equal(e, a, args...) +} + +func (s *lsnSuite) NoError(err error) { + s.R().NoError(err) +} + +func (s *lsnSuite) TestScannerInterface() { + var lsn LSN + lsnText := "16/B374D848" + lsnUint64 := uint64(97500059720) + var err error + + err = lsn.Scan(lsnText) + s.NoError(err) + s.Equal(lsnText, lsn.String()) + + err = lsn.Scan([]byte(lsnText)) + s.NoError(err) + s.Equal(lsnText, lsn.String()) + + lsn = 0 + err = lsn.Scan(lsnUint64) + s.NoError(err) + s.Equal(lsnText, lsn.String()) + + err = lsn.Scan(int64(lsnUint64)) + s.Error(err) + s.T().Log(err) +} + +func (s *lsnSuite) TestScanToNil() { + var lsnPtr *LSN + err := lsnPtr.Scan("16/B374D848") + s.NoError(err) +} + +func (s *lsnSuite) TestValueInterface() { + lsn := LSN(97500059720) + driverValue, err := lsn.Value() + s.NoError(err) + lsnStr, ok := driverValue.(string) + s.R().True(ok) + s.Equal("16/B374D848", lsnStr) +} + +const slotName = "pglogrepl_test" +const outputPlugin = "pgoutput" + +func closeConn(t testing.TB, conn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.NoError(t, conn.Close(ctx)) +} + +func createDockerInstance(t *testing.T) (*dockertest.Pool, *dockertest.Resource, string) { + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=user_name", + "POSTGRES_DB=dbname", + }, + Cmd: []string{ + "postgres", + "-c", "wal_level=logical", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + + require.NoError(t, err) + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s replication=database", hostAndPortSplited[0], hostAndPortSplited[1]) + + var db *sql.DB + pool.MaxWait = 120 * time.Second + err = pool.Retry(func() error { + if db, err = sql.Open("postgres", databaseURL); err != nil { + return err + } + + if err = db.Ping(); err != nil { + return err + } + + return err + }) + require.NoError(t, err) + + return pool, resource, databaseURL +} + +func TestIdentifySystem(t *testing.T) { + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*100) + defer cancel() + + conn, err := pgconn.Connect(ctx, dbURL) + require.NoError(t, err) + defer closeConn(t, conn) + + sysident, err 
:= IdentifySystem(ctx, conn) + require.NoError(t, err) + + assert.NotEmpty(t, sysident.SystemID, 0) + assert.Greater(t, sysident.Timeline, int32(0)) + + xlogPositionIsPositive := sysident.XLogPos > 0 + assert.True(t, xlogPositionIsPositive) + assert.NotEmpty(t, sysident.DBName, 0) +} + +func TestCreateReplicationSlot(t *testing.T) { + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, dbURL) + require.NoError(t, err) + defer closeConn(t, conn) + + result, err := CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}, 16, nil) + require.NoError(t, err) + + assert.Equal(t, slotName, result.SlotName) +} + +func TestDropReplicationSlot(t *testing.T) { + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, dbURL) + require.NoError(t, err) + defer closeConn(t, conn) + + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false}, 16, nil) + require.NoError(t, err) + + err = DropReplicationSlot(ctx, conn, slotName, DropReplicationSlotOptions{}) + require.NoError(t, err) + + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false}, 16, nil) + require.NoError(t, err) +} + +func TestCreatePublication(t *testing.T) { + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, dbURL) + require.NoError(t, err) + defer closeConn(t, conn) + + publicationName := "test_publication" + err = CreatePublication(context.Background(), conn, publicationName, []string{}) + require.NoError(t, err) + + tables, forAllTables, err := GetPublicationTables(context.Background(), conn, publicationName) + require.NoError(t, err) + assert.Empty(t, tables) + assert.True(t, forAllTables) + + multiReader := conn.Exec(context.Background(), "CREATE TABLE test_table (id serial PRIMARY KEY, name text);") + _, err = multiReader.ReadAll() + require.NoError(t, err) + + publicationWithTables := "test_pub_with_tables" + err = CreatePublication(context.Background(), conn, publicationWithTables, []string{"test_table"}) + require.NoError(t, err) + + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationName) + require.NoError(t, err) + assert.NotEmpty(t, tables) + assert.Contains(t, tables, "test_table") + assert.False(t, forAllTables) + + // add more tables to publication + multiReader = conn.Exec(context.Background(), "CREATE TABLE test_table2 (id serial PRIMARY KEY, name text);") + _, err = multiReader.ReadAll() + require.NoError(t, err) + + // Pass more tables to the publication + err = CreatePublication(context.Background(), conn, publicationWithTables, []string{ + "test_table2", + "test_table", + }) + require.NoError(t, err) + + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables) + require.NoError(t, err) + assert.NotEmpty(t, tables) + assert.Contains(t, tables, "test_table") 
+ assert.Contains(t, tables, "test_table2") + assert.False(t, forAllTables) + + // Removing one table from the publication + err = CreatePublication(context.Background(), conn, publicationWithTables, []string{ + "test_table", + }) + require.NoError(t, err) + + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables) + require.NoError(t, err) + assert.NotEmpty(t, tables) + assert.Contains(t, tables, "test_table") + assert.False(t, forAllTables) + + // Add one table and remove one at the same time + err = CreatePublication(context.Background(), conn, publicationWithTables, []string{ + "test_table2", + }) + require.NoError(t, err) + + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables) + require.NoError(t, err) + assert.NotEmpty(t, tables) + assert.Contains(t, tables, "test_table2") + assert.False(t, forAllTables) + +} + +func TestStartReplication(t *testing.T) { + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, dbURL) + require.NoError(t, err) + defer closeConn(t, conn) + + sysident, err := IdentifySystem(ctx, conn) + require.NoError(t, err) + + // create publication + publicationName := "test_publication" + err = CreatePublication(context.Background(), conn, publicationName, []string{}) + require.NoError(t, err) + + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}, 16, nil) + require.NoError(t, err) + + err = StartReplication(ctx, conn, slotName, sysident.XLogPos, StartReplicationOptions{ + PluginArgs: []string{ + "proto_version '1'", + "publication_names 'test_publication'", + "messages 'true'", + }, + Mode: LogicalReplication, + }) + require.NoError(t, err) + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + config, err := pgconn.ParseConfig(dbURL) + require.NoError(t, err) + delete(config.RuntimeParams, "replication") + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, conn) + + _, err = conn.Exec(ctx, ` +create table t(id int primary key, name text); + +insert into t values (1, 'foo'); +insert into t values (2, 'bar'); +insert into t values (3, 'baz'); + +update t set name='quz' where id=3; + +delete from t where id=2; + +drop table t; +`).ReadAll() + require.NoError(t, err) + }() + + rxKeepAlive := func() PrimaryKeepaliveMessage { + msg, err := conn.ReceiveMessage(ctx) + require.NoError(t, err) + cdMsg, ok := msg.(*pgproto3.CopyData) + require.True(t, ok) + + require.Equal(t, byte(PrimaryKeepaliveMessageByteID), cdMsg.Data[0]) + pkm, err := ParsePrimaryKeepaliveMessage(cdMsg.Data[1:]) + require.NoError(t, err) + return pkm + } + + relations := map[uint32]*RelationMessage{} + typeMap := pgtype.NewMap() + + rxXLogData := func() XLogData { + var cdMsg *pgproto3.CopyData + // Discard keepalive messages + for { + msg, err := conn.ReceiveMessage(ctx) + require.NoError(t, err) + var ok bool + cdMsg, ok = msg.(*pgproto3.CopyData) + require.True(t, ok) + if cdMsg.Data[0] != PrimaryKeepaliveMessageByteID { + break + } + } + require.Equal(t, byte(XLogDataByteID), cdMsg.Data[0]) + xld, err := ParseXLogData(cdMsg.Data[1:]) + require.NoError(t, err) + return xld + } + + rxKeepAlive() + xld := rxXLogData() + 
begin, err := isBeginMessage(xld.WALData) + require.NoError(t, err) + assert.True(t, begin) + + xld = rxXLogData() + var streamMessage *StreamMessageChanges + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) + assert.Nil(t, streamMessage) + + xld = rxXLogData() + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) + jsonData, err := json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, "{\"operation\":\"insert\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":1,\"name\":\"foo\"}}", string(jsonData)) + + xld = rxXLogData() + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) + jsonData, err = json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, "{\"operation\":\"insert\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":2,\"name\":\"bar\"}}", string(jsonData)) + + xld = rxXLogData() + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) + jsonData, err = json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, "{\"operation\":\"insert\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":3,\"name\":\"baz\"}}", string(jsonData)) + + xld = rxXLogData() + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) + jsonData, err = json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, "{\"operation\":\"update\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":3,\"name\":\"quz\"}}", string(jsonData)) + + xld = rxXLogData() + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) + jsonData, err = json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, "{\"operation\":\"delete\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":2,\"name\":null}}", string(jsonData)) + xld = rxXLogData() + + var commit bool + commit, _, err = isCommitMessage(xld.WALData) + require.NoError(t, err) + assert.True(t, commit) +} + +func TestSendStandbyStatusUpdate(t *testing.T) { + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, dbURL) + require.NoError(t, err) + defer closeConn(t, conn) + + sysident, err := IdentifySystem(ctx, conn) + require.NoError(t, err) + + err = SendStandbyStatusUpdate(ctx, conn, StandbyStatusUpdate{WALWritePosition: sysident.XLogPos}) + require.NoError(t, err) +} diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go new file mode 100644 index 0000000000..6e228dc21b --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -0,0 +1,166 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at
+//
+// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md
+
+package pglogicalstream
+
+import (
+	"context"
+
+	"github.com/jackc/pgx/v5/pgtype"
+
+	"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark"
+)
+
+// PluginHandler is an interface that must be implemented by all plugin handlers
+type PluginHandler interface {
+	// returns true if we need to ack the clientXLogPos
+	Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error)
+}
+
+// PgOutputUnbufferedPluginHandler is a native output handler that emits each message as it's received.
+type PgOutputUnbufferedPluginHandler struct {
+	messages chan StreamMessage
+	monitor  *Monitor
+
+	relations map[uint32]*RelationMessage
+	typeMap   *pgtype.Map
+
+	lastEmitted  LSN
+	lsnWatermark *watermark.Value[LSN]
+}
+
+// PgOutputBufferedPluginHandler is a native output handler that buffers and emits each transaction together
+type PgOutputBufferedPluginHandler struct {
+	messages chan StreamMessage
+	monitor  *Monitor
+
+	relations       map[uint32]*RelationMessage
+	typeMap         *pgtype.Map
+	pgoutputChanges []StreamMessageChanges
+}
+
+// NewPgOutputPluginHandler creates a new PgOutputPluginHandler
+func NewPgOutputPluginHandler(
+	messages chan StreamMessage,
+	batchTransactions bool,
+	monitor *Monitor,
+	lsnWatermark *watermark.Value[LSN],
+) PluginHandler {
+	if batchTransactions {
+		// batching enabled: buffer the whole transaction and emit it on commit
+		return &PgOutputBufferedPluginHandler{
+			messages:        messages,
+			monitor:         monitor,
+			relations:       map[uint32]*RelationMessage{},
+			typeMap:         pgtype.NewMap(),
+			pgoutputChanges: []StreamMessageChanges{},
+		}
+	}
+	// otherwise stream each change as soon as it is decoded
+	return &PgOutputUnbufferedPluginHandler{
+		messages:     messages,
+		monitor:      monitor,
+		relations:    map[uint32]*RelationMessage{},
+		typeMap:      pgtype.NewMap(),
+		lastEmitted:  lsnWatermark.Get(),
+		lsnWatermark: lsnWatermark,
+	}
+}
+
+// Handle emits each decoded pgoutput change as soon as it is received
+func (p *PgOutputUnbufferedPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) {
+	// parse changes inside the transaction
+	message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap)
+	if err != nil {
+		return false, err
+	}
+
+	isCommit, _, err := isCommitMessage(xld.WALData)
+	if err != nil {
+		return false, err
+	}
+
+	// when receiving a commit message, we need to acknowledge the LSN
+	// but we must wait for Connect to flush the messages before we can do that
+	if isCommit {
+		select {
+		case <-p.lsnWatermark.WaitFor(p.lastEmitted):
+			return true, nil
+		case <-ctx.Done():
+			return false, ctx.Err()
+		}
+	}
+
+	if message != nil {
+		lsn := clientXLogPos.String()
+		msg := StreamMessage{
+			Lsn:         &lsn,
+			Changes:     []StreamMessageChanges{*message},
+			Mode:        StreamModeStreaming,
+			WALLagBytes: &p.monitor.Report().WalLagInBytes,
+		}
+		select {
+		case p.messages <- msg:
+			p.lastEmitted = clientXLogPos
+		case <-ctx.Done():
+			return false, ctx.Err()
+		}
+	}
+
+	return false, nil
+}
+
+// Handle buffers pgoutput changes for a transaction and emits them together on commit
+func (p *PgOutputBufferedPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) {
+	// message changes must be collected in the buffer in the context of the same transaction,
+	// as a single transaction can contain multiple changes
+	// and acking the LSN too early could cause loss of changes
+	isBegin, err := isBeginMessage(xld.WALData)
+	if err != nil {
+		return false, err
+	}
+
+	if isBegin {
+		p.pgoutputChanges = []StreamMessageChanges{}
+	}
+
+	// parse changes inside the transaction
+	message, err :=
decodePgOutput(xld.WALData, p.relations, p.typeMap) + if err != nil { + return false, err + } + + if message != nil { + p.pgoutputChanges = append(p.pgoutputChanges, *message) + } + + isCommit, _, err := isCommitMessage(xld.WALData) + if err != nil { + return false, err + } + + if !isCommit { + return false, nil + } + + if len(p.pgoutputChanges) > 0 { + // send all collected changes + lsn := clientXLogPos.String() + msg := StreamMessage{ + Lsn: &lsn, + Changes: p.pgoutputChanges, + Mode: StreamModeStreaming, + WALLagBytes: &p.monitor.Report().WalLagInBytes, + } + select { + case p.messages <- msg: + case <-ctx.Done(): + return false, ctx.Err() + } + } + + return false, nil +} diff --git a/internal/impl/postgresql/pglogicalstream/replication_message.go b/internal/impl/postgresql/pglogicalstream/replication_message.go new file mode 100644 index 0000000000..f68abd7d3b --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/replication_message.go @@ -0,0 +1,728 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "strconv" + "time" +) + +var ( + errMsgNotSupported = errors.New("replication message not supported") +) + +// MessageType indicates the type of logical replication message. +type MessageType uint8 + +func (t MessageType) String() string { + switch t { + case MessageTypeBegin: + return "Begin" + case MessageTypeCommit: + return "Commit" + case MessageTypeOrigin: + return "Origin" + case MessageTypeRelation: + return "Relation" + case MessageTypeType: + return "Type" + case MessageTypeInsert: + return "Insert" + case MessageTypeUpdate: + return "Update" + case MessageTypeDelete: + return "Delete" + case MessageTypeTruncate: + return "Truncate" + case MessageTypeMessage: + return "Message" + case MessageTypeStreamStart: + return "StreamStart" + case MessageTypeStreamStop: + return "StreamStop" + case MessageTypeStreamCommit: + return "StreamCommit" + case MessageTypeStreamAbort: + return "StreamAbort" + default: + return "Unknown" + } +} + +// List of types of logical replication messages. +const ( + MessageTypeBegin MessageType = 'B' + MessageTypeMessage MessageType = 'M' + MessageTypeCommit MessageType = 'C' + MessageTypeOrigin MessageType = 'O' + MessageTypeRelation MessageType = 'R' + MessageTypeType MessageType = 'Y' + MessageTypeInsert MessageType = 'I' + MessageTypeUpdate MessageType = 'U' + MessageTypeDelete MessageType = 'D' + MessageTypeTruncate MessageType = 'T' + MessageTypeStreamStart MessageType = 'S' + MessageTypeStreamStop MessageType = 'E' + MessageTypeStreamCommit MessageType = 'c' + MessageTypeStreamAbort MessageType = 'A' +) + +// Message is a message received from server. +type Message interface { + Type() MessageType +} + +// MessageDecoder decodes message into struct. +type MessageDecoder interface { + Decode([]byte) error +} + +type baseMessage struct { + msgType MessageType +} + +// Type returns message type. +func (m *baseMessage) Type() MessageType { + return m.msgType +} + +// SetType sets message type. +// This method is added to help writing test code in application. +// The message type is still defined by message data. 
+func (m *baseMessage) SetType(t MessageType) { + m.msgType = t +} + +// Decode parse src into message struct. The src must contain the complete message starts after +// the first message type byte. +func (m *baseMessage) Decode(_ []byte) error { + return errors.New("message decode not implemented") +} + +func (m *baseMessage) lengthError(name string, expectedLen, actualLen int) error { + return fmt.Errorf("%s must have %d bytes, got %d bytes", name, expectedLen, actualLen) +} + +func (m *baseMessage) decodeStringError(name, field string) error { + return fmt.Errorf("%s.%s decode string error", name, field) +} + +func (m *baseMessage) decodeTupleDataError(name, field string, e error) error { + return fmt.Errorf("%s.%s decode tuple error: %s", name, field, e.Error()) +} + +func (m *baseMessage) invalidTupleTypeError(name, field string, e string, a byte) error { + return fmt.Errorf("%s.%s invalid tuple type value, expect %s, actual %c", name, field, e, a) +} + +// decodeString decode a string from src and returns the length of bytes being parsed. +// +// String type definition: https://www.postgresql.org/docs/current/protocol-message-types.html +// String(s) +// +// A null-terminated string (C-style string). There is no specific length limitation on strings. +// If s is specified it is the exact value that will appear, otherwise the value is variable. +// Eg. String, String("user"). +// +// If there is no null byte in src, return -1. +func (m *baseMessage) decodeString(src []byte) (string, int) { + end := bytes.IndexByte(src, byte(0)) + if end == -1 { + return "", -1 + } + // Trim the last null byte before converting it to a Golang string, then we can + // compare the result string with a Golang string literal. + return string(src[:end]), end + 1 +} + +func (m *baseMessage) decodeLSN(src []byte) (LSN, int) { + return LSN(binary.BigEndian.Uint64(src)), 8 +} + +func (m *baseMessage) decodeTime(src []byte) (time.Time, int) { + return pgTimeToTime(int64(binary.BigEndian.Uint64(src))), 8 +} + +func (m *baseMessage) decodeUint16(src []byte) (uint16, int) { + return binary.BigEndian.Uint16(src), 2 +} + +func (m *baseMessage) decodeUint32(src []byte) (uint32, int) { + return binary.BigEndian.Uint32(src), 4 +} + +func (m *baseMessage) decodeInt32(src []byte) (int32, int) { + asUint32, size := m.decodeUint32(src) + return int32(asUint32), size +} + +// BeginMessage is a begin message. +type BeginMessage struct { + baseMessage + //FinalLSN is the final LSN of the transaction. + FinalLSN LSN + // CommitTime is the commit timestamp of the transaction. + CommitTime time.Time + // Xid of the transaction. + Xid uint32 +} + +// Decode decodes the message from src. +func (m *BeginMessage) Decode(src []byte) error { + if len(src) < 20 { + return m.lengthError("BeginMessage", 20, len(src)) + } + var low, used int + m.FinalLSN, used = m.decodeLSN(src) + low += used + m.CommitTime, used = m.decodeTime(src[low:]) + low += used + m.Xid = binary.BigEndian.Uint32(src[low:]) + + m.SetType(MessageTypeBegin) + + return nil +} + +// CommitMessage is a commit message. +type CommitMessage struct { + baseMessage + // Flags currently unused (must be 0). + Flags uint8 + // CommitLSN is the LSN of the commit. + CommitLSN LSN + // TransactionEndLSN is the end LSN of the transaction. + TransactionEndLSN LSN + // CommitTime is the commit timestamp of the transaction + CommitTime time.Time +} + +// Decode decodes the message from src. 
+func (m *CommitMessage) Decode(src []byte) error { + if len(src) < 25 { + return m.lengthError("CommitMessage", 25, len(src)) + } + var low, used int + m.Flags = src[0] + low += 1 + m.CommitLSN, used = m.decodeLSN(src[low:]) + low += used + m.TransactionEndLSN, used = m.decodeLSN(src[low:]) + low += used + m.CommitTime, _ = m.decodeTime(src[low:]) + + m.SetType(MessageTypeCommit) + + return nil +} + +// OriginMessage is an origin message. +type OriginMessage struct { + baseMessage + // CommitLSN is the LSN of the commit on the origin server. + CommitLSN LSN + Name string +} + +// Decode decodes to message from src. +func (m *OriginMessage) Decode(src []byte) error { + if len(src) < 8 { + return m.lengthError("OriginMessage", 9, len(src)) + } + + var low, used int + m.CommitLSN, used = m.decodeLSN(src) + low += used + m.Name, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("OriginMessage", "Name") + } + + m.SetType(MessageTypeOrigin) + + return nil +} + +// RelationMessageColumn is one column in a RelationMessage. +type RelationMessageColumn struct { + // Flags for the column. Currently, it can be either 0 for no flags or 1 which marks the column as part of the key. + Flags uint8 + + Name string + + // DataType is the ID of the column's data type. + DataType uint32 + + // TypeModifier is type modifier of the column (atttypmod). + TypeModifier int32 +} + +// RelationMessage is a relation message. +type RelationMessage struct { + baseMessage + RelationID uint32 + Namespace string + RelationName string + ReplicaIdentity uint8 + ColumnNum uint16 + Columns []*RelationMessageColumn +} + +// Decode decodes to message from src. +func (m *RelationMessage) Decode(src []byte) error { + if len(src) < 7 { + return m.lengthError("RelationMessage", 7, len(src)) + } + + var low, used int + m.RelationID, used = m.decodeUint32(src) + low += used + + m.Namespace, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("RelationMessage", "Namespace") + } + low += used + + m.RelationName, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("RelationMessage", "RelationName") + } + low += used + + m.ReplicaIdentity = src[low] + low++ + + m.ColumnNum, used = m.decodeUint16(src[low:]) + low += used + + for i := 0; i < int(m.ColumnNum); i++ { + column := new(RelationMessageColumn) + column.Flags = src[low] + low++ + column.Name, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("RelationMessage", fmt.Sprintf("Column[%d].Name", i)) + } + low += used + + column.DataType, used = m.decodeUint32(src[low:]) + low += used + + column.TypeModifier, used = m.decodeInt32(src[low:]) + low += used + + m.Columns = append(m.Columns, column) + } + + m.SetType(MessageTypeRelation) + + return nil +} + +// TypeMessage is a type message. +type TypeMessage struct { + baseMessage + DataType uint32 + Namespace string + Name string +} + +// Decode decodes to message from src. +func (m *TypeMessage) Decode(src []byte) error { + if len(src) < 6 { + return m.lengthError("TypeMessage", 6, len(src)) + } + + var low, used int + m.DataType, used = m.decodeUint32(src) + low += used + + m.Namespace, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("TypeMessage", "Namespace") + } + low += used + + m.Name, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("TypeMessage", "Name") + } + + m.SetType(MessageTypeType) + + return nil +} + +// List of types of data in a tuple. 
+const ( + TupleDataTypeNull = uint8('n') + TupleDataTypeToast = uint8('u') + TupleDataTypeText = uint8('t') + TupleDataTypeBinary = uint8('b') +) + +// TupleDataColumn is a column in a TupleData. +type TupleDataColumn struct { + // DataType indicates how the data is stored. + // Byte1('n') Identifies the data as NULL value. + // Or + // Byte1('u') Identifies unchanged TOASTed value (the actual value is not sent). + // Or + // Byte1('t') Identifies the data as text formatted value. + // Or + // Byte1('b') Identifies the data as binary value. + DataType uint8 + Length uint32 + // Data is th value of the column, in text format. (A future release might support additional formats.) n is the above length. + Data []byte +} + +// Int64 parse column data as an int64 integer. +func (c *TupleDataColumn) Int64() (int64, error) { + if c.DataType != TupleDataTypeText { + return 0, fmt.Errorf("invalid column's data type, expect %c, actual %c", + TupleDataTypeText, c.DataType) + } + + return strconv.ParseInt(string(c.Data), 10, 64) +} + +// TupleData contains row change information. +type TupleData struct { + baseMessage + ColumnNum uint16 + Columns []*TupleDataColumn +} + +// Decode decodes to message from src. +func (m *TupleData) Decode(src []byte) (int, error) { + var low, used int + + m.ColumnNum, used = m.decodeUint16(src) + low += used + + for i := 0; i < int(m.ColumnNum); i++ { + column := new(TupleDataColumn) + column.DataType = src[low] + low += 1 + + switch column.DataType { + case TupleDataTypeText, TupleDataTypeBinary: + column.Length, used = m.decodeUint32(src[low:]) + low += used + + column.Data = make([]byte, int(column.Length)) + for j := 0; j < int(column.Length); j++ { + column.Data[j] = src[low+j] + } + low += int(column.Length) + case TupleDataTypeNull, TupleDataTypeToast: + } + + m.Columns = append(m.Columns, column) + } + + return low, nil +} + +// InsertMessage is a insert message +type InsertMessage struct { + baseMessage + // RelationID is the ID of the relation corresponding to the ID in the relation message. + RelationID uint32 + Tuple *TupleData +} + +// Decode decodes to message from src. +func (m *InsertMessage) Decode(src []byte) error { + if len(src) < 8 { + return m.lengthError("InsertMessage", 8, len(src)) + } + + var low, used int + + m.RelationID, used = m.decodeUint32(src) + low += used + + tupleType := src[low] + low += 1 + if tupleType != 'N' { + return m.invalidTupleTypeError("InsertMessage", "TupleType", "N", tupleType) + } + + m.Tuple = new(TupleData) + _, err := m.Tuple.Decode(src[low:]) + if err != nil { + return m.decodeTupleDataError("InsertMessage", "TupleData", err) + } + + m.SetType(MessageTypeInsert) + + return nil +} + +// List of types of UpdateMessage tuples. +const ( + UpdateMessageTupleTypeNone = uint8(0) + UpdateMessageTupleTypeKey = uint8('K') + UpdateMessageTupleTypeOld = uint8('O') + UpdateMessageTupleTypeNew = uint8('N') +) + +// UpdateMessage is a update message. +type UpdateMessage struct { + baseMessage + RelationID uint32 + + // OldTupleType + // Byte1('K'): + // Identifies the following TupleData submessage as a key. + // This field is optional and is only present if the update changed data + // in any of the column(s) that are part of the REPLICA IDENTITY index. + // + // Byte1('O'): + // Identifies the following TupleData submessage as an old tuple. + // This field is optional and is only present if table in which the update happened + // has REPLICA IDENTITY set to FULL. 
+ // + // The Update message may contain either a 'K' message part or an 'O' message part + // or neither of them, but never both of them. + OldTupleType uint8 + OldTuple *TupleData + + // NewTuple is the contents of a new tuple. + // Byte1('N'): Identifies the following TupleData message as a new tuple. + NewTuple *TupleData +} + +// Decode decodes to message from src. +func (m *UpdateMessage) Decode(src []byte) (err error) { + if len(src) < 6 { + return m.lengthError("UpdateMessage", 6, len(src)) + } + + var low, used int + + m.RelationID, used = m.decodeUint32(src) + low += used + + tupleType := src[low] + low++ + + switch tupleType { + case UpdateMessageTupleTypeKey, UpdateMessageTupleTypeOld: + m.OldTupleType = tupleType + m.OldTuple = new(TupleData) + used, err = m.OldTuple.Decode(src[low:]) + if err != nil { + return m.decodeTupleDataError("UpdateMessage", "OldTuple", err) + } + low += used + low++ + fallthrough + case UpdateMessageTupleTypeNew: + m.NewTuple = new(TupleData) + _, err = m.NewTuple.Decode(src[low:]) + if err != nil { + return m.decodeTupleDataError("UpdateMessage", "NewTuple", err) + } + default: + return m.invalidTupleTypeError("UpdateMessage", "Tuple", "K/O/N", tupleType) + } + + m.SetType(MessageTypeUpdate) + + return nil +} + +// List of types of DeleteMessage tuples. +const ( + DeleteMessageTupleTypeKey = uint8('K') + DeleteMessageTupleTypeOld = uint8('O') +) + +// DeleteMessage is a delete message. +type DeleteMessage struct { + baseMessage + RelationID uint32 + // OldTupleType + // Byte1('K'): + // Identifies the following TupleData submessage as a key. + // This field is present if the table in which the delete has happened uses an index + // as REPLICA IDENTITY. + // + // Byte1('O') + // Identifies the following TupleData message as an old tuple. + // This field is present if the table in which the delete has happened has + // REPLICA IDENTITY set to FULL. + // + // The Delete message may contain either a 'K' message part or an 'O' message part, + // but never both of them. + OldTupleType uint8 + OldTuple *TupleData +} + +// Decode decodes a message from src. +func (m *DeleteMessage) Decode(src []byte) (err error) { + if len(src) < 4 { + return m.lengthError("DeleteMessage", 4, len(src)) + } + + var low, used int + + m.RelationID, used = m.decodeUint32(src) + low += used + + m.OldTupleType = src[low] + low++ + + switch m.OldTupleType { + case DeleteMessageTupleTypeKey, DeleteMessageTupleTypeOld: + m.OldTuple = new(TupleData) + _, err = m.OldTuple.Decode(src[low:]) + if err != nil { + return m.decodeTupleDataError("DeleteMessage", "OldTuple", err) + } + default: + return m.invalidTupleTypeError("DeleteMessage", "OldTupleType", "K/O", m.OldTupleType) + } + + m.SetType(MessageTypeDelete) + + return nil +} + +// List of truncate options. +const ( + TruncateOptionCascade = uint8(1) << iota + TruncateOptionRestartIdentity +) + +// TruncateMessage is a truncate message. +type TruncateMessage struct { + baseMessage + RelationNum uint32 + Option uint8 + RelationIDs []uint32 +} + +// Decode decodes to message from src. 
+func (m *TruncateMessage) Decode(src []byte) (err error) { + if len(src) < 9 { + return m.lengthError("TruncateMessage", 9, len(src)) + } + + var low, used int + m.RelationNum, used = m.decodeUint32(src) + low += used + + m.Option = src[low] + low++ + + m.RelationIDs = make([]uint32, m.RelationNum) + for i := 0; i < int(m.RelationNum); i++ { + m.RelationIDs[i], used = m.decodeUint32(src[low:]) + low += used + } + + m.SetType(MessageTypeTruncate) + + return nil +} + +// LogicalDecodingMessage is a logical decoding message. +type LogicalDecodingMessage struct { + baseMessage + + LSN LSN + Transactional bool + Prefix string + Content []byte +} + +// Decode decodes a message from src. +func (m *LogicalDecodingMessage) Decode(src []byte) (err error) { + if len(src) < 14 { + return m.lengthError("LogicalDecodingMessage", 14, len(src)) + } + + var low, used int + + flags := src[low] + m.Transactional = flags == 1 + low++ + + m.LSN, used = m.decodeLSN(src[low:]) + low += used + + m.Prefix, used = m.decodeString(src[low:]) + low += used + + contentLength, used := m.decodeUint32(src[low:]) + low += used + + m.Content = src[low : low+int(contentLength)] + + m.SetType(MessageTypeMessage) + + return nil +} + +// Parse parse a logical replication message. +func Parse(data []byte) (m Message, err error) { + var decoder MessageDecoder + msgType := MessageType(data[0]) + switch msgType { + case MessageTypeRelation: + decoder = new(RelationMessage) + case MessageTypeType: + decoder = new(TypeMessage) + case MessageTypeInsert: + decoder = new(InsertMessage) + case MessageTypeUpdate: + decoder = new(UpdateMessage) + case MessageTypeDelete: + decoder = new(DeleteMessage) + case MessageTypeTruncate: + decoder = new(TruncateMessage) + case MessageTypeMessage: + decoder = new(LogicalDecodingMessage) + default: + decoder = getCommonDecoder(msgType) + } + + if decoder == nil { + return nil, errMsgNotSupported + } + + if err = decoder.Decode(data[1:]); err != nil { + return nil, err + } + + return decoder.(Message), nil +} + +func getCommonDecoder(msgType MessageType) MessageDecoder { + var decoder MessageDecoder + switch msgType { + case MessageTypeBegin: + decoder = new(BeginMessage) + case MessageTypeCommit: + decoder = new(CommitMessage) + case MessageTypeOrigin: + decoder = new(OriginMessage) + } + + return decoder +} diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go new file mode 100644 index 0000000000..0c1ed3b236 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -0,0 +1,153 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at
+//
+// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md
+
+package pglogicalstream
+
+import (
+	"fmt"
+
+	"github.com/jackc/pgx/v5/pgtype"
+)
+
+// ----------------------------------------------------------------------------
+// PgOutput section
+
+func isBeginMessage(WALData []byte) (bool, error) {
+	logicalMsg, err := Parse(WALData)
+	if err != nil {
+		return false, err
+	}
+
+	_, ok := logicalMsg.(*BeginMessage)
+	return ok, nil
+}
+
+func isCommitMessage(WALData []byte) (bool, *CommitMessage, error) {
+	logicalMsg, err := Parse(WALData)
+	if err != nil {
+		return false, nil, err
+	}
+
+	m, ok := logicalMsg.(*CommitMessage)
+	return ok, m, nil
+}
+
+// decodePgOutput decodes a logical replication message in pgoutput format.
+// It uses the provided relations map to look up the relation metadata for the
+// relation a change applies to; as a side effect it updates the relations map
+// with any new relation metadata. When a relation is changed in the database,
+// the relation message is sent before the change message.
+func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeMap *pgtype.Map) (*StreamMessageChanges, error) {
+	logicalMsg, err := Parse(WALData)
+	message := &StreamMessageChanges{}
+
+	if err != nil {
+		return nil, err
+	}
+	switch logicalMsg := logicalMsg.(type) {
+	case *RelationMessage:
+		relations[logicalMsg.RelationID] = logicalMsg
+		return nil, nil
+	case *BeginMessage:
+		return nil, nil
+	case *CommitMessage:
+		return nil, nil
+	case *InsertMessage:
+		rel, ok := relations[logicalMsg.RelationID]
+		if !ok {
+			return nil, fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
+		}
+		message.Operation = "insert"
+		message.Schema = rel.Namespace
+		message.Table = rel.RelationName
+		values := map[string]interface{}{}
+		for idx, col := range logicalMsg.Tuple.Columns {
+			colName := rel.Columns[idx].Name
+			switch col.DataType {
+			case 'n': // null
+				values[colName] = nil
+			case 'u': // unchanged toast
+				// This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you.
+			case 't': // text
+				val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType)
+				if err != nil {
+					return nil, err
+				}
+				values[colName] = val
+			}
+		}
+		message.Data = values
+	case *UpdateMessage:
+		rel, ok := relations[logicalMsg.RelationID]
+		if !ok {
+			return nil, fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
+		}
+		message.Operation = "update"
+		message.Schema = rel.Namespace
+		message.Table = rel.RelationName
+		values := map[string]interface{}{}
+		for idx, col := range logicalMsg.NewTuple.Columns {
+			colName := rel.Columns[idx].Name
+			switch col.DataType {
+			case 'n': // null
+				values[colName] = nil
+			case 'u': // unchanged toast
+				// This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you.
+			case 't': // text
+				val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType)
+				if err != nil {
+					return nil, err
+				}
+				values[colName] = val
+			}
+		}
+		message.Data = values
+	case *DeleteMessage:
+		rel, ok := relations[logicalMsg.RelationID]
+		if !ok {
+			return nil, fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID)
+		}
+		message.Operation = "delete"
+		message.Schema = rel.Namespace
+		message.Table = rel.RelationName
+		values := map[string]interface{}{}
+		for idx, col := range logicalMsg.OldTuple.Columns {
+			colName := rel.Columns[idx].Name
+			switch col.DataType {
+			case 'n': // null
+				values[colName] = nil
+			case 'u': // unchanged toast
+				// This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you.
+			case 't': // text
+				val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType)
+				if err != nil {
+					return nil, err
+				}
+				values[colName] = val
+			}
+		}
+		message.Data = values
+	case *TruncateMessage:
+	case *TypeMessage:
+	case *OriginMessage:
+	case *LogicalDecodingMessage:
+		return nil, nil
+	default:
+		return nil, nil
+	}
+
+	return message, nil
+}
+
+func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interface{}, error) {
+	if dt, ok := mi.TypeForOID(dataType); ok {
+		return dt.Codec.DecodeValue(mi, dataType, pgtype.TextFormatCode, data)
+	}
+	return string(data), nil
+}
diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_test.go b/internal/impl/postgresql/pglogicalstream/replication_message_test.go
new file mode 100644
index 0000000000..d0c438900e
--- /dev/null
+++ b/internal/impl/postgresql/pglogicalstream/replication_message_test.go
@@ -0,0 +1,838 @@
+// Copyright 2024 Redpanda Data, Inc.
+//
+// Licensed as a Redpanda Enterprise file under the Redpanda Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md
+
+package pglogicalstream
+
+import (
+	"encoding/binary"
+	"math/rand"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"github.com/stretchr/testify/suite"
+)
+
+var bigEndian = binary.BigEndian
+
+type messageSuite struct {
+	suite.Suite
+}
+
+func (s *messageSuite) R() *require.Assertions {
+	return s.Require()
+}
+
+func (s *messageSuite) Equal(e, a interface{}, args ...interface{}) {
+	s.R().Equal(e, a, args...)
+}
+
+func (s *messageSuite) NoError(err error) {
+	s.R().NoError(err)
+}
+
+func (s *messageSuite) True(value bool) {
+	s.R().True(value)
+}
+
+func (s *messageSuite) newLSN() LSN {
+	return LSN(rand.Int63())
+}
+
+func (s *messageSuite) newXid() uint32 {
+	return uint32(rand.Int31())
+}
+
+func (s *messageSuite) newTime() (time.Time, uint64) {
+	// Postgres time format only supports millisecond accuracy.
+ now := time.Now().Truncate(time.Millisecond) + return now, uint64(timeToPgTime(now)) +} + +func (s *messageSuite) newRelationID() uint32 { + return uint32(rand.Int31()) +} + +func (s *messageSuite) putString(dst []byte, value string) int { + copy(dst, value) + dst[len(value)] = byte(0) + return len(value) + 1 +} + +func (s *messageSuite) tupleColumnLength(dataType uint8, data []byte) int { + switch dataType { + case uint8('n'), uint8('u'): + return 1 + case uint8('t'): + return 1 + 4 + len(data) + default: + s.FailNow("invalid data type of a tuple: %c", dataType) + return 0 + } +} + +func (s *messageSuite) putTupleColumn(dst []byte, dataType uint8, data []byte) int { + dst[0] = dataType + + switch dataType { + case uint8('n'), uint8('u'): + return 1 + case uint8('t'): + bigEndian.PutUint32(dst[1:], uint32(len(data))) + copy(dst[5:], data) + return 5 + len(data) + default: + s.FailNow("invalid data type of a tuple: %c", dataType) + return 0 + } +} + +func (s *messageSuite) putMessageTestData(msg []byte) *LogicalDecodingMessage { + // transaction flag + msg[0] = 1 + off := 1 + + lsn := s.newLSN() + bigEndian.PutUint64(msg[off:], uint64(lsn)) + off += 8 + + off += s.putString(msg[off:], "test") + + content := "hello" + + bigEndian.PutUint32(msg[off:], uint32(len(content))) + off += 4 + + for i := 0; i < len(content); i++ { + msg[off] = content[i] + off++ + } + return &LogicalDecodingMessage{ + Transactional: true, + LSN: lsn, + Prefix: "test", + Content: []byte("hello"), + } +} + +func (s *messageSuite) createRelationTestData() ([]byte, *RelationMessage) { + relationID := uint32(rand.Int31()) + namespace := "public" + relationName := "table1" + noAtttypmod := int32(-1) + col1 := "id" // int8 + col2 := "name" // text + col3 := "created_at" // timestamptz + + col1Length := 1 + len(col1) + 1 + 4 + 4 + col2Length := 1 + len(col2) + 1 + 4 + 4 + col3Length := 1 + len(col3) + 1 + 4 + 4 + + msg := make([]byte, 1+4+len(namespace)+1+len(relationName)+1+1+ + 2+col1Length+col2Length+col3Length) + msg[0] = 'R' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + off += s.putString(msg[off:], namespace) + off += s.putString(msg[off:], relationName) + msg[off] = 1 + off++ + bigEndian.PutUint16(msg[off:], 3) + off += 2 + + msg[off] = 1 // column id is key + off++ + off += s.putString(msg[off:], col1) + bigEndian.PutUint32(msg[off:], 20) // int8 + off += 4 + bigEndian.PutUint32(msg[off:], uint32(noAtttypmod)) + off += 4 + + msg[off] = 0 + off++ + off += s.putString(msg[off:], col2) + bigEndian.PutUint32(msg[off:], 25) // text + off += 4 + bigEndian.PutUint32(msg[off:], uint32(noAtttypmod)) + off += 4 + + msg[off] = 0 + off++ + off += s.putString(msg[off:], col3) + bigEndian.PutUint32(msg[off:], 1184) // timestamptz + off += 4 + bigEndian.PutUint32(msg[off:], uint32(noAtttypmod)) + + expected := &RelationMessage{ + RelationID: relationID, + Namespace: namespace, + RelationName: relationName, + ReplicaIdentity: 1, + ColumnNum: 3, + Columns: []*RelationMessageColumn{ + { + Flags: 1, + Name: col1, + DataType: 20, + TypeModifier: -1, + }, + { + Flags: 0, + Name: col2, + DataType: 25, + TypeModifier: -1, + }, + { + Flags: 0, + Name: col3, + DataType: 1184, + TypeModifier: -1, + }, + }, + } + expected.msgType = 'R' + + return msg, expected +} + +func (s *messageSuite) createTypeTestData() ([]byte, *TypeMessage) { + dataType := uint32(1184) // timestamptz + namespace := "public" + name := "created_at" + + msg := make([]byte, 1+4+len(namespace)+1+len(name)+1) + msg[0] = 'Y' + off := 1 + 
bigEndian.PutUint32(msg[off:], dataType) + off += 4 + off += s.putString(msg[off:], namespace) + s.putString(msg[off:], name) + + expected := &TypeMessage{ + DataType: dataType, + Namespace: namespace, + Name: name, + } + expected.msgType = 'Y' + + return msg, expected +} + +func (s *messageSuite) createInsertTestData() ([]byte, *InsertMessage) { + relationID := s.newRelationID() + + col1Data := []byte("1") + col2Data := []byte("myname") + col3Data := []byte("123456789") + col1Length := s.tupleColumnLength('t', col1Data) + col2Length := s.tupleColumnLength('t', col2Data) + col3Length := s.tupleColumnLength('t', col3Data) + col4Length := s.tupleColumnLength('n', nil) + col5Length := s.tupleColumnLength('u', nil) + + msg := make([]byte, 1+4+1+2+col1Length+col2Length+col3Length+col4Length+col5Length) + msg[0] = 'I' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'N' + off++ + bigEndian.PutUint16(msg[off:], 5) + off += 2 + off += s.putTupleColumn(msg[off:], 't', col1Data) + off += s.putTupleColumn(msg[off:], 't', col2Data) + off += s.putTupleColumn(msg[off:], 't', col3Data) + off += s.putTupleColumn(msg[off:], 'n', nil) + s.putTupleColumn(msg[off:], 'u', nil) + + expected := &InsertMessage{ + RelationID: relationID, + Tuple: &TupleData{ + ColumnNum: 5, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(col1Data)), + Data: col1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(col2Data)), + Data: col2Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(col3Data)), + Data: col3Data, + }, + { + DataType: TupleDataTypeNull, + }, + { + DataType: TupleDataTypeToast, + }, + }, + }, + } + expected.msgType = 'I' + + return msg, expected +} + +func (s *messageSuite) createUpdateTestDataTypeK() ([]byte, *UpdateMessage) { + relationID := s.newRelationID() + + oldCol1Data := []byte("123") // like an id + oldCol1Length := s.tupleColumnLength('t', oldCol1Data) + + newCol1Data := []byte("1124") + newCol2Data := []byte("myname") + newCol1Length := s.tupleColumnLength('t', newCol1Data) + newCol2Length := s.tupleColumnLength('t', newCol2Data) + + msg := make([]byte, 1+4+ + 1+2+oldCol1Length+ + 1+2+newCol1Length+newCol2Length) + msg[0] = 'U' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'K' + off += 1 + bigEndian.PutUint16(msg[off:], 1) + off += 2 + off += s.putTupleColumn(msg[off:], 't', oldCol1Data) + msg[off] = 'N' + off++ + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', newCol1Data) + s.putTupleColumn(msg[off:], 't', newCol2Data) + expected := &UpdateMessage{ + RelationID: relationID, + OldTupleType: UpdateMessageTupleTypeKey, + OldTuple: &TupleData{ + ColumnNum: 1, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol1Data)), + Data: oldCol1Data, + }, + }, + }, + NewTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol1Data)), + Data: newCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol2Data)), + Data: newCol2Data, + }, + }, + }, + } + expected.msgType = 'U' + + return msg, expected +} + +func (s *messageSuite) createUpdateTestDataTypeO() ([]byte, *UpdateMessage) { + relationID := s.newRelationID() + + oldCol1Data := []byte("123") // like an id + oldCol1Length := s.tupleColumnLength('t', oldCol1Data) + oldCol2Data := []byte("myoldname") + oldCol2Length := s.tupleColumnLength('t', oldCol2Data) + + 
newCol1Data := []byte("1124") + newCol2Data := []byte("myname") + newCol1Length := s.tupleColumnLength('t', newCol1Data) + newCol2Length := s.tupleColumnLength('t', newCol2Data) + + msg := make([]byte, 1+4+ + 1+2+oldCol1Length+oldCol2Length+ + 1+2+newCol1Length+newCol2Length) + msg[0] = 'U' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'O' + off += 1 + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', oldCol1Data) + off += s.putTupleColumn(msg[off:], 't', oldCol2Data) + msg[off] = 'N' + off++ + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', newCol1Data) + s.putTupleColumn(msg[off:], 't', newCol2Data) + expected := &UpdateMessage{ + RelationID: relationID, + OldTupleType: UpdateMessageTupleTypeOld, + OldTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol1Data)), + Data: oldCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol2Data)), + Data: oldCol2Data, + }, + }, + }, + NewTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol1Data)), + Data: newCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol2Data)), + Data: newCol2Data, + }, + }, + }, + } + expected.msgType = 'U' + + return msg, expected +} + +func (s *messageSuite) createUpdateTestDataWithoutOldTuple() ([]byte, *UpdateMessage) { + relationID := s.newRelationID() + + newCol1Data := []byte("1124") + newCol2Data := []byte("myname") + newCol1Length := s.tupleColumnLength('t', newCol1Data) + newCol2Length := s.tupleColumnLength('t', newCol2Data) + + msg := make([]byte, 1+4+ + 1+2+newCol1Length+newCol2Length) + msg[0] = 'U' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'N' + off++ + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', newCol1Data) + s.putTupleColumn(msg[off:], 't', newCol2Data) + expected := &UpdateMessage{ + RelationID: relationID, + OldTupleType: UpdateMessageTupleTypeNone, + NewTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol1Data)), + Data: newCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol2Data)), + Data: newCol2Data, + }, + }, + }, + } + expected.msgType = 'U' + + return msg, expected +} + +func (s *messageSuite) createDeleteTestDataTypeK() ([]byte, *DeleteMessage) { + relationID := s.newRelationID() + + oldCol1Data := []byte("123") // like an id + oldCol1Length := s.tupleColumnLength('t', oldCol1Data) + + msg := make([]byte, 1+4+ + 1+2+oldCol1Length) + msg[0] = 'D' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'K' + off++ + bigEndian.PutUint16(msg[off:], 1) + off += 2 + s.putTupleColumn(msg[off:], 't', oldCol1Data) + expected := &DeleteMessage{ + RelationID: relationID, + OldTupleType: DeleteMessageTupleTypeKey, + OldTuple: &TupleData{ + ColumnNum: 1, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol1Data)), + Data: oldCol1Data, + }, + }, + }, + } + expected.msgType = 'D' + return msg, expected +} + +func (s *messageSuite) createDeleteTestDataTypeO() ([]byte, *DeleteMessage) { + relationID := s.newRelationID() + + oldCol1Data := []byte("123") // like an id + oldCol1Length := s.tupleColumnLength('t', oldCol1Data) + oldCol2Data := []byte("myoldname") + 
oldCol2Length := s.tupleColumnLength('t', oldCol2Data) + + msg := make([]byte, 1+4+ + 1+2+oldCol1Length+oldCol2Length) + msg[0] = 'D' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'O' + off += 1 + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', oldCol1Data) + s.putTupleColumn(msg[off:], 't', oldCol2Data) + expected := &DeleteMessage{ + RelationID: relationID, + OldTupleType: DeleteMessageTupleTypeOld, + OldTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol1Data)), + Data: oldCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol2Data)), + Data: oldCol2Data, + }, + }, + }, + } + expected.msgType = 'D' + return msg, expected +} + +func (s *messageSuite) createTruncateTestData() ([]byte, *TruncateMessage) { + relationID1 := s.newRelationID() + relationID2 := s.newRelationID() + option := uint8(0x01 | 0x02) + + msg := make([]byte, 1+4+1+4*2) + msg[0] = 'T' + off := 1 + bigEndian.PutUint32(msg[off:], 2) + off += 4 + msg[off] = option + off++ + bigEndian.PutUint32(msg[off:], relationID1) + off += 4 + bigEndian.PutUint32(msg[off:], relationID2) + expected := &TruncateMessage{ + RelationNum: 2, + Option: TruncateOptionCascade | TruncateOptionRestartIdentity, + RelationIDs: []uint32{ + relationID1, + relationID2, + }, + } + expected.msgType = 'T' + return msg, expected +} + +func TestBeginMessageSuite(t *testing.T) { + suite.Run(t, new(beginMessageSuite)) +} + +type beginMessageSuite struct { + messageSuite +} + +func (s *beginMessageSuite) Test() { + finalLSN := s.newLSN() + commitTime, pgCommitTime := s.newTime() + xid := s.newXid() + + msg := make([]byte, 1+8+8+4) + msg[0] = 'B' + bigEndian.PutUint64(msg[1:], uint64(finalLSN)) + bigEndian.PutUint64(msg[9:], pgCommitTime) + bigEndian.PutUint32(msg[17:], xid) + + m, err := Parse(msg) + s.NoError(err) + beginMsg, ok := m.(*BeginMessage) + s.True(ok) + + expected := &BeginMessage{ + FinalLSN: finalLSN, + CommitTime: commitTime, + Xid: xid, + } + expected.msgType = 'B' + s.Equal(expected, beginMsg) +} + +func TestCommitMessage(t *testing.T) { + suite.Run(t, new(commitMessageSuite)) +} + +type commitMessageSuite struct { + messageSuite +} + +func (s *commitMessageSuite) Test() { + flags := uint8(0) + commitLSN := s.newLSN() + transactionEndLSN := s.newLSN() + commitTime, pgCommitTime := s.newTime() + + msg := make([]byte, 1+1+8+8+8) + msg[0] = 'C' + msg[1] = flags + bigEndian.PutUint64(msg[2:], uint64(commitLSN)) + bigEndian.PutUint64(msg[10:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[18:], pgCommitTime) + + m, err := Parse(msg) + s.NoError(err) + commitMsg, ok := m.(*CommitMessage) + s.True(ok) + + expected := &CommitMessage{ + Flags: 0, + CommitLSN: commitLSN, + TransactionEndLSN: transactionEndLSN, + CommitTime: commitTime, + } + expected.msgType = 'C' + s.Equal(expected, commitMsg) +} + +func TestOriginMessage(t *testing.T) { + suite.Run(t, new(originMessageSuite)) +} + +type originMessageSuite struct { + messageSuite +} + +func (s *originMessageSuite) Test() { + commitLSN := s.newLSN() + name := "someorigin" + + msg := make([]byte, 1+8+len(name)+1) // 1 byte for \0 + msg[0] = 'O' + bigEndian.PutUint64(msg[1:], uint64(commitLSN)) + s.putString(msg[9:], name) + + m, err := Parse(msg) + s.NoError(err) + originMsg, ok := m.(*OriginMessage) + s.True(ok) + + expected := &OriginMessage{ + CommitLSN: commitLSN, + Name: name, + } + expected.msgType = 'O' + 
s.Equal(expected, originMsg) +} + +func TestRelationMessageSuite(t *testing.T) { + suite.Run(t, new(relationMessageSuite)) +} + +type relationMessageSuite struct { + messageSuite +} + +func (s *relationMessageSuite) Test() { + + msg, expected := s.createRelationTestData() + + m, err := Parse(msg) + s.NoError(err) + relationMsg, ok := m.(*RelationMessage) + s.True(ok) + + s.Equal(expected, relationMsg) +} + +func TestTypeMessageSuite(t *testing.T) { + suite.Run(t, new(typeMessageSuite)) +} + +type typeMessageSuite struct { + messageSuite +} + +func (s *typeMessageSuite) Test() { + msg, expected := s.createTypeTestData() + + m, err := Parse(msg) + s.NoError(err) + typeMsg, ok := m.(*TypeMessage) + s.True(ok) + + s.Equal(expected, typeMsg) +} + +func TestInsertMessageSuite(t *testing.T) { + suite.Run(t, new(insertMessageSuite)) +} + +type insertMessageSuite struct { + messageSuite +} + +func (s *insertMessageSuite) Test() { + + msg, expected := s.createInsertTestData() + + m, err := Parse(msg) + s.NoError(err) + insertMsg, ok := m.(*InsertMessage) + s.True(ok) + + s.Equal(expected, insertMsg) +} + +func TestUpdateMessageSuite(t *testing.T) { + suite.Run(t, new(updateMessageSuite)) +} + +type updateMessageSuite struct { + messageSuite +} + +func (s *updateMessageSuite) TestWithOldTupleTypeK() { + msg, expected := s.createUpdateTestDataTypeK() + m, err := Parse(msg) + s.NoError(err) + updateMsg, ok := m.(*UpdateMessage) + s.True(ok) + + s.Equal(expected, updateMsg) +} + +func (s *updateMessageSuite) TestWithOldTupleTypeO() { + msg, expected := s.createUpdateTestDataTypeO() + m, err := Parse(msg) + s.NoError(err) + updateMsg, ok := m.(*UpdateMessage) + s.True(ok) + + s.Equal(expected, updateMsg) +} + +func (s *updateMessageSuite) TestWithoutOldTuple() { + msg, expected := s.createUpdateTestDataWithoutOldTuple() + m, err := Parse(msg) + s.NoError(err) + updateMsg, ok := m.(*UpdateMessage) + s.True(ok) + + s.Equal(expected, updateMsg) +} + +func TestDeleteMessageSuite(t *testing.T) { + suite.Run(t, new(deleteMessageSuite)) +} + +type deleteMessageSuite struct { + messageSuite +} + +func (s *deleteMessageSuite) TestWithOldTupleTypeK() { + msg, expected := s.createDeleteTestDataTypeK() + + m, err := Parse(msg) + s.NoError(err) + deleteMsg, ok := m.(*DeleteMessage) + s.True(ok) + + s.Equal(expected, deleteMsg) +} + +func (s *deleteMessageSuite) TestWithOldTupleTypeO() { + msg, expected := s.createDeleteTestDataTypeO() + + m, err := Parse(msg) + s.NoError(err) + deleteMsg, ok := m.(*DeleteMessage) + s.True(ok) + + s.Equal(expected, deleteMsg) +} + +func TestTruncateMessageSuite(t *testing.T) { + suite.Run(t, new(truncateMessageSuite)) +} + +type truncateMessageSuite struct { + messageSuite +} + +func (s *truncateMessageSuite) Test() { + msg, expected := s.createTruncateTestData() + + m, err := Parse(msg) + s.NoError(err) + truncateMsg, ok := m.(*TruncateMessage) + s.True(ok) + + s.Equal(expected, truncateMsg) +} + +func TestLogicalDecodingMessageSuite(t *testing.T) { + suite.Run(t, new(logicalDecodingMessageSuite)) +} + +type logicalDecodingMessageSuite struct { + messageSuite +} + +func (s *logicalDecodingMessageSuite) Test() { + msg := make([]byte, 1+1+8+5+4+5) + msg[0] = 'M' + + expected := s.putMessageTestData(msg[1:]) + + expected.msgType = MessageTypeMessage + + m, err := Parse(msg) + s.NoError(err) + logicalDecodingMsg, ok := m.(*LogicalDecodingMessage) + s.True(ok) + + s.Equal(expected, logicalDecodingMsg) +} diff --git a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go 
b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go new file mode 100644 index 0000000000..5bba854bbb --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go @@ -0,0 +1,390 @@ +// Copyright (c) 2013-2021 Jack Christensen +// +// MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// An import of sanitization code from pgx/internal/sanitize so that we +// can sanitize +package sanitize + +import ( + "bytes" + "encoding/hex" + "errors" + "fmt" + "strconv" + "strings" + "time" + "unicode" + "unicode/utf8" +) + +// MaxIdentifierLength is PostgreSQL's maximum identifier length +const MaxIdentifierLength = 63 + +// Part is either a string or an int. A string is raw SQL. An int is a +// argument placeholder. +type Part any + +// Query represents a SQL query that consists of []Part +type Query struct { + Parts []Part +} + +// utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement +// character. utf8.RuneError is not an error if it is also width 3. 
+// +// https://github.com/jackc/pgx/issues/1380 +const replacementcharacterwidth = 3 + +// Sanitize sanitizes a SQL query +func (q *Query) Sanitize(args ...any) (string, error) { + argUse := make([]bool, len(args)) + buf := &bytes.Buffer{} + + for _, part := range q.Parts { + var str string + switch part := part.(type) { + case string: + str = part + case int: + argIdx := part - 1 + + if argIdx < 0 { + return "", errors.New("first sql argument must be > 0") + } + + if argIdx >= len(args) { + return "", errors.New("insufficient arguments") + } + arg := args[argIdx] + switch arg := arg.(type) { + case nil: + str = "null" + case int64: + str = strconv.FormatInt(arg, 10) + case float64: + str = strconv.FormatFloat(arg, 'f', -1, 64) + case bool: + str = strconv.FormatBool(arg) + case []byte: + str = quoteBytes(arg) + case string: + str = quoteString(arg) + case time.Time: + str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + default: + return "", fmt.Errorf("invalid arg type: %T", arg) + } + argUse[argIdx] = true + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + str = " " + str + " " + default: + return "", fmt.Errorf("invalid Part type: %T", part) + } + buf.WriteString(str) + } + + for i, used := range argUse { + if !used { + return "", fmt.Errorf("unused argument: %d", i) + } + } + return buf.String(), nil +} + +// NewQuery parses a SQL query string and returns a Query object. +func NewQuery(sql string) (*Query, error) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + query := &Query{Parts: l.parts} + + return query, nil +} + +func quoteString(str string) string { + return "'" + strings.ReplaceAll(str, "'", "''") + "'" +} + +func quoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +} + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. 
+ stateFn stateFn + parts []Part +} + +type stateFn func(*sqlLexer) stateFn + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '$': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if '0' <= nextRune && nextRune <= '9' { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return placeholderState + } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +// placeholderState consumes a placeholder value. The $ must have already has +// already been consumed. The first rune must be a digit. 
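+//
+// For example, lexing "select $1" works like this: rawState reaches the '$',
+// sees that the next rune is a digit, emits the raw SQL part "select ", and
+// hands control to placeholderState, which consumes the digits and emits the
+// int part 1, producing []Part{"select ", 1}. This exact case is covered by
+// TestNewQuery in sanitize_test.go.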
+func placeholderState(l *sqlLexer) stateFn { + num := 0 + + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if '0' <= r && r <= '9' { + num *= 10 + num += int(r - '0') + } else { + l.parts = append(l.parts, num) + l.pos -= width + l.start = l.pos + return rawState + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n', '\r': + return rawState + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + l.nested-- + + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +// SQLQuery replaces placeholder values with args. It quotes and escapes args +// as necessary. This function is only safe when standard_conforming_strings is +// on. +func SQLQuery(sql string, args ...any) (string, error) { + query, err := NewQuery(sql) + if err != nil { + return "", err + } + return query.Sanitize(args...) +} + +// ValidatePostgresIdentifier checks if a string is a valid PostgreSQL identifier +// This follows PostgreSQL's standard naming rules +func ValidatePostgresIdentifier(name string) error { + if len(name) == 0 { + return errors.New("empty identifier is not allowed") + } + + if len(name) > MaxIdentifierLength { + return fmt.Errorf("identifier length exceeds maximum of %d characters", MaxIdentifierLength) + } + + // First character must be a letter or underscore + if !unicode.IsLetter(rune(name[0])) && name[0] != '_' { + return errors.New("identifier must start with a letter or underscore") + } + + // Subsequent characters must be letters, numbers, underscores, or dots + for i, char := range name { + if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '_' && char != '.' 
{ + return fmt.Errorf("invalid character '%c' at position %d in identifier '%s'", char, i, name) + } + } + + return nil +} diff --git a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go new file mode 100644 index 0000000000..ba87ba5eaa --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go @@ -0,0 +1,252 @@ +// Copyright (c) 2013-2021 Jack Christensen +// +// MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package sanitize_test + +import ( + "testing" + "time" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" +) + +func TestNewQuery(t *testing.T) { + successTests := []struct { + sql string + expected sanitize.Query + }{ + { + sql: "select 42", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, + }, + { + sql: "select $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + }, + { + sql: "select 'quoted $42', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}}, + }, + { + sql: `select "doubled quoted $42", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}}, + }, + { + sql: "select 'foo''bar', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}}, + }, + { + sql: `select "foo""bar", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}}, + }, + { + sql: "select '''', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}}, + }, + { + sql: `select """", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}}, + }, + { + sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11", + expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}}, + }, + { + sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}}, + }, + { + sql: `select E'escape string\' $42', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}}, + }, + { + sql: `select e'escape string\' $42', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}}, + }, + { + sql: `select /* a baby's toy */ 'barbie', $1`, + expected: sanitize.Query{Parts: 
[]sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}}, + }, + { + sql: `select /* *_* */ $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}}, + }, + { + sql: `select 42 /* /* /* 42 */ */ */, $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}}, + }, + { + sql: "select -- a baby's toy\n'barbie', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}}, + }, + { + sql: "select 42 -- is a Deep Thought's favorite number", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Deep Thought's favorite number"}}, + }, + { + sql: "select 42, -- \\nis a Deep Thought's favorite number\n$1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\n", 1}}, + }, + { + sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}}, + }, + { + // https://github.com/jackc/pgx/issues/1380 + sql: "select 'hello w�rld'", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello w�rld'"}}, + }, + { + // Unterminated quoted string + sql: "select 'hello world", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}}, + }, + } + + for i, tt := range successTests { + query, err := sanitize.NewQuery(tt.sql) + if err != nil { + t.Errorf("%d. %v", i, err) + } + + if len(query.Parts) == len(tt.expected.Parts) { + for j := range query.Parts { + if query.Parts[j] != tt.expected.Parts[j] { + t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j]) + } + } + } else { + t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts) + } + } +} + +func TestQuerySanitize(t *testing.T) { + successfulTests := []struct { + query sanitize.Query + args []any + expected string + }{ + { + query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, + args: []any{}, + expected: `select 42`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{int64(42)}, + expected: `select 42 `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{float64(1.23)}, + expected: `select 1.23 `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{true}, + expected: `select true `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{[]byte{0, 1, 2, 3, 255}}, + expected: `select '\x00010203ff' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{nil}, + expected: `select null `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{"foobar"}, + expected: `select 'foobar' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{"foo'bar"}, + expected: `select 'foo''bar' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{`foo\'bar`}, + expected: `select 'foo\''bar' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, + args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, + expected: `insert '2020-03-01 23:59:59.999999Z' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{int64(-1)}, + expected: `select 1- -1 `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{float64(-1)}, + expected: 
`select 1- -1 `, + }, + } + + for i, tt := range successfulTests { + actual, err := tt.query.Sanitize(tt.args...) + if err != nil { + t.Errorf("%d. %v", i, err) + continue + } + + if tt.expected != actual { + t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual) + } + } + + errorTests := []struct { + query sanitize.Query + args []any + expected string + }{ + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, + args: []any{int64(42)}, + expected: `insufficient arguments`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, + args: []any{int64(42)}, + expected: `unused argument: 0`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{42}, + expected: `invalid arg type: int`, + }, + } + + for i, tt := range errorTests { + _, err := tt.query.Sanitize(tt.args...) + if err == nil || err.Error() != tt.expected { + t.Errorf("%d. expected error %v, got %v", i, tt.expected, err) + } + } +} diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go new file mode 100644 index 0000000000..8f6c737da6 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -0,0 +1,239 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "errors" + + _ "github.com/lib/pq" + "github.com/redpanda-data/benthos/v4/public/service" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" +) + +// SnapshotCreationResponse is a structure that contains the name of the snapshot that was created +type SnapshotCreationResponse struct { + ExportedSnapshotName string +} + +// Snapshotter is a structure that allows the creation of a snapshot of a database at a given point in time +// At the time we initialize logical replication - we specify what we want to export the snapshot. +// This snapshot exists until the connection that created the replication slot remains open. +// Therefore Snapshotter opens another connection to the database and sets the transaction to the snapshot. +// This allows you to read the data that was in the database at the time of the snapshot creation. 
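+//
+// Concretely (see initSnapshotTransaction and prepare below): on PostgreSQL 14
+// and older the snapshot is exported explicitly,
+//
+//     BEGIN; SELECT pg_export_snapshot();  -- returns a name such as '00000004-00000002-1'
+//
+// and the connection that reads the snapshot then pins its transaction to it:
+//
+//     BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;
+//     SET TRANSACTION SNAPSHOT '00000004-00000002-1';
+//
+// On newer versions a snapshot is exported by default when logical replication
+// is initialised, so only the second step is performed here.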
+type Snapshotter struct { + pgConnection *sql.DB + snapshotCreateConnection *sql.DB + logger *service.Logger + + snapshotName string + + version int +} + +// NewSnapshotter creates a new Snapshotter instance +func NewSnapshotter(dbDSN string, logger *service.Logger, version int) (*Snapshotter, error) { + pgConn, err := openPgConnectionFromConfig(dbDSN) + if err != nil { + return nil, err + } + + snapshotCreateConnection, err := openPgConnectionFromConfig(dbDSN) + if err != nil { + return nil, err + } + + return &Snapshotter{ + pgConnection: pgConn, + snapshotCreateConnection: snapshotCreateConnection, + logger: logger, + version: version, + }, nil +} + +func (s *Snapshotter) initSnapshotTransaction() (SnapshotCreationResponse, error) { + if s.version > 14 { + return SnapshotCreationResponse{}, errors.New("snapshot is exported by default for versions above PG14") + } + + var snapshotName sql.NullString + + snapshotRow, err := s.pgConnection.Query(`BEGIN; SELECT pg_export_snapshot();`) + if err != nil { + return SnapshotCreationResponse{}, fmt.Errorf("cant get exported snapshot for initial streaming %w pg version: %d", err, s.version) + } + + if snapshotRow.Err() != nil { + return SnapshotCreationResponse{}, fmt.Errorf("can get avg row size due to query failure: %w", snapshotRow.Err()) + } + + if snapshotRow.Next() { + if err = snapshotRow.Scan(&snapshotName); err != nil { + return SnapshotCreationResponse{}, fmt.Errorf("cant scan snapshot name into string: %w", err) + } + } else { + return SnapshotCreationResponse{}, errors.New("cant get avg row size; 0 rows returned") + } + + return SnapshotCreationResponse{ExportedSnapshotName: snapshotName.String}, nil +} + +func (s *Snapshotter) setTransactionSnapshotName(snapshotName string) { + s.snapshotName = snapshotName +} + +func (s *Snapshotter) prepare() error { + if s.snapshotName == "" { + return errors.New("snapshot name is not set") + } + + if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil { + return err + } + + sq, err := sanitize.SQLQuery("SET TRANSACTION SNAPSHOT $1;", s.snapshotName) + if err != nil { + return err + } + + if _, err := s.pgConnection.Exec(sq); err != nil { + return err + } + + return nil +} + +func (s *Snapshotter) findAvgRowSize(ctx context.Context, table string) (sql.NullInt64, error) { + var ( + avgRowSize sql.NullInt64 + rows *sql.Rows + err error + ) + + // table is validated to be correct pg identifier, so we can use it directly + if rows, err = s.pgConnection.QueryContext(ctx, fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { + return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", err) + } + + if rows.Err() != nil { + return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", rows.Err()) + } + + if rows.Next() { + if err = rows.Scan(&avgRowSize); err != nil { + return avgRowSize, fmt.Errorf("can get avg row size: %w", err) + } + } else { + return avgRowSize, errors.New("can get avg row size; 0 rows returned") + } + + return avgRowSize, nil +} + +func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ([]interface{}, []func(interface{}) interface{}) { + scanArgs := make([]interface{}, len(columnTypes)) + valueGetters := make([]func(interface{}) interface{}, len(columnTypes)) + + for i, v := range columnTypes { + switch v.DatabaseTypeName() { + case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": + scanArgs[i] = new(sql.NullString) + valueGetters[i] = func(v 
interface{}) interface{} { return v.(*sql.NullString).String } + case "BOOL": + scanArgs[i] = new(sql.NullBool) + valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullBool).Bool } + case "INT4": + scanArgs[i] = new(sql.NullInt64) + valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullInt64).Int64 } + default: + scanArgs[i] = new(sql.NullString) + valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullString).String } + } + } + + return scanArgs, valueGetters +} + +func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSize uint64) int { + // Adjust this factor based on your system's memory constraints. + // This example uses a safety factor of 0.8 to leave some memory headroom. + safetyFactor := 0.6 + batchSize := int(float64(availableMemory) * safetyFactor / float64(estimatedRowSize)) + if batchSize < 1 { + batchSize = 1 + } + + return batchSize +} + +func (s *Snapshotter) querySnapshotData(ctx context.Context, table string, lastSeenPk map[string]any, pkColumns []string, limit int) (rows *sql.Rows, err error) { + + s.logger.Debugf("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pkColumns) + + if lastSeenPk == nil { + // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, strings.Join(pkColumns, ", "), limit)) + if err != nil { + return nil, err + } + + return s.pgConnection.QueryContext(ctx, sq) + } + + var ( + placeholders []string + lastSeenPksValues []any + i = 1 + ) + + for _, col := range pkColumns { + placeholders = append(placeholders, fmt.Sprintf("$%d", i)) + i++ + lastSeenPksValues = append(lastSeenPksValues, lastSeenPk[col]) + } + + lastSeenPlaceHolders := "(" + strings.Join(placeholders, ", ") + ")" + pkAsTuple := "(" + strings.Join(pkColumns, ", ") + ")" + + // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table, pkAsTuple, lastSeenPlaceHolders, strings.Join(pkColumns, ", "), limit), lastSeenPksValues...) + if err != nil { + return nil, err + } + + return s.pgConnection.QueryContext(ctx, sq) +} + +func (s *Snapshotter) releaseSnapshot() error { + if s.version < 14 && s.snapshotCreateConnection != nil { + if _, err := s.snapshotCreateConnection.Exec("COMMIT;"); err != nil { + return err + } + } + + _, err := s.pgConnection.Exec("COMMIT;") + return err +} + +func (s *Snapshotter) closeConn() error { + if s.pgConnection != nil { + return s.pgConnection.Close() + } + + if s.snapshotCreateConnection != nil { + return s.snapshotCreateConnection.Close() + } + + return nil +} diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go new file mode 100644 index 0000000000..99ebf1acd0 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -0,0 +1,44 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +// StreamMessageChanges represents the changes in a single message +// Single message can have multiple changes +type StreamMessageChanges struct { + Operation string `json:"operation"` + Schema string `json:"schema"` + Table string `json:"table"` + TableSnapshotProgress *float64 `json:"table_snapshot_progress,omitempty"` + // For deleted messages - there will be old changes if replica identity set to full or empty changes + Data map[string]any `json:"data"` +} + +// StreamMessageMetrics represents the metrics of a stream. Passed to each message +type StreamMessageMetrics struct { + WALLagBytes *int64 `json:"wal_lag_bytes"` + IsStreaming bool `json:"is_streaming"` +} + +// StreamMode represents the mode of the stream at the time of the message +type StreamMode string + +const ( + // StreamModeStreaming indicates that the stream is in streaming mode + StreamModeStreaming StreamMode = "streaming" + // StreamModeSnapshot indicates that the stream is in snapshot mode + StreamModeSnapshot StreamMode = "snapshot" +) + +// StreamMessage represents a single message after it has been decoded by the plugin +type StreamMessage struct { + Lsn *string `json:"lsn"` + Changes []StreamMessageChanges `json:"changes"` + Mode StreamMode `json:"mode"` + WALLagBytes *int64 `json:"wal_lag_bytes,omitempty"` +} diff --git a/internal/impl/postgresql/pglogicalstream/types.go b/internal/impl/postgresql/pglogicalstream/types.go new file mode 100644 index 0000000000..2d1d0ff3ad --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/types.go @@ -0,0 +1,9 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream diff --git a/internal/impl/postgresql/pglogicalstream/watermark/watermark.go b/internal/impl/postgresql/pglogicalstream/watermark/watermark.go new file mode 100644 index 0000000000..56dc30783b --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/watermark/watermark.go @@ -0,0 +1,70 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package watermark + +import ( + "cmp" + "sync" +) + +// Value is a utility that allows you to store the highest value and subscribe to when +// a specific offset is reached +type ( + Value[T cmp.Ordered] struct { + val T + mu sync.Mutex + waiters map[chan<- any]T + } +) + +// New makes a new Value holding `initial` +func New[T cmp.Ordered](initial T) *Value[T] { + w := &Value[T]{val: initial} + w.waiters = map[chan<- any]T{} + return w +} + +// Set the watermark value if it's newer +func (w *Value[T]) Set(v T) { + w.mu.Lock() + defer w.mu.Unlock() + if v <= w.val { + return + } + w.val = v + for notify, val := range w.waiters { + if val <= w.val { + notify <- nil + delete(w.waiters, notify) + } + } +} + +// Get the current watermark value +func (w *Value[T]) Get() T { + w.mu.Lock() + cpy := w.val + w.mu.Unlock() + return cpy +} + +// WaitFor returns a channel that recieves a value when the watermark reaches `val`. +func (w *Value[T]) WaitFor(val T) <-chan any { + w.mu.Lock() + defer w.mu.Unlock() + ch := make(chan any, 1) + if w.val >= val { + ch <- nil + return ch + } + w.waiters[ch] = val + return ch +} diff --git a/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go b/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go new file mode 100644 index 0000000000..637deff653 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go @@ -0,0 +1,53 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package watermark_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" +) + +func TestWatermark(t *testing.T) { + w := watermark.New(5) + require.Equal(t, 5, w.Get()) + w.Set(3) + require.Equal(t, 5, w.Get()) + require.Len(t, w.WaitFor(1), 1) + ch1 := w.WaitFor(9) + ch2 := w.WaitFor(10) + ch3 := w.WaitFor(10) + ch4 := w.WaitFor(100) + require.Empty(t, ch1) + require.Empty(t, ch2) + require.Empty(t, ch3) + require.Empty(t, ch4) + w.Set(8) + require.Equal(t, 8, w.Get()) + require.Empty(t, ch1) + require.Empty(t, ch2) + require.Empty(t, ch3) + require.Empty(t, ch4) + w.Set(9) + require.Equal(t, 9, w.Get()) + require.Len(t, ch1, 1) + require.Empty(t, ch2) + require.Empty(t, ch3) + require.Empty(t, ch4) + w.Set(10) + require.Equal(t, 10, w.Get()) + require.Len(t, ch1, 1) + require.Len(t, ch2, 1) + require.Len(t, ch3, 1) + require.Empty(t, ch4) +} diff --git a/internal/impl/postgresql/utils.go b/internal/impl/postgresql/utils.go new file mode 100644 index 0000000000..d01bf441cd --- /dev/null +++ b/internal/impl/postgresql/utils.go @@ -0,0 +1,54 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pgstream + +import ( + "fmt" + "strconv" + "strings" +) + +// LSNToInt64 converts a PostgreSQL LSN string to int64 +func LSNToInt64(lsn string) (int64, error) { + // Split the LSN into segments + parts := strings.Split(lsn, "/") + if len(parts) != 2 { + return 0, fmt.Errorf("invalid LSN format: %s", lsn) + } + + // Parse both segments as hex with uint64 first + upper, err := strconv.ParseUint(parts[0], 16, 32) + if err != nil { + return 0, fmt.Errorf("failed to parse upper part: %w", err) + } + + lower, err := strconv.ParseUint(parts[1], 16, 32) + if err != nil { + return 0, fmt.Errorf("failed to parse lower part: %w", err) + } + + // Combine the segments into a single int64 + // Upper part is shifted left by 32 bits + result := int64((upper << 32) | lower) + + return result, nil +} + +// Int64ToLSN converts an int64 to a PostgreSQL LSN string +func Int64ToLSN(value int64) string { + // Convert to uint64 to handle the bitwise operations properly + uvalue := uint64(value) + + // Extract upper and lower parts + upper := uvalue >> 32 + lower := uvalue & 0xFFFFFFFF + + // Format as hexadecimal with proper padding + return fmt.Sprintf("%X/%X", upper, lower) +} diff --git a/internal/impl/redis/rate_limit_integration_test.go b/internal/impl/redis/rate_limit_integration_test.go index 19dee16685..92098b85ff 100644 --- a/internal/impl/redis/rate_limit_integration_test.go +++ b/internal/impl/redis/rate_limit_integration_test.go @@ -24,10 +24,9 @@ import ( "github.com/ory/dockertest/v3" "github.com/redis/go-redis/v9" + "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/redpanda-data/benthos/v4/public/service/integration" ) func TestIntegrationRedisRateLimit(t *testing.T) { diff --git a/internal/plugins/info.csv b/internal/plugins/info.csv index a1c0199ad2..1727bc002f 100644 --- a/internal/plugins/info.csv +++ b/internal/plugins/info.csv @@ -168,6 +168,7 @@ parquet ,processor ,parquet ,3.62.0 ,commun parquet_decode ,processor ,parquet_decode ,4.4.0 ,certified ,n ,y ,y parquet_encode ,processor ,parquet_encode ,4.4.0 ,certified ,n ,y ,y parse_log ,processor ,parse_log ,0.0.0 ,community ,n ,y ,y +pg_stream ,input ,pg_stream ,0.0.0 ,community ,n ,n ,n pinecone ,output ,pinecone ,4.31.0 ,certified ,n ,y ,y processors ,processor ,processors ,0.0.0 ,certified ,n ,y ,y prometheus ,metric ,prometheus ,0.0.0 ,certified ,n ,y ,y diff --git a/public/components/all/package.go b/public/components/all/package.go index d950cc3e5a..0b7b3a6c3e 100644 --- a/public/components/all/package.go +++ b/public/components/all/package.go @@ -23,6 +23,7 @@ import ( _ "github.com/redpanda-data/connect/v4/public/components/kafka/enterprise" _ "github.com/redpanda-data/connect/v4/public/components/ollama" _ "github.com/redpanda-data/connect/v4/public/components/openai" + _ "github.com/redpanda-data/connect/v4/public/components/postgresql" _ "github.com/redpanda-data/connect/v4/public/components/snowflake" _ "github.com/redpanda-data/connect/v4/public/components/splunk" ) diff --git a/public/components/postgresql/package.go b/public/components/postgresql/package.go new file mode 100644 index 0000000000..275ee41b31 --- /dev/null +++ b/public/components/postgresql/package.go @@ -0,0 +1,16 @@ +/* + * Copyright 2024 Redpanda Data, Inc. 
+ * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package postgresql + +import ( + // Bring in the internal plugin definitions. + _ "github.com/redpanda-data/connect/v4/internal/impl/postgresql" +)
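Taken together, the LSN helpers in utils.go and the watermark utility cover the bookkeeping for WAL positions: LSNToInt64 turns the textual position reported by Postgres into an ordered integer, watermark.Value tracks the highest such position that has been acknowledged, and Int64ToLSN formats it back into the textual form. The following is a minimal, self-contained sketch of how they fit together; it is illustrative only (these packages are internal to the module, and the LSN value is an arbitrary example), not code introduced by the patch.

package main

import (
	"fmt"

	pgstream "github.com/redpanda-data/connect/v4/internal/impl/postgresql"
	"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark"
)

func main() {
	// Convert a WAL position reported by Postgres into a comparable integer.
	// "16/B374D848" is an arbitrary example value.
	lsn, err := pgstream.LSNToInt64("16/B374D848")
	if err != nil {
		panic(err)
	}

	// Track the highest LSN that has been fully acknowledged so far.
	acked := watermark.New(int64(0))

	// Anything that needs to wait for that position to be reached blocks on
	// the channel returned by WaitFor.
	reached := acked.WaitFor(lsn)

	// Advancing the watermark wakes the waiter; Set only ever moves forward.
	acked.Set(lsn)
	<-reached

	// Format the position back into the textual form, e.g. for logging.
	fmt.Println(pgstream.Int64ToLSN(acked.Get())) // prints: 16/B374D848
}

Because WaitFor hands back a buffered channel, Set never blocks on a waiter that has not yet consumed its notification.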