diff --git a/.avalanche-golangci.yml b/.avalanche-golangci.yml index 12650b5146..e0aed8d2ad 100644 --- a/.avalanche-golangci.yml +++ b/.avalanche-golangci.yml @@ -54,7 +54,7 @@ linters: - asciicheck - bodyclose - copyloopvar - # - depguard + - depguard # - errcheck - errorlint - forbidigo diff --git a/core/blockchain_ext_test.go b/core/blockchain_ext_test.go index cb4e1bbd1a..4488878330 100644 --- a/core/blockchain_ext_test.go +++ b/core/blockchain_ext_test.go @@ -18,7 +18,6 @@ import ( "github.com/ava-labs/libevm/crypto" "github.com/ava-labs/libevm/ethdb" "github.com/holiman/uint256" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/commontype" @@ -1439,7 +1438,6 @@ func StatefulPrecompiles(t *testing.T, create createFunc) { MaxBlockGasCost: big.NewInt(4_000_000), BlockGasCostStep: big.NewInt(500_000), } - assert := assert.New(t) tests := map[string]test{ "allow list": { addTx: func(gen *BlockGen) { @@ -1501,23 +1499,23 @@ func StatefulPrecompiles(t *testing.T, create createFunc) { }, verifyState: func(sdb *state.StateDB) error { res := feemanager.GetFeeManagerStatus(sdb, addr1) - assert.Equal(allowlist.AdminRole, res) + require.Equal(t, allowlist.AdminRole, res) storedConfig := feemanager.GetStoredFeeConfig(sdb) - assert.Equal(testFeeConfig, storedConfig) + require.Equal(t, testFeeConfig, storedConfig) feeConfig, _, err := blockchain.GetFeeConfigAt(blockchain.CurrentHeader()) require.NoError(t, err) - assert.Equal(testFeeConfig, feeConfig) + require.Equal(t, testFeeConfig, feeConfig) return nil }, verifyGenesis: func(sdb *state.StateDB) { res := feemanager.GetFeeManagerStatus(sdb, addr1) - assert.Equal(allowlist.AdminRole, res) + require.Equal(t, allowlist.AdminRole, res) feeConfig, _, err := blockchain.GetFeeConfigAt(blockchain.Genesis().Header()) require.NoError(t, err) - assert.Equal(params.GetExtra(&config).FeeConfig, feeConfig) + require.Equal(t, params.GetExtra(&config).FeeConfig, feeConfig) }, }, } diff --git a/core/coretest/test_indices.go b/core/coretest/test_indices.go index 7549296911..3c34768822 100644 --- a/core/coretest/test_indices.go +++ b/core/coretest/test_indices.go @@ -9,7 +9,6 @@ import ( "github.com/ava-labs/libevm/core/rawdb" "github.com/ava-labs/libevm/ethdb" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -25,10 +24,10 @@ func CheckTxIndices(t *testing.T, expectedTail *uint64, indexedFrom uint64, inde var stored uint64 tailValue := *expectedTail - require.EventuallyWithTf(t, - func(c *assert.CollectT) { + require.Eventually(t, + func() bool { stored = *rawdb.ReadTxIndexTail(db) - assert.Equalf(c, tailValue, stored, "expected tail to be %d, found %d", tailValue, stored) + return tailValue == stored }, 30*time.Second, 500*time.Millisecond, "expected tail to be %d eventually", tailValue) } diff --git a/internal/ethapi/api_extra_test.go b/internal/ethapi/api_extra_test.go index 189eac7dd1..e9c0222e77 100644 --- a/internal/ethapi/api_extra_test.go +++ b/internal/ethapi/api_extra_test.go @@ -12,7 +12,6 @@ import ( "github.com/ava-labs/libevm/common" "github.com/ava-labs/libevm/common/hexutil" "github.com/ava-labs/libevm/core/types" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -47,7 +46,7 @@ func TestBlockchainAPI_GetChainConfig(t *testing.T) { api := NewBlockChainAPI(backend) gotConfig := api.GetChainConfig(t.Context()) - assert.Equal(t, params.ToWithUpgradesJSON(wantConfig), gotConfig) + require.Equal(t, params.ToWithUpgradesJSON(wantConfig), gotConfig) } // Copy one test case from TestCall @@ -204,9 +203,9 @@ func TestBlockChainAPI_stateQueryBlockNumberAllowed(t *testing.T) { err := api.stateQueryBlockNumberAllowed(testCase.blockNumOrHash) if testCase.wantErrMessage == "" { - assert.NoError(t, err) + require.NoError(t, err) } else { - assert.EqualError(t, err, testCase.wantErrMessage) + require.EqualError(t, err, testCase.wantErrMessage) } }) } diff --git a/network/network_test.go b/network/network_test.go index 1608ca30ef..8ffb1930a6 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -22,8 +22,8 @@ import ( "github.com/ava-labs/avalanchego/utils/set" "github.com/ava-labs/avalanchego/version" "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" "github.com/ava-labs/subnet-evm/network/peertest" "github.com/ava-labs/subnet-evm/plugin/evm/message" @@ -59,37 +59,31 @@ func TestNetworkDoesNotConnectToItself(t *testing.T) { } func TestRequestAnyRequestsRoutingAndResponse(t *testing.T) { - callNum := uint32(0) - senderWg := &sync.WaitGroup{} - var net Network + var ( + net Network + callNum atomic.Uint32 + ) + eg, ctx := errgroup.WithContext(t.Context()) sender := testAppSender{ sendAppRequestFn: func(_ context.Context, nodes set.Set[ids.NodeID], requestID uint32, requestBytes []byte) error { nodeID, _ := nodes.Pop() - senderWg.Add(1) - go func() { - defer senderWg.Done() - if err := net.AppRequest(t.Context(), nodeID, requestID, time.Now().Add(5*time.Second), requestBytes); err != nil { - panic(err) - } - }() + eg.Go(func() error { + return net.AppRequest(ctx, nodeID, requestID, time.Now().Add(5*time.Second), requestBytes) + }) return nil }, sendAppResponseFn: func(nodeID ids.NodeID, requestID uint32, responseBytes []byte) error { - senderWg.Add(1) - go func() { - defer senderWg.Done() - if err := net.AppResponse(t.Context(), nodeID, requestID, responseBytes); err != nil { - panic(err) - } - atomic.AddUint32(&callNum, 1) - }() + eg.Go(func() error { + defer callNum.Add(1) + return net.AppResponse(ctx, nodeID, requestID, responseBytes) + }) return nil }, } codecManager := buildCodec(t, HelloRequest{}, HelloResponse{}) - ctx := snowtest.Context(t, snowtest.CChainID) - net, err := NewNetwork(ctx, sender, codecManager, 16, prometheus.NewRegistry()) + snowCtx := snowtest.Context(t, snowtest.CChainID) + net, err := NewNetwork(snowCtx, sender, codecManager, 16, prometheus.NewRegistry()) require.NoError(t, err) net.SetRequestHandler(&HelloGreetingRequestHandler{codec: codecManager}) nodeID := ids.GenerateTestNodeID() @@ -100,32 +94,35 @@ func TestRequestAnyRequestsRoutingAndResponse(t *testing.T) { defer net.Shutdown() require.NoError(t, net.Connected(t.Context(), nodeID, defaultPeerVersion)) - totalRequests := 5000 - numCallsPerRequest := 1 // on sending response - totalCalls := totalRequests * numCallsPerRequest - - requestWg := &sync.WaitGroup{} - requestWg.Add(totalCalls) - for i := 0; i < totalCalls; i++ { - go func(wg *sync.WaitGroup) { - defer wg.Done() + const ( + totalRequests = 5000 + numCallsPerRequest = 1 // on sending response + totalCalls uint32 = totalRequests * numCallsPerRequest + ) + for range totalCalls { + eg.Go(func() error { requestBytes, err := message.RequestToBytes(codecManager, requestMessage) - assert.NoError(t, err) - responseBytes, _, err := net.SendSyncedAppRequestAny(t.Context(), defaultPeerVersion, requestBytes) - assert.NoError(t, err) - assert.NotNil(t, responseBytes) + if err != nil { + return fmt.Errorf("failed to encode request: %w", err) + } + + responseBytes, _, err := net.SendSyncedAppRequestAny(ctx, defaultPeerVersion, requestBytes) + if err != nil { + return fmt.Errorf("failed to send synced app request: %w", err) + } var response TestMessage if _, err = codecManager.Unmarshal(responseBytes, &response); err != nil { - panic(fmt.Errorf("unexpected error during unmarshal: %w", err)) + return fmt.Errorf("failed to decode response: %w", err) + } + if response.Message != "Hi" { + return fmt.Errorf("expected response message 'Hi', got %q", response.Message) } - assert.Equal(t, "Hi", response.Message) - }(requestWg) + return nil + }) } - - requestWg.Wait() - senderWg.Wait() - require.Equal(t, totalCalls, int(atomic.LoadUint32(&callNum))) + require.NoError(t, eg.Wait()) // only the first error is informative + require.Equal(t, totalCalls, callNum.Load()) } func TestAppRequestOnCtxCancellation(t *testing.T) { @@ -158,42 +155,36 @@ func TestAppRequestOnCtxCancellation(t *testing.T) { } func TestRequestRequestsRoutingAndResponse(t *testing.T) { - callNum := uint32(0) - senderWg := &sync.WaitGroup{} - var net Network - var lock sync.Mutex - contactedNodes := make(map[ids.NodeID]struct{}) + var ( + net Network + lock sync.Mutex + contactedNodes set.Set[ids.NodeID] + callNum atomic.Uint32 + ) + eg, ctx := errgroup.WithContext(t.Context()) sender := testAppSender{ sendAppRequestFn: func(_ context.Context, nodes set.Set[ids.NodeID], requestID uint32, requestBytes []byte) error { nodeID, _ := nodes.Pop() lock.Lock() - contactedNodes[nodeID] = struct{}{} + contactedNodes.Add(nodeID) lock.Unlock() - senderWg.Add(1) - go func() { - defer senderWg.Done() - if err := net.AppRequest(t.Context(), nodeID, requestID, time.Now().Add(5*time.Second), requestBytes); err != nil { - panic(err) - } - }() + eg.Go(func() error { + return net.AppRequest(ctx, nodeID, requestID, time.Now().Add(5*time.Second), requestBytes) + }) return nil }, sendAppResponseFn: func(nodeID ids.NodeID, requestID uint32, responseBytes []byte) error { - senderWg.Add(1) - go func() { - defer senderWg.Done() - if err := net.AppResponse(t.Context(), nodeID, requestID, responseBytes); err != nil { - panic(err) - } - atomic.AddUint32(&callNum, 1) - }() + eg.Go(func() error { + defer callNum.Add(1) + return net.AppResponse(ctx, nodeID, requestID, responseBytes) + }) return nil }, } codecManager := buildCodec(t, HelloRequest{}, HelloResponse{}) - ctx := snowtest.Context(t, snowtest.CChainID) - net, err := NewNetwork(ctx, sender, codecManager, 16, prometheus.NewRegistry()) + snowCtx := snowtest.Context(t, snowtest.CChainID) + net, err := NewNetwork(snowCtx, sender, codecManager, 16, prometheus.NewRegistry()) require.NoError(t, err) net.SetRequestHandler(&HelloGreetingRequestHandler{codec: codecManager}) @@ -211,38 +202,37 @@ func TestRequestRequestsRoutingAndResponse(t *testing.T) { requestMessage := HelloRequest{Message: "this is a request"} defer net.Shutdown() - totalRequests := 5000 - numCallsPerRequest := 1 // on sending response - totalCalls := totalRequests * numCallsPerRequest - - requestWg := &sync.WaitGroup{} - requestWg.Add(totalCalls) - nodeIdx := 0 - for i := 0; i < totalCalls; i++ { - nodeIdx = (nodeIdx + 1) % (len(nodes)) - nodeID := nodes[nodeIdx] - go func(wg *sync.WaitGroup, nodeID ids.NodeID) { - defer wg.Done() + const ( + totalRequests = 5000 + numCallsPerRequest = 1 // on sending response + totalCalls = totalRequests * numCallsPerRequest + ) + for i := range totalCalls { + nodeID := nodes[i%len(nodes)] + eg.Go(func() error { requestBytes, err := message.RequestToBytes(codecManager, requestMessage) - assert.NoError(t, err) - responseBytes, err := net.SendSyncedAppRequest(t.Context(), nodeID, requestBytes) - assert.NoError(t, err) - assert.NotNil(t, responseBytes) + if err != nil { + return fmt.Errorf("failed to encode request: %w", err) + } + + responseBytes, err := net.SendSyncedAppRequest(ctx, nodeID, requestBytes) + if err != nil { + return fmt.Errorf("failed to send synced app request: %w", err) + } var response TestMessage if _, err = codecManager.Unmarshal(responseBytes, &response); err != nil { - panic(fmt.Errorf("unexpected error during unmarshal: %w", err)) + return fmt.Errorf("failed to decode response: %w", err) } - assert.Equal(t, "Hi", response.Message) - }(requestWg, nodeID) - } - - requestWg.Wait() - senderWg.Wait() - require.Equal(t, totalCalls, int(atomic.LoadUint32(&callNum))) - for _, nodeID := range nodes { - require.Contains(t, contactedNodes, nodeID, "node %s was not contacted", nodeID) + if response.Message != "Hi" { + return fmt.Errorf("expected response message 'Hi', got %q", response.Message) + } + return nil + }) } + require.NoError(t, eg.Wait()) + require.Equal(t, uint32(totalCalls), callNum.Load()) + require.Equal(t, set.Of(nodes...), contactedNodes) // ensure empty nodeID is not allowed require.ErrorContains(t, @@ -280,16 +270,11 @@ func TestAppRequestOnShutdown(t *testing.T) { requestMessage := HelloRequest{Message: "this is a request"} require.NoError(t, net.Connected(t.Context(), nodeID, defaultPeerVersion)) - wg.Add(1) - go func() { - defer wg.Done() - requestBytes, err := message.RequestToBytes(codecManager, requestMessage) - assert.NoError(t, err) - responseBytes, _, err := net.SendSyncedAppRequestAny(t.Context(), defaultPeerVersion, requestBytes) - assert.ErrorIs(t, err, errRequestFailed) - assert.Nil(t, responseBytes) - }() - wg.Wait() + requestBytes, err := message.RequestToBytes(codecManager, requestMessage) + require.NoError(t, err) + responseBytes, _, err := net.SendSyncedAppRequestAny(t.Context(), defaultPeerVersion, requestBytes) + require.ErrorIs(t, err, errRequestFailed) + require.Nil(t, responseBytes) require.True(t, called) } @@ -350,18 +335,17 @@ func TestSyncedAppRequestAnyOnCtxCancellation(t *testing.T) { // Cancel context after sending require.Empty(t, net.(*network).outstandingRequestHandlers) // no outstanding requests ctx, cancel = context.WithCancel(t.Context()) - doneChan := make(chan struct{}) + errChan := make(chan error, 1) go func() { _, _, err = net.SendSyncedAppRequestAny(ctx, defaultPeerVersion, requestBytes) - assert.ErrorIs(t, err, context.Canceled) - close(doneChan) + errChan <- err }() // Wait until we've "sent" the app request over the network // before cancelling context. sentAppRequestInfo := <-sentAppRequest require.Len(t, net.(*network).outstandingRequestHandlers, 1) cancel() - <-doneChan + require.ErrorIs(t, <-errChan, context.Canceled) // Should still be able to process a response after cancelling. require.Len(t, net.(*network).outstandingRequestHandlers, 1) // context cancellation SendAppRequestAny failure doesn't clear require.NoError(t, net.AppResponse( @@ -373,33 +357,34 @@ func TestSyncedAppRequestAnyOnCtxCancellation(t *testing.T) { } func TestRequestMinVersion(t *testing.T) { - callNum := uint32(0) - nodeID := ids.GenerateTestNodeID() - codecManager := buildCodec(t, TestMessage{}) - - var net Network + const responseMessage = "this is a response" + var ( + callNum atomic.Uint32 + nodeID = ids.GenerateTestNodeID() + codecManager = buildCodec(t, TestMessage{}) + net Network + ) + eg, ctx := errgroup.WithContext(t.Context()) sender := testAppSender{ sendAppRequestFn: func(_ context.Context, nodes set.Set[ids.NodeID], reqID uint32, _ []byte) error { - atomic.AddUint32(&callNum, 1) - require.True(t, nodes.Contains(nodeID), "request nodes should contain expected nodeID") - require.Len(t, nodes, 1, "request nodes should contain exactly one node") + callNum.Add(1) + require.Equal(t, set.Of(nodeID), nodes) - go func() { + eg.Go(func() error { time.Sleep(200 * time.Millisecond) - atomic.AddUint32(&callNum, 1) - responseBytes, err := codecManager.Marshal(message.Version, TestMessage{Message: "this is a response"}) + responseBytes, err := codecManager.Marshal(message.Version, TestMessage{Message: responseMessage}) if err != nil { - panic(err) + return err } - assert.NoError(t, net.AppResponse(t.Context(), nodeID, reqID, responseBytes)) - }() + return net.AppResponse(ctx, nodeID, reqID, responseBytes) + }) return nil }, } // passing nil as codec works because the net.AppRequest is never called - ctx := snowtest.Context(t, snowtest.CChainID) - net, err := NewNetwork(ctx, sender, codecManager, 1, prometheus.NewRegistry()) + snowCtx := snowtest.Context(t, snowtest.CChainID) + net, err := NewNetwork(snowCtx, sender, codecManager, 1, prometheus.NewRegistry()) require.NoError(t, err) requestMessage := TestMessage{Message: "this is a request"} requestBytes, err := message.RequestToBytes(codecManager, requestMessage) @@ -438,7 +423,10 @@ func TestRequestMinVersion(t *testing.T) { var response TestMessage _, err = codecManager.Unmarshal(responseBytes, &response) require.NoError(t, err) - require.Equal(t, "this is a response", response.Message) + require.Equal(t, responseMessage, response.Message) + + require.NoError(t, eg.Wait()) + require.Equal(t, uint32(1), callNum.Load()) } func TestOnRequestHonoursDeadline(t *testing.T) { diff --git a/params/extras/config_extra_test.go b/params/extras/config_extra_test.go index 3e08dc1c70..bb4ea91ea9 100644 --- a/params/extras/config_extra_test.go +++ b/params/extras/config_extra_test.go @@ -6,7 +6,7 @@ package extras import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/utils" ) @@ -52,7 +52,7 @@ func TestIsTimestampForked(t *testing.T) { } { t.Run(name, func(t *testing.T) { res := isTimestampForked(test.fork, test.block) - assert.Equal(t, test.isForked, res) + require.Equal(t, test.isForked, res) }) } } @@ -122,7 +122,7 @@ func TestIsForkTransition(t *testing.T) { } { t.Run(name, func(t *testing.T) { res := IsForkTransition(test.fork, test.parent, test.current) - assert.Equal(t, test.transitioned, res) + require.Equal(t, test.transitioned, res) }) } } diff --git a/params/extras/config_test.go b/params/extras/config_test.go index c91508bae1..45f8d0fd89 100644 --- a/params/extras/config_test.go +++ b/params/extras/config_test.go @@ -11,7 +11,6 @@ import ( "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/upgrade" "github.com/ava-labs/libevm/common" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/commontype" @@ -86,7 +85,7 @@ $`, for name, test := range tests { t.Run(name, func(t *testing.T) { got := test.config.Description() - assert.Regexp(t, test.wantRegex, got, "config description mismatch") + require.Regexp(t, test.wantRegex, got, "config description mismatch") }) } } @@ -176,10 +175,10 @@ func TestChainConfigVerify(t *testing.T) { t.Run(name, func(t *testing.T) { err := test.config.Verify() if test.errRegex == "" { - assert.NoError(t, err) + require.NoError(t, err) } else { require.Error(t, err) - assert.Regexp(t, test.errRegex, err.Error()) + require.Regexp(t, test.errRegex, err.Error()) } }) } diff --git a/params/protocol_params_test.go b/params/protocol_params_test.go index 23b2fedd41..a14c35e5d5 100644 --- a/params/protocol_params_test.go +++ b/params/protocol_params_test.go @@ -7,7 +7,7 @@ import ( "testing" "github.com/ava-labs/libevm/common" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ethparams "github.com/ava-labs/libevm/params" ) @@ -138,6 +138,6 @@ func TestUpstreamParamsValues(t *testing.T) { } for name, test := range tests { - assert.Equal(t, test.want, test.param, name) + require.Equal(t, test.want, test.param, name) } } diff --git a/plugin/evm/customheader/block_gas_cost_test.go b/plugin/evm/customheader/block_gas_cost_test.go index 14d2ab902b..4f123992fb 100644 --- a/plugin/evm/customheader/block_gas_cost_test.go +++ b/plugin/evm/customheader/block_gas_cost_test.go @@ -10,7 +10,6 @@ import ( "github.com/ava-labs/libevm/common" "github.com/ava-labs/libevm/core/types" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/commontype" @@ -115,7 +114,7 @@ func BlockGasCostTest(t *testing.T, feeConfig commontype.FeeConfig) { }, ) - assert.Equal(t, test.expected, BlockGasCost( + require.Equal(t, test.expected, BlockGasCost( config, feeConfig, parent, @@ -216,7 +215,7 @@ func BlockGasCostWithStepTest(t *testing.T, feeConfig commontype.FeeConfig) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - assert.Equal(t, test.expected, BlockGasCostWithStep( + require.Equal(t, test.expected, BlockGasCostWithStep( feeConfig, test.parentCost, blockGasCostStep, diff --git a/plugin/evm/customtypes/block_ext_test.go b/plugin/evm/customtypes/block_ext_test.go index a0082a76eb..7c244df884 100644 --- a/plugin/evm/customtypes/block_ext_test.go +++ b/plugin/evm/customtypes/block_ext_test.go @@ -11,7 +11,6 @@ import ( "github.com/ava-labs/avalanchego/vms/evm/acp226" "github.com/ava-labs/libevm/common" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/internal/blocktest" @@ -55,13 +54,13 @@ func TestBlockGetters(t *testing.T) { block := NewBlock(header, nil, nil, nil, blocktest.NewHasher()) blockGasCost := BlockGasCost(block) - assert.Equal(t, test.wantBlockGasCost, blockGasCost, "BlockGasCost()") + require.Equal(t, test.wantBlockGasCost, blockGasCost, "BlockGasCost()") timeMilliseconds := BlockTimeMilliseconds(block) - assert.Equal(t, test.wantTimeMilliseconds, timeMilliseconds, "BlockTimeMilliseconds()") + require.Equal(t, test.wantTimeMilliseconds, timeMilliseconds, "BlockTimeMilliseconds()") minDelayExcess := BlockMinDelayExcess(block) - assert.Equal(t, test.wantMinDelayExcess, minDelayExcess, "BlockMinDelayExcess()") + require.Equal(t, test.wantMinDelayExcess, minDelayExcess, "BlockMinDelayExcess()") }) } } @@ -87,7 +86,7 @@ func TestCopyHeader(t *testing.T) { headerExtra = &HeaderExtra{} extras.Header.Set(want, headerExtra) - assert.Equal(t, want, cpy) + require.Equal(t, want, cpy) }) t.Run("filled_header", func(t *testing.T) { @@ -99,8 +98,8 @@ func TestCopyHeader(t *testing.T) { gotExtra := GetHeaderExtra(gotHeader) wantHeader, wantExtra := headerWithNonZeroFields() - assert.Equal(t, wantHeader, gotHeader) - assert.Equal(t, wantExtra, gotExtra) + require.Equal(t, wantHeader, gotHeader) + require.Equal(t, wantExtra, gotExtra) exportedFieldsPointToDifferentMemory(t, header, gotHeader) exportedFieldsPointToDifferentMemory(t, GetHeaderExtra(header), gotExtra) @@ -140,7 +139,7 @@ func exportedFieldsPointToDifferentMemory[T interface { case []uint8: assertDifferentPointers(t, unsafe.SliceData(f), unsafe.SliceData(fieldCp.([]uint8))) default: - require.Failf(t, "invalid type", "field %q type %T needs to be added to switch cases of exportedFieldsDeepCopied", field.Name, f) + require.Failf(t, "field type needs to be added to switch cases", "field %q type %T needs to be added to switch cases of exportedFieldsDeepCopied", field.Name, f) } }) } diff --git a/plugin/evm/customtypes/header_ext_test.go b/plugin/evm/customtypes/header_ext_test.go index 8652a4709c..42c05aed73 100644 --- a/plugin/evm/customtypes/header_ext_test.go +++ b/plugin/evm/customtypes/header_ext_test.go @@ -16,7 +16,6 @@ import ( "github.com/ava-labs/avalanchego/vms/evm/acp226" "github.com/ava-labs/libevm/common" "github.com/ava-labs/libevm/rlp" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/utils/utilstest" @@ -40,11 +39,11 @@ func TestHeaderRLP(t *testing.T) { wantHashHex = "460d4b45e82d1690f901bc1281125e51b138c7b0559dad122e2bc7386ecf3486" ) - assert.Equal(t, wantHex, hex.EncodeToString(got), "Header RLP") + require.Equal(t, wantHex, hex.EncodeToString(got), "Header RLP") header, _ := headerWithNonZeroFields() gotHashHex := header.Hash().Hex() - assert.Equal(t, "0x"+wantHashHex, gotHashHex, "Header.Hash()") + require.Equal(t, "0x"+wantHashHex, gotHashHex, "Header.Hash()") } func TestHeaderJSON(t *testing.T) { @@ -73,8 +72,8 @@ func testHeaderEncodeDecode( wantHeader, wantExtra := headerWithNonZeroFields() wantHeader.WithdrawalsHash = nil - assert.Equal(t, wantHeader, gotHeader) - assert.Equal(t, wantExtra, gotExtra) + require.Equal(t, wantHeader, gotHeader) + require.Equal(t, wantExtra, gotExtra) return encoded } @@ -173,9 +172,9 @@ func allFieldsSet[T interface { case *acp226.DelayExcess: assertNonZero(t, f) case []uint8, []*Header, Transactions, []*Transaction, Withdrawals, []*Withdrawal: - assert.NotEmpty(t, f) + require.NotEmpty(t, f) default: - assert.Failf(t, "invalid type", "Field %q has unsupported type %T", field.Name, f) + require.Failf(t, "Field has unsupported type", "Field %q has unsupported type %T", field.Name, f) } }) } diff --git a/plugin/evm/gossip_test.go b/plugin/evm/gossip_test.go index f355e03cb4..01ec1a21b7 100644 --- a/plugin/evm/gossip_test.go +++ b/plugin/evm/gossip_test.go @@ -18,7 +18,6 @@ import ( "github.com/ava-labs/libevm/core/vm" "github.com/ava-labs/libevm/crypto" "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/consensus/dummy" @@ -79,19 +78,17 @@ func TestGossipSubscribe(t *testing.T) { require.NoError(err, "failed adding tx to remote mempool") } - require.EventuallyWithTf( - func(c *assert.CollectT) { - gossipTxPool.lock.RLock() - defer gossipTxPool.lock.RUnlock() + require.Eventually(func() bool { + gossipTxPool.lock.RLock() + defer gossipTxPool.lock.RUnlock() - for i, tx := range ethTxs { - assert.Truef(c, gossipTxPool.bloom.Has(&GossipEthTx{Tx: tx}), "expected tx[%d] to be in bloom filter", i) + for _, tx := range ethTxs { + if !gossipTxPool.bloom.Has(&GossipEthTx{Tx: tx}) { + return false } - }, - 30*time.Second, - 500*time.Millisecond, - "expected all transactions to eventually be in the bloom filter", - ) + } + return true + }, 30*time.Second, 500*time.Millisecond, "expected all transactions to eventually be in the bloom filter") } func setupPoolWithConfig(t *testing.T, config *params.ChainConfig, fundedAddress common.Address) *txpool.TxPool { diff --git a/plugin/evm/syncervm_test.go b/plugin/evm/syncervm_test.go index 228f4ef70d..8054fdbae6 100644 --- a/plugin/evm/syncervm_test.go +++ b/plugin/evm/syncervm_test.go @@ -30,7 +30,6 @@ import ( "github.com/ava-labs/libevm/rlp" "github.com/ava-labs/libevm/trie" "github.com/ava-labs/libevm/triedb" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/consensus/dummy" @@ -151,7 +150,7 @@ func TestStateSyncToggleEnabledToDisabled(t *testing.T) { enabled, err := syncDisabledVM.StateSyncEnabled(t.Context()) require.NoError(t, err) - assert.False(t, enabled, "sync should be disabled") + require.False(t, enabled, "sync should be disabled") // Process the first 10 blocks from the serverVM for i := uint64(1); i < 10; i++ { @@ -209,7 +208,7 @@ func TestStateSyncToggleEnabledToDisabled(t *testing.T) { enabled, err = syncReEnabledVM.StateSyncEnabled(t.Context()) require.NoError(t, err) - assert.True(t, enabled, "sync should be enabled") + require.True(t, enabled, "sync should be enabled") vmSetup.syncerVM = syncReEnabledVM testSyncerVM(t, vmSetup, test) diff --git a/plugin/evm/vm_test.go b/plugin/evm/vm_test.go index b5f6cb3324..6c1dbcae7c 100644 --- a/plugin/evm/vm_test.go +++ b/plugin/evm/vm_test.go @@ -41,7 +41,6 @@ import ( "github.com/ava-labs/libevm/crypto" "github.com/ava-labs/libevm/log" "github.com/ava-labs/libevm/trie" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/commontype" @@ -2740,10 +2739,10 @@ func TestStandaloneDB(t *testing.T) { require.Equal(t, newBlock.Block.Hash(), common.Hash(blk.ID())) // Ensure that the shared database is empty - assert.True(t, isDBEmpty(baseDB)) + require.True(t, isDBEmpty(baseDB)) // Ensure that the standalone database is not empty - assert.False(t, isDBEmpty(vm.db)) - assert.False(t, isDBEmpty(vm.acceptedBlockDB)) + require.False(t, isDBEmpty(vm.db)) + require.False(t, isDBEmpty(vm.acceptedBlockDB)) } func TestFeeManagerRegressionMempoolMinFeeAfterRestart(t *testing.T) { diff --git a/precompile/contract/utils_test.go b/precompile/contract/utils_test.go index 672070a63d..44410f673e 100644 --- a/precompile/contract/utils_test.go +++ b/precompile/contract/utils_test.go @@ -6,7 +6,7 @@ package contract import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestFunctionSignatureRegex(t *testing.T) { @@ -61,6 +61,6 @@ func TestFunctionSignatureRegex(t *testing.T) { pass: false, }, } { - assert.Equal(t, test.pass, functionSignatureRegex.MatchString(test.str), "unexpected result for %q", test.str) + require.Equal(t, test.pass, functionSignatureRegex.MatchString(test.str), "unexpected result for %q", test.str) } } diff --git a/sync/handlers/block_request_test.go b/sync/handlers/block_request_test.go index 5a680633a2..b60506d53f 100644 --- a/sync/handlers/block_request_test.go +++ b/sync/handlers/block_request_test.go @@ -17,7 +17,6 @@ import ( "github.com/ava-labs/libevm/crypto" "github.com/ava-labs/libevm/rlp" "github.com/ava-labs/libevm/triedb" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/consensus/dummy" @@ -26,6 +25,7 @@ import ( "github.com/ava-labs/subnet-evm/plugin/evm/customtypes" "github.com/ava-labs/subnet-evm/plugin/evm/message" "github.com/ava-labs/subnet-evm/sync/handlers/stats" + "github.com/ava-labs/subnet-evm/sync/handlers/stats/statstest" ) func TestMain(m *testing.M) { @@ -45,11 +45,11 @@ type blockRequestTest struct { requestedParents uint16 expectedBlocks int expectNilResponse bool - assertResponse func(t testing.TB, stats *stats.MockHandlerStats, b []byte) + requireResponse func(t testing.TB, stats *statstest.TestHandlerStats, b []byte) } func executeBlockRequestTest(t testing.TB, test blockRequestTest, blocks []*types.Block) { - mockHandlerStats := &stats.MockHandlerStats{} + testHandlerStats := &statstest.TestHandlerStats{} // convert into map blocksDB := make(map[common.Hash]*types.Block, len(blocks)) @@ -65,7 +65,7 @@ func executeBlockRequestTest(t testing.TB, test blockRequestTest, blocks []*type return blk }, } - blockRequestHandler := NewBlockRequestHandler(blockProvider, message.Codec, mockHandlerStats) + blockRequestHandler := NewBlockRequestHandler(blockProvider, message.Codec, testHandlerStats) var blockRequest message.BlockRequest if test.startBlockHash != (common.Hash{}) { @@ -80,30 +80,30 @@ func executeBlockRequestTest(t testing.TB, test blockRequestTest, blocks []*type responseBytes, err := blockRequestHandler.OnBlockRequest(t.Context(), ids.GenerateTestNodeID(), 1, blockRequest) require.NoError(t, err) - if test.assertResponse != nil { - test.assertResponse(t, mockHandlerStats, responseBytes) + if test.requireResponse != nil { + test.requireResponse(t, testHandlerStats, responseBytes) } if test.expectNilResponse { - assert.Nil(t, responseBytes) + require.Nil(t, responseBytes) return } - assert.NotEmpty(t, responseBytes) + require.NotEmpty(t, responseBytes) var response message.BlockResponse _, err = message.Codec.Unmarshal(responseBytes, &response) require.NoError(t, err) - assert.Len(t, response.Blocks, test.expectedBlocks) + require.Len(t, response.Blocks, test.expectedBlocks) for _, blockBytes := range response.Blocks { block := new(types.Block) require.NoError(t, rlp.DecodeBytes(blockBytes, block)) - assert.GreaterOrEqual(t, test.startBlockIndex, 0) - assert.Equal(t, blocks[test.startBlockIndex].Hash(), block.Hash()) + require.GreaterOrEqual(t, test.startBlockIndex, 0) + require.Equal(t, blocks[test.startBlockIndex].Hash(), block.Hash()) test.startBlockIndex-- } - mockHandlerStats.Reset() + testHandlerStats.Reset() } func TestBlockRequestHandler(t *testing.T) { @@ -116,7 +116,7 @@ func TestBlockRequestHandler(t *testing.T) { engine := dummy.NewETHFaker() blocks, _, err := core.GenerateChain(params.TestChainConfig, genesis, engine, memdb, 96, 0, func(_ int, _ *core.BlockGen) {}) require.NoError(t, err) - assert.Len(t, blocks, 96) + require.Len(t, blocks, 96) tests := []blockRequestTest{ { @@ -143,8 +143,8 @@ func TestBlockRequestHandler(t *testing.T) { startBlockHeight: 1_000_000, requestedParents: 64, expectNilResponse: true, - assertResponse: func(t testing.TB, mockHandlerStats *stats.MockHandlerStats, _ []byte) { - assert.Equal(t, uint32(1), mockHandlerStats.MissingBlockHashCount) + requireResponse: func(t testing.TB, testHandlerStats *statstest.TestHandlerStats, _ []byte) { + require.Equal(t, uint32(1), testHandlerStats.MissingBlockHashCount) }, }, } @@ -183,7 +183,7 @@ func TestBlockRequestHandlerLargeBlocks(t *testing.T) { b.AddTx(tx) }) require.NoError(t, err) - assert.Len(t, blocks, 96) + require.Len(t, blocks, 96) tests := []blockRequestTest{ { @@ -223,7 +223,7 @@ func TestBlockRequestHandlerCtxExpires(t *testing.T) { blocks, _, err := core.GenerateChain(params.TestChainConfig, genesis, engine, memdb, 11, 0, func(_ int, _ *core.BlockGen) {}) require.NoError(t, err) - assert.Len(t, blocks, 11) + require.Len(t, blocks, 11) // convert into map blocksDB := make(map[common.Hash]*types.Block, 11) @@ -257,17 +257,17 @@ func TestBlockRequestHandlerCtxExpires(t *testing.T) { Parents: uint16(8), }) require.NoError(t, err) - assert.NotEmpty(t, responseBytes) + require.NotEmpty(t, responseBytes) var response message.BlockResponse _, err = message.Codec.Unmarshal(responseBytes, &response) require.NoError(t, err) // requested 8 blocks, received cancelAfterNumRequests because of timeout - assert.Len(t, response.Blocks, cancelAfterNumRequests) + require.Len(t, response.Blocks, cancelAfterNumRequests) for i, blockBytes := range response.Blocks { block := new(types.Block) require.NoError(t, rlp.DecodeBytes(blockBytes, block)) - assert.Equal(t, blocks[len(blocks)-i-1].Hash(), block.Hash()) + require.Equal(t, blocks[len(blocks)-i-1].Hash(), block.Hash()) } } diff --git a/sync/handlers/code_request_test.go b/sync/handlers/code_request_test.go index 81a47e81b2..2e9de70975 100644 --- a/sync/handlers/code_request_test.go +++ b/sync/handlers/code_request_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/plugin/evm/message" - "github.com/ava-labs/subnet-evm/sync/handlers/stats" + "github.com/ava-labs/subnet-evm/sync/handlers/stats/statstest" ethparams "github.com/ava-labs/libevm/params" ) @@ -34,7 +34,7 @@ func TestCodeRequestHandler(t *testing.T) { maxSizeCodeHash := crypto.Keccak256Hash(maxSizeCodeBytes) rawdb.WriteCode(database, maxSizeCodeHash, maxSizeCodeBytes) - mockHandlerStats := &stats.MockHandlerStats{} + mockHandlerStats := &statstest.TestHandlerStats{} codeRequestHandler := NewCodeRequestHandler(database, message.Codec, mockHandlerStats) tests := map[string]struct { diff --git a/sync/handlers/leafs_request_test.go b/sync/handlers/leafs_request_test.go index 93ad901998..52d63f893e 100644 --- a/sync/handlers/leafs_request_test.go +++ b/sync/handlers/leafs_request_test.go @@ -21,13 +21,13 @@ import ( "github.com/ava-labs/subnet-evm/core/state/snapshot" "github.com/ava-labs/subnet-evm/plugin/evm/message" - "github.com/ava-labs/subnet-evm/sync/handlers/stats" + "github.com/ava-labs/subnet-evm/sync/handlers/stats/statstest" "github.com/ava-labs/subnet-evm/sync/statesync/statesynctest" ) func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { r := rand.New(rand.NewSource(1)) - mockHandlerStats := &stats.MockHandlerStats{} + testHandlerStats := &statstest.TestHandlerStats{} memdb := rawdb.NewMemoryDatabase() trieDB := triedb.NewDatabase(memdb, nil) @@ -75,7 +75,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { } } snapshotProvider := &TestSnapshotProvider{} - leafsHandler := NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, snapshotProvider, message.Codec, mockHandlerStats) + leafsHandler := NewLeafsRequestHandler(trieDB, message.StateTrieKeyLength, snapshotProvider, message.Codec, testHandlerStats) snapConfig := snapshot.Config{ CacheSize: 64, AsyncBuild: false, @@ -100,7 +100,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { requireResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { require.Nil(t, response) require.NoError(t, err) - require.Equal(t, uint32(1), mockHandlerStats.InvalidLeafsRequestCount) + require.Equal(t, uint32(1), testHandlerStats.InvalidLeafsRequestCount) }, }, "empty root dropped": { @@ -116,7 +116,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { requireResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { require.Nil(t, response) require.NoError(t, err) - require.Equal(t, uint32(1), mockHandlerStats.InvalidLeafsRequestCount) + require.Equal(t, uint32(1), testHandlerStats.InvalidLeafsRequestCount) }, }, "bad start len dropped": { @@ -132,7 +132,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { requireResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { require.Nil(t, response) require.NoError(t, err) - require.Equal(t, uint32(1), mockHandlerStats.InvalidLeafsRequestCount) + require.Equal(t, uint32(1), testHandlerStats.InvalidLeafsRequestCount) }, }, "bad end len dropped": { @@ -148,7 +148,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { requireResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { require.Nil(t, response) require.NoError(t, err) - require.Equal(t, uint32(1), mockHandlerStats.InvalidLeafsRequestCount) + require.Equal(t, uint32(1), testHandlerStats.InvalidLeafsRequestCount) }, }, "empty storage root dropped": { @@ -164,7 +164,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { requireResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { require.Nil(t, response) require.NoError(t, err) - require.Equal(t, uint32(1), mockHandlerStats.InvalidLeafsRequestCount) + require.Equal(t, uint32(1), testHandlerStats.InvalidLeafsRequestCount) }, }, "missing root dropped": { @@ -180,7 +180,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { requireResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { require.Nil(t, response) require.NoError(t, err) - require.Equal(t, uint32(1), mockHandlerStats.MissingRootCount) + require.Equal(t, uint32(1), testHandlerStats.MissingRootCount) }, }, "corrupted trie drops request": { @@ -196,7 +196,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { requireResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { require.Nil(t, response) require.NoError(t, err) - require.Equal(t, uint32(1), mockHandlerStats.TrieErrorCount) + require.Equal(t, uint32(1), testHandlerStats.TrieErrorCount) }, }, "cancelled context dropped": { @@ -270,7 +270,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { requireResponseFn: func(t *testing.T, _ message.LeafsRequest, response []byte, err error) { require.Nil(t, response) require.NoError(t, err) - require.Equal(t, uint32(1), mockHandlerStats.InvalidLeafsRequestCount) + require.Equal(t, uint32(1), testHandlerStats.InvalidLeafsRequestCount) }, }, "invalid node type dropped": { @@ -307,8 +307,8 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, int(maxLeavesLimit)) require.Len(t, leafsResponse.Vals, int(maxLeavesLimit)) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) }, }, "full range with nil start": { @@ -328,8 +328,8 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, int(maxLeavesLimit)) require.Len(t, leafsResponse.Vals, int(maxLeavesLimit)) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) requireRangeProofIsValid(t, &request, &leafsResponse, true) }, }, @@ -350,8 +350,8 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, int(maxLeavesLimit)) require.Len(t, leafsResponse.Vals, int(maxLeavesLimit)) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) requireRangeProofIsValid(t, &request, &leafsResponse, true) }, }, @@ -375,8 +375,8 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, 40) require.Len(t, leafsResponse.Vals, 40) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) requireRangeProofIsValid(t, &request, &leafsResponse, true) }, }, @@ -397,8 +397,8 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, 600) require.Len(t, leafsResponse.Vals, 600) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) requireRangeProofIsValid(t, &request, &leafsResponse, false) }, }, @@ -419,8 +419,8 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Empty(t, leafsResponse.Keys) require.Empty(t, leafsResponse.Vals) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) requireRangeProofIsValid(t, &request, &leafsResponse, false) }, }, @@ -445,8 +445,8 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.Len(t, leafsResponse.Keys, 500) require.Len(t, leafsResponse.Vals, 500) require.Empty(t, leafsResponse.ProofVals) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) requireRangeProofIsValid(t, &request, &leafsResponse, false) }, }, @@ -468,10 +468,10 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, int(maxLeavesLimit)) require.Len(t, leafsResponse.Vals, int(maxLeavesLimit)) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) - require.Equal(t, uint32(1), mockHandlerStats.SnapshotReadAttemptCount) - require.Equal(t, uint32(1), mockHandlerStats.SnapshotReadSuccessCount) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.SnapshotReadAttemptCount) + require.Equal(t, uint32(1), testHandlerStats.SnapshotReadSuccessCount) requireRangeProofIsValid(t, &request, &leafsResponse, true) }, }, @@ -512,16 +512,16 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, int(maxLeavesLimit)) require.Len(t, leafsResponse.Vals, int(maxLeavesLimit)) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) - require.Equal(t, uint32(1), mockHandlerStats.SnapshotReadAttemptCount) - require.Equal(t, uint32(0), mockHandlerStats.SnapshotReadSuccessCount) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.SnapshotReadAttemptCount) + require.Equal(t, uint32(0), testHandlerStats.SnapshotReadSuccessCount) requireRangeProofIsValid(t, &request, &leafsResponse, true) // expect 1/4th of segments to be invalid numSegments := maxLeavesLimit / segmentLen - require.Equal(t, uint32(numSegments/4), mockHandlerStats.SnapshotSegmentInvalidCount) - require.Equal(t, uint32(3*numSegments/4), mockHandlerStats.SnapshotSegmentValidCount) + require.Equal(t, uint32(numSegments/4), testHandlerStats.SnapshotSegmentInvalidCount) + require.Equal(t, uint32(3*numSegments/4), testHandlerStats.SnapshotSegmentValidCount) }, }, "storage data served from snapshot": { @@ -543,10 +543,10 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, int(maxLeavesLimit)) require.Len(t, leafsResponse.Vals, int(maxLeavesLimit)) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) - require.Equal(t, uint32(1), mockHandlerStats.SnapshotReadAttemptCount) - require.Equal(t, uint32(1), mockHandlerStats.SnapshotReadSuccessCount) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.SnapshotReadAttemptCount) + require.Equal(t, uint32(1), testHandlerStats.SnapshotReadSuccessCount) requireRangeProofIsValid(t, &request, &leafsResponse, true) }, }, @@ -587,16 +587,16 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, int(maxLeavesLimit)) require.Len(t, leafsResponse.Vals, int(maxLeavesLimit)) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) - require.Equal(t, uint32(1), mockHandlerStats.SnapshotReadAttemptCount) - require.Equal(t, uint32(0), mockHandlerStats.SnapshotReadSuccessCount) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.SnapshotReadAttemptCount) + require.Equal(t, uint32(0), testHandlerStats.SnapshotReadSuccessCount) requireRangeProofIsValid(t, &request, &leafsResponse, true) // expect 1/4th of segments to be invalid numSegments := maxLeavesLimit / segmentLen - require.Equal(t, uint32(numSegments/4), mockHandlerStats.SnapshotSegmentInvalidCount) - require.Equal(t, uint32(3*numSegments/4), mockHandlerStats.SnapshotSegmentValidCount) + require.Equal(t, uint32(numSegments/4), testHandlerStats.SnapshotSegmentInvalidCount) + require.Equal(t, uint32(3*numSegments/4), testHandlerStats.SnapshotSegmentValidCount) }, }, "last snapshot key removed": { @@ -626,10 +626,10 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, 500) require.Len(t, leafsResponse.Vals, 500) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) - require.Equal(t, uint32(1), mockHandlerStats.SnapshotReadAttemptCount) - require.Equal(t, uint32(1), mockHandlerStats.SnapshotReadSuccessCount) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.SnapshotReadAttemptCount) + require.Equal(t, uint32(1), testHandlerStats.SnapshotReadSuccessCount) requireRangeProofIsValid(t, &request, &leafsResponse, false) }, }, @@ -661,10 +661,10 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { require.NoError(t, err) require.Len(t, leafsResponse.Keys, 1) require.Len(t, leafsResponse.Vals, 1) - require.Equal(t, uint32(1), mockHandlerStats.LeafsRequestCount) - require.Equal(t, uint32(len(leafsResponse.Keys)), mockHandlerStats.LeafsReturnedSum) - require.Equal(t, uint32(1), mockHandlerStats.SnapshotReadAttemptCount) - require.Equal(t, uint32(0), mockHandlerStats.SnapshotReadSuccessCount) + require.Equal(t, uint32(1), testHandlerStats.LeafsRequestCount) + require.Equal(t, uint32(len(leafsResponse.Keys)), testHandlerStats.LeafsReturnedSum) + require.Equal(t, uint32(1), testHandlerStats.SnapshotReadAttemptCount) + require.Equal(t, uint32(0), testHandlerStats.SnapshotReadSuccessCount) requireRangeProofIsValid(t, &request, &leafsResponse, false) }, }, @@ -674,7 +674,7 @@ func TestLeafsRequestHandler_OnLeafsRequest(t *testing.T) { ctx, request := test.prepareTestFn() t.Cleanup(func() { <-snapshot.WipeSnapshot(memdb, true) - mockHandlerStats.Reset() + testHandlerStats.Reset() snapshotProvider.Snapshot = nil // reset the snapshot to nil }) diff --git a/sync/handlers/stats/mock_stats.go b/sync/handlers/stats/statstest/test_stats.go similarity index 67% rename from sync/handlers/stats/mock_stats.go rename to sync/handlers/stats/statstest/test_stats.go index 008eee45b5..dda77a4bb5 100644 --- a/sync/handlers/stats/mock_stats.go +++ b/sync/handlers/stats/statstest/test_stats.go @@ -1,17 +1,19 @@ // Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. // See the file LICENSE for licensing terms. -package stats +package statstest import ( "sync" "time" + + "github.com/ava-labs/subnet-evm/sync/handlers/stats" ) -var _ HandlerStats = &MockHandlerStats{} +var _ stats.HandlerStats = (*TestHandlerStats)(nil) -// MockHandlerStats is mock for capturing and asserting on handler metrics in test -type MockHandlerStats struct { +// TestHandlerStats is test for capturing and asserting on handler metrics in test +type TestHandlerStats struct { lock sync.Mutex BlockRequestCount, @@ -44,7 +46,7 @@ type MockHandlerStats struct { LeafRequestProcessingTimeSum time.Duration } -func (m *MockHandlerStats) Reset() { +func (m *TestHandlerStats) Reset() { m.lock.Lock() defer m.lock.Unlock() m.BlockRequestCount = 0 @@ -75,157 +77,157 @@ func (m *MockHandlerStats) Reset() { m.LeafRequestProcessingTimeSum = 0 } -func (m *MockHandlerStats) IncBlockRequest() { +func (m *TestHandlerStats) IncBlockRequest() { m.lock.Lock() defer m.lock.Unlock() m.BlockRequestCount++ } -func (m *MockHandlerStats) IncMissingBlockHash() { +func (m *TestHandlerStats) IncMissingBlockHash() { m.lock.Lock() defer m.lock.Unlock() m.MissingBlockHashCount++ } -func (m *MockHandlerStats) UpdateBlocksReturned(num uint16) { +func (m *TestHandlerStats) UpdateBlocksReturned(num uint16) { m.lock.Lock() defer m.lock.Unlock() m.BlocksReturnedSum += uint32(num) } -func (m *MockHandlerStats) UpdateBlockRequestProcessingTime(duration time.Duration) { +func (m *TestHandlerStats) UpdateBlockRequestProcessingTime(duration time.Duration) { m.lock.Lock() defer m.lock.Unlock() m.BlockRequestProcessingTimeSum += duration } -func (m *MockHandlerStats) IncCodeRequest() { +func (m *TestHandlerStats) IncCodeRequest() { m.lock.Lock() defer m.lock.Unlock() m.CodeRequestCount++ } -func (m *MockHandlerStats) IncMissingCodeHash() { +func (m *TestHandlerStats) IncMissingCodeHash() { m.lock.Lock() defer m.lock.Unlock() m.MissingCodeHashCount++ } -func (m *MockHandlerStats) IncTooManyHashesRequested() { +func (m *TestHandlerStats) IncTooManyHashesRequested() { m.lock.Lock() defer m.lock.Unlock() m.TooManyHashesRequested++ } -func (m *MockHandlerStats) IncDuplicateHashesRequested() { +func (m *TestHandlerStats) IncDuplicateHashesRequested() { m.lock.Lock() defer m.lock.Unlock() m.DuplicateHashesRequested++ } -func (m *MockHandlerStats) UpdateCodeReadTime(duration time.Duration) { +func (m *TestHandlerStats) UpdateCodeReadTime(duration time.Duration) { m.lock.Lock() defer m.lock.Unlock() m.CodeReadTimeSum += duration } -func (m *MockHandlerStats) UpdateCodeBytesReturned(bytes uint32) { +func (m *TestHandlerStats) UpdateCodeBytesReturned(bytes uint32) { m.lock.Lock() defer m.lock.Unlock() m.CodeBytesReturnedSum += bytes } -func (m *MockHandlerStats) IncLeafsRequest() { +func (m *TestHandlerStats) IncLeafsRequest() { m.lock.Lock() defer m.lock.Unlock() m.LeafsRequestCount++ } -func (m *MockHandlerStats) IncInvalidLeafsRequest() { +func (m *TestHandlerStats) IncInvalidLeafsRequest() { m.lock.Lock() defer m.lock.Unlock() m.InvalidLeafsRequestCount++ } -func (m *MockHandlerStats) UpdateLeafsReturned(numLeafs uint16) { +func (m *TestHandlerStats) UpdateLeafsReturned(numLeafs uint16) { m.lock.Lock() defer m.lock.Unlock() m.LeafsReturnedSum += uint32(numLeafs) } -func (m *MockHandlerStats) UpdateLeafsRequestProcessingTime(duration time.Duration) { +func (m *TestHandlerStats) UpdateLeafsRequestProcessingTime(duration time.Duration) { m.lock.Lock() defer m.lock.Unlock() m.LeafRequestProcessingTimeSum += duration } -func (m *MockHandlerStats) UpdateReadLeafsTime(duration time.Duration) { +func (m *TestHandlerStats) UpdateReadLeafsTime(duration time.Duration) { m.lock.Lock() defer m.lock.Unlock() m.LeafsReadTime += duration } -func (m *MockHandlerStats) UpdateGenerateRangeProofTime(duration time.Duration) { +func (m *TestHandlerStats) UpdateGenerateRangeProofTime(duration time.Duration) { m.lock.Lock() defer m.lock.Unlock() m.GenerateRangeProofTime += duration } -func (m *MockHandlerStats) UpdateSnapshotReadTime(duration time.Duration) { +func (m *TestHandlerStats) UpdateSnapshotReadTime(duration time.Duration) { m.lock.Lock() defer m.lock.Unlock() m.SnapshotReadTime += duration } -func (m *MockHandlerStats) UpdateRangeProofValsReturned(numProofVals int64) { +func (m *TestHandlerStats) UpdateRangeProofValsReturned(numProofVals int64) { m.lock.Lock() defer m.lock.Unlock() m.ProofValsReturned += numProofVals } -func (m *MockHandlerStats) IncMissingRoot() { +func (m *TestHandlerStats) IncMissingRoot() { m.lock.Lock() defer m.lock.Unlock() m.MissingRootCount++ } -func (m *MockHandlerStats) IncTrieError() { +func (m *TestHandlerStats) IncTrieError() { m.lock.Lock() defer m.lock.Unlock() m.TrieErrorCount++ } -func (m *MockHandlerStats) IncProofError() { +func (m *TestHandlerStats) IncProofError() { m.lock.Lock() defer m.lock.Unlock() m.ProofErrorCount++ } -func (m *MockHandlerStats) IncSnapshotReadError() { +func (m *TestHandlerStats) IncSnapshotReadError() { m.lock.Lock() defer m.lock.Unlock() m.SnapshotReadErrorCount++ } -func (m *MockHandlerStats) IncSnapshotReadAttempt() { +func (m *TestHandlerStats) IncSnapshotReadAttempt() { m.lock.Lock() defer m.lock.Unlock() m.SnapshotReadAttemptCount++ } -func (m *MockHandlerStats) IncSnapshotReadSuccess() { +func (m *TestHandlerStats) IncSnapshotReadSuccess() { m.lock.Lock() defer m.lock.Unlock() m.SnapshotReadSuccessCount++ } -func (m *MockHandlerStats) IncSnapshotSegmentValid() { +func (m *TestHandlerStats) IncSnapshotSegmentValid() { m.lock.Lock() defer m.lock.Unlock() m.SnapshotSegmentValidCount++ } -func (m *MockHandlerStats) IncSnapshotSegmentInvalid() { +func (m *TestHandlerStats) IncSnapshotSegmentInvalid() { m.lock.Lock() defer m.lock.Unlock() m.SnapshotSegmentInvalidCount++ diff --git a/sync/statesync/statesynctest/test_sync.go b/sync/statesync/statesynctest/test_sync.go index b9be14f74f..733b68a45c 100644 --- a/sync/statesync/statesynctest/test_sync.go +++ b/sync/statesync/statesynctest/test_sync.go @@ -15,7 +15,6 @@ import ( "github.com/ava-labs/libevm/ethdb" "github.com/ava-labs/libevm/rlp" "github.com/ava-labs/libevm/triedb" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/plugin/evm/customrawdb" @@ -48,7 +47,7 @@ func AssertDBConsistency(t testing.TB, root common.Hash, clientDB ethdb.Database // check snapshot consistency snapshotVal := rawdb.ReadAccountSnapshot(clientDB, accHash) expectedSnapshotVal := types.SlimAccountRLP(acc) - assert.Equal(t, expectedSnapshotVal, snapshotVal) + require.Equal(t, expectedSnapshotVal, snapshotVal) // check code consistency if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash[:]) { @@ -56,7 +55,7 @@ func AssertDBConsistency(t testing.TB, root common.Hash, clientDB ethdb.Database code := rawdb.ReadCode(clientDB, codeHash) actualHash := crypto.Keccak256Hash(code) require.NotEmpty(t, code) - assert.Equal(t, codeHash, actualHash) + require.Equal(t, codeHash, actualHash) } if acc.Root == types.EmptyRootHash { return nil @@ -76,16 +75,16 @@ func AssertDBConsistency(t testing.TB, root common.Hash, clientDB ethdb.Database AssertTrieConsistency(t, acc.Root, serverTrieDB, clientTrieDB, func(key, val []byte) error { storageTrieLeavesCount++ snapshotVal := rawdb.ReadStorageSnapshot(clientDB, accHash, common.BytesToHash(key)) - assert.Equal(t, val, snapshotVal) + require.Equal(t, val, snapshotVal) return nil }) - assert.Equal(t, storageTrieLeavesCount, snapshotStorageKeysCount) + require.Equal(t, storageTrieLeavesCount, snapshotStorageKeysCount) return nil }) // Check that the number of accounts in the snapshot matches the number of leaves in the accounts trie - assert.Equal(t, trieAccountLeaves, numSnapshotAccounts) + require.Equal(t, trieAccountLeaves, numSnapshotAccounts) } // FillAccountsWithOverlappingStorage adds [numAccounts] randomly generated accounts to the secure trie at [root] diff --git a/sync/statesync/sync_test.go b/sync/statesync/sync_test.go index fcebe29eea..f4f92d2daf 100644 --- a/sync/statesync/sync_test.go +++ b/sync/statesync/sync_test.go @@ -21,7 +21,6 @@ import ( "github.com/ava-labs/libevm/rlp" "github.com/ava-labs/libevm/trie" "github.com/ava-labs/libevm/triedb" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/ava-labs/subnet-evm/core/state/snapshot" @@ -544,15 +543,15 @@ func assertDBConsistency(t testing.TB, root common.Hash, clientDB ethdb.Database // check snapshot consistency snapshotVal := rawdb.ReadAccountSnapshot(clientDB, accHash) expectedSnapshotVal := types.SlimAccountRLP(acc) - assert.Equal(t, expectedSnapshotVal, snapshotVal) + require.Equal(t, expectedSnapshotVal, snapshotVal) // check code consistency if !bytes.Equal(acc.CodeHash, types.EmptyCodeHash[:]) { codeHash := common.BytesToHash(acc.CodeHash) code := rawdb.ReadCode(clientDB, codeHash) actualHash := crypto.Keccak256Hash(code) - assert.NotEmpty(t, code) - assert.Equal(t, codeHash, actualHash) + require.NotEmpty(t, code) + require.Equal(t, codeHash, actualHash) } if acc.Root == types.EmptyRootHash { return nil @@ -572,16 +571,16 @@ func assertDBConsistency(t testing.TB, root common.Hash, clientDB ethdb.Database statesynctest.AssertTrieConsistency(t, acc.Root, serverTrieDB, clientTrieDB, func(key, val []byte) error { storageTrieLeavesCount++ snapshotVal := rawdb.ReadStorageSnapshot(clientDB, accHash, common.BytesToHash(key)) - assert.Equal(t, val, snapshotVal) + require.Equal(t, val, snapshotVal) return nil }) - assert.Equal(t, storageTrieLeavesCount, snapshotStorageKeysCount) + require.Equal(t, storageTrieLeavesCount, snapshotStorageKeysCount) return nil }) // Check that the number of accounts in the snapshot matches the number of leaves in the accounts trie - assert.Equal(t, trieAccountLeaves, numSnapshotAccounts) + require.Equal(t, trieAccountLeaves, numSnapshotAccounts) } func fillAccountsWithStorage(t *testing.T, r *rand.Rand, serverDB ethdb.Database, serverTrieDB *triedb.Database, root common.Hash, numAccounts int) common.Hash { //nolint:unparam diff --git a/utils/numbers_test.go b/utils/numbers_test.go index 074a3d45bd..cee2bae2ae 100644 --- a/utils/numbers_test.go +++ b/utils/numbers_test.go @@ -7,7 +7,7 @@ import ( "math/big" "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestBigEqual(t *testing.T) { @@ -44,10 +44,10 @@ func TestBigEqual(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - assert := assert.New(t) + require := require.New(t) - assert.Equal(test.want, BigEqual(test.a, test.b)) - assert.Equal(test.want, BigEqual(test.b, test.a)) + require.Equal(test.want, BigEqual(test.a, test.b)) + require.Equal(test.want, BigEqual(test.b, test.a)) }) } } @@ -87,7 +87,7 @@ func TestBigEqualUint64(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { got := BigEqualUint64(test.a, test.b) - assert.Equal(t, test.want, got) + require.Equal(t, test.want, got) }) } } @@ -133,7 +133,7 @@ func TestLessOrEqualUint64(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { got := BigLessOrEqualUint64(test.a, test.b) - assert.Equal(t, test.want, got) + require.Equal(t, test.want, got) }) } }