Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions internal/bits/bit_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func NewBitArray(bits int) *BitArray {
}
return &BitArray{
Bits: bits,
Elems: make([]uint64, (bits+63)/64),
Elems: make([]uint64, numElements(bits)),
}
}

Expand All @@ -41,7 +41,7 @@ func NewBitArrayFromFn(bits int, fn func(int) bool) *BitArray {
}
bA := &BitArray{
Bits: bits,
Elems: make([]uint64, (bits+63)/64),
Elems: make([]uint64, numElements(bits)),
}
for i := 0; i < bits; i++ {
v := fn(i)
Expand Down Expand Up @@ -90,7 +90,7 @@ func (bA *BitArray) SetIndex(i int, v bool) bool {
}

func (bA *BitArray) setIndex(i int, v bool) bool {
if i >= bA.Bits {
if i >= bA.Bits || i/64 >= len(bA.Elems) {
return false
}
if v {
Expand Down Expand Up @@ -121,7 +121,7 @@ func (bA *BitArray) copy() *BitArray {
}

func (bA *BitArray) copyBits(bits int) *BitArray {
c := make([]uint64, (bits+63)/64)
c := make([]uint64, numElements(bits))
copy(c, bA.Elems)
return &BitArray{
Bits: bits,
Expand Down Expand Up @@ -282,6 +282,11 @@ func (bA *BitArray) PickRandom(r *rand.Rand) (int, bool) {
}

func (bA *BitArray) getNumTrueIndices() int {
if bA.Size() == 0 || len(bA.Elems) == 0 || len(bA.Elems) != numElements(bA.Size()) {
// size and elements must be valid to do this calc
return 0
}

count := 0
numElems := len(bA.Elems)
// handle all elements except the last one
Expand Down Expand Up @@ -500,3 +505,22 @@ func (bA *BitArray) FromProto(protoBitArray *cmtprotobits.BitArray) {
bA.Elems = protoBitArray.Elems
}
}

// ValidateBasic validates a BitArray. Note that a nil BitArray and BitArray of
// size 0 bits is valid. However the number of Bits and Elems be valid based on
// each other.
func (bA *BitArray) ValidateBasic() error {
if bA == nil {
return nil
}

expectedElems := numElements(bA.Size())
if expectedElems != len(bA.Elems) {
return fmt.Errorf("mismatch between specified number of bits %d, and number of elements %d, expected %d elements", bA.Size(), len(bA.Elems), expectedElems)
}
return nil
}

func numElements(bits int) int {
return (bits + 63) / 64
}
50 changes: 49 additions & 1 deletion internal/bits/bit_array_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,28 @@ func TestGetNumTrueIndices(t *testing.T) {
}
}

func TestGetNumTrueIndicesInvalidStates(t *testing.T) {
testCases := []struct {
name string
bA1 *BitArray
exp int
}{
{"empty", &BitArray{}, 0},
{"explicit 0 bits nil elements", &BitArray{Bits: 0, Elems: nil}, 0},
{"explicit 0 bits 0 len elements", &BitArray{Bits: 0, Elems: make([]uint64, 0)}, 0},
{"nil", nil, 0},
{"with elements", NewBitArray(10), 0},
{"more elements than bits specifies", &BitArray{Bits: 0, Elems: make([]uint64, 5)}, 0},
{"less elements than bits specifies", &BitArray{Bits: 200, Elems: make([]uint64, 1)}, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
n := tc.bA1.getNumTrueIndices()
require.Equal(t, n, tc.exp)
})
}
}

func TestGetNthTrueIndex(t *testing.T) {
type testcase struct {
Input string
Expand Down Expand Up @@ -230,7 +252,7 @@ func TestGetNthTrueIndex(t *testing.T) {
}
}

func TestBytes(_ *testing.T) {
func TestBytes(t *testing.T) {
bA := NewBitArray(4)
bA.SetIndex(0, true)
check := func(bA *BitArray, bz []byte) {
Expand All @@ -257,6 +279,10 @@ func TestBytes(_ *testing.T) {
check(bA, []byte{0x80, 0x01})
bA.SetIndex(9, true)
check(bA, []byte{0x80, 0x03})

bA = NewBitArray(4)
bA.Elems = nil
require.False(t, bA.SetIndex(1, true))
}

func TestEmptyFull(t *testing.T) {
Expand Down Expand Up @@ -374,6 +400,28 @@ func TestBitArrayProtoBuf(t *testing.T) {
}
}

func TestBitArrayValidateBasic(t *testing.T) {
testCases := []struct {
name string
bA1 *BitArray
expPass bool
}{
{"valid empty", &BitArray{}, true},
{"valid explicit 0 bits nil elements", &BitArray{Bits: 0, Elems: nil}, true},
{"valid explicit 0 bits 0 len elements", &BitArray{Bits: 0, Elems: make([]uint64, 0)}, true},
{"valid nil", nil, true},
{"valid with elements", NewBitArray(10), true},
{"more elements than bits specifies", &BitArray{Bits: 0, Elems: make([]uint64, 5)}, false},
{"less elements than bits specifies", &BitArray{Bits: 200, Elems: make([]uint64, 1)}, false},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.bA1.ValidateBasic()
require.Equal(t, err == nil, tc.expPass)
})
}
}

// Tests that UnmarshalJSON doesn't crash when no bits are passed into the JSON.
// See issue https://github.com/cometbft/cometbft/issues/2658
func TestUnmarshalJSONDoesntCrashOnZeroBits(t *testing.T) {
Expand Down
9 changes: 9 additions & 0 deletions internal/consensus/reactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1826,6 +1826,9 @@ func (m *NewValidBlockMessage) ValidateBasic() error {
if err := m.BlockPartSetHeader.ValidateBasic(); err != nil {
return cmterrors.ErrWrongField{Field: "BlockPartSetHeader", Err: err}
}
if err := m.BlockParts.ValidateBasic(); err != nil {
return fmt.Errorf("validating BlockParts: %w", err)
}
if m.BlockParts.Size() == 0 {
return cmterrors.ErrRequiredField{Field: "blockParts"}
}
Expand Down Expand Up @@ -1880,6 +1883,9 @@ func (m *ProposalPOLMessage) ValidateBasic() error {
if m.ProposalPOLRound < 0 {
return cmterrors.ErrNegativeField{Field: "ProposalPOLRound"}
}
if err := m.ProposalPOL.ValidateBasic(); err != nil {
return fmt.Errorf("validating ProposalPOL: %w", err)
}
if m.ProposalPOL.Size() == 0 {
return cmterrors.ErrRequiredField{Field: "ProposalPOL"}
}
Expand Down Expand Up @@ -2042,6 +2048,9 @@ func (m *VoteSetBitsMessage) ValidateBasic() error {
if err := m.BlockID.ValidateBasic(); err != nil {
return cmterrors.ErrWrongField{Field: "BlockID", Err: err}
}
if err := m.Votes.ValidateBasic(); err != nil {
return fmt.Errorf("validating Votes: %w", err)
}
// NOTE: Votes.Size() can be zero if the node does not have any
if m.Votes.Size() > types.MaxVotesCount {
return fmt.Errorf("votes bit array is too big: %d, max: %d", m.Votes.Size(), types.MaxVotesCount)
Expand Down
24 changes: 24 additions & 0 deletions internal/consensus/reactor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,14 @@ func TestNewValidBlockMessageValidateBasic(t *testing.T) {
func(msg *NewValidBlockMessage) { msg.BlockParts = bits.NewBitArray(int(types.MaxBlockPartsCount) + 1) },
"blockParts bit array size 1602 not equal to BlockPartSetHeader.Total 1",
},
{
func(msg *NewValidBlockMessage) { msg.BlockParts.Elems = nil },
"mismatch between specified number of bits 1, and number of elements 0, expected 1 elements",
},
{
func(msg *NewValidBlockMessage) { msg.BlockParts.Bits = 500 },
"mismatch between specified number of bits 500, and number of elements 1, expected 8 elements",
},
}

for i, tc := range testCases {
Expand Down Expand Up @@ -914,6 +922,14 @@ func TestProposalPOLMessageValidateBasic(t *testing.T) {
func(msg *ProposalPOLMessage) { msg.ProposalPOL = bits.NewBitArray(types.MaxVotesCount + 1) },
"proposalPOL bit array is too big: 10001, max: 10000",
},
{
func(msg *ProposalPOLMessage) { msg.ProposalPOL.Elems = nil },
"mismatch between specified number of bits 1, and number of elements 0, expected 1 elements",
},
{
func(msg *ProposalPOLMessage) { msg.ProposalPOL.Bits = 500 },
"mismatch between specified number of bits 500, and number of elements 1, expected 8 elements",
},
}

for i, tc := range testCases {
Expand Down Expand Up @@ -1066,6 +1082,14 @@ func TestVoteSetBitsMessageValidateBasic(t *testing.T) {
func(msg *VoteSetBitsMessage) { msg.Votes = bits.NewBitArray(types.MaxVotesCount + 1) },
"votes bit array is too big: 10001, max: 10000",
},
{
func(msg *VoteSetBitsMessage) { msg.Votes.Elems = nil },
"mismatch between specified number of bits 1, and number of elements 0, expected 1 elements",
},
{
func(msg *VoteSetBitsMessage) { msg.Votes.Bits = 500 },
"mismatch between specified number of bits 500, and number of elements 1, expected 8 elements",
},
}

for i, tc := range testCases {
Expand Down
Loading