pgcdc: introduce TableFQN type
There is currently a mess of "what does this string mean?" around table
names, so it's time to introduce some type safety for this problem.

TableFQN is a Schema+Table pair that is prevalidated to rule out SQL
injection opportunities. We can pass these around to make call sites
clearer, as well as ensure quoted identifiers are handled correctly.
rockwotj committed Dec 13, 2024
1 parent 781d5bc commit 20e90bf
Showing 10 changed files with 148 additions and 135 deletions.
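
The TableFQN definition itself lands in one of the files not expanded on this page. Below is a minimal sketch of the shape implied by the usage in this diff; the field names and the String() method come straight from the changed call sites, while the exact validation and quoting rules are assumptions:

```go
// Sketch only: reconstructed from usage in this diff, not the committed code.
package pglogicalstream

// TableFQN is a fully qualified Postgres table name whose parts are
// validated up front (sanitize.ValidatePostgresIdentifier in NewPgStream),
// so the value is safe to splice into SQL.
type TableFQN struct {
	Schema string
	Table  string
}

// String renders the pair as schema.table for use in queries. The parts are
// assumed to arrive as valid identifiers, already quoted by the user when
// case matters (the integration test below passes "FlightsCompositePK"
// with its quotes intact).
func (t TableFQN) String() string {
	return t.Schema + "." + t.Table
}
```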
2 changes: 1 addition & 1 deletion internal/impl/postgresql/input_pg_stream.go
@@ -321,7 +321,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher
 			// Periodically collect stats
 			report := pgStream.GetProgress()
 			for name, progress := range report.TableProgress {
-				p.snapshotMetrics.SetFloat64(progress, name)
+				p.snapshotMetrics.SetFloat64(progress, name.String())
 			}
 			p.replicationLag.Set(report.WalLagInBytes)
 		})
12 changes: 6 additions & 6 deletions internal/impl/postgresql/integration_test.go
@@ -128,7 +128,7 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version
 	}

 	_, err = db.Exec(`
-	CREATE TABLE IF NOT EXISTS flights_composite_pks (
+	CREATE TABLE IF NOT EXISTS "FlightsCompositePK" (
 		id serial, seq integer, name VARCHAR(50), created_at TIMESTAMP,
 		PRIMARY KEY (id, seq)
 	);`)
@@ -171,7 +171,7 @@ func TestIntegrationPostgresNoTxnMarkers(t *testing.T) {

 	for i := 0; i < 10; i++ {
 		f := GetFakeFlightRecord()
-		_, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
+		_, err = db.Exec(`INSERT INTO "FlightsCompositePK" (seq, name, created_at) VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
 		require.NoError(t, err)
 	}

@@ -184,7 +184,7 @@ pg_stream:
     snapshot_batch_size: 5
     schema: public
     tables:
-       - flights_composite_pks
+       - '"FlightsCompositePK"'
 `, databaseURL)

 	cacheConf := fmt.Sprintf(`
@@ -226,9 +226,9 @@ file:

 	for i := 10; i < 20; i++ {
 		f := GetFakeFlightRecord()
-		_, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
+		_, err = db.Exec(`INSERT INTO "FlightsCompositePK" (seq, name, created_at) VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
 		require.NoError(t, err)
-		_, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
+		_, err = db.Exec(`INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);`, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
 		require.NoError(t, err)
 	}

@@ -270,7 +270,7 @@ file:
 	time.Sleep(time.Second * 5)
 	for i := 20; i < 30; i++ {
 		f := GetFakeFlightRecord()
-		_, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
+		_, err = db.Exec(`INSERT INTO "FlightsCompositePK" (seq, name, created_at) VALUES ($1, $2, $3);`, i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339))
 		require.NoError(t, err)
 	}

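The rename to a mixed-case table name is deliberate: Postgres folds unquoted identifiers to lower case, so the table is only reachable while the name stays double-quoted, which is exactly the invariant TableFQN has to preserve end to end. A standalone illustration of the folding behavior (standard Postgres semantics; the driver and DSN here are placeholders, not part of this commit):

```go
package main

import (
	"database/sql"
	"fmt"

	_ "github.com/lib/pq" // any Postgres driver works; the choice is illustrative
)

func main() {
	db, err := sql.Open("postgres", "postgres://localhost/example?sslmode=disable") // placeholder DSN
	if err != nil {
		panic(err)
	}
	defer db.Close()

	// Unquoted: folded to flightscompositepk, so this fails against a table
	// created as "FlightsCompositePK".
	if _, err := db.Exec(`SELECT COUNT(*) FROM FlightsCompositePK`); err != nil {
		fmt.Println("unquoted lookup failed:", err)
	}

	// Quoted: the mixed-case name is preserved and the lookup succeeds.
	if _, err := db.Exec(`SELECT COUNT(*) FROM "FlightsCompositePK"`); err != nil {
		fmt.Println("quoted lookup failed:", err)
	}
}
```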
47 changes: 22 additions & 25 deletions internal/impl/postgresql/pglogicalstream/logical_stream.go
@@ -13,8 +13,6 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
-	"slices"
-	"strings"
 	"sync"
 	"time"

@@ -46,12 +44,10 @@ type Stream struct {
 	messages chan StreamMessage
 	errors   chan error

-	includeTxnMarkers bool
-	snapshotName      string
-	slotName          string
-	schema            string
-	// includes schema
-	tableQualifiedName []string
+	includeTxnMarkers bool
+	snapshotName      string
+	slotName          string
+	tables            []TableFQN
 	snapshotBatchSize          int
 	decodingPluginArguments    []string
 	snapshotMemorySafetyFactor float64
@@ -90,13 +86,16 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
 		return nil, err
 	}

-	tableNames := slices.Clone(config.DBTables)
-	for i, table := range tableNames {
+	if err := sanitize.ValidatePostgresIdentifier(config.DBSchema); err != nil {
+		return nil, fmt.Errorf("invalid schema name %q: %w", config.DBSchema, err)
+	}
+
+	tables := []TableFQN{}
+	for _, table := range config.DBTables {
 		if err := sanitize.ValidatePostgresIdentifier(table); err != nil {
 			return nil, fmt.Errorf("invalid table name %q: %w", table, err)
 		}
-
-		tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table)
+		tables = append(tables, TableFQN{Schema: config.DBSchema, Table: table})
 	}
 	stream := &Stream{
 		pgConn: dbConn,
@@ -105,8 +104,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
 		slotName:                   config.ReplicationSlotName,
 		snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor,
 		snapshotBatchSize:          config.BatchSize,
-		schema:                     config.DBSchema,
-		tableQualifiedName:         tableNames,
+		tables:                     tables,
 		maxParallelSnapshotTables:  config.MaxParallelSnapshotTables,
 		logger:                     config.Logger,
 		shutSig:                    shutdown.NewSignaller(),
@@ -143,8 +141,8 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
 	stream.decodingPluginArguments = pluginArguments

 	pubName := "pglog_stream_" + config.ReplicationSlotName
-	stream.logger.Infof("Creating publication %s for tables: %s", pubName, tableNames)
-	if err = CreatePublication(ctx, stream.pgConn, pubName, config.DBSchema, tableNames); err != nil {
+	stream.logger.Infof("Creating publication %s for tables: %s", pubName, tables)
+	if err = CreatePublication(ctx, stream.pgConn, pubName, tables); err != nil {
 		return nil, err
 	}
 	cleanups = append(cleanups, func() {
@@ -219,7 +217,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
 	stream.standbyMessageTimeout = config.PgStandbyTimeout
 	stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout)

-	monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorInterval)
+	monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tables, stream.slotName, config.WalMonitorInterval)
 	if err != nil {
 		return nil, err
 	}
@@ -481,7 +479,7 @@ func (s *Stream) processSnapshot() error {
 	var wg errgroup.Group
 	wg.SetLimit(s.maxParallelSnapshotTables)

-	for _, table := range s.tableQualifiedName {
+	for _, table := range s.tables {
 		tableName := table
 		wg.Go(func() (err error) {
 			s.logger.Debugf("Processing snapshot for table: %v", table)
@@ -551,7 +549,6 @@ func (s *Stream) processSnapshot() error {
 			totalScanDuration := time.Duration(0)
 			totalWaitingFromBenthos := time.Duration(0)

-			tableWithoutSchema := strings.Split(table, ".")[1]
 			for snapshotRows.Next() {
 				rowsCount += 1

@@ -581,13 +578,13 @@ func (s *Stream) processSnapshot() error {
 				snapshotChangePacket := StreamMessage{
 					LSN:       nil,
 					Operation: ReadOpType,
-					Table:     tableWithoutSchema,
-					Schema:    s.schema,
+					Table:     table.Table,
+					Schema:    table.Schema,
 					Data:      data,
 				}

 				if rowsCount%100 == 0 {
-					s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset)
+					s.monitor.UpdateSnapshotProgressForTable(table, rowsCount+offset)
 				}

 				waitingFromBenthos := time.Now()
@@ -627,7 +624,7 @@ func (s *Stream) Errors() chan error {
 	return s.errors
 }

-func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map[string]any, []string, error) {
+func (s *Stream) getPrimaryKeyColumn(ctx context.Context, table TableFQN) (map[string]any, []string, error) {
 	/// Query to get all primary key columns in their correct order
 	q, err := sanitize.SQLQuery(`
 		SELECT a.attname
@@ -637,7 +634,7 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map
 		WHERE i.indrelid = $1::regclass
 		AND i.indisprimary
 		ORDER BY array_position(i.indkey, a.attnum);
-	`, tableName)
+	`, table.String())

 	if err != nil {
 		return nil, nil, fmt.Errorf("failed to sanitize query: %w", err)
@@ -650,7 +647,7 @@ func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map
 	}

 	if len(data) == 0 || len(data[0].Rows) == 0 {
-		return nil, nil, fmt.Errorf("no primary key found for table %s", tableName)
+		return nil, nil, fmt.Errorf("no primary key found for table %s", table)
 	}

 	// Extract all primary key column names
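Note that CreatePublication drops its separate schema argument, since each TableFQN now carries its own. The function body is not expanded on this page; below is a hypothetical sketch of what the narrowed signature plausibly does (the committed implementation likely also reconciles an already-existing publication):

```go
// Hypothetical sketch, not the committed implementation.
package pglogicalstream

import (
	"context"
	"fmt"
	"strings"

	"github.com/jackc/pgx/v5/pgconn" // import path assumed from pgx v5 usage
)

func CreatePublication(ctx context.Context, conn *pgconn.PgConn, name string, tables []TableFQN) error {
	fqns := make([]string, len(tables))
	for i, t := range tables {
		fqns[i] = t.String() // prevalidated at construction, safe to interpolate
	}
	stmt := fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s;", name, strings.Join(fqns, ", "))
	_, err := conn.Exec(ctx, stmt).ReadAll()
	return err
}
```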
33 changes: 12 additions & 21 deletions internal/impl/postgresql/pglogicalstream/monitor.go
@@ -21,23 +21,22 @@ import (
 	"github.com/redpanda-data/benthos/v4/public/service"

 	"github.com/redpanda-data/connect/v4/internal/asyncroutine"
-	"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize"
 )

 // Report is a structure that contains the current state of the Monitor
 type Report struct {
 	WalLagInBytes int64
-	TableProgress map[string]float64
+	TableProgress map[TableFQN]float64
 }

 // Monitor is a structure that allows monitoring the progress of snapshot ingestion and replication lag
 type Monitor struct {
 	// tableStat contains numbers of rows for each table determined at the moment of the snapshot creation
 	// this is used to calculate snapshot ingestion progress
-	tableStat map[string]int64
+	tableStat map[TableFQN]int64
 	lock      sync.Mutex
 	// snapshotProgress is a map of table names to the percentage of rows ingested from the snapshot
-	snapshotProgress map[string]float64
+	snapshotProgress map[TableFQN]float64
 	// replicationLagInBytes is the replication lag in bytes measured by
 	// finding the difference between the latest LSN and the last confirmed LSN for the replication slot
 	replicationLagInBytes int64
@@ -53,7 +52,7 @@ func NewMonitor(
 	ctx context.Context,
 	dbDSN string,
 	logger *service.Logger,
-	tables []string,
+	tables []TableFQN,
 	slotName string,
 	interval time.Duration,
 ) (*Monitor, error) {
@@ -66,7 +65,7 @@
 	}

 	m := &Monitor{
-		snapshotProgress:      map[string]float64{},
+		snapshotProgress:      map[TableFQN]float64{},
 		replicationLagInBytes: 0,
 		dbConn:                dbConn,
 		slotName:              slotName,
@@ -81,40 +80,32 @@
 }

 // UpdateSnapshotProgressForTable updates the snapshot ingestion progress for a given table
-func (m *Monitor) UpdateSnapshotProgressForTable(table string, position int) {
+func (m *Monitor) UpdateSnapshotProgressForTable(table TableFQN, position int) {
 	m.lock.Lock()
 	defer m.lock.Unlock()
 	m.snapshotProgress[table] = math.Round(float64(position) / float64(m.tableStat[table]) * 100)
 }

 // we need to read the tables stat to calculate the snapshot ingestion progress
-func (m *Monitor) readTablesStat(ctx context.Context, tables []string) error {
-	results := make(map[string]int64)
+func (m *Monitor) readTablesStat(ctx context.Context, tables []TableFQN) error {
+	results := make(map[TableFQN]int64)

 	for _, table := range tables {
-		tableWithoutSchema := strings.Split(table, ".")[1]
-		err := sanitize.ValidatePostgresIdentifier(tableWithoutSchema)
-
-		if err != nil {
-			return fmt.Errorf("error sanitizing query: %w", err)
-		}
-
 		var count int64
-		// tableWithoutSchema has been validated so its safe to use in the query
-		err = m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tableWithoutSchema).Scan(&count)
+		err := m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+table.String()).Scan(&count)

 		if err != nil {
 			// If the error is because the table doesn't exist, we'll set the count to 0
 			// and continue. You might want to log this situation.
 			if strings.Contains(err.Error(), "does not exist") {
-				results[tableWithoutSchema] = 0
+				results[table] = 0
 				continue
 			}
 			// For any other error, we'll return it
-			return fmt.Errorf("error counting rows in table %s: %w", tableWithoutSchema, err)
+			return fmt.Errorf("error counting rows in table %s: %w", table, err)
 		}

-		results[tableWithoutSchema] = count
+		results[table] = count
 	}

 	m.tableStat = results
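With TableFQN as the map key, callers now choose the label format explicitly, which is why the metrics code at the top of this diff calls name.String(). A small self-contained sketch of consuming the new Report shape (types mirrored from the diff, values invented for illustration):

```go
package main

import "fmt"

// Reduced mirrors of the diff's types, for illustration only.
type TableFQN struct{ Schema, Table string }

func (t TableFQN) String() string { return t.Schema + "." + t.Table }

type Report struct {
	WalLagInBytes int64
	TableProgress map[TableFQN]float64
}

func main() {
	report := Report{
		WalLagInBytes: 4096, // invented value
		TableProgress: map[TableFQN]float64{
			{Schema: "public", Table: `"FlightsCompositePK"`}: 42, // percent ingested
		},
	}
	// Struct keys no longer stringify implicitly, so labels are explicit:
	for fqn, pct := range report.TableProgress {
		fmt.Printf("snapshot %s: %.0f%%\n", fqn.String(), pct)
	}
	fmt.Printf("wal lag: %d bytes\n", report.WalLagInBytes)
}
```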