Merge pull request #3059 from redpanda-data/pgcdc
rockwotj authored Dec 4, 2024
2 parents d303c7f + cce7e9c commit 88681e3
Showing 9 changed files with 158 additions and 332 deletions.
22 changes: 9 additions & 13 deletions internal/impl/postgresql/input_pg_stream.go
@@ -338,7 +338,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher
var nextTimedBatchChan <-chan time.Time

// offsets are nilable since we don't provide offset tracking during the snapshot phase
cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit))
cp := checkpoint.NewCapped[*string](int64(p.checkpointLimit))
for !p.stopSig.IsSoftStopSignalled() {
select {
case <-nextTimedBatchChan:
@@ -365,8 +365,8 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher
batchMsg.MetaSet("mode", string(message.Mode))
batchMsg.MetaSet("table", message.Table)
batchMsg.MetaSet("operation", string(message.Operation))
if message.Lsn != nil {
batchMsg.MetaSet("lsn", *message.Lsn)
if message.LSN != nil {
batchMsg.MetaSet("lsn", *message.LSN)
}
if batcher.Add(batchMsg) {
nextTimedBatchChan = nil
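The metadata keys set here (mode, table, operation and, outside the snapshot phase, lsn) are the contract that downstream stages rely on. A minimal sketch of reading them back through the same public benthos service API, assuming the service.NewMessage constructor from that package (the metadata values below are illustrative, not taken from this change):

```go
package main

import (
	"fmt"

	"github.com/redpanda-data/benthos/v4/public/service"
)

func main() {
	// Build a message the way processStream does: payload plus CDC metadata.
	msg := service.NewMessage([]byte(`{"id":1,"name":"example"}`))
	msg.MetaSet("mode", "streaming")      // illustrative; real values come from message.Mode
	msg.MetaSet("table", "public.users")  // illustrative table name
	msg.MetaSet("operation", "insert")    // illustrative; real values come from message.Operation
	msg.MetaSet("lsn", "00000016/B374D848") // absent during the snapshot phase

	// A consumer reads the same keys back; lsn may be missing, so check ok.
	if lsn, ok := msg.MetaGet("lsn"); ok {
		table, _ := msg.MetaGet("table")
		op, _ := msg.MetaGet("operation")
		fmt.Printf("%s on %s at LSN %s\n", op, table, lsn)
	}
}
```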
@@ -398,22 +398,18 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher
func (p *pgStreamInput) flushBatch(
ctx context.Context,
pgStream *pglogicalstream.Stream,
checkpointer *checkpoint.Capped[*int64],
checkpointer *checkpoint.Capped[*string],
batch service.MessageBatch,
) error {
if len(batch) == 0 {
return nil
}

var lsn *int64
var lsn *string
lastMsg := batch[len(batch)-1]
lsnStr, ok := lastMsg.MetaGet("lsn")
if ok {
parsed, err := LSNToInt64(lsnStr)
if err != nil {
return fmt.Errorf("unable to extract LSN from last message in batch: %w", err)
}
lsn = &parsed
lsn = &lsnStr
}
resolveFn, err := checkpointer.Track(ctx, lsn, int64(len(batch)))
if err != nil {
@@ -425,11 +421,11 @@ func (p *pgStreamInput) flushBatch(
if maxOffset == nil {
return nil
}
lsn := *maxOffset
if lsn == nil {
maxLSN := *maxOffset
if maxLSN == nil {
return nil
}
if err = pgStream.AckLSN(ctx, Int64ToLSN(*lsn)); err != nil {
if err = pgStream.AckLSN(ctx, *maxLSN); err != nil {
return fmt.Errorf("unable to ack LSN to postgres: %w", err)
}
return nil
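With the checkpointer parameterised over *string, flushBatch no longer round-trips LSNs through LSNToInt64/Int64ToLSN: the last message's lsn metadata is tracked as-is and the resolved maximum is handed straight to AckLSN. This only works because LSN.String is now zero-padded (see the pglogrepl.go change below), so string comparison agrees with numeric comparison. A stripped-down sketch of the flush shape, with a plain callback standing in for the capped checkpointer and pgStream.AckLSN (the helper names are illustrative):

```go
package main

import (
	"context"
	"fmt"
)

// flushLastLSN mirrors the shape of flushBatch: only the last message's LSN
// matters, it may be absent during the snapshot phase, and the ack is skipped
// entirely in that case.
func flushLastLSN(ctx context.Context, lsns []string, ack func(context.Context, string) error) error {
	if len(lsns) == 0 {
		return nil
	}
	var lsn *string
	if last := lsns[len(lsns)-1]; last != "" {
		lsn = &last
	}
	if lsn == nil {
		return nil // snapshot batch: no offset to acknowledge
	}
	return ack(ctx, *lsn)
}

func main() {
	batch := []string{"00000016/B374D840", "00000016/B374D844", "00000016/B374D848"}
	_ = flushLastLSN(context.Background(), batch, func(_ context.Context, lsn string) error {
		fmt.Println("ack LSN", lsn) // prints the batch's final LSN
		return nil
	})
}
```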
157 changes: 112 additions & 45 deletions internal/impl/postgresql/pglogicalstream/logical_stream.go
@@ -15,16 +15,17 @@ import (
"fmt"
"slices"
"strings"
"sync"
"time"

"github.com/Jeffail/shutdown"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
"github.com/redpanda-data/benthos/v4/public/service"
"golang.org/x/sync/errgroup"

"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize"
"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark"
)

const decodingPlugin = "pgoutput"
Expand All @@ -36,7 +37,9 @@ type Stream struct {

shutSig *shutdown.Signaller

clientXLogPos *watermark.Value[LSN]
ackedLSNMu sync.Mutex
// The LSN acked by the stream; it may not have been acknowledged to Postgres yet.
ackedLSN LSN

standbyMessageTimeout time.Duration
nextStandbyMessageDeadline time.Time
@@ -209,7 +212,9 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
} else {
lsnrestart, _ = ParseLSN(confirmedLSNFromDB)
}
stream.clientXLogPos = watermark.New(lsnrestart)
if lsnrestart > 0 {
stream.ackedLSN = lsnrestart - 1
}

stream.standbyMessageTimeout = config.PgStandbyTimeout
stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout)
@@ -225,7 +230,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
}
})

stream.logger.Debugf("starting stream from LSN %s with clientXLogPos %s and snapshot name %s", lsnrestart.String(), stream.clientXLogPos.Get().String(), stream.snapshotName)
stream.logger.Debugf("starting stream from LSN %s with snapshot name %s", lsnrestart.String(), stream.snapshotName)
// TODO(le-vlad): if snapshot processing is restarted we will just skip right to streaming...
if !freshlyCreatedSlot || !config.StreamOldData {
if err = stream.startLr(ctx, lsnrestart); err != nil {
@@ -234,7 +239,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {

go func() {
defer stream.shutSig.TriggerHasStopped()
if err := stream.streamMessages(); err != nil {
if err := stream.streamMessages(lsnrestart); err != nil {
stream.errors <- fmt.Errorf("logical replication stream error: %w", err)
}
}()
@@ -250,7 +255,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {
stream.errors <- fmt.Errorf("failed to start logical replication: %w", err)
return
}
if err := stream.streamMessages(); err != nil {
if err := stream.streamMessages(lsnrestart); err != nil {
stream.errors <- fmt.Errorf("logical replication stream error: %w", err)
}
}()
@@ -287,56 +292,76 @@ func (s *Stream) startLr(ctx context.Context, lsnStart LSN) error {
// AckLSN acknowledges the LSN up to which the stream has processed messages.
// This allows Postgres to remove WAL files that are no longer needed.
func (s *Stream) AckLSN(ctx context.Context, lsn string) error {
parsed, err := ParseLSN(lsn)
if err != nil {
return fmt.Errorf("unable to parse LSN: %w", err)
}
s.ackedLSNMu.Lock()
defer s.ackedLSNMu.Unlock()
if s.shutSig.IsHardStopSignalled() {
return fmt.Errorf("unable to ack LSN %s stream shutting down", lsn)
}
clientXLogPos, err := ParseLSN(lsn)
if err != nil {
return err
}
s.ackedLSN = parsed
return nil
}

func (s *Stream) getAckedLSN() LSN {
s.ackedLSNMu.Lock()
ackedLSN := s.ackedLSN
s.ackedLSNMu.Unlock()
return ackedLSN
}

err = SendStandbyStatusUpdate(
func (s *Stream) commitAckedLSN(ctx context.Context, lsn LSN) error {
err := SendStandbyStatusUpdate(
ctx,
s.pgConn,
StandbyStatusUpdate{
WALApplyPosition: clientXLogPos + 1,
WALWritePosition: clientXLogPos + 1,
WALFlushPosition: clientXLogPos + 1,
WALWritePosition: lsn + 1,
ReplyRequested: true,
},
)

if err != nil {
return fmt.Errorf("failed to send Standby status message at LSN %s: %w", clientXLogPos.String(), err)
return fmt.Errorf("failed to send standby status message at LSN %s: %w", lsn, err)
}

// Update client XLogPos after we ack the message
s.clientXLogPos.Set(clientXLogPos)
s.logger.Debugf("Sent Standby status message at LSN#%s", clientXLogPos.String())
s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout)

return nil
}
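Acknowledgement is now split into two steps: AckLSN only records the highest LSN the pipeline has confirmed, under ackedLSNMu, and commitAckedLSN later reports it to Postgres with a standby status update from the replication loop. A minimal sketch of that record-now/flush-later pattern, with a generic send callback standing in for SendStandbyStatusUpdate (the tracker type is illustrative, not the connector's API):

```go
package main

import (
	"fmt"
	"sync"
)

// lsnTracker records acks from any goroutine and flushes them from a single
// replication loop, mirroring ackedLSN/ackedLSNMu in Stream.
type lsnTracker struct {
	mu    sync.Mutex
	acked uint64 // highest LSN confirmed by the pipeline
	sent  uint64 // highest LSN already reported to the server
}

// Ack is cheap and safe to call from message-processing goroutines.
func (t *lsnTracker) Ack(lsn uint64) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.acked = lsn
}

// Flush is called from the replication loop (e.g. on the standby timeout) and
// only talks to the server when the acked position has actually advanced.
func (t *lsnTracker) Flush(send func(lsn uint64) error) error {
	t.mu.Lock()
	acked := t.acked
	t.mu.Unlock()
	if acked <= t.sent {
		return nil
	}
	if err := send(acked); err != nil {
		return err
	}
	t.sent = acked
	return nil
}

func main() {
	var t lsnTracker
	t.Ack(42)
	_ = t.Flush(func(lsn uint64) error {
		fmt.Println("standby status update at LSN", lsn)
		return nil
	})
}
```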

func (s *Stream) streamMessages() error {
handler := NewPgOutputPluginHandler(s.messages, s.monitor, s.clientXLogPos, s.includeTxnMarkers)
func (s *Stream) streamMessages(currentLSN LSN) error {
relations := map[uint32]*RelationMessage{}
typeMap := pgtype.NewMap()
// If commit messages are not streamed we cannot ack them, and Postgres would replay the whole transaction.
// So when the ack corresponds to the last emitted message of a transaction, we need to ack the transaction's
// commit LSN rather than that message's LSN.
lastEmittedLSN := currentLSN
lastEmittedCommitLSN := currentLSN

commitLSN := func(force bool) error {
ctx, _ := s.shutSig.HardStopCtx(context.Background())
ackedLSN := s.getAckedLSN()
if ackedLSN == lastEmittedLSN {
ackedLSN = lastEmittedCommitLSN
}
if force || ackedLSN > currentLSN {
if err := s.commitAckedLSN(ctx, ackedLSN); err != nil {
return err
}
// Update the currentLSN
currentLSN = ackedLSN
}
return nil
}
defer func() {
if err := commitLSN(false); err != nil {
s.logger.Errorf("unable to acknowledge LSN on stream shutdown: %v", err)
}
}()

ctx, _ := s.shutSig.SoftStopCtx(context.Background())
for !s.shutSig.IsSoftStopSignalled() {
if time.Now().After(s.nextStandbyMessageDeadline) {
pos := s.clientXLogPos.Get()
err := SendStandbyStatusUpdate(
ctx,
s.pgConn,
StandbyStatusUpdate{
WALWritePosition: pos,
},
)
if err != nil {
return fmt.Errorf("unable to send standby status message at LSN %s: %w", pos, err)
}
s.logger.Debugf("Sent Standby status message at LSN#%s", pos.String())
s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout)
if err := commitLSN(time.Now().After(s.nextStandbyMessageDeadline)); err != nil {
return err
}
recvCtx, cancel := context.WithDeadline(ctx, s.nextStandbyMessageDeadline)
rawMsg, err := s.pgConn.ReceiveMessage(recvCtx)
@@ -381,22 +406,64 @@ func (s *Stream) streamMessages() error {
if err != nil {
return fmt.Errorf("failed to parse XLogData: %w", err)
}
clientXLogPos := xld.WALStart + LSN(len(xld.WALData))
commit, err := handler.Handle(ctx, clientXLogPos, xld)
msgLSN := xld.WALStart + LSN(len(xld.WALData))
result, err := s.processChange(ctx, msgLSN, xld, relations, typeMap)
if err != nil {
return fmt.Errorf("decoding postgres changes failed: %w", err)
} else if commit {
// This is a hack and we probably should not do it
if err = s.AckLSN(ctx, clientXLogPos.String()); err != nil {
s.logger.Warnf("Failed to ack commit message LSN: %v", err)
}
}
// See the explanation above about lastEmittedCommitLSN: only the last message of a transaction should be
// remapped to its commit LSN, so the remap target is advanced only for a suppressed commit; an emitted
// message moves both watermarks together, which keeps the mapping a no-op.
if result == changeResultSuppressedCommitMessage {
lastEmittedCommitLSN = msgLSN
} else if result == changeResultEmittedMessage {
lastEmittedLSN = msgLSN
lastEmittedCommitLSN = msgLSN
}
}
}
// clean shutdown, return nil
return nil
}
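The interplay between lastEmittedLSN and lastEmittedCommitLSN is easiest to see with concrete numbers. Suppose a transaction writes changes at LSNs 100 and 101 and commits at 102, with transaction markers suppressed: the consumer only ever sees and acks 101, but confirming 101 to Postgres would leave the commit unconfirmed and the transaction would be replayed after a restart. The remap in the commitLSN closure promotes an ack of the last emitted message to the suppressed commit's LSN. A self-contained sketch of that decision (the values and helper are illustrative):

```go
package main

import "fmt"

// resolveAckLSN mirrors the remap in streamMessages' commitLSN closure: if the
// acked LSN is exactly the last emitted message of a transaction whose commit
// marker was suppressed, the suppressed commit's LSN is flushed instead.
func resolveAckLSN(acked, lastEmitted, lastEmittedCommit uint64) uint64 {
	if acked == lastEmitted {
		return lastEmittedCommit
	}
	return acked
}

func main() {
	const (
		lastChange       = 101 // last data message emitted to the consumer
		suppressedCommit = 102 // commit marker that was filtered out
	)
	// The consumer acks what it saw last (101), but Postgres is told 102 so
	// the whole transaction is confirmed and not replayed after a restart.
	fmt.Println(resolveAckLSN(101, lastChange, suppressedCommit)) // 102
	// An ack in the middle of the stream is passed through unchanged.
	fmt.Println(resolveAckLSN(100, lastChange, suppressedCommit)) // 100
}
```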

type processChangeResult int

const (
changeResultNoMessage = 0
changeResultSuppressedCommitMessage = 1
changeResultEmittedMessage = 2
)

// processChange decodes a single pgoutput change and, when it should be surfaced, emits it on the messages channel.
func (s *Stream) processChange(ctx context.Context, msgLSN LSN, xld XLogData, relations map[uint32]*RelationMessage, typeMap *pgtype.Map) (processChangeResult, error) {
// parse changes inside the transaction
message, err := decodePgOutput(xld.WALData, relations, typeMap)
if err != nil {
return changeResultNoMessage, err
}
if message == nil {
return changeResultNoMessage, nil
}

if !s.includeTxnMarkers {
switch message.Operation {
case CommitOpType:
return changeResultSuppressedCommitMessage, nil
case BeginOpType:
return changeResultNoMessage, nil
}
}

lsn := msgLSN.String()
message.LSN = &lsn
select {
case s.messages <- *message:
return changeResultEmittedMessage, nil
case <-ctx.Done():
return changeResultNoMessage, ctx.Err()
}
}

func (s *Stream) processSnapshot() error {
if err := s.snapshotter.prepare(); err != nil {
return fmt.Errorf("failed to prepare database snapshot - snapshot may be expired: %w", err)
@@ -512,7 +579,7 @@ func (s *Stream) processSnapshot() error {
}

snapshotChangePacket := StreamMessage{
Lsn: nil,
LSN: nil,
Mode: StreamModeSnapshot,
Operation: InsertOpType,

2 changes: 1 addition & 1 deletion internal/impl/postgresql/pglogicalstream/pglogrepl.go
@@ -66,7 +66,7 @@ type LSN uint64

// String formats the LSN as a zero-padded XXXXXXXX/XXXXXXXX string, a fixed-width form of PostgreSQL's LSN text format whose lexicographic order matches its numeric order.
func (lsn LSN) String() string {
return fmt.Sprintf("%X/%X", uint32(lsn>>32), uint32(lsn))
return fmt.Sprintf("%08X/%08X", uint32(lsn>>32), uint32(lsn))
}

func (lsn *LSN) decodeText(src string) error {
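The only change here is padding each half of the LSN to eight hex digits. The unpadded form is what PostgreSQL prints, but it does not sort as text: "2/0" compares greater than "16/0". With fixed-width padding, lexicographic order matches numeric order, which is what allows the input to track and compare LSNs as plain strings and what the new test below verifies. A small standalone illustration, using a local LSN type with the same formatting:

```go
package main

import "fmt"

type LSN uint64

func (lsn LSN) padded() string {
	return fmt.Sprintf("%08X/%08X", uint32(lsn>>32), uint32(lsn))
}

func (lsn LSN) unpadded() string {
	return fmt.Sprintf("%X/%X", uint32(lsn>>32), uint32(lsn))
}

func main() {
	lo, hi := LSN(0x2_00000000), LSN(0x16_00000000) // 0x2 < 0x16 numerically

	// Unpadded text order disagrees with numeric order.
	fmt.Println(lo.unpadded() < hi.unpadded()) // false: "2/0" > "16/0" as strings
	// Zero-padded text order matches numeric order.
	fmt.Println(lo.padded() < hi.padded()) // true: "00000002/00000000" < "00000016/00000000"
}
```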
37 changes: 35 additions & 2 deletions internal/impl/postgresql/pglogicalstream/pglogrepl_test.go
@@ -13,6 +13,8 @@ import (
"database/sql"
"encoding/json"
"fmt"
"math"
"slices"
"strings"
"testing"
"time"
@@ -50,7 +52,7 @@ func (s *lsnSuite) NoError(err error) {

func (s *lsnSuite) TestScannerInterface() {
var lsn LSN
lsnText := "16/B374D848"
lsnText := "00000016/B374D848"
lsnUint64 := uint64(97500059720)
var err error

@@ -84,7 +86,7 @@ func (s *lsnSuite) TestValueInterface() {
s.NoError(err)
lsnStr, ok := driverValue.(string)
s.R().True(ok)
s.Equal("16/B374D848", lsnStr)
s.Equal("00000016/B374D848", lsnStr)
}

const slotName = "pglogrepl_test"
@@ -472,3 +474,34 @@ func TestIntegrationSendStandbyStatusUpdate(t *testing.T) {
err = SendStandbyStatusUpdate(ctx, conn, StandbyStatusUpdate{WALWritePosition: sysident.XLogPos})
require.NoError(t, err)
}

func TestLSNStringLexicographicalOrder(t *testing.T) {
ordered := []uint64{
0,
1,
42,
math.MaxInt16 - 1,
math.MaxInt16,
math.MaxInt16 + 1,
math.MaxInt32 - 1,
math.MaxInt32,
math.MaxInt32 + 1,
math.MaxInt64 - 1,
math.MaxInt64,
math.MaxInt64 + 1,
math.MaxUint64 - 1,
math.MaxUint64,
}
slices.SortFunc(ordered, func(a, b uint64) int {
aStr := LSN(a).String()
bStr := LSN(b).String()
if aStr < bStr {
return -1
} else if aStr > bStr {
return 1
} else {
return 0
}
})
require.IsIncreasing(t, ordered)
}