diff --git a/internal/bits/bit_array.go b/internal/bits/bit_array.go index 4dad7a94e..686a63c6c 100644 --- a/internal/bits/bit_array.go +++ b/internal/bits/bit_array.go @@ -28,7 +28,7 @@ func NewBitArray(bits int) *BitArray { } return &BitArray{ Bits: bits, - Elems: make([]uint64, (bits+63)/64), + Elems: make([]uint64, numElements(bits)), } } @@ -41,7 +41,7 @@ func NewBitArrayFromFn(bits int, fn func(int) bool) *BitArray { } bA := &BitArray{ Bits: bits, - Elems: make([]uint64, (bits+63)/64), + Elems: make([]uint64, numElements(bits)), } for i := 0; i < bits; i++ { v := fn(i) @@ -90,7 +90,7 @@ func (bA *BitArray) SetIndex(i int, v bool) bool { } func (bA *BitArray) setIndex(i int, v bool) bool { - if i >= bA.Bits { + if i >= bA.Bits || i/64 >= len(bA.Elems) { return false } if v { @@ -121,7 +121,7 @@ func (bA *BitArray) copy() *BitArray { } func (bA *BitArray) copyBits(bits int) *BitArray { - c := make([]uint64, (bits+63)/64) + c := make([]uint64, numElements(bits)) copy(c, bA.Elems) return &BitArray{ Bits: bits, @@ -282,6 +282,11 @@ func (bA *BitArray) PickRandom(r *rand.Rand) (int, bool) { } func (bA *BitArray) getNumTrueIndices() int { + if bA.Size() == 0 || len(bA.Elems) == 0 || len(bA.Elems) != numElements(bA.Size()) { + // size and elements must be valid to do this calc + return 0 + } + count := 0 numElems := len(bA.Elems) // handle all elements except the last one @@ -500,3 +505,22 @@ func (bA *BitArray) FromProto(protoBitArray *cmtprotobits.BitArray) { bA.Elems = protoBitArray.Elems } } + +// ValidateBasic validates a BitArray. Note that a nil BitArray and BitArray of +// size 0 bits is valid. However the number of Bits and Elems be valid based on +// each other. +func (bA *BitArray) ValidateBasic() error { + if bA == nil { + return nil + } + + expectedElems := numElements(bA.Size()) + if expectedElems != len(bA.Elems) { + return fmt.Errorf("mismatch between specified number of bits %d, and number of elements %d, expected %d elements", bA.Size(), len(bA.Elems), expectedElems) + } + return nil +} + +func numElements(bits int) int { + return (bits + 63) / 64 +} diff --git a/internal/bits/bit_array_test.go b/internal/bits/bit_array_test.go index 5b0a86c52..8d718689f 100644 --- a/internal/bits/bit_array_test.go +++ b/internal/bits/bit_array_test.go @@ -177,6 +177,28 @@ func TestGetNumTrueIndices(t *testing.T) { } } +func TestGetNumTrueIndicesInvalidStates(t *testing.T) { + testCases := []struct { + name string + bA1 *BitArray + exp int + }{ + {"empty", &BitArray{}, 0}, + {"explicit 0 bits nil elements", &BitArray{Bits: 0, Elems: nil}, 0}, + {"explicit 0 bits 0 len elements", &BitArray{Bits: 0, Elems: make([]uint64, 0)}, 0}, + {"nil", nil, 0}, + {"with elements", NewBitArray(10), 0}, + {"more elements than bits specifies", &BitArray{Bits: 0, Elems: make([]uint64, 5)}, 0}, + {"less elements than bits specifies", &BitArray{Bits: 200, Elems: make([]uint64, 1)}, 0}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + n := tc.bA1.getNumTrueIndices() + require.Equal(t, n, tc.exp) + }) + } +} + func TestGetNthTrueIndex(t *testing.T) { type testcase struct { Input string @@ -230,7 +252,7 @@ func TestGetNthTrueIndex(t *testing.T) { } } -func TestBytes(_ *testing.T) { +func TestBytes(t *testing.T) { bA := NewBitArray(4) bA.SetIndex(0, true) check := func(bA *BitArray, bz []byte) { @@ -257,6 +279,10 @@ func TestBytes(_ *testing.T) { check(bA, []byte{0x80, 0x01}) bA.SetIndex(9, true) check(bA, []byte{0x80, 0x03}) + + bA = NewBitArray(4) + bA.Elems = nil + require.False(t, bA.SetIndex(1, true)) } func TestEmptyFull(t *testing.T) { @@ -374,6 +400,28 @@ func TestBitArrayProtoBuf(t *testing.T) { } } +func TestBitArrayValidateBasic(t *testing.T) { + testCases := []struct { + name string + bA1 *BitArray + expPass bool + }{ + {"valid empty", &BitArray{}, true}, + {"valid explicit 0 bits nil elements", &BitArray{Bits: 0, Elems: nil}, true}, + {"valid explicit 0 bits 0 len elements", &BitArray{Bits: 0, Elems: make([]uint64, 0)}, true}, + {"valid nil", nil, true}, + {"valid with elements", NewBitArray(10), true}, + {"more elements than bits specifies", &BitArray{Bits: 0, Elems: make([]uint64, 5)}, false}, + {"less elements than bits specifies", &BitArray{Bits: 200, Elems: make([]uint64, 1)}, false}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := tc.bA1.ValidateBasic() + require.Equal(t, err == nil, tc.expPass) + }) + } +} + // Tests that UnmarshalJSON doesn't crash when no bits are passed into the JSON. // See issue https://github.com/cometbft/cometbft/issues/2658 func TestUnmarshalJSONDoesntCrashOnZeroBits(t *testing.T) { diff --git a/internal/consensus/reactor.go b/internal/consensus/reactor.go index 3e1a352a1..909bfd8df 100644 --- a/internal/consensus/reactor.go +++ b/internal/consensus/reactor.go @@ -1826,6 +1826,9 @@ func (m *NewValidBlockMessage) ValidateBasic() error { if err := m.BlockPartSetHeader.ValidateBasic(); err != nil { return cmterrors.ErrWrongField{Field: "BlockPartSetHeader", Err: err} } + if err := m.BlockParts.ValidateBasic(); err != nil { + return fmt.Errorf("validating BlockParts: %w", err) + } if m.BlockParts.Size() == 0 { return cmterrors.ErrRequiredField{Field: "blockParts"} } @@ -1880,6 +1883,9 @@ func (m *ProposalPOLMessage) ValidateBasic() error { if m.ProposalPOLRound < 0 { return cmterrors.ErrNegativeField{Field: "ProposalPOLRound"} } + if err := m.ProposalPOL.ValidateBasic(); err != nil { + return fmt.Errorf("validating ProposalPOL: %w", err) + } if m.ProposalPOL.Size() == 0 { return cmterrors.ErrRequiredField{Field: "ProposalPOL"} } @@ -2042,6 +2048,9 @@ func (m *VoteSetBitsMessage) ValidateBasic() error { if err := m.BlockID.ValidateBasic(); err != nil { return cmterrors.ErrWrongField{Field: "BlockID", Err: err} } + if err := m.Votes.ValidateBasic(); err != nil { + return fmt.Errorf("validating Votes: %w", err) + } // NOTE: Votes.Size() can be zero if the node does not have any if m.Votes.Size() > types.MaxVotesCount { return fmt.Errorf("votes bit array is too big: %d, max: %d", m.Votes.Size(), types.MaxVotesCount) diff --git a/internal/consensus/reactor_test.go b/internal/consensus/reactor_test.go index 7ad27cab3..284cde16a 100644 --- a/internal/consensus/reactor_test.go +++ b/internal/consensus/reactor_test.go @@ -879,6 +879,14 @@ func TestNewValidBlockMessageValidateBasic(t *testing.T) { func(msg *NewValidBlockMessage) { msg.BlockParts = bits.NewBitArray(int(types.MaxBlockPartsCount) + 1) }, "blockParts bit array size 1602 not equal to BlockPartSetHeader.Total 1", }, + { + func(msg *NewValidBlockMessage) { msg.BlockParts.Elems = nil }, + "mismatch between specified number of bits 1, and number of elements 0, expected 1 elements", + }, + { + func(msg *NewValidBlockMessage) { msg.BlockParts.Bits = 500 }, + "mismatch between specified number of bits 500, and number of elements 1, expected 8 elements", + }, } for i, tc := range testCases { @@ -914,6 +922,14 @@ func TestProposalPOLMessageValidateBasic(t *testing.T) { func(msg *ProposalPOLMessage) { msg.ProposalPOL = bits.NewBitArray(types.MaxVotesCount + 1) }, "proposalPOL bit array is too big: 10001, max: 10000", }, + { + func(msg *ProposalPOLMessage) { msg.ProposalPOL.Elems = nil }, + "mismatch between specified number of bits 1, and number of elements 0, expected 1 elements", + }, + { + func(msg *ProposalPOLMessage) { msg.ProposalPOL.Bits = 500 }, + "mismatch between specified number of bits 500, and number of elements 1, expected 8 elements", + }, } for i, tc := range testCases { @@ -1066,6 +1082,14 @@ func TestVoteSetBitsMessageValidateBasic(t *testing.T) { func(msg *VoteSetBitsMessage) { msg.Votes = bits.NewBitArray(types.MaxVotesCount + 1) }, "votes bit array is too big: 10001, max: 10000", }, + { + func(msg *VoteSetBitsMessage) { msg.Votes.Elems = nil }, + "mismatch between specified number of bits 1, and number of elements 0, expected 1 elements", + }, + { + func(msg *VoteSetBitsMessage) { msg.Votes.Bits = 500 }, + "mismatch between specified number of bits 500, and number of elements 1, expected 8 elements", + }, } for i, tc := range testCases {