Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix postgres_cdc input #3075

Merged
merged 8 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ All notable changes to this project will be documented in this file.
### Fixed

- `gcp_bigquery` output with parquet format no longer returns errors incorrectly. (@rockwotj)
- `postgres_cdc` input now allows quoted identifiers for the table names. (@mihaitodor)

## 4.43.1 - 2024-12-09

Expand Down
2 changes: 1 addition & 1 deletion internal/impl/postgresql/pglogicalstream/logical_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {

pubName := "pglog_stream_" + config.ReplicationSlotName
stream.logger.Infof("Creating publication %s for tables: %s", pubName, tableNames)
if err = CreatePublication(ctx, stream.pgConn, pubName, tableNames); err != nil {
if err = CreatePublication(ctx, stream.pgConn, pubName, config.DBSchema, tableNames); err != nil {
return nil, err
}
cleanups = append(cleanups, func() {
Expand Down
11 changes: 6 additions & 5 deletions internal/impl/postgresql/pglogicalstream/pglogrepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName stri
}

// CreatePublication creates a PostgreSQL publication with the given name for a list of tables, or, if the publication already exists, updates its table list to match.
func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []string) error {
func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, schema string, tables []string) error {
// Check if publication exists
pubQuery, err := sanitize.SQLQuery(`
SELECT pubname, puballtables
Expand Down Expand Up @@ -403,7 +403,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

// assuming publication already exists
// get a list of tables in the publication
pubTables, forAllTables, err := GetPublicationTables(ctx, conn, publicationName)
pubTables, forAllTables, err := GetPublicationTables(ctx, conn, publicationName, schema)
if err != nil {
return fmt.Errorf("failed to get publication tables: %w", err)
}
Expand All @@ -430,7 +430,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

// remove tables from publication
for _, dropTable := range tablesToRemoveFromPublication {
sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable))
sq, err := sanitize.SQLQuery(fmt.Sprintf(`ALTER PUBLICATION %s DROP TABLE %s."%s";`, publicationName, schema, dropTable))
if err != nil {
return fmt.Errorf("failed to sanitize drop table query: %w", err)
}
Expand All @@ -457,14 +457,15 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

// GetPublicationTables returns a list of tables currently in the publication
// Return values, in order: the list of tables, whether the publication is for all tables, error
func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationName string) ([]string, bool, error) {
func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationName, schema string) ([]string, bool, error) {
query, err := sanitize.SQLQuery(`
SELECT DISTINCT
tablename as table_name
FROM pg_publication_tables
WHERE pubname = $1
AND schemaname = $2
rockwotj marked this conversation as resolved.
Show resolved Hide resolved
ORDER BY table_name;
`, publicationName)
`, publicationName, strings.Trim(schema, "\""))
if err != nil {
return nil, false, fmt.Errorf("failed to get publication tables: %w", err)
}
Expand Down
73 changes: 55 additions & 18 deletions internal/impl/postgresql/pglogicalstream/pglogrepl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,11 @@ func TestIntegrationCreatePublication(t *testing.T) {
defer closeConn(t, conn)

publicationName := "test_publication"
err = CreatePublication(context.Background(), conn, publicationName, []string{})
schema := "public"
err = CreatePublication(context.Background(), conn, publicationName, schema, []string{})
require.NoError(t, err)

tables, forAllTables, err := GetPublicationTables(context.Background(), conn, publicationName)
tables, forAllTables, err := GetPublicationTables(context.Background(), conn, publicationName, schema)
require.NoError(t, err)
assert.Empty(t, tables)
assert.True(t, forAllTables)
Expand All @@ -246,58 +247,93 @@ func TestIntegrationCreatePublication(t *testing.T) {
require.NoError(t, err)

publicationWithTables := "test_pub_with_tables"
err = CreatePublication(context.Background(), conn, publicationWithTables, []string{"test_table"})
err = CreatePublication(context.Background(), conn, publicationWithTables, schema, []string{"test_table"})
require.NoError(t, err)

tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationName)
tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationName, schema)
require.NoError(t, err)
assert.NotEmpty(t, tables)
assert.Contains(t, tables, "test_table")
assert.False(t, forAllTables)

// add more tables to publication
// Add more tables to publication
multiReader = conn.Exec(context.Background(), "CREATE TABLE test_table2 (id serial PRIMARY KEY, name text);")
_, err = multiReader.ReadAll()
require.NoError(t, err)

// Pass more tables to the publication
err = CreatePublication(context.Background(), conn, publicationWithTables, []string{
err = CreatePublication(context.Background(), conn, publicationWithTables, schema, []string{
"test_table2",
"test_table",
})
require.NoError(t, err)

tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables)
tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables, schema)
require.NoError(t, err)
assert.NotEmpty(t, tables)
assert.Contains(t, tables, "test_table")
assert.Contains(t, tables, "test_table2")
assert.False(t, forAllTables)

// Removing one table from the publication
err = CreatePublication(context.Background(), conn, publicationWithTables, []string{
// Remove one table from the publication
err = CreatePublication(context.Background(), conn, publicationWithTables, schema, []string{
"test_table",
})
require.NoError(t, err)

tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables)
tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables, schema)
require.NoError(t, err)
assert.NotEmpty(t, tables)
assert.Contains(t, tables, "test_table")
assert.False(t, forAllTables)

// Add one table and remove one at the same time
err = CreatePublication(context.Background(), conn, publicationWithTables, []string{
err = CreatePublication(context.Background(), conn, publicationWithTables, schema, []string{
"test_table2",
})
require.NoError(t, err)

tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables)
tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables, schema)
require.NoError(t, err)
assert.NotEmpty(t, tables)
assert.Contains(t, tables, "test_table2")
assert.False(t, forAllTables)

// Create a schema with a quoted identifier
caseSensitiveSchema := `"FooBar"`
multiReader = conn.Exec(context.Background(), fmt.Sprintf("CREATE SCHEMA %s;", caseSensitiveSchema))
_, err = multiReader.ReadAll()
require.NoError(t, err)

caseSensitiveTable := `"Foo"`
multiReader = conn.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s.%s (id serial PRIMARY KEY, name text);", caseSensitiveSchema, caseSensitiveTable))
_, err = multiReader.ReadAll()
require.NoError(t, err)

caseSensitiveTable2 := `"Bar"`
multiReader = conn.Exec(context.Background(), fmt.Sprintf("CREATE TABLE %s.%s (id serial PRIMARY KEY, name text);", caseSensitiveSchema, caseSensitiveTable2))
_, err = multiReader.ReadAll()
require.NoError(t, err)

// Pass tables to the schema with quoted identifiers
publicationQuotedIdentifiers := "quoted_identifiers"
err = CreatePublication(context.Background(), conn, publicationQuotedIdentifiers, caseSensitiveSchema, []string{
caseSensitiveSchema + "." + caseSensitiveTable,
caseSensitiveSchema + "." + caseSensitiveTable2,
})
require.NoError(t, err)

// Remove one table with a quoted identifier from the publication
err = CreatePublication(context.Background(), conn, publicationQuotedIdentifiers, caseSensitiveSchema, []string{
caseSensitiveSchema + "." + caseSensitiveTable,
})
require.NoError(t, err)

tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationQuotedIdentifiers, caseSensitiveSchema)
require.NoError(t, err)
assert.NotEmpty(t, tables)
assert.Contains(t, tables, "Foo")
assert.False(t, forAllTables)
}

func TestIntegrationStartReplication(t *testing.T) {
Expand All @@ -321,7 +357,8 @@ func TestIntegrationStartReplication(t *testing.T) {

// create publication
publicationName := "test_publication"
err = CreatePublication(context.Background(), conn, publicationName, []string{})
schema := "public"
err = CreatePublication(context.Background(), conn, publicationName, schema, []string{})
require.NoError(t, err)

_, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}, 16, nil)
Expand Down Expand Up @@ -416,35 +453,35 @@ drop table t;
require.NoError(t, err)
jsonData, err := json.Marshal(&streamMessage)
require.NoError(t, err)
assert.JSONEq(t, `{"operation":"insert","schema":"public","table":"t","mode":"streaming","lsn":null,"data":{"id":1, "name":"foo"}}`, string(jsonData))
assert.JSONEq(t, `{"operation":"insert","schema":"public","table":"t","lsn":null,"data":{"id":1, "name":"foo"}}`, string(jsonData))

xld = rxXLogData()
streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap)
require.NoError(t, err)
jsonData, err = json.Marshal(&streamMessage)
require.NoError(t, err)
assert.JSONEq(t, `{"operation":"insert","schema":"public","table":"t","mode":"streaming","lsn":null,"data":{"id":2,"name":"bar"}}`, string(jsonData))
assert.JSONEq(t, `{"operation":"insert","schema":"public","table":"t","lsn":null,"data":{"id":2,"name":"bar"}}`, string(jsonData))

xld = rxXLogData()
streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap)
require.NoError(t, err)
jsonData, err = json.Marshal(&streamMessage)
require.NoError(t, err)
assert.JSONEq(t, `{"operation":"insert","schema":"public","table":"t","mode":"streaming","lsn":null,"data":{"id":3,"name":"baz"}}`, string(jsonData))
assert.JSONEq(t, `{"operation":"insert","schema":"public","table":"t","lsn":null,"data":{"id":3,"name":"baz"}}`, string(jsonData))

xld = rxXLogData()
streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap)
require.NoError(t, err)
jsonData, err = json.Marshal(&streamMessage)
require.NoError(t, err)
assert.JSONEq(t, `{"operation":"update","schema":"public","table":"t","mode":"streaming","lsn":null,"data":{"id":3,"name":"quz"}}`, string(jsonData))
assert.JSONEq(t, `{"operation":"update","schema":"public","table":"t","lsn":null,"data":{"id":3,"name":"quz"}}`, string(jsonData))

xld = rxXLogData()
streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap)
require.NoError(t, err)
jsonData, err = json.Marshal(&streamMessage)
require.NoError(t, err)
assert.JSONEq(t, `{"operation":"delete","schema":"public","table":"t","mode":"streaming","lsn":null,"data":{"id":2,"name":null}}`, string(jsonData))
assert.JSONEq(t, `{"operation":"delete","schema":"public","table":"t","lsn":null,"data":{"id":2,"name":null}}`, string(jsonData))
xld = rxXLogData()

commit, _, err := isCommitMessage(xld.WALData)
Expand Down
12 changes: 12 additions & 0 deletions internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,18 @@ func SQLQuery(sql string, args ...any) (string, error) {
// ValidatePostgresIdentifier checks if a string is a valid PostgreSQL identifier
// This follows PostgreSQL's standard naming rules
func ValidatePostgresIdentifier(name string) error {
if parts := strings.Split(name, "."); len(parts) == 2 {
rockwotj marked this conversation as resolved.
Show resolved Hide resolved
if err := ValidatePostgresIdentifier(parts[0]); err != nil {
return fmt.Errorf("invalid schema identifier: %s", err)
}
name = parts[1]
}

// Strip quotes if they are present
if strings.HasPrefix(name, "\"") && strings.HasSuffix(name, "\"") {
name = strings.Trim(name, "\"")
rockwotj marked this conversation as resolved.
Show resolved Hide resolved
}

if len(name) == 0 {
return errors.New("empty identifier is not allowed")
}
Expand Down
Loading