Skip to content

Commit

Permalink
pgcdc: introduce TableFQN type
Browse files Browse the repository at this point in the history
There is currently a mess of "what does this string mean?", which means
it's time to introduce some typesafety to this problem.

TableFQN is a Schema+Table pair that is prevalidated to not have SQL
injection opportunities and we can pass these around to make things a
bit more clear as well as ensure we're handling quoted stuff correctly.
  • Loading branch information
rockwotj committed Dec 13, 2024
1 parent 781d5bc commit d8f02c4
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 128 deletions.
2 changes: 1 addition & 1 deletion internal/impl/postgresql/input_pg_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher
// Periodically collect stats
report := pgStream.GetProgress()
for name, progress := range report.TableProgress {
p.snapshotMetrics.SetFloat64(progress, name)
p.snapshotMetrics.SetFloat64(progress, name.String())
}
p.replicationLag.Set(report.WalLagInBytes)
})
Expand Down
47 changes: 22 additions & 25 deletions internal/impl/postgresql/pglogicalstream/logical_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import (
"database/sql"
"errors"
"fmt"
"slices"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -46,12 +44,10 @@ type Stream struct {
messages chan StreamMessage
errors chan error

includeTxnMarkers bool
snapshotName string
slotName string
schema string
// includes schema
tableQualifiedName []string
includeTxnMarkers bool
snapshotName string
slotName string
tables []TableFQN
snapshotBatchSize int
decodingPluginArguments []string
snapshotMemorySafetyFactor float64
Expand Down Expand Up @@ -90,13 +86,16 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
return nil, err
}

tableNames := slices.Clone(config.DBTables)
for i, table := range tableNames {
if err := sanitize.ValidatePostgresIdentifier(config.DBSchema); err != nil {
return nil, fmt.Errorf("invalid schema name %q: %w", config.DBSchema, err)
}

tables := []TableFQN{}
for _, table := range config.DBTables {
if err := sanitize.ValidatePostgresIdentifier(table); err != nil {
return nil, fmt.Errorf("invalid table name %q: %w", table, err)
}

tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table)
tables = append(tables, TableFQN{Schema: config.DBSchema, Table: table})
}
stream := &Stream{
pgConn: dbConn,
Expand All @@ -105,8 +104,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
slotName: config.ReplicationSlotName,
snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor,
snapshotBatchSize: config.BatchSize,
schema: config.DBSchema,
tableQualifiedName: tableNames,
tables: tables,
maxParallelSnapshotTables: config.MaxParallelSnapshotTables,
logger: config.Logger,
shutSig: shutdown.NewSignaller(),
Expand Down Expand Up @@ -143,8 +141,8 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
stream.decodingPluginArguments = pluginArguments

pubName := "pglog_stream_" + config.ReplicationSlotName
stream.logger.Infof("Creating publication %s for tables: %s", pubName, tableNames)
if err = CreatePublication(ctx, stream.pgConn, pubName, config.DBSchema, tableNames); err != nil {
stream.logger.Infof("Creating publication %s for tables: %s", pubName, tables)
if err = CreatePublication(ctx, stream.pgConn, pubName, tables); err != nil {
return nil, err
}
cleanups = append(cleanups, func() {
Expand Down Expand Up @@ -219,7 +217,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
stream.standbyMessageTimeout = config.PgStandbyTimeout
stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout)

monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorInterval)
monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tables, stream.slotName, config.WalMonitorInterval)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -481,7 +479,7 @@ func (s *Stream) processSnapshot() error {
var wg errgroup.Group
wg.SetLimit(s.maxParallelSnapshotTables)

for _, table := range s.tableQualifiedName {
for _, table := range s.tables {
tableName := table
wg.Go(func() (err error) {
s.logger.Debugf("Processing snapshot for table: %v", table)
Expand Down Expand Up @@ -551,7 +549,6 @@ func (s *Stream) processSnapshot() error {
totalScanDuration := time.Duration(0)
totalWaitingFromBenthos := time.Duration(0)

tableWithoutSchema := strings.Split(table, ".")[1]
for snapshotRows.Next() {
rowsCount += 1

Expand Down Expand Up @@ -581,13 +578,13 @@ func (s *Stream) processSnapshot() error {
snapshotChangePacket := StreamMessage{
LSN: nil,
Operation: ReadOpType,
Table: tableWithoutSchema,
Schema: s.schema,
Table: table.Table,
Schema: table.Schema,
Data: data,
}

if rowsCount%100 == 0 {
s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset)
s.monitor.UpdateSnapshotProgressForTable(table, rowsCount+offset)
}

waitingFromBenthos := time.Now()
Expand Down Expand Up @@ -627,7 +624,7 @@ func (s *Stream) Errors() chan error {
return s.errors
}

func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map[string]any, []string, error) {
func (s *Stream) getPrimaryKeyColumn(ctx context.Context, table TableFQN) (map[string]any, []string, error) {
/// Query to get all primary key columns in their correct order
q, err := sanitize.SQLQuery(`
SELECT a.attname
Expand All @@ -637,7 +634,7 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map
WHERE i.indrelid = $1::regclass
AND i.indisprimary
ORDER BY array_position(i.indkey, a.attnum);
`, tableName)
`, table.String())

if err != nil {
return nil, nil, fmt.Errorf("failed to sanitize query: %w", err)
Expand All @@ -650,7 +647,7 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map
}

if len(data) == 0 || len(data[0].Rows) == 0 {
return nil, nil, fmt.Errorf("no primary key found for table %s", tableName)
return nil, nil, fmt.Errorf("no primary key found for table %s", table)
}

// Extract all primary key column names
Expand Down
33 changes: 12 additions & 21 deletions internal/impl/postgresql/pglogicalstream/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,22 @@ import (
"github.com/redpanda-data/benthos/v4/public/service"

"github.com/redpanda-data/connect/v4/internal/asyncroutine"
"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize"
)

// Report is a structure that contains the current state of the Monitor
type Report struct {
WalLagInBytes int64
TableProgress map[string]float64
TableProgress map[TableFQN]float64
}

// Monitor is a structure that allows monitoring the progress of snapshot ingestion and replication lag
type Monitor struct {
// tableStat contains numbers of rows for each table determined at the moment of the snapshot creation
// this is used to calculate snapshot ingestion progress
tableStat map[string]int64
tableStat map[TableFQN]int64
lock sync.Mutex
// snapshotProgress is a map of table names to the percentage of rows ingested from the snapshot
snapshotProgress map[string]float64
snapshotProgress map[TableFQN]float64
// replicationLagInBytes is the replication lag in bytes measured by
// finding the difference between the latest LSN and the last confirmed LSN for the replication slot
replicationLagInBytes int64
Expand All @@ -53,7 +52,7 @@ func NewMonitor(
ctx context.Context,
dbDSN string,
logger *service.Logger,
tables []string,
tables []TableFQN,
slotName string,
interval time.Duration,
) (*Monitor, error) {
Expand All @@ -66,7 +65,7 @@ func NewMonitor(
}

m := &Monitor{
snapshotProgress: map[string]float64{},
snapshotProgress: map[TableFQN]float64{},
replicationLagInBytes: 0,
dbConn: dbConn,
slotName: slotName,
Expand All @@ -81,40 +80,32 @@ func NewMonitor(
}

// UpdateSnapshotProgressForTable updates the snapshot ingestion progress for a given table
func (m *Monitor) UpdateSnapshotProgressForTable(table string, position int) {
func (m *Monitor) UpdateSnapshotProgressForTable(table TableFQN, position int) {
m.lock.Lock()
defer m.lock.Unlock()
m.snapshotProgress[table] = math.Round(float64(position) / float64(m.tableStat[table]) * 100)
}

// we need to read the tables stat to calculate the snapshot ingestion progress
func (m *Monitor) readTablesStat(ctx context.Context, tables []string) error {
results := make(map[string]int64)
func (m *Monitor) readTablesStat(ctx context.Context, tables []TableFQN) error {
results := make(map[TableFQN]int64)

for _, table := range tables {
tableWithoutSchema := strings.Split(table, ".")[1]
err := sanitize.ValidatePostgresIdentifier(tableWithoutSchema)

if err != nil {
return fmt.Errorf("error sanitizing query: %w", err)
}

var count int64
// tableWithoutSchema has been validated so its safe to use in the query
err = m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tableWithoutSchema).Scan(&count)
err := m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+table.String()).Scan(&count)

if err != nil {
// If the error is because the table doesn't exist, we'll set the count to 0
// and continue. You might want to log this situation.
if strings.Contains(err.Error(), "does not exist") {
results[tableWithoutSchema] = 0
results[table] = 0
continue
}
// For any other error, we'll return it
return fmt.Errorf("error counting rows in table %s: %w", tableWithoutSchema, err)
return fmt.Errorf("error counting rows in table %s: %w", table, err)
}

results[tableWithoutSchema] = count
results[table] = count
}

m.tableStat = results
Expand Down
59 changes: 24 additions & 35 deletions internal/impl/postgresql/pglogicalstream/pglogrepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"context"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"slices"
"strconv"
Expand Down Expand Up @@ -343,7 +342,7 @@ func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName stri
}

// CreatePublication creates a new PostgreSQL publication with the given name for a list of tables and drop if exists flag
func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, schema string, tables []string) error {
func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []TableFQN) error {
// Check if publication exists
pubQuery, err := sanitize.SQLQuery(`
SELECT pubname, puballtables
Expand All @@ -354,17 +353,6 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName
return fmt.Errorf("failed to sanitize publication query: %w", err)
}

// Since we need to pass table names without quoting, we need to validate it
for _, table := range tables {
if err := sanitize.ValidatePostgresIdentifier(table); err != nil {
return errors.New("invalid table name")
}
}
// the same for publication name
if err := sanitize.ValidatePostgresIdentifier(publicationName); err != nil {
return errors.New("invalid publication name")
}

result := conn.Exec(ctx, pubQuery)

rows, err := result.ReadAll()
Expand All @@ -374,16 +362,13 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

tablesClause := "FOR ALL TABLES"
if len(tables) > 0 {
// quotedTables := make([]string, len(tables))
// for i, table := range tables {
// // Use sanitize.SQLIdentifier to properly quote and escape table names
// quoted, err := sanitize.SQLIdentifier(table)
// if err != nil {
// return fmt.Errorf("invalid table name %q: %w", table, err)
// }
// quotedTables[i] = quoted
// }
tablesClause = "FOR TABLE " + strings.Join(tables, ", ")
tablesClause = "FOR TABLE "
for i, table := range tables {
if i > 0 {
tablesClause += ", "
}
tablesClause += table.String()
}
}

if len(rows) == 0 || len(rows[0].Rows) == 0 {
Expand All @@ -403,7 +388,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

// assuming publication already exists
// get a list of tables in the publication
pubTables, forAllTables, err := GetPublicationTables(ctx, conn, publicationName, schema)
pubTables, forAllTables, err := GetPublicationTables(ctx, conn, publicationName)
if err != nil {
return fmt.Errorf("failed to get publication tables: %w", err)
}
Expand All @@ -414,8 +399,8 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName
return nil
}

var tablesToRemoveFromPublication = []string{}
var tablesToAddToPublication = []string{}
var tablesToRemoveFromPublication = []TableFQN{}
var tablesToAddToPublication = []TableFQN{}
for _, table := range tables {
if !slices.Contains(pubTables, table) {
tablesToAddToPublication = append(tablesToAddToPublication, table)
Expand All @@ -430,7 +415,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

// remove tables from publication
for _, dropTable := range tablesToRemoveFromPublication {
sq, err := sanitize.SQLQuery(fmt.Sprintf(`ALTER PUBLICATION %s DROP TABLE %s."%s";`, publicationName, schema, dropTable))
sq, err := sanitize.SQLQuery(fmt.Sprintf(`ALTER PUBLICATION %s DROP TABLE %s;`, publicationName, dropTable.String()))
if err != nil {
return fmt.Errorf("failed to sanitize drop table query: %w", err)
}
Expand All @@ -442,7 +427,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

// add tables to publication
for _, addTable := range tablesToAddToPublication {
sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable))
sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable.String()))
if err != nil {
return fmt.Errorf("failed to sanitize add table query: %w", err)
}
Expand All @@ -457,15 +442,15 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

// GetPublicationTables returns a list of tables currently in the publication
// Arguments, in order: list of the tables, exist for all tables, errror
func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationName, schema string) ([]string, bool, error) {
func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationName string) ([]TableFQN, bool, error) {
query, err := sanitize.SQLQuery(`
SELECT DISTINCT
tablename as table_name
tablename as table_name,
schemaname as schema_name
FROM pg_publication_tables
WHERE pubname = $1
AND schemaname = $2
ORDER BY table_name;
`, publicationName, strings.Trim(schema, "\""))
ORDER BY schema_name, table_name;
`, publicationName)
if err != nil {
return nil, false, fmt.Errorf("failed to get publication tables: %w", err)
}
Expand All @@ -482,9 +467,13 @@ func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationN
return nil, true, nil // Publication exists and is for all tables
}

tables := make([]string, 0, len(rows))
tables := make([]TableFQN, 0, len(rows))
for _, row := range rows[0].Rows {
tables = append(tables, string(row[0]))
// These come from postgres so they are valid, but we have to quote them
// to prevent normalization
table := sanitize.QuotePostgresIdentifier(string(row[0]))
schema := sanitize.QuotePostgresIdentifier(string(row[1]))
tables = append(tables, TableFQN{Table: table, Schema: schema})
}

return tables, false, nil
Expand Down
Loading

0 comments on commit d8f02c4

Please sign in to comment.