diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index d451287946..f3c9b17a61 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -338,7 +338,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher var nextTimedBatchChan <-chan time.Time // offsets are nilable since we don't provide offset tracking during the snapshot phase - cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) + cp := checkpoint.NewCapped[*string](int64(p.checkpointLimit)) for !p.stopSig.IsSoftStopSignalled() { select { case <-nextTimedBatchChan: @@ -365,8 +365,8 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher batchMsg.MetaSet("mode", string(message.Mode)) batchMsg.MetaSet("table", message.Table) batchMsg.MetaSet("operation", string(message.Operation)) - if message.Lsn != nil { - batchMsg.MetaSet("lsn", *message.Lsn) + if message.LSN != nil { + batchMsg.MetaSet("lsn", *message.LSN) } if batcher.Add(batchMsg) { nextTimedBatchChan = nil @@ -398,22 +398,18 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher func (p *pgStreamInput) flushBatch( ctx context.Context, pgStream *pglogicalstream.Stream, - checkpointer *checkpoint.Capped[*int64], + checkpointer *checkpoint.Capped[*string], batch service.MessageBatch, ) error { if len(batch) == 0 { return nil } - var lsn *int64 + var lsn *string lastMsg := batch[len(batch)-1] lsnStr, ok := lastMsg.MetaGet("lsn") if ok { - parsed, err := LSNToInt64(lsnStr) - if err != nil { - return fmt.Errorf("unable to extract LSN from last message in batch: %w", err) - } - lsn = &parsed + lsn = &lsnStr } resolveFn, err := checkpointer.Track(ctx, lsn, int64(len(batch))) if err != nil { @@ -425,11 +421,11 @@ func (p *pgStreamInput) flushBatch( if maxOffset == nil { return nil } - lsn := *maxOffset - if lsn == nil { + maxLSN := *maxOffset + if maxLSN == nil { return nil } - if err = pgStream.AckLSN(ctx, Int64ToLSN(*lsn)); err != nil { + if err = pgStream.AckLSN(ctx, *maxLSN); err != nil { return fmt.Errorf("unable to ack LSN to postgres: %w", err) } return nil diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index e17f4367fc..5d46407877 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -15,16 +15,17 @@ import ( "fmt" "slices" "strings" + "sync" "time" "github.com/Jeffail/shutdown" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" "github.com/redpanda-data/benthos/v4/public/service" "golang.org/x/sync/errgroup" "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" ) const decodingPlugin = "pgoutput" @@ -36,7 +37,9 @@ type Stream struct { shutSig *shutdown.Signaller - clientXLogPos *watermark.Value[LSN] + ackedLSNMu sync.Mutex + // The LSN acked by the stream, we may not have acked this to postgres yet (ack, ack, ack) + ackedLSN LSN standbyMessageTimeout time.Duration nextStandbyMessageDeadline time.Time @@ -209,7 +212,9 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } else { lsnrestart, _ = ParseLSN(confirmedLSNFromDB) } - stream.clientXLogPos = watermark.New(lsnrestart) + if lsnrestart > 0 { + stream.ackedLSN = lsnrestart - 1 + } stream.standbyMessageTimeout = config.PgStandbyTimeout stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) @@ -225,7 +230,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } }) - stream.logger.Debugf("starting stream from LSN %s with clientXLogPos %s and snapshot name %s", lsnrestart.String(), stream.clientXLogPos.Get().String(), stream.snapshotName) + stream.logger.Debugf("starting stream from LSN %s with snapshot name %s", lsnrestart.String(), stream.snapshotName) // TODO(le-vlad): if snapshot processing is restarted we will just skip right to streaming... if !freshlyCreatedSlot || !config.StreamOldData { if err = stream.startLr(ctx, lsnrestart); err != nil { @@ -234,7 +239,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { go func() { defer stream.shutSig.TriggerHasStopped() - if err := stream.streamMessages(); err != nil { + if err := stream.streamMessages(lsnrestart); err != nil { stream.errors <- fmt.Errorf("logical replication stream error: %w", err) } }() @@ -250,7 +255,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.errors <- fmt.Errorf("failed to start logical replication: %w", err) return } - if err := stream.streamMessages(); err != nil { + if err := stream.streamMessages(lsnrestart); err != nil { stream.errors <- fmt.Errorf("logical replication stream error: %w", err) } }() @@ -287,56 +292,76 @@ func (s *Stream) startLr(ctx context.Context, lsnStart LSN) error { // AckLSN acknowledges the LSN up to which the stream has processed the messages. // This makes Postgres to remove the WAL files that are no longer needed. func (s *Stream) AckLSN(ctx context.Context, lsn string) error { + parsed, err := ParseLSN(lsn) + if err != nil { + return fmt.Errorf("unable to parse LSN: %w", err) + } + s.ackedLSNMu.Lock() + defer s.ackedLSNMu.Unlock() if s.shutSig.IsHardStopSignalled() { return fmt.Errorf("unable to ack LSN %s stream shutting down", lsn) } - clientXLogPos, err := ParseLSN(lsn) - if err != nil { - return err - } + s.ackedLSN = parsed + return nil +} + +func (s *Stream) getAckedLSN() LSN { + s.ackedLSNMu.Lock() + ackedLSN := s.ackedLSN + s.ackedLSNMu.Unlock() + return ackedLSN +} - err = SendStandbyStatusUpdate( +func (s *Stream) commitAckedLSN(ctx context.Context, lsn LSN) error { + err := SendStandbyStatusUpdate( ctx, s.pgConn, StandbyStatusUpdate{ - WALApplyPosition: clientXLogPos + 1, - WALWritePosition: clientXLogPos + 1, - WALFlushPosition: clientXLogPos + 1, + WALWritePosition: lsn + 1, ReplyRequested: true, }, ) - if err != nil { - return fmt.Errorf("failed to send Standby status message at LSN %s: %w", clientXLogPos.String(), err) + return fmt.Errorf("failed to send standby status message at LSN %s: %w", lsn, err) } - - // Update client XLogPos after we ack the message - s.clientXLogPos.Set(clientXLogPos) - s.logger.Debugf("Sent Standby status message at LSN#%s", clientXLogPos.String()) s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) - return nil } -func (s *Stream) streamMessages() error { - handler := NewPgOutputPluginHandler(s.messages, s.monitor, s.clientXLogPos, s.includeTxnMarkers) +func (s *Stream) streamMessages(currentLSN LSN) error { + relations := map[uint32]*RelationMessage{} + typeMap := pgtype.NewMap() + // If we don't stream commit messages we could not ack them, which means postgres will replay the whole transaction + // so if we're at the end of a stream and we get an ack for the last message in a txn, we need to ack the txn not the + // last message. + lastEmittedLSN := currentLSN + lastEmittedCommitLSN := currentLSN + + commitLSN := func(force bool) error { + ctx, _ := s.shutSig.HardStopCtx(context.Background()) + ackedLSN := s.getAckedLSN() + if ackedLSN == lastEmittedLSN { + ackedLSN = lastEmittedCommitLSN + } + if force || ackedLSN > currentLSN { + if err := s.commitAckedLSN(ctx, ackedLSN); err != nil { + return err + } + // Update the currentLSN + currentLSN = ackedLSN + } + return nil + } + defer func() { + if err := commitLSN(false); err != nil { + s.logger.Errorf("unable to acknowledge LSN on stream shutdown: %v", err) + } + }() ctx, _ := s.shutSig.SoftStopCtx(context.Background()) for !s.shutSig.IsSoftStopSignalled() { - if time.Now().After(s.nextStandbyMessageDeadline) { - pos := s.clientXLogPos.Get() - err := SendStandbyStatusUpdate( - ctx, - s.pgConn, - StandbyStatusUpdate{ - WALWritePosition: pos, - }, - ) - if err != nil { - return fmt.Errorf("unable to send standby status message at LSN %s: %w", pos, err) - } - s.logger.Debugf("Sent Standby status message at LSN#%s", pos.String()) - s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) + if err := commitLSN(time.Now().After(s.nextStandbyMessageDeadline)); err != nil { + return err } recvCtx, cancel := context.WithDeadline(ctx, s.nextStandbyMessageDeadline) rawMsg, err := s.pgConn.ReceiveMessage(recvCtx) @@ -381,15 +406,19 @@ func (s *Stream) streamMessages() error { if err != nil { return fmt.Errorf("failed to parse XLogData: %w", err) } - clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) - commit, err := handler.Handle(ctx, clientXLogPos, xld) + msgLSN := xld.WALStart + LSN(len(xld.WALData)) + result, err := s.processChange(ctx, msgLSN, xld, relations, typeMap) if err != nil { return fmt.Errorf("decoding postgres changes failed: %w", err) - } else if commit { - // This is a hack and we probably should not do it - if err = s.AckLSN(ctx, clientXLogPos.String()); err != nil { - s.logger.Warnf("Failed to ack commit message LSN: %v", err) - } + } + // See the explaination above about lastEmittedCommitLSN but if this is a commit message, we want to + // only remap the commit of the last message in a transaction, so only update the remapped value if + // it was a suppressed commit, otherwise we just provide a noop mapping of commit LSN + if result == changeResultSuppressedCommitMessage { + lastEmittedCommitLSN = msgLSN + } else if result == changeResultEmittedMessage { + lastEmittedLSN = msgLSN + lastEmittedCommitLSN = msgLSN } } } @@ -397,6 +426,44 @@ func (s *Stream) streamMessages() error { return nil } +type processChangeResult int + +const ( + changeResultNoMessage = 0 + changeResultSuppressedCommitMessage = 1 + changeResultEmittedMessage = 2 +) + +// Handle handles the pgoutput output +func (s *Stream) processChange(ctx context.Context, msgLSN LSN, xld XLogData, relations map[uint32]*RelationMessage, typeMap *pgtype.Map) (processChangeResult, error) { + // parse changes inside the transaction + message, err := decodePgOutput(xld.WALData, relations, typeMap) + if err != nil { + return changeResultNoMessage, err + } + if message == nil { + return changeResultNoMessage, nil + } + + if !s.includeTxnMarkers { + switch message.Operation { + case CommitOpType: + return changeResultSuppressedCommitMessage, nil + case BeginOpType: + return changeResultNoMessage, nil + } + } + + lsn := msgLSN.String() + message.LSN = &lsn + select { + case s.messages <- *message: + return changeResultEmittedMessage, nil + case <-ctx.Done(): + return changeResultNoMessage, ctx.Err() + } +} + func (s *Stream) processSnapshot() error { if err := s.snapshotter.prepare(); err != nil { return fmt.Errorf("failed to prepare database snapshot - snapshot may be expired: %w", err) @@ -512,7 +579,7 @@ func (s *Stream) processSnapshot() error { } snapshotChangePacket := StreamMessage{ - Lsn: nil, + LSN: nil, Mode: StreamModeSnapshot, Operation: InsertOpType, diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 99debf5190..368d7f0b02 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -66,7 +66,7 @@ type LSN uint64 // String formats the LSN value into the XXX/XXX format which is the text format used by PostgreSQL. func (lsn LSN) String() string { - return fmt.Sprintf("%X/%X", uint32(lsn>>32), uint32(lsn)) + return fmt.Sprintf("%08X/%08X", uint32(lsn>>32), uint32(lsn)) } func (lsn *LSN) decodeText(src string) error { diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index a67ff0fcd4..683d2910b0 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -13,6 +13,8 @@ import ( "database/sql" "encoding/json" "fmt" + "math" + "slices" "strings" "testing" "time" @@ -50,7 +52,7 @@ func (s *lsnSuite) NoError(err error) { func (s *lsnSuite) TestScannerInterface() { var lsn LSN - lsnText := "16/B374D848" + lsnText := "00000016/B374D848" lsnUint64 := uint64(97500059720) var err error @@ -84,7 +86,7 @@ func (s *lsnSuite) TestValueInterface() { s.NoError(err) lsnStr, ok := driverValue.(string) s.R().True(ok) - s.Equal("16/B374D848", lsnStr) + s.Equal("00000016/B374D848", lsnStr) } const slotName = "pglogrepl_test" @@ -472,3 +474,34 @@ func TestIntegrationSendStandbyStatusUpdate(t *testing.T) { err = SendStandbyStatusUpdate(ctx, conn, StandbyStatusUpdate{WALWritePosition: sysident.XLogPos}) require.NoError(t, err) } + +func TestLSNStringLexicographicalOrder(t *testing.T) { + ordered := []uint64{ + 0, + 1, + 42, + math.MaxInt16 - 1, + math.MaxInt16, + math.MaxInt16 + 1, + math.MaxInt32 - 1, + math.MaxInt32, + math.MaxInt32 + 1, + math.MaxInt64 - 1, + math.MaxInt64, + math.MaxInt64 + 1, + math.MaxUint64 - 1, + math.MaxUint64, + } + slices.SortFunc(ordered, func(a, b uint64) int { + aStr := LSN(a).String() + bStr := LSN(b).String() + if aStr < bStr { + return -1 + } else if aStr > bStr { + return 1 + } else { + return 0 + } + }) + require.IsIncreasing(t, ordered) +} diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go deleted file mode 100644 index 431608f74e..0000000000 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package pglogicalstream - -import ( - "context" - - "github.com/jackc/pgx/v5/pgtype" - - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" -) - -// PluginHandler is an interface that must be implemented by all plugin handlers -type PluginHandler interface { - // returns true if we need to ack the clientXLogPos - Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) -} - -// PgOutputUnbufferedPluginHandler is a native output handler that emits each message as it's received. -type PgOutputUnbufferedPluginHandler struct { - messages chan StreamMessage - monitor *Monitor - - relations map[uint32]*RelationMessage - typeMap *pgtype.Map - - lastEmitted LSN - lsnWatermark *watermark.Value[LSN] - includeTxnMarkers bool -} - -// NewPgOutputPluginHandler creates a new PgOutputPluginHandler -func NewPgOutputPluginHandler( - messages chan StreamMessage, - monitor *Monitor, - lsnWatermark *watermark.Value[LSN], - includeTxnMarkers bool, -) PluginHandler { - return &PgOutputUnbufferedPluginHandler{ - messages: messages, - monitor: monitor, - relations: map[uint32]*RelationMessage{}, - typeMap: pgtype.NewMap(), - lastEmitted: lsnWatermark.Get(), - lsnWatermark: lsnWatermark, - includeTxnMarkers: includeTxnMarkers, - } -} - -// Handle handles the pgoutput output -func (p *PgOutputUnbufferedPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) { - // parse changes inside the transaction - message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap) - if err != nil { - return false, err - } - if message == nil { - return false, nil - } - - if !p.includeTxnMarkers { - switch message.Operation { - case CommitOpType: - // when receiving a commit message, we need to acknowledge the LSN - // but we must wait for connect to flush the messages before we can do that - select { - case <-p.lsnWatermark.WaitFor(p.lastEmitted): - return true, nil - case <-ctx.Done(): - return false, ctx.Err() - } - case BeginOpType: - return false, nil - } - } - - lsn := clientXLogPos.String() - message.Lsn = &lsn - select { - case p.messages <- *message: - p.lastEmitted = clientXLogPos - case <-ctx.Done(): - return false, ctx.Err() - } - - return false, nil -} diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go index 3291287956..bf75faaa0a 100644 --- a/internal/impl/postgresql/pglogicalstream/stream_message.go +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -36,7 +36,7 @@ const ( // StreamMessage represents a single change from the database type StreamMessage struct { - Lsn *string `json:"lsn"` + LSN *string `json:"lsn"` Operation OpType `json:"operation"` Schema string `json:"schema"` Table string `json:"table"` diff --git a/internal/impl/postgresql/pglogicalstream/watermark/watermark.go b/internal/impl/postgresql/pglogicalstream/watermark/watermark.go deleted file mode 100644 index 56dc30783b..0000000000 --- a/internal/impl/postgresql/pglogicalstream/watermark/watermark.go +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2024 Redpanda Data, Inc. - * - * Licensed as a Redpanda Enterprise file under the Redpanda Community - * License (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md - */ - -package watermark - -import ( - "cmp" - "sync" -) - -// Value is a utility that allows you to store the highest value and subscribe to when -// a specific offset is reached -type ( - Value[T cmp.Ordered] struct { - val T - mu sync.Mutex - waiters map[chan<- any]T - } -) - -// New makes a new Value holding `initial` -func New[T cmp.Ordered](initial T) *Value[T] { - w := &Value[T]{val: initial} - w.waiters = map[chan<- any]T{} - return w -} - -// Set the watermark value if it's newer -func (w *Value[T]) Set(v T) { - w.mu.Lock() - defer w.mu.Unlock() - if v <= w.val { - return - } - w.val = v - for notify, val := range w.waiters { - if val <= w.val { - notify <- nil - delete(w.waiters, notify) - } - } -} - -// Get the current watermark value -func (w *Value[T]) Get() T { - w.mu.Lock() - cpy := w.val - w.mu.Unlock() - return cpy -} - -// WaitFor returns a channel that recieves a value when the watermark reaches `val`. -func (w *Value[T]) WaitFor(val T) <-chan any { - w.mu.Lock() - defer w.mu.Unlock() - ch := make(chan any, 1) - if w.val >= val { - ch <- nil - return ch - } - w.waiters[ch] = val - return ch -} diff --git a/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go b/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go deleted file mode 100644 index 637deff653..0000000000 --- a/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2024 Redpanda Data, Inc. - * - * Licensed as a Redpanda Enterprise file under the Redpanda Community - * License (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md - */ - -package watermark_test - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" -) - -func TestWatermark(t *testing.T) { - w := watermark.New(5) - require.Equal(t, 5, w.Get()) - w.Set(3) - require.Equal(t, 5, w.Get()) - require.Len(t, w.WaitFor(1), 1) - ch1 := w.WaitFor(9) - ch2 := w.WaitFor(10) - ch3 := w.WaitFor(10) - ch4 := w.WaitFor(100) - require.Empty(t, ch1) - require.Empty(t, ch2) - require.Empty(t, ch3) - require.Empty(t, ch4) - w.Set(8) - require.Equal(t, 8, w.Get()) - require.Empty(t, ch1) - require.Empty(t, ch2) - require.Empty(t, ch3) - require.Empty(t, ch4) - w.Set(9) - require.Equal(t, 9, w.Get()) - require.Len(t, ch1, 1) - require.Empty(t, ch2) - require.Empty(t, ch3) - require.Empty(t, ch4) - w.Set(10) - require.Equal(t, 10, w.Get()) - require.Len(t, ch1, 1) - require.Len(t, ch2, 1) - require.Len(t, ch3, 1) - require.Empty(t, ch4) -} diff --git a/internal/impl/postgresql/utils.go b/internal/impl/postgresql/utils.go deleted file mode 100644 index d01bf441cd..0000000000 --- a/internal/impl/postgresql/utils.go +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package pgstream - -import ( - "fmt" - "strconv" - "strings" -) - -// LSNToInt64 converts a PostgreSQL LSN string to int64 -func LSNToInt64(lsn string) (int64, error) { - // Split the LSN into segments - parts := strings.Split(lsn, "/") - if len(parts) != 2 { - return 0, fmt.Errorf("invalid LSN format: %s", lsn) - } - - // Parse both segments as hex with uint64 first - upper, err := strconv.ParseUint(parts[0], 16, 32) - if err != nil { - return 0, fmt.Errorf("failed to parse upper part: %w", err) - } - - lower, err := strconv.ParseUint(parts[1], 16, 32) - if err != nil { - return 0, fmt.Errorf("failed to parse lower part: %w", err) - } - - // Combine the segments into a single int64 - // Upper part is shifted left by 32 bits - result := int64((upper << 32) | lower) - - return result, nil -} - -// Int64ToLSN converts an int64 to a PostgreSQL LSN string -func Int64ToLSN(value int64) string { - // Convert to uint64 to handle the bitwise operations properly - uvalue := uint64(value) - - // Extract upper and lower parts - upper := uvalue >> 32 - lower := uvalue & 0xFFFFFFFF - - // Format as hexadecimal with proper padding - return fmt.Sprintf("%X/%X", upper, lower) -}