diff --git a/chain/chain.go b/chain/chain.go index 5ae1d646c4..d48991dd05 100644 --- a/chain/chain.go +++ b/chain/chain.go @@ -27,7 +27,6 @@ type Chain struct { func NewChain( tracer trace.Tracer, registerer *prometheus.Registry, - parser Parser, mempool Mempool, logger logging.Logger, ruleFactory RuleFactory, @@ -35,6 +34,7 @@ func NewChain( balanceHandler BalanceHandler, authVerifiers workers.Workers, authEngines AuthEngines, + blockParser *BlockParser, validityWindow ValidityWindow, config Config, ) (*Chain, error) { @@ -72,7 +72,7 @@ func NewChain( metadataManager, balanceHandler, ), - blockParser: NewBlockParser(tracer, parser), + blockParser: blockParser, accepter: NewAccepter(tracer, validityWindow, metrics), }, nil } diff --git a/chain/chaintest/block.go b/chain/chaintest/block.go index f31c7e6231..51f84fafb6 100644 --- a/chain/chaintest/block.go +++ b/chain/chaintest/block.go @@ -278,27 +278,6 @@ func (test *BlockBenchmark[T]) Run(ctx context.Context, b *testing.B) { } chainIndex := &validitywindowtest.MockChainIndex[*chain.Transaction]{} - validityWindow := validitywindow.NewTimeValidityWindow( - logging.NoLog{}, - trace.Noop, - chainIndex, - func(timestamp int64) int64 { - return test.RuleFactory.GetRules(timestamp).GetValidityWindow() - }, - ) - - processor := chain.NewProcessor( - trace.Noop, - &logging.NoLog{}, - test.RuleFactory, - processorWorkers, - test.AuthEngines, - test.MetadataManager, - test.BalanceHandler, - validityWindow, - metrics, - test.Config, - ) factories, keys, genesis, err := test.GenesisF(test.NumTxsPerBlock) r.NoError(err) @@ -343,6 +322,34 @@ func (test *BlockBenchmark[T]) Run(ctx context.Context, b *testing.B) { chainIndex.Set(blk.GetID(), blk) } + // Populate a validity window starting from the last block from chain-index + head := blocks[len(blocks)-1] + + validityWindow, err := validitywindow.NewTimeValidityWindow( + ctx, + logging.NoLog{}, + trace.Noop, + chainIndex, + head, + func(timestamp int64) int64 { + return test.RuleFactory.GetRules(timestamp).GetValidityWindow() + }, + ) + r.NoError(err) + + processor := chain.NewProcessor( + trace.Noop, + &logging.NoLog{}, + test.RuleFactory, + processorWorkers, + test.AuthEngines, + test.MetadataManager, + test.BalanceHandler, + validityWindow, + metrics, + test.Config, + ) + var parentView merkledb.View parentView = db b.ResetTimer() diff --git a/chainindex/chain_index.go b/chainindex/chain_index.go index 06e5bed9d8..b1ab921450 100644 --- a/chainindex/chain_index.go +++ b/chainindex/chain_index.go @@ -7,6 +7,7 @@ import ( "context" "encoding/binary" "errors" + "fmt" "math/rand" "time" @@ -65,6 +66,7 @@ type Parser[T Block] interface { } func New[T Block]( + ctx context.Context, log logging.Logger, registry prometheus.Registerer, config Config, @@ -79,7 +81,7 @@ func New[T Block]( return nil, errBlockCompactionFrequencyZero } - return &ChainIndex[T]{ + ci := &ChainIndex[T]{ config: config, // Offset by random number to ensure the network does not compact simultaneously compactionOffset: rand.Uint64() % config.BlockCompactionFrequency, //nolint:gosec @@ -87,7 +89,13 @@ func New[T Block]( log: log, db: db, parser: parser, - }, nil + } + + if err := ci.cleanupOnStartup(ctx); err != nil { + return nil, err + } + + return ci, nil } func (c *ChainIndex[T]) GetLastAcceptedHeight(_ context.Context) (uint64, error) { @@ -101,19 +109,11 @@ func (c *ChainIndex[T]) GetLastAcceptedHeight(_ context.Context) (uint64, error) func (c *ChainIndex[T]) UpdateLastAccepted(ctx context.Context, blk T) error { batch := 
c.db.NewBatch()
-	var (
-		blkID    = blk.GetID()
-		height   = blk.GetHeight()
-		blkBytes = blk.GetBytes()
-	)
+	height := blk.GetHeight()
 	heightBytes := binary.BigEndian.AppendUint64(nil, height)
-	err := errors.Join(
+	if err := errors.Join(
 		batch.Put(lastAcceptedKey, heightBytes),
-		batch.Put(prefixBlockIDHeightKey(blkID), heightBytes),
-		batch.Put(prefixBlockHeightIDKey(height), blkID[:]),
-		batch.Put(prefixBlockKey(height), blkBytes),
-	)
-	if err != nil {
+		c.writeBlock(batch, blk)); err != nil {
 		return err
 	}
@@ -122,17 +122,15 @@ func (c *ChainIndex[T]) UpdateLastAccepted(ctx context.Context, blk T) error {
 		return batch.Write()
 	}
-	if err := batch.Delete(prefixBlockKey(expiryHeight)); err != nil {
-		return err
-	}
 	deleteBlkID, err := c.GetBlockIDAtHeight(ctx, expiryHeight)
 	if err != nil {
 		return err
 	}
-	if err := batch.Delete(prefixBlockIDHeightKey(deleteBlkID)); err != nil {
-		return err
-	}
-	if err := batch.Delete(prefixBlockHeightIDKey(expiryHeight)); err != nil {
+	if err = errors.Join(
+		batch.Delete(prefixBlockKey(expiryHeight)),
+		batch.Delete(prefixBlockIDHeightKey(deleteBlkID)),
+		batch.Delete(prefixBlockHeightIDKey(expiryHeight)),
+	); err != nil {
 		return err
 	}
 	c.metrics.deletedBlocks.Inc()
@@ -151,6 +149,18 @@ func (c *ChainIndex[T]) UpdateLastAccepted(ctx context.Context, blk T) error {
 	return batch.Write()
 }

+// SaveHistorical writes a block to disk without updating lastAcceptedKey.
+// It should be used only for historical blocks; it relies on UpdateLastAccepted
+// eventually being called, which deletes expired blocks.
+func (c *ChainIndex[T]) SaveHistorical(blk T) error {
+	batch := c.db.NewBatch()
+	if err := c.writeBlock(batch, blk); err != nil {
+		return err
+	}
+
+	return batch.Write()
+}
+
 func (c *ChainIndex[T]) GetBlock(ctx context.Context, blkID ids.ID) (T, error) {
 	height, err := c.GetBlockIDHeight(ctx, blkID)
 	if err != nil {
@@ -183,6 +193,105 @@ func (c *ChainIndex[T]) GetBlockByHeight(ctx context.Context, blkHeight uint64)
 	return c.parser.ParseBlock(ctx, blkBytes)
 }

+func (_ *ChainIndex[T]) writeBlock(batch database.Batch, blk T) error {
+	var (
+		blkID    = blk.GetID()
+		height   = blk.GetHeight()
+		blkBytes = blk.GetBytes()
+	)
+	heightBytes := binary.BigEndian.AppendUint64(nil, height)
+	return errors.Join(
+		batch.Put(prefixBlockIDHeightKey(blkID), heightBytes),
+		batch.Put(prefixBlockHeightIDKey(height), blkID[:]),
+		batch.Put(prefixBlockKey(height), blkBytes),
+	)
+}
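
As a reference for reviewers, the on-disk layout writeBlock leaves (shared by both UpdateLastAccepted and SaveHistorical) can be sketched with a toy batch. The toyBatch type and the prefix byte values below are assumptions for illustration; this is not avalanchego's database.Batch.

```go
// Toy sketch of the three records writeBlock puts per block; only Put is
// mimicked here, and the prefix byte values are assumed for illustration.
package main

import (
	"encoding/binary"
	"fmt"
)

const (
	blockPrefix         byte = 0x0 // height   -> raw block bytes (assumed value)
	blockIDHeightPrefix byte = 0x1 // block ID -> height          (assumed value)
	blockHeightIDPrefix byte = 0x2 // height   -> block ID        (assumed value)
)

type toyBatch map[string][]byte

func (b toyBatch) Put(k, v []byte) { b[string(k)] = v }

func main() {
	var (
		height   = uint64(42)
		blkID    = [32]byte{0xaa} // stand-in for ids.ID
		blkBytes = []byte("serialized block")
	)
	heightBytes := binary.BigEndian.AppendUint64(nil, height)

	b := toyBatch{}
	b.Put(append([]byte{blockIDHeightPrefix}, blkID[:]...), heightBytes) // look up height by ID
	b.Put(append([]byte{blockHeightIDPrefix}, heightBytes...), blkID[:]) // look up ID by height
	b.Put(append([]byte{blockPrefix}, heightBytes...), blkBytes)         // fetch block bytes by height

	// SaveHistorical writes exactly these three records; UpdateLastAccepted
	// additionally updates lastAcceptedKey and expires old heights.
	fmt.Println("records per block:", len(b)) // 3
}
```

+// cleanupOnStartup performs cleanup of historical blocks outside the accepted window.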
+// +// The cleanup removes all blocks below this threshold: +// | <--- Historical Blocks (delete) ---> | <--- AcceptedBlockWindow ---> | Last Accepted | +func (c *ChainIndex[T]) cleanupOnStartup(ctx context.Context) error { + lastAcceptedHeight, err := c.GetLastAcceptedHeight(ctx) + if err != nil && err != database.ErrNotFound { + return err + } + + // If there's no accepted window or lastAcceptedHeight is too small, nothing to clean + if c.config.AcceptedBlockWindow == 0 || lastAcceptedHeight <= c.config.AcceptedBlockWindow { + return nil + } + + thresholdHeight := lastAcceptedHeight - c.config.AcceptedBlockWindow + + c.log.Debug("cleaning up historical blocks outside accepted window", + zap.Uint64("lastAcceptedHeight", lastAcceptedHeight), + zap.Uint64("thresholdHeight", thresholdHeight), + zap.Uint64("acceptedBlockWindow", c.config.AcceptedBlockWindow)) + + it := c.db.NewIteratorWithPrefix([]byte{blockHeightIDPrefix}) + defer it.Release() + + batch := c.db.NewBatch() + var lastDeletedHeight uint64 + + for it.Next() { + key := it.Key() + height := extractBlockHeightFromKey(key) + + // Nothing to delete after the threshold height + if height >= thresholdHeight { + break + } + + // Skip if: + // Block is at genesis height (0) + if height == 0 { + continue + } + + deleteBlkID, err := c.GetBlockIDAtHeight(ctx, height) + if err != nil { + return err + } + + if err = errors.Join( + batch.Delete(prefixBlockKey(height)), + batch.Delete(prefixBlockIDHeightKey(deleteBlkID)), + batch.Delete(prefixBlockHeightIDKey(height)), + ); err != nil { + return err + } + c.metrics.deletedBlocks.Inc() + + // Keep track of the last height we deleted + lastDeletedHeight = height + } + + if err := it.Error(); err != nil { + return fmt.Errorf("iterator error during cleanup: %w", err) + } + + // Write all the deletions + if err := batch.Write(); err != nil { + return err + } + + // Perform a single compaction at the end if we deleted anything + if lastDeletedHeight > 0 { + go func() { + start := time.Now() + if err := c.db.Compact([]byte{blockPrefix}, prefixBlockKey(lastDeletedHeight)); err != nil { + c.log.Error("failed to compact block store", zap.Error(err)) + return + } + c.log.Info("compacted disk blocks", zap.Uint64("end", lastDeletedHeight), zap.Duration("t", time.Since(start))) + }() + } + + return nil +} + func prefixBlockKey(height uint64) []byte { k := make([]byte, 1+consts.Uint64Len) k[0] = blockPrefix @@ -203,3 +312,9 @@ func prefixBlockHeightIDKey(height uint64) []byte { binary.BigEndian.PutUint64(k[1:], height) return k } + +// extractBlockHeightFromKey extracts block height from the key. 
+// The key is expected to be in the format: [1-byte prefix][8-byte big-endian encoded uint64] +func extractBlockHeightFromKey(key []byte) uint64 { + return binary.BigEndian.Uint64(key[1:]) +} diff --git a/chainindex/chain_index_test.go b/chainindex/chain_index_test.go index 1268c6bf4e..d8715c5c36 100644 --- a/chainindex/chain_index_test.go +++ b/chainindex/chain_index_test.go @@ -38,8 +38,8 @@ func (*parser) ParseBlock(_ context.Context, b []byte) (*testBlock, error) { return &testBlock{height: height}, nil } -func newTestChainIndex(config Config, db database.Database) (*ChainIndex[*testBlock], error) { - return New(logging.NoLog{}, prometheus.NewRegistry(), config, &parser{}, db) +func newTestChainIndex(ctx context.Context, config Config, db database.Database) (*ChainIndex[*testBlock], error) { + return New(ctx, logging.NoLog{}, prometheus.NewRegistry(), config, &parser{}, db) } func confirmBlockIndexed(r *require.Assertions, ctx context.Context, chainIndex *ChainIndex[*testBlock], expectedBlk *testBlock, expectedErr error) { @@ -74,7 +74,7 @@ func confirmLastAcceptedHeight(r *require.Assertions, ctx context.Context, chain func TestChainIndex(t *testing.T) { r := require.New(t) ctx := context.Background() - chainIndex, err := newTestChainIndex(NewDefaultConfig(), memdb.New()) + chainIndex, err := newTestChainIndex(ctx, NewDefaultConfig(), memdb.New()) r.NoError(err) genesisBlk := &testBlock{height: 0} @@ -93,14 +93,15 @@ func TestChainIndex(t *testing.T) { } func TestChainIndexInvalidCompactionFrequency(t *testing.T) { - _, err := newTestChainIndex(Config{BlockCompactionFrequency: 0}, memdb.New()) + ctx := context.Background() + _, err := newTestChainIndex(ctx, Config{BlockCompactionFrequency: 0}, memdb.New()) require.ErrorIs(t, err, errBlockCompactionFrequencyZero) } func TestChainIndexExpiry(t *testing.T) { r := require.New(t) ctx := context.Background() - chainIndex, err := newTestChainIndex(Config{AcceptedBlockWindow: 1, BlockCompactionFrequency: 64}, memdb.New()) + chainIndex, err := newTestChainIndex(ctx, Config{AcceptedBlockWindow: 1, BlockCompactionFrequency: 64}, memdb.New()) r.NoError(err) genesisBlk := &testBlock{height: 0} @@ -128,3 +129,174 @@ func TestChainIndexExpiry(t *testing.T) { confirmBlockIndexed(r, ctx, chainIndex, blk2, database.ErrNotFound) confirmLastAcceptedHeight(r, ctx, chainIndex, blk3.GetHeight()) } + +func TestChainIndex_SaveHistorical(t *testing.T) { + r := require.New(t) + ctx := context.Background() + chainIndex, err := newTestChainIndex(ctx, NewDefaultConfig(), memdb.New()) + r.NoError(err) + + // Create and save a genesis block normally first + genesisBlk := &testBlock{height: 0} + r.NoError(chainIndex.UpdateLastAccepted(ctx, genesisBlk)) + confirmBlockIndexed(r, ctx, chainIndex, genesisBlk, nil) + confirmLastAcceptedHeight(r, ctx, chainIndex, genesisBlk.GetHeight()) + + // Create a higher height block, but don't make it the last accepted + historicalBlk := &testBlock{height: 100} + + // Save the historical block + r.NoError(chainIndex.SaveHistorical(historicalBlk)) + + // Verify the historical block is indexed + confirmBlockIndexed(r, ctx, chainIndex, historicalBlk, nil) + + // Verify lastAccepted hasn't changed (still points to genesis) + confirmLastAcceptedHeight(r, ctx, chainIndex, genesisBlk.GetHeight()) + + // Create and save a normal block that should become the new last accepted + blk1 := &testBlock{height: 1} + r.NoError(chainIndex.UpdateLastAccepted(ctx, blk1)) + + // Verify the new block is indexed + confirmBlockIndexed(r, ctx, chainIndex, 
blk1, nil) + + // Verify the historical block is still indexed + confirmBlockIndexed(r, ctx, chainIndex, historicalBlk, nil) + + // Verify lastAccepted points to the new block + confirmLastAcceptedHeight(r, ctx, chainIndex, blk1.GetHeight()) +} + +func TestChainIndex_Cleanup(t *testing.T) { + tests := []struct { + name string + config Config + setup func(ctx context.Context, r *require.Assertions, chainIndex *ChainIndex[*testBlock]) + verifyAfterCleanup func(ctx context.Context, r *require.Assertions, chainIndex *ChainIndex[*testBlock]) + }{ + { + name: "If there's no accepted window, nothing to clean", + config: Config{ + AcceptedBlockWindow: 0, + BlockCompactionFrequency: 1, + }, + setup: func(ctx context.Context, r *require.Assertions, chainIndex *ChainIndex[*testBlock]) { + // Add blocks 0-10 + for i := 0; i <= 10; i++ { + blkHeight := uint64(i) + blk := &testBlock{height: blkHeight} + if i < 5 { + r.NoError(chainIndex.SaveHistorical(blk)) + } else { + r.NoError(chainIndex.UpdateLastAccepted(ctx, blk)) + } + confirmBlockIndexed(r, ctx, chainIndex, blk, nil) + } + confirmLastAcceptedHeight(r, ctx, chainIndex, 10) + }, + verifyAfterCleanup: func(ctx context.Context, r *require.Assertions, chainIndex *ChainIndex[*testBlock]) { + // All blocks should still exist + for i := uint64(0); i <= 10; i++ { + _, err := chainIndex.GetBlockByHeight(ctx, i) + r.NoError(err, "Block at height %d should exist", i) + } + }, + }, + { + name: "If lastAcceptedHeight is too small (less than window), nothing to clean", + config: Config{ + AcceptedBlockWindow: 20, + BlockCompactionFrequency: 1, + }, + setup: func(ctx context.Context, r *require.Assertions, chainIndex *ChainIndex[*testBlock]) { + // Add blocks 0-10 + for i := 0; i <= 10; i++ { + blk := &testBlock{height: uint64(i)} + if i < 5 { + r.NoError(chainIndex.SaveHistorical(blk)) + } else { + r.NoError(chainIndex.UpdateLastAccepted(ctx, blk)) + } + confirmBlockIndexed(r, ctx, chainIndex, blk, nil) + } + confirmLastAcceptedHeight(r, ctx, chainIndex, 10) + }, + verifyAfterCleanup: func(ctx context.Context, r *require.Assertions, chainIndex *ChainIndex[*testBlock]) { + // All blocks should still exist because 10 < 20 (window) + for i := uint64(0); i <= 10; i++ { + _, err := chainIndex.GetBlockByHeight(ctx, i) + r.NoError(err, "Block at height %d should exist", i) + } + }, + }, + { + name: "Should cleanup historical blocks", + config: Config{ + AcceptedBlockWindow: 5, + BlockCompactionFrequency: 1, + }, + setup: func(ctx context.Context, r *require.Assertions, chainIndex *ChainIndex[*testBlock]) { + // Create blocks in reverse order (historical first) + + // Historical blocks (0-7) added in reverse order + for i := 7; i >= 0; i-- { + blk := &testBlock{height: uint64(i)} + r.NoError(chainIndex.SaveHistorical(blk)) + confirmBlockIndexed(r, ctx, chainIndex, blk, nil) + _, err := chainIndex.GetLastAcceptedHeight(ctx) + r.ErrorIs(err, database.ErrNotFound) + } + + // Last accepted blocks (8-10) + for i := 8; i <= 10; i++ { + blk := &testBlock{height: uint64(i)} + // UpdateLastAccepted should clean blocks: `expiryHeight := i-AcceptedBlockWindow` + // example: 8-5 = 3, 9-5 = 4, 10-5 = 5 + r.NoError(chainIndex.UpdateLastAccepted(ctx, blk)) + confirmBlockIndexed(r, ctx, chainIndex, blk, nil) + } + + confirmLastAcceptedHeight(r, ctx, chainIndex, 10) + }, + verifyAfterCleanup: func(ctx context.Context, r *require.Assertions, chainIndex *ChainIndex[*testBlock]) { + // UpdateLastAccepted deleted blocks, 3, 4 and 5 but there's gap 1 and 2, those should be deleted as well 
+ for i := 1; i <= 5; i++ { + _, err := chainIndex.GetBlockByHeight(ctx, uint64(i)) + r.ErrorIs(err, database.ErrNotFound, "Block at height %d should be deleted", i) + } + + // Blocks 6-10 should still exist + for i := uint64(6); i <= 10; i++ { + _, err := chainIndex.GetBlockByHeight(ctx, i) + r.NoError(err, "Block at height %d should exist", i) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + r := require.New(t) + ctx := context.Background() + + chainIndex, err := newTestChainIndex(ctx, test.config, memdb.New()) + r.NoError(err) + + if test.setup != nil { + test.setup(ctx, r, chainIndex) + } + + r.NoError(chainIndex.cleanupOnStartup(ctx)) + + if test.verifyAfterCleanup != nil { + test.verifyAfterCleanup(ctx, r, chainIndex) + } + + // Genesis should never be deleted + genesis, err := chainIndex.GetBlockByHeight(ctx, uint64(0)) + r.Equal(uint64(0), genesis.GetHeight(), "Genesis block should not be deleted") + r.NoError(err) + }) + } +} diff --git a/internal/validitywindow/syncer.go b/internal/validitywindow/syncer.go index 1f47d85d08..5beeb33219 100644 --- a/internal/validitywindow/syncer.go +++ b/internal/validitywindow/syncer.go @@ -5,6 +5,7 @@ package validitywindow import ( "context" + "errors" "fmt" "sync" "sync/atomic" @@ -16,6 +17,12 @@ type BlockFetcher[T Block] interface { FetchBlocks(ctx context.Context, blk Block, minTimestamp *atomic.Int64) <-chan T } +type BlockStore[T Block] interface { + SaveHistorical(blk T) error +} + +var errSaveHistoricalBlocks = errors.New("failed to save historical blocks") + // Syncer ensures the node does not transition to normal operation // until it has built a complete validity window of blocks. // @@ -39,7 +46,7 @@ type BlockFetcher[T Block] interface { // // The validity window can be marked as complete once either mechanism completes. 
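
Both mechanisms stop at the same boundary: the oldest-allowed timestamp, i.e. the target's timestamp minus the validity window, floored at zero (what timeValidityWindow.calculateOldestAllowed computes). A standalone sketch of that termination rule, using hypothetical timestamps rather than the Syncer's real block types:

```go
// Sketch of the backward-fill termination rule: walk parents from the sync
// target until one block's timestamp drops below the oldest-allowed bound;
// that extra block completes the window. Timestamps here are hypothetical.
package main

import "fmt"

func oldestAllowed(targetTS, validityWindow int64) int64 {
	if bound := targetTS - validityWindow; bound > 0 {
		return bound
	}
	return 0
}

func main() {
	const validityWindow = int64(5)
	parents := []int64{14, 13, 12, 11, 10, 9, 8, 7, 6} // newest -> oldest
	bound := oldestAllowed(parents[0], validityWindow) // 14 - 5 = 9

	var window []int64
	for _, ts := range parents {
		window = append(window, ts)
		if ts < bound { // one block below the bound => full validity window
			break
		}
	}
	fmt.Println("oldest allowed:", bound, "window:", window) // 9 [14 ... 8]
}
```

This mirrors why the "full validity window from cache" test below expects the oldest block at height 8 for a target at timestamp 14 with a validity window of 5.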
type Syncer[T emap.Item, B ExecutionBlock[T]] struct { - chainIndex ChainIndex[T] + blockStore BlockStore[B] timeValidityWindow *TimeValidityWindow[T] getValidityWindow GetTimeValidityWindowFunc blockFetcherClient BlockFetcher[B] @@ -53,9 +60,9 @@ type Syncer[T emap.Item, B ExecutionBlock[T]] struct { cancel context.CancelFunc // For canceling backward sync } -func NewSyncer[T emap.Item, B ExecutionBlock[T]](chainIndex ChainIndex[T], timeValidityWindow *TimeValidityWindow[T], blockFetcherClient BlockFetcher[B], getValidityWindow GetTimeValidityWindowFunc) *Syncer[T, B] { +func NewSyncer[T emap.Item, B ExecutionBlock[T]](blockStore BlockStore[B], timeValidityWindow *TimeValidityWindow[T], blockFetcherClient BlockFetcher[B], getValidityWindow GetTimeValidityWindowFunc) *Syncer[T, B] { return &Syncer[T, B]{ - chainIndex: chainIndex, + blockStore: blockStore, timeValidityWindow: timeValidityWindow, blockFetcherClient: blockFetcherClient, getValidityWindow: getValidityWindow, @@ -65,7 +72,7 @@ func NewSyncer[T emap.Item, B ExecutionBlock[T]](chainIndex ChainIndex[T], timeV } func (s *Syncer[T, B]) Start(ctx context.Context, target B) error { - minTS := s.calculateMinTimestamp(target.GetTimestamp()) + minTS := s.timeValidityWindow.calculateOldestAllowed(target.GetTimestamp()) s.minTimestamp.Store(minTS) // Try to build a partial validity window from existing blocks @@ -83,6 +90,14 @@ func (s *Syncer[T, B]) Start(ctx context.Context, target B) error { go func() { resultChan := s.blockFetcherClient.FetchBlocks(syncCtx, s.oldestBlock, &s.minTimestamp) for blk := range resultChan { + if err := s.blockStore.SaveHistorical(blk); err != nil { + s.errChan <- fmt.Errorf( + "%w: aborting to prevent inconsistencies %w", + errSaveHistoricalBlocks, + err, + ) + return + } s.timeValidityWindow.AcceptHistorical(blk) } @@ -97,7 +112,7 @@ func (s *Syncer[T, B]) Wait(ctx context.Context) error { case <-s.doneChan: return nil case err := <-s.errChan: - return fmt.Errorf("timve valdity syncer exited with error: %w", err) + return fmt.Errorf("time validity syncer exited with error: %w", err) case <-ctx.Done(): return fmt.Errorf("waiting for time validity syncer timed out: %w", ctx.Err()) } @@ -118,8 +133,8 @@ func (s *Syncer[T, B]) UpdateSyncTarget(_ context.Context, target B) error { return s.Close() } - // Update minimum timestamp based on new target - minTS := s.calculateMinTimestamp(target.GetTimestamp()) + // Update timestamp based on new target + minTS := s.timeValidityWindow.calculateOldestAllowed(target.GetTimestamp()) s.minTimestamp.Store(minTS) return nil @@ -136,17 +151,15 @@ func (s *Syncer[T, B]) accept(blk B) bool { } // backfillFromExisting attempts to build a validity window from existing blocks -// Returns: -// - The last accepted block (newest) -// - Whether we saw the full validity window +// returns whether we saw the full validity window func (s *Syncer[T, B]) backfillFromExisting( ctx context.Context, block ExecutionBlock[T], ) bool { - parents, seenValidityWindow := s.timeValidityWindow.PopulateValidityWindow(ctx, block) + validityBlocks, windowComplete := s.timeValidityWindow.populate(ctx, block) - s.oldestBlock = parents[len(parents)-1] - return seenValidityWindow + s.oldestBlock = validityBlocks[0] + return windowComplete } func (s *Syncer[T, B]) signalDone() { @@ -154,18 +167,3 @@ func (s *Syncer[T, B]) signalDone() { close(s.doneChan) }) } - -// calculateMinTimestamp determines the oldest allowable timestamp for blocks -// in the validity window based on: -// - target block's timestamp 
-// - validity window duration from getValidityWindow -// The minimum timestamp is used to determine when to stop fetching historical -// blocks when backfilling the validity window. -func (s *Syncer[T, B]) calculateMinTimestamp(targetTS int64) int64 { - validityWindow := s.getValidityWindow(targetTS) - minTS := targetTS - validityWindow - if minTS < 0 { - minTS = 0 - } - return minTS -} diff --git a/internal/validitywindow/syncer_test.go b/internal/validitywindow/syncer_test.go index 9cf000493e..3414b92205 100644 --- a/internal/validitywindow/syncer_test.go +++ b/internal/validitywindow/syncer_test.go @@ -7,6 +7,7 @@ import ( "context" "testing" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/trace" "github.com/ava-labs/avalanchego/utils/logging" "github.com/stretchr/testify/require" @@ -18,224 +19,264 @@ type testCases struct { numOfBlocks int setupChainIndex func([]ExecutionBlock[container]) *testChainIndex setupFetcher func([]ExecutionBlock[container]) *BlockFetcherClient[ExecutionBlock[container]] - verifyFunc func(context.Context, *require.Assertions, []ExecutionBlock[container], *Syncer[container, ExecutionBlock[container]]) + verify func(context.Context, *require.Assertions, []ExecutionBlock[container], *Syncer[container, ExecutionBlock[container]], *testChainIndex) } -func TestSyncer(t *testing.T) { - t.Run("Start", func(t *testing.T) { - tests := []testCases{ - { - name: "should return full validity window from cache", - numOfBlocks: 15, - validityWindow: 5, - setupChainIndex: newTestChainIndex, - setupFetcher: func(_ []ExecutionBlock[container]) *BlockFetcherClient[ExecutionBlock[container]] { - // no need for fetcher - return nil - }, - verifyFunc: func(ctx context.Context, req *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]]) { - target := blkChain[len(blkChain)-1] - err := syncer.Start(ctx, target) - req.NoError(err) - req.Equal(blkChain[len(blkChain)-1].GetHeight(), syncer.timeValidityWindow.lastAcceptedBlockHeight) - // We're expecting oldestBlock to have height 8 because: - // - We have 15 blocks (height 0-14) - // - Validity window is 5 time units - // - Given target block at height 14 (timestamp 14) - // - We need blocks until timestamp difference > 5 - // - This happens at block height 8 (14 - 8 > 5) - req.Equal(blkChain[8].GetHeight(), syncer.oldestBlock.GetHeight()) - }, +func TestSyncer_Start(t *testing.T) { + tests := []testCases{ + { + name: "should return full validity window from cache", + numOfBlocks: 15, + validityWindow: 5, + setupChainIndex: newTestChainIndex, + setupFetcher: func(_ []ExecutionBlock[container]) *BlockFetcherClient[ExecutionBlock[container]] { + // no need for fetcher + return nil }, - { - name: "should return full validity window built partially from cache and peers", - validityWindow: 15, - numOfBlocks: 20, - setupChainIndex: func(blkChain []ExecutionBlock[container]) *testChainIndex { - // Add the most recent 5 blocks in-memory - return newTestChainIndex(blkChain[15:]) - }, - setupFetcher: newFetcher, - verifyFunc: func(ctx context.Context, req *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]]) { - // we should have the most recent 5 blocks in-memory - // that is not enough to build full validity window, we need to fetch the rest from the network - target := blkChain[len(blkChain)-1] - err := syncer.Start(ctx, target) - req.NoError(err) - req.NoError(syncer.Wait(ctx)) - - // the last accepted 
height should be the last accepted height from the cache, since historical blocks should not update the last accepted field - req.Equal(blkChain[len(blkChain)-1].GetHeight(), syncer.timeValidityWindow.lastAcceptedBlockHeight) - req.Equal(blkChain[15].GetHeight(), syncer.oldestBlock.GetHeight()) - - // verify the oldest allowed block in time validity window - req.Equal(blkChain[4].GetTimestamp(), syncer.timeValidityWindow.calculateOldestAllowed(target.GetTimestamp())) - req.NotEqual(blkChain[3].GetTimestamp(), syncer.timeValidityWindow.calculateOldestAllowed(target.GetTimestamp())) - }, + verify: func(ctx context.Context, r *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]], _ *testChainIndex) { + target := blkChain[len(blkChain)-1] + err := syncer.Start(ctx, target) + r.NoError(err) + r.Equal(blkChain[len(blkChain)-1].GetHeight(), syncer.timeValidityWindow.lastAcceptedBlockHeight) + // We're expecting oldestBlock to have height 8 because: + // - We have 15 blocks (height 0-14) + // - Validity window is 5 time units + // - Given target block at height 14 (timestamp 14) + // - We need blocks until timestamp difference > 5 + // - This happens at block height 8 (14 - 8 > 5) + r.Equal(blkChain[8].GetHeight(), syncer.oldestBlock.GetHeight()) }, - { - name: "should return full validity window from peers", - validityWindow: 15, - numOfBlocks: 20, - setupChainIndex: func(_ []ExecutionBlock[container]) *testChainIndex { - return &testChainIndex{} - }, - setupFetcher: newFetcher, - verifyFunc: func(ctx context.Context, req *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]]) { - target := blkChain[len(blkChain)-1] - err := syncer.Start(ctx, target) - req.NoError(err) - req.NoError(syncer.Wait(ctx)) - - req.Equal(uint64(19), syncer.timeValidityWindow.lastAcceptedBlockHeight) - req.Equal(syncer.timeValidityWindow.calculateOldestAllowed(target.GetTimestamp()), blkChain[4].GetTimestamp()) - }, + }, + { + name: "should return full validity window built partially from cache and peers", + validityWindow: 15, + numOfBlocks: 20, + setupChainIndex: func(blkChain []ExecutionBlock[container]) *testChainIndex { + // Add the most recent 5 blocks in-memory + return newTestChainIndex(blkChain[15:]) }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - runSyncerTest(t, test) - }) - } - }) - - t.Run("UpdateSyncTarget", func(t *testing.T) { - tests := []testCases{ - { - name: "update with newer block expands window forward", - validityWindow: 15, - numOfBlocks: 25, - setupChainIndex: func(blkChain []ExecutionBlock[container]) *testChainIndex { - // Start with the most recent 5 blocks in cache - return newTestChainIndex(blkChain[20:]) - }, - setupFetcher: newFetcher, - verifyFunc: func(ctx context.Context, req *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]]) { - // Perform initial sync with second-to-last block - initialTarget := blkChain[len(blkChain)-2] - err := syncer.Start(ctx, initialTarget) - req.NoError(err) - req.NoError(syncer.Wait(ctx)) - - initialMinTS := syncer.minTimestamp.Load() - initialOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(initialTarget.GetTimestamp()) - - // Update to newer block (the last block in chain) - newTarget := blkChain[len(blkChain)-1] - err = syncer.UpdateSyncTarget(ctx, newTarget) - req.NoError(err) - - // Verify window has moved forward - 
newOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(newTarget.GetTimestamp()) - req.Greater(newOldestAllowed, initialOldestAllowed, "window should expand forward with newer block") - - // minTimestamp defines the earliest point in time from which we need to maintain block history - // When new blocks arrive from consensus, they effectively push this boundary forward in time - // as newer blocks are added to the chain - req.Greater(syncer.minTimestamp.Load(), initialMinTS, "min timestamp should move forward with newer block") - }, + setupFetcher: newFetcher, + verify: func(ctx context.Context, r *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]], _ *testChainIndex) { + // we should have the most recent 5 blocks in-memory + // that is not enough to build full validity window, we need to fetch the rest from the network + target := blkChain[len(blkChain)-1] + err := syncer.Start(ctx, target) + r.NoError(err) + r.NoError(syncer.Wait(ctx)) + + // the last acceptedIndex height should be the last acceptedIndex height from the cache, since historical blocks should not update the last acceptedIndex field + r.Equal(blkChain[len(blkChain)-1].GetHeight(), syncer.timeValidityWindow.lastAcceptedBlockHeight) + r.Equal(blkChain[15].GetHeight(), syncer.oldestBlock.GetHeight()) + + // verify the oldest allowed block in time validity window + r.Equal(blkChain[4].GetTimestamp(), syncer.timeValidityWindow.calculateOldestAllowed(target.GetTimestamp())) + r.NotEqual(blkChain[3].GetTimestamp(), syncer.timeValidityWindow.calculateOldestAllowed(target.GetTimestamp())) }, - { - name: "process sequence of consensus blocks maintains correct window", - validityWindow: 15, - numOfBlocks: 25, - setupChainIndex: func(blkChain []ExecutionBlock[container]) *testChainIndex { - return newTestChainIndex(blkChain[20:]) - }, - setupFetcher: newFetcher, - verifyFunc: func(ctx context.Context, req *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]]) { - // Start initial sync - initialTarget := blkChain[len(blkChain)-1] - err := syncer.Start(ctx, initialTarget) - req.NoError(err) - req.NoError(syncer.Wait(ctx)) - - initialOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(initialTarget.GetTimestamp()) - - // Simulate processing 5 new blocks from consensus - currentTimestamp := initialTarget.GetTimestamp() - currentHeight := initialTarget.GetHeight() - - for i := 0; i < 5; i++ { - newBlock := newExecutionBlock( - currentHeight+1, - currentTimestamp+1, - []int64{}, - ) - // Update sync target - err = syncer.UpdateSyncTarget(ctx, newBlock) - req.NoError(err) - - // Verify window boundaries - newOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(newBlock.GetTimestamp()) - req.Greater( - newOldestAllowed, - initialOldestAllowed, - "window should move forward with each consensus block", - ) - - currentTimestamp = newBlock.GetTimestamp() - currentHeight = newBlock.GetHeight() - } - }, + }, + { + name: "should return full validity window from peers", + validityWindow: 15, + numOfBlocks: 20, + setupChainIndex: func(_ []ExecutionBlock[container]) *testChainIndex { + return &testChainIndex{} }, - { - name: "update with block at same height maintains window", - validityWindow: 15, - numOfBlocks: 25, - setupChainIndex: func(blkChain []ExecutionBlock[container]) *testChainIndex { - return newTestChainIndex(blkChain[20:]) - }, - setupFetcher: newFetcher, - verifyFunc: func(ctx 
context.Context, req *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]]) { - // Start initial sync - initialTarget := blkChain[len(blkChain)-1] - err := syncer.Start(ctx, initialTarget) - req.NoError(err) - req.NoError(syncer.Wait(ctx)) - - initialOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(initialTarget.GetTimestamp()) - - // Create new block at same height but different ID - sameHeightBlock := newExecutionBlock( - initialTarget.GetHeight(), - initialTarget.GetTimestamp(), - []int64{1}, // Different container to get different ID - ) + setupFetcher: newFetcher, + verify: func(ctx context.Context, r *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]], blockStore *testChainIndex) { + target := blkChain[len(blkChain)-1] + err := syncer.Start(ctx, target) + r.NoError(err) + r.NoError(syncer.Wait(ctx)) + + r.Equal(uint64(19), syncer.timeValidityWindow.lastAcceptedBlockHeight) + r.Equal(syncer.timeValidityWindow.calculateOldestAllowed(target.GetTimestamp()), blkChain[4].GetTimestamp()) + + // Ensure historical (fetched) blocks are saved + // this is required to ensure consistency between on-disk representation and validity window + for i := 3; i < 19; i++ { + blk := blkChain[i] + fetchedBlk, err := blockStore.GetExecutionBlock(ctx, blk.GetID()) + r.NoError(err) + r.Equal(blk, fetchedBlk) + } + }, + }, + { + name: "should stop historical block fetching in case of persistence error", + validityWindow: 15, + numOfBlocks: 20, + setupChainIndex: func(_ []ExecutionBlock[container]) *testChainIndex { + return &testChainIndex{ + beforeSaveFunc: func(_ map[ids.ID]ExecutionBlock[container]) error { + return errSaveHistoricalBlocks + }, + } + }, + setupFetcher: newFetcher, + verify: func(ctx context.Context, r *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]], blockStore *testChainIndex) { + target := blkChain[len(blkChain)-1] + err := syncer.Start(ctx, target) + r.NoError(err) + r.ErrorIs(syncer.Wait(ctx), errSaveHistoricalBlocks) + + // Last accepted block in time validity window should be the target + r.Equal(uint64(19), syncer.timeValidityWindow.lastAcceptedBlockHeight) + + // No blocks should be saved on chain index starting from target-1 + for i := 3; i < 19; i++ { + blk := blkChain[i] + _, err := blockStore.GetExecutionBlock(ctx, blk.GetID()) + r.Error(err) + } + }, + }, + } - // Update to new block - err = syncer.UpdateSyncTarget(ctx, sameHeightBlock) - req.NoError(err) + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + runSyncerTest(t, test) + }) + } +} + +func TestSyncer_UpdateSyncTarget(t *testing.T) { + tests := []testCases{ + { + name: "update with newer block expands window forward", + validityWindow: 15, + numOfBlocks: 25, + setupChainIndex: func(blkChain []ExecutionBlock[container]) *testChainIndex { + // Start with the most recent 5 blocks in cache + return newTestChainIndex(blkChain[20:]) + }, + setupFetcher: newFetcher, + verify: func(ctx context.Context, r *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]], _ *testChainIndex) { + // Perform initial sync with second-to-last block + initialTarget := blkChain[len(blkChain)-2] + err := syncer.Start(ctx, initialTarget) + r.NoError(err) + r.NoError(syncer.Wait(ctx)) + + initialMinTS := syncer.minTimestamp.Load() + initialOldestAllowed := 
syncer.timeValidityWindow.calculateOldestAllowed(initialTarget.GetTimestamp()) + + // Update to newer block (the last block in chain) + newTarget := blkChain[len(blkChain)-1] + err = syncer.UpdateSyncTarget(ctx, newTarget) + r.NoError(err) + + // Verify window has moved forward + newOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(newTarget.GetTimestamp()) + r.Greater(newOldestAllowed, initialOldestAllowed, "window should expand forward with newer block") + + // minTimestamp defines the earliest point in time from which we need to maintain block history + // When new blocks arrive from consensus, they effectively push this boundary forward in time + // as newer blocks are added to the chain + r.Greater(syncer.minTimestamp.Load(), initialMinTS, "min timestamp should move forward with newer block") + }, + }, + { + name: "process sequence of consensus blocks maintains correct window", + validityWindow: 15, + numOfBlocks: 25, + setupChainIndex: func(blkChain []ExecutionBlock[container]) *testChainIndex { + return newTestChainIndex(blkChain[20:]) + }, + setupFetcher: newFetcher, + verify: func(ctx context.Context, r *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]], _ *testChainIndex) { + // Start initial sync + initialTarget := blkChain[len(blkChain)-1] + err := syncer.Start(ctx, initialTarget) + r.NoError(err) + r.NoError(syncer.Wait(ctx)) + + initialOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(initialTarget.GetTimestamp()) + + // Simulate processing 5 new blocks from consensus + currentTimestamp := initialTarget.GetTimestamp() + currentHeight := initialTarget.GetHeight() + + for i := 0; i < 5; i++ { + newBlock := newExecutionBlock( + currentHeight+1, + currentTimestamp+1, + []int64{}, + ) + // Update sync target + err = syncer.UpdateSyncTarget(ctx, newBlock) + r.NoError(err) + + // Verify window boundaries + newOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(newBlock.GetTimestamp()) + r.Greater( + newOldestAllowed, + initialOldestAllowed, + "window should move forward with each consensus block", + ) - // Verify window remains unchanged - newOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(sameHeightBlock.GetTimestamp()) - req.Equal(initialOldestAllowed, newOldestAllowed, "window should not change with same-height block") - }, + currentTimestamp = newBlock.GetTimestamp() + currentHeight = newBlock.GetHeight() + } + }, + }, + { + name: "update with block at same height maintains window", + validityWindow: 15, + numOfBlocks: 25, + setupChainIndex: func(blkChain []ExecutionBlock[container]) *testChainIndex { + return newTestChainIndex(blkChain[20:]) }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - runSyncerTest(t, test) - }) - } - }) + setupFetcher: newFetcher, + verify: func(ctx context.Context, r *require.Assertions, blkChain []ExecutionBlock[container], syncer *Syncer[container, ExecutionBlock[container]], _ *testChainIndex) { + // Start initial sync + initialTarget := blkChain[len(blkChain)-1] + err := syncer.Start(ctx, initialTarget) + r.NoError(err) + r.NoError(syncer.Wait(ctx)) + + initialOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(initialTarget.GetTimestamp()) + + // Create new block at same height but different ID + sameHeightBlock := newExecutionBlock( + initialTarget.GetHeight(), + initialTarget.GetTimestamp(), + []int64{1}, // Different container to get different ID + ) + + // Update to new block 
+ err = syncer.UpdateSyncTarget(ctx, sameHeightBlock) + r.NoError(err) + + // Verify window remains unchanged + newOldestAllowed := syncer.timeValidityWindow.calculateOldestAllowed(sameHeightBlock.GetTimestamp()) + r.Equal(initialOldestAllowed, newOldestAllowed, "window should not change with same-height block") + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + runSyncerTest(t, test) + }) + } } func runSyncerTest(t *testing.T, test testCases) { ctx := context.Background() - req := require.New(t) + r := require.New(t) blkChain := generateTestChain(test.numOfBlocks) chainIndex := test.setupChainIndex(blkChain) - validityWindow := NewTimeValidityWindow( + head := blkChain[len(blkChain)-1] + validityWindow, err := NewTimeValidityWindow( + ctx, &logging.NoLog{}, trace.Noop, chainIndex, + head, func(_ int64) int64 { return test.validityWindow }, ) + r.NoError(err) fetcher := test.setupFetcher(blkChain) syncer := NewSyncer[container, ExecutionBlock[container]]( @@ -245,7 +286,7 @@ func runSyncerTest(t *testing.T, test testCases) { func(_ int64) int64 { return test.validityWindow }, ) - test.verifyFunc(ctx, req, blkChain, syncer) + test.verify(ctx, r, blkChain, syncer, chainIndex) } func newTestChainIndex(blocks []ExecutionBlock[container]) *testChainIndex { diff --git a/internal/validitywindow/validitywindow.go b/internal/validitywindow/validitywindow.go index 7155fb1cf0..ac2c855c45 100644 --- a/internal/validitywindow/validitywindow.go +++ b/internal/validitywindow/validitywindow.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "slices" "sync" "time" @@ -21,6 +22,7 @@ import ( var ( _ Interface[emap.Item] = (*TimeValidityWindow[emap.Item])(nil) + ErrNilInitialBlock = errors.New("missing head block") ErrDuplicateContainer = errors.New("duplicate container") ErrMisalignedTime = errors.New("misaligned time") ErrTimestampExpired = errors.New("declared timestamp expired") @@ -53,36 +55,66 @@ type Interface[T emap.Item] interface { IsRepeat(ctx context.Context, parentBlk ExecutionBlock[T], currentTimestamp int64, containers []T) (set.Bits, error) } +// TimeValidityWindow is a timestamp-based replay protection mechanism. +// It maintains a configurable timestamp validity window of emap.Item entries +// that have been with a specific expiry time. +// +// Items can be: +// 1. Too far in the future (item timestamp > block timestamp + validity window) +// 2. Currently valid for inclusion (block timestamp <= item timestamp <= block timestamp + validity window) +// 3. Expired (item timestamp < block timestamp) +// +// To prevent duplicates; we track all items included in a block until they expire. +// Once they are invalidated by their timestamp, we remove them from tracking as it's no +// longer necessary to track them to guarantee that they cannot be replayed. +// +// TimeValidityWindow assumes the ChainIndex contains all blocks with timestamps +// in the validity window interval +1 extra block below the interval minimum. +// The extra block below the interval minimum is necessary to verify that every +// block within the interval has been included. 
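
A worked classification of the three item states described in the comment above, using hypothetical numbers (block timestamp 100, validity window 10); this is an illustration only, not the emap-backed implementation:

```go
// Classify item expiry timestamps relative to a block, matching the three
// cases in the TimeValidityWindow comment. Values are hypothetical.
package main

import "fmt"

func classify(itemTS, blockTS, validityWindow int64) string {
	switch {
	case itemTS > blockTS+validityWindow:
		return "too far in the future"
	case itemTS < blockTS:
		return "expired"
	default:
		return "valid for inclusion"
	}
}

func main() {
	const blockTS, validityWindow = int64(100), int64(10)
	for _, itemTS := range []int64{95, 100, 110, 115} {
		fmt.Printf("item expiry %d: %s\n", itemTS, classify(itemTS, blockTS, validityWindow))
	}
}
```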
type TimeValidityWindow[T emap.Item] struct { - log logging.Logger - tracer trace.Tracer - - lock sync.Mutex + log logging.Logger + tracer trace.Tracer + mu sync.Mutex chainIndex ChainIndex[T] seen *emap.EMap[T] lastAcceptedBlockHeight uint64 getTimeValidityWindow GetTimeValidityWindowFunc } +// NewTimeValidityWindow constructs TimeValidityWindow and eagerly tries to populate +// a validity window from the tip func NewTimeValidityWindow[T emap.Item]( + ctx context.Context, log logging.Logger, tracer trace.Tracer, chainIndex ChainIndex[T], + head ExecutionBlock[T], getTimeValidityWindowF GetTimeValidityWindowFunc, -) *TimeValidityWindow[T] { - return &TimeValidityWindow[T]{ +) (*TimeValidityWindow[T], error) { + t := &TimeValidityWindow[T]{ log: log, tracer: tracer, chainIndex: chainIndex, seen: emap.NewEMap[T](), getTimeValidityWindow: getTimeValidityWindowF, } + + t.populate(ctx, head) + return t, nil +} + +// Complete will attempt to complete a validity window. +// It returns a boolean that signals if it's ready to reliably prevent replay attacks +func (v *TimeValidityWindow[T]) Complete(ctx context.Context, block ExecutionBlock[T]) bool { + _, isComplete := v.populate(ctx, block) + return isComplete } func (v *TimeValidityWindow[T]) Accept(blk ExecutionBlock[T]) { // Grab the lock before modifiying seen - v.lock.Lock() - defer v.lock.Unlock() + v.mu.Lock() + defer v.mu.Unlock() evicted := v.seen.SetMin(blk.GetTimestamp()) v.log.Debug("accepting block to validity window", @@ -95,8 +127,8 @@ func (v *TimeValidityWindow[T]) Accept(blk ExecutionBlock[T]) { } func (v *TimeValidityWindow[T]) AcceptHistorical(blk ExecutionBlock[T]) { - v.lock.Lock() - defer v.lock.Unlock() + v.mu.Lock() + defer v.mu.Unlock() v.log.Debug("adding historical block to validity window", zap.Stringer("blkID", blk.GetID()), @@ -113,7 +145,11 @@ func (v *TimeValidityWindow[T]) VerifyExpiryReplayProtection( _, span := v.tracer.Start(ctx, "Chain.VerifyExpiryReplayProtection") defer span.End() - if blk.GetHeight() <= v.lastAcceptedBlockHeight { + v.mu.Lock() + lastAcceptedBlockHeight := v.lastAcceptedBlockHeight + v.mu.Unlock() + + if blk.GetHeight() <= lastAcceptedBlockHeight { return nil } @@ -164,8 +200,8 @@ func (v *TimeValidityWindow[T]) isRepeat( ) (set.Bits, error) { marker := set.NewBits() - v.lock.Lock() - defer v.lock.Unlock() + v.mu.Lock() + defer v.mu.Unlock() var err error for { @@ -196,24 +232,27 @@ func (v *TimeValidityWindow[T]) isRepeat( } } -func (v *TimeValidityWindow[T]) calculateOldestAllowed(timestamp int64) int64 { - return max(0, timestamp-v.getTimeValidityWindow(timestamp)) -} - -func (v *TimeValidityWindow[T]) PopulateValidityWindow(ctx context.Context, block ExecutionBlock[T]) ([]ExecutionBlock[T], bool) { +func (v *TimeValidityWindow[T]) populate(ctx context.Context, block ExecutionBlock[T]) ([]ExecutionBlock[T], bool) { var ( parent = block parents = []ExecutionBlock[T]{parent} - seenValidityWindow = false - validityWindow = v.getTimeValidityWindow(block.GetTimestamp()) + fullValidityWindow = false + oldestAllowed = v.calculateOldestAllowed(block.GetTimestamp()) err error ) // Keep fetching parents until we: + // - Reach block height 0 (Genesis) at that point we have a full validity window, + // and we can correctly preform replay protection // - Fill a validity window, or // - Can't find more blocks // Descending order is guaranteed by the parent-based traversal method for { + if parent.GetHeight() == 0 { + fullValidityWindow = true + break + } + // Get execution block from cache or disk 
parent, err = v.chainIndex.GetExecutionBlock(ctx, parent.GetParent()) if err != nil { @@ -221,16 +260,23 @@ func (v *TimeValidityWindow[T]) PopulateValidityWindow(ctx context.Context, bloc } parents = append(parents, parent) - seenValidityWindow = block.GetTimestamp()-parent.GetTimestamp() > validityWindow - if seenValidityWindow { + fullValidityWindow = parent.GetTimestamp() < oldestAllowed + if fullValidityWindow { break } } - for i := len(parents) - 1; i >= 0; i-- { - v.Accept(parents[i]) + // Reverse blocks to process in chronological order + slices.Reverse(parents) + for _, blk := range parents { + v.Accept(blk) } - return parents, seenValidityWindow + + return parents, fullValidityWindow +} + +func (v *TimeValidityWindow[T]) calculateOldestAllowed(timestamp int64) int64 { + return max(0, timestamp-v.getTimeValidityWindow(timestamp)) } func VerifyTimestamp(containerTimestamp int64, executionTimestamp int64, divisor int64, validityWindow int64) error { diff --git a/internal/validitywindow/validitywindow_test.go b/internal/validitywindow/validitywindow_test.go index 4808ccd81d..f009da80fb 100644 --- a/internal/validitywindow/validitywindow_test.go +++ b/internal/validitywindow/validitywindow_test.go @@ -16,6 +16,20 @@ import ( "github.com/stretchr/testify/require" ) +func newPopulatedValidityWindow(ctx context.Context, r *require.Assertions, blocks []executionBlock, head executionBlock, validityWindowDuration int64) *TimeValidityWindow[container] { + chainIndex := &testChainIndex{} + for _, blk := range blocks { + chainIndex.set(blk.GetID(), blk) + } + + validityWindow, err := NewTimeValidityWindow(ctx, &logging.NoLog{}, trace.Noop, chainIndex, head, func(int64) int64 { + return validityWindowDuration + }) + r.NoError(err) + + return validityWindow +} + func TestValidityWindowVerifyExpiryReplayProtection(t *testing.T) { tests := []struct { name string @@ -116,17 +130,9 @@ func TestValidityWindowVerifyExpiryReplayProtection(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { r := require.New(t) + ctx := context.Background() - chainIndex := &testChainIndex{} - validityWindow := NewTimeValidityWindow(&logging.NoLog{}, trace.Noop, chainIndex, func(int64) int64 { - return test.validityWindow - }) - for i, blk := range test.blocks { - if i <= test.accepted { - validityWindow.Accept(blk) - } - chainIndex.set(blk.GetID(), blk) - } + validityWindow := newPopulatedValidityWindow(ctx, r, test.blocks, test.blocks[test.accepted], test.validityWindow) r.ErrorIs(validityWindow.VerifyExpiryReplayProtection(context.Background(), test.verifyBlock), test.expectedError) }) } @@ -272,17 +278,9 @@ func TestValidityWindowIsRepeat(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { r := require.New(t) + ctx := context.Background() - chainIndex := &testChainIndex{} - validityWindow := NewTimeValidityWindow(&logging.NoLog{}, trace.Noop, chainIndex, func(int64) int64 { - return test.validityWindow - }) - for i, blk := range test.blocks { - if i <= test.accepted { - validityWindow.Accept(blk) - } - chainIndex.set(blk.GetID(), blk) - } + validityWindow := newPopulatedValidityWindow(ctx, r, test.blocks, test.blocks[test.accepted], test.validityWindow) parent := test.blocks[len(test.blocks)-1] if test.overrideParentBlock != nil { parent = test.overrideParentBlock() @@ -357,19 +355,21 @@ func TestVerifyTimestamp(t *testing.T) { } // TestValidityWindowBoundaryLifespan tests that a container included at the validity window boundary transitions -// seamlessly 
from failing veriifcation due to a duplicate within the validity window to failing because it expired. +// seamlessly from failing verification due to a duplicate within the validity window to failing because it expired. func TestValidityWindowBoundaryLifespan(t *testing.T) { r := require.New(t) + ctx := context.Background() + // Create accepted genesis block chainIndex := &testChainIndex{} + genesisBlk := newExecutionBlock(0, 0, []int64{1}) + chainIndex.set(genesisBlk.GetID(), genesisBlk) + validityWindowDuration := int64(10) - validityWindow := NewTimeValidityWindow(&logging.NoLog{}, trace.Noop, chainIndex, func(int64) int64 { + validityWindow, err := NewTimeValidityWindow[container](ctx, &logging.NoLog{}, trace.Noop, chainIndex, genesisBlk, func(int64) int64 { return validityWindowDuration }) - - // Create accepted genesis block - genesisBlk := newExecutionBlock(0, 0, []int64{1}) - chainIndex.set(genesisBlk.GetID(), genesisBlk) + r.NoError(err) validityWindow.Accept(genesisBlk) blk1 := newExecutionBlock(1, 0, []int64{validityWindowDuration}) @@ -397,16 +397,18 @@ func TestValidityWindowBoundaryLifespan(t *testing.T) { func TestAcceptHistorical(t *testing.T) { r := require.New(t) + ctx := context.Background() + // Create and accept the genesis block to set an initial lastAcceptedBlockHeight chainIndex := &testChainIndex{} + genesisBlk := newExecutionBlock(0, 0, []int64{}) + chainIndex.set(genesisBlk.GetID(), genesisBlk) + validityWindowDuration := int64(10) - validityWindow := NewTimeValidityWindow(&logging.NoLog{}, trace.Noop, chainIndex, func(int64) int64 { + validityWindow, err := NewTimeValidityWindow(ctx, &logging.NoLog{}, trace.Noop, chainIndex, genesisBlk, func(int64) int64 { return validityWindowDuration }) - - // Create and accept the genesis block to set an initial lastAcceptedBlockHeight - genesisBlk := newExecutionBlock(0, 0, []int64{}) - chainIndex.set(genesisBlk.GetID(), genesisBlk) + r.NoError(err) validityWindow.Accept(genesisBlk) r.Equal(uint64(0), validityWindow.lastAcceptedBlockHeight) @@ -425,7 +427,8 @@ func TestAcceptHistorical(t *testing.T) { } type testChainIndex struct { - blocks map[ids.ID]ExecutionBlock[container] + beforeSaveFunc func(blocks map[ids.ID]ExecutionBlock[container]) error + blocks map[ids.ID]ExecutionBlock[container] } func (t *testChainIndex) GetExecutionBlock(_ context.Context, blkID ids.ID) (ExecutionBlock[container], error) { @@ -435,6 +438,18 @@ func (t *testChainIndex) GetExecutionBlock(_ context.Context, blkID ids.ID) (Exe return nil, database.ErrNotFound } +func (t *testChainIndex) SaveHistorical(blk ExecutionBlock[container]) error { + if t.blocks == nil { + t.blocks = make(map[ids.ID]ExecutionBlock[container]) + } + + if t.beforeSaveFunc != nil { + return t.beforeSaveFunc(t.blocks) + } + t.blocks[blk.GetID()] = blk + return nil +} + func (t *testChainIndex) set(blkID ids.ID, blk ExecutionBlock[container]) { if t.blocks == nil { t.blocks = make(map[ids.ID]ExecutionBlock[container]) diff --git a/snow/vm_test.go b/snow/vm_test.go index e565cd9670..45fa3d9dfc 100644 --- a/snow/vm_test.go +++ b/snow/vm_test.go @@ -137,7 +137,7 @@ func (t *TestChain) Initialize( chainInput ChainInput, _ *VM[*TestBlock, *TestBlock, *TestBlock], ) (ChainIndex[*TestBlock], *TestBlock, *TestBlock, bool, error) { - chainIndex, err := chainindex.New[*TestBlock](chainInput.SnowCtx.Log, prometheus.NewRegistry(), chainindex.NewDefaultConfig(), t, memdb.New()) + chainIndex, err := chainindex.New[*TestBlock](ctx, chainInput.SnowCtx.Log, prometheus.NewRegistry(), 
chainindex.NewDefaultConfig(), t, memdb.New()) if err != nil { return nil, nil, nil, false, err } diff --git a/vm/statesync.go b/vm/statesync.go index 79e89dbf5f..4953cf7b68 100644 --- a/vm/statesync.go +++ b/vm/statesync.go @@ -69,7 +69,7 @@ func (vm *VM) initStateSync(ctx context.Context) error { vm, p2p.PeerSampler{Peers: vm.network.Peers}, ) - syncer := validitywindow.NewSyncer[*chain.Transaction, *chain.ExecutionBlock](vm, vm.chainTimeValidityWindow, blockFetcherClient, func(time int64) int64 { + syncer := validitywindow.NewSyncer[*chain.Transaction, *chain.ExecutionBlock](vm.chainStore, vm.chainTimeValidityWindow, blockFetcherClient, func(time int64) int64 { return vm.ruleFactory.GetRules(time).GetValidityWindow() }) diff --git a/vm/vm.go b/vm/vm.go index a82f2e6b4b..55fef1a0b2 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -62,7 +62,11 @@ const ( blockFetchHandleID = 0x3 ) -var ErrNotAdded = errors.New("not added") +var ( + errInconsistentValidityWindow = errors.New("critical error: validity window's partial state may lead to inconsistencies") + + ErrNotAdded = errors.New("not added") +) var ( _ hsnow.Block = (*chain.ExecutionBlock)(nil) @@ -297,9 +301,24 @@ func (vm *VM) Initialize( return nil, nil, nil, false, fmt.Errorf("failed to apply options : %w", err) } - vm.chainTimeValidityWindow = validitywindow.NewTimeValidityWindow(vm.snowCtx.Log, vm.tracer, vm, func(timestamp int64) int64 { + executionBlockParser := chain.NewBlockParser(vm.Tracer(), vm.txParser) + + if err := vm.initChainStore(ctx, executionBlockParser); err != nil { + return nil, nil, nil, false, err + } + + lastAccepted, err := vm.initLastAccepted(ctx) + if err != nil { + return nil, nil, nil, false, fmt.Errorf("failed to initialize last accepted block: %w", err) + } + + vm.chainTimeValidityWindow, err = validitywindow.NewTimeValidityWindow[*chain.Transaction](ctx, vm.snowCtx.Log, vm.tracer, vm, lastAccepted, func(timestamp int64) int64 { return vm.ruleFactory.GetRules(timestamp).GetValidityWindow() }) + if err != nil { + return nil, nil, nil, false, fmt.Errorf("failed to initialize chain time validity window: %w", err) + } + chainRegistry, err := metrics.MakeAndRegister(vm.snowCtx.Metrics, chainNamespace) if err != nil { return nil, nil, nil, false, fmt.Errorf("failed to make %q registry: %w", chainNamespace, err) @@ -308,10 +327,10 @@ func (vm *VM) Initialize( if err != nil { return nil, nil, nil, false, fmt.Errorf("failed to get chain config: %w", err) } + vm.chain, err = chain.NewChain( vm.Tracer(), chainRegistry, - vm.txParser, vm.Mempool(), vm.Logger(), vm.ruleFactory, @@ -319,6 +338,7 @@ func (vm *VM) Initialize( vm.BalanceHandler(), vm.AuthVerifiers(), vm.authEngines, + executionBlockParser, vm.chainTimeValidityWindow, chainConfig, ) @@ -326,18 +346,10 @@ func (vm *VM) Initialize( return nil, nil, nil, false, err } - if err := vm.initChainStore(); err != nil { - return nil, nil, nil, false, err - } - if err := vm.initStateSync(ctx); err != nil { return nil, nil, nil, false, err } - if err := vm.populateValidityWindow(ctx); err != nil { - return nil, nil, nil, false, err - } - snowApp.AddNormalOpStarter(func(_ context.Context) error { if vm.SyncClient.Started() { return nil @@ -354,13 +366,11 @@ func (vm *VM) Initialize( } stateReady := !vm.SyncClient.MustStateSync() - var lastAccepted *chain.OutputBlock - if stateReady { - lastAccepted, err = vm.initLastAccepted(ctx) - if err != nil { - return nil, nil, nil, false, err - } + // The branch is executed when the VM must preform state sync + if !stateReady { + 
lastAccepted = nil } + return vm.chainStore, lastAccepted, lastAccepted, stateReady, nil } @@ -368,7 +378,7 @@ func (vm *VM) SetConsensusIndex(consensusIndex *hsnow.ConsensusIndex[*chain.Exec vm.consensusIndex = consensusIndex } -func (vm *VM) initChainStore() error { +func (vm *VM) initChainStore(ctx context.Context, executionBlockParser *chain.BlockParser) error { blockDBRegistry, err := metrics.MakeAndRegister(vm.snowCtx.Metrics, blockDB) if err != nil { return fmt.Errorf("failed to register %s metrics: %w", blockDB, err) @@ -383,13 +393,18 @@ func (vm *VM) initChainStore() error { if err != nil { return fmt.Errorf("failed to create chain index config: %w", err) } - vm.chainStore, err = chainindex.New[*chain.ExecutionBlock](vm.snowCtx.Log, blockDBRegistry, config, vm.chain, chainStoreDB) + vm.chainStore, err = chainindex.New[*chain.ExecutionBlock](ctx, vm.snowCtx.Log, blockDBRegistry, config, executionBlockParser, chainStoreDB) if err != nil { return fmt.Errorf("failed to create chain index: %w", err) } return nil } +// initLastAccepted determines and loads the last accepted block during VM initialization. +// It serves three critical purposes: +// 1. For a fresh chain: Creates and commits genesis block +// 2. For an existing chain: Load the last accepted block that corresponds to current state +// 3. Ensures state consistency between the chain store and state database func (vm *VM) initLastAccepted(ctx context.Context) (*chain.OutputBlock, error) { lastAcceptedHeight, err := vm.chainStore.GetLastAcceptedHeight(ctx) if err != nil && err != database.ErrNotFound { @@ -521,6 +536,8 @@ func (vm *VM) initGenesisAsLastAccepted(ctx context.Context) (*chain.OutputBlock }, nil } +// startNormalOp initializes components required for normal VM operation when transitioning +// from bootstrapping or state sync func (vm *VM) startNormalOp(ctx context.Context) error { vm.builder.Start() vm.snowApp.AddCloser("builder", func() error { @@ -544,8 +561,22 @@ func (vm *VM) startNormalOp(ctx context.Context) error { return fmt.Errorf("failed to add tx gossip handler: %w", err) } vm.checkActivity(ctx) - vm.normalOp.Store(true) + lastAccepted, err := vm.LastAcceptedBlock(ctx) + if err != nil { + return err + } + executionBlk, err := vm.GetExecutionBlock(ctx, lastAccepted.GetID()) + if err != nil { + return err + } + + isValidityWindowComplete := vm.chainTimeValidityWindow.Complete(ctx, executionBlk) + if !isValidityWindowComplete { + return errInconsistentValidityWindow + } + + vm.normalOp.Store(true) return nil } @@ -701,29 +732,3 @@ func (vm *VM) Submit( vm.snowCtx.Log.Info("Submitted tx(s)", zap.Int("validTxs", len(validTxs)), zap.Int("invalidTxs", len(errs)-len(validTxs)), zap.Int("mempoolSize", vm.mempool.Len(ctx))) return errs } - -// populateValidityWindow populates the VM's time validity window on startup, -// ensuring it contains recent transactions even if state sync is skipped (e.g., due to restart). -// This is necessary because a node might restart with only a few blocks behind (or slightly ahead) -// of the network, and thus opt not to trigger state sync. Without backfilling, the node's validity window -// may be incomplete, causing the node to accept a duplicate transaction that the network already processed. - -// When Initialize is called, vm.consensusIndex is nil—it is set later via SetConsensusIndex. -// Therefore, we must use the chainStore (which reads blocks from disk) to backfill the validity window. 
-// This prepopulation ensures the validity window is complete, even if state sync is skipped. -func (vm *VM) populateValidityWindow(ctx context.Context) error { - lastAcceptedBlkHeight, err := vm.chainStore.GetLastAcceptedHeight(ctx) - if err != nil { - if errors.Is(err, database.ErrNotFound) { - return nil - } - return err - } - lastAcceptedBlock, err := vm.chainStore.GetBlockByHeight(ctx, lastAcceptedBlkHeight) - if err != nil { - return err - } - - vm.chainTimeValidityWindow.PopulateValidityWindow(ctx, lastAcceptedBlock) - return nil -} diff --git a/vm/vm_test.go b/vm/vm_test.go index 7c64f5ff71..c667a4ddb6 100644 --- a/vm/vm_test.go +++ b/vm/vm_test.go @@ -738,8 +738,8 @@ func TestSkipStateSync(t *testing.T) { return []*chain.Transaction{tx} }) - config := map[string]interface{}{ - vm.StateSyncNamespace: map[string]interface{}{ + config := map[string]any{ + vm.StateSyncNamespace: map[string]any{ "minBlocks": uint64(numBlocks + 1), }, } @@ -809,8 +809,8 @@ func TestStateSync(t *testing.T) { return []*chain.Transaction{tx} }) - config := map[string]interface{}{ - vm.StateSyncNamespace: map[string]interface{}{ + config := map[string]any{ + vm.StateSyncNamespace: map[string]any{ "minBlocks": uint64(numBlocks - 1), }, }
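
Finally, the interaction between cleanupOnStartup's threshold and UpdateLastAccepted's own expiry can be traced with a small worked example. The numbers mirror the "Should cleanup historical blocks" test case above (AcceptedBlockWindow = 5, last accepted height = 10): there, UpdateLastAccepted had already expired heights 3-5, so the startup cleanup only closes the remaining 1-2 gap; run against an untouched store, the same rule would remove every non-genesis height below the threshold. This is a sketch of the arithmetic only, not the ChainIndex code:

```go
// Worked example of the startup-cleanup threshold: delete non-genesis heights
// strictly below lastAccepted - AcceptedBlockWindow. Sketch only.
package main

import "fmt"

func main() {
	const (
		acceptedBlockWindow = uint64(5)
		lastAccepted        = uint64(10)
	)
	if acceptedBlockWindow == 0 || lastAccepted <= acceptedBlockWindow {
		fmt.Println("nothing to clean")
		return
	}
	threshold := lastAccepted - acceptedBlockWindow // 10 - 5 = 5

	var deleted, kept []uint64
	for h := uint64(0); h <= lastAccepted; h++ {
		switch {
		case h == 0: // genesis is never deleted
			kept = append(kept, h)
		case h < threshold:
			deleted = append(deleted, h)
		default:
			kept = append(kept, h)
		}
	}
	fmt.Println("threshold:", threshold) // 5
	fmt.Println("deleted:", deleted)     // [1 2 3 4]
	fmt.Println("kept:", kept)           // [0 5 6 7 8 9 10]; height 5 is removed by
	// UpdateLastAccepted's own expiry (10 - 5 = 5), not by the startup cleanup,
	// which matches the test's expectation that heights 1-5 end up deleted.
}
```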