Skip to content

Commit 07f2371

Browse files
committed
Strong clobber test
1 parent f834802 commit 07f2371

1 file changed

Lines changed: 145 additions & 38 deletions

File tree

cpp/src/arrow/util/rle_bitmap_test.cc

Lines changed: 145 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include <array>
1919
#include <cstdint>
20+
#include <random>
2021
#include <string>
2122
#include <vector>
2223

@@ -31,6 +32,17 @@ namespace arrow::util {
3132

3233
namespace {
3334

35+
/// Make a vector of `size` pseudo-random bytes, deterministic for a given `seed`.
36+
std::vector<uint8_t> MakeRandomBytes(size_t size, uint32_t seed = 56) {
37+
std::vector<uint8_t> bytes(size);
38+
std::minstd_rand gen(seed);
39+
std::uniform_int_distribution<uint8_t> dist(0, 255);
40+
for (auto& byte : bytes) {
41+
byte = dist(gen);
42+
}
43+
return bytes;
44+
}
45+
3446
/// Read the first `count` bits of `bytes` (LSB first) into a vector of booleans.
3547
std::vector<bool> BitsFromBytes(const std::vector<uint8_t>& bytes, rle_size_t count) {
3648
std::vector<bool> bits(count);
@@ -40,20 +52,41 @@ std::vector<bool> BitsFromBytes(const std::vector<uint8_t>& bytes, rle_size_t co
4052
return bits;
4153
}
4254

55+
struct CheckDecodedBitsParams {
56+
const std::vector<uint8_t>& actual;
57+
const std::vector<bool>& expected;
58+
rle_size_t count;
59+
rle_size_t actual_start_bit = 0;
60+
rle_size_t expected_start_idx = 0;
61+
};
62+
4363
/// Check the decoded output in `out` against `expected`.
44-
/// Bits `out[out_offset..out_offset + count]` must equal
45-
/// `expected[expected_skip..expected_skip + count]`. The `out_offset` bits before them
46-
/// must still be zero.
47-
void CheckDecodedBits(const std::vector<uint8_t>& out, const std::vector<bool>& expected,
48-
rle_size_t count, rle_size_t out_offset = 0,
49-
rle_size_t expected_skip = 0) {
50-
ARROW_SCOPED_TRACE("out_offset = ", out_offset, ", expected_skip = ", expected_skip);
51-
for (rle_size_t i = 0; i < out_offset; ++i) {
52-
EXPECT_FALSE(bit_util::GetBit(out.data(), i)) << "clobbered bit " << i;
64+
void CheckDecodedBits(const CheckDecodedBitsParams& params) {
65+
ARROW_SCOPED_TRACE("out_start_bit = ", params.actual_start_bit,
66+
", expected_start_idx = ", params.expected_start_idx);
67+
for (rle_size_t i = 0; i < params.count; ++i) {
68+
ASSERT_EQ(bit_util::GetBit(params.actual.data(), params.actual_start_bit + i),
69+
params.expected[params.expected_start_idx + i])
70+
<< "first difference at bit " << i;
5371
}
54-
for (rle_size_t i = 0; i < count; ++i) {
55-
EXPECT_EQ(bit_util::GetBit(out.data(), out_offset + i), expected[expected_skip + i])
56-
<< "at bit " << i;
72+
}
73+
74+
struct CheckBitsEqualParams {
75+
const std::vector<uint8_t>& actual;
76+
const std::vector<uint8_t>& expected;
77+
rle_size_t count;
78+
rle_size_t actual_start_bit = 0;
79+
rle_size_t expected_start_bit = 0;
80+
};
81+
82+
/// Check that two bit ranges, stored in `actual` and `expected`, are equal.
83+
void CheckBitsEqual(const CheckBitsEqualParams& params) {
84+
ARROW_SCOPED_TRACE("actual_start_bit = ", params.actual_start_bit,
85+
", expected_start_bit = ", params.expected_start_bit);
86+
for (rle_size_t i = 0; i < params.count; ++i) {
87+
ASSERT_EQ(bit_util::GetBit(params.actual.data(), params.actual_start_bit + i),
88+
bit_util::GetBit(params.expected.data(), params.expected_start_bit + i))
89+
<< "first difference at bit " << i;
5790
}
5891
}
5992

@@ -68,37 +101,103 @@ void CheckDecodedBits(const std::vector<uint8_t>& out, const std::vector<bool>&
68101
/// read path of BitPackedRunToBitmapDecoder. With `expected_skip == 0` they stay in sync
69102
/// and only the aligned path runs.
70103
template <typename Decoder>
71-
void CheckChunkedDecode(const typename Decoder::RunType& run,
72-
const std::vector<bool>& expected, rle_size_t chunk_size = 1,
73-
rle_size_t expected_skip = 0) {
104+
void CheckDecoderValuesChunked(const typename Decoder::RunType& run,
105+
const std::vector<bool>& expected,
106+
rle_size_t chunk_size = 1, rle_size_t expected_skip = 0) {
74107
ARROW_SCOPED_TRACE("chunk_size = ", chunk_size, ", expected_skip = ", expected_skip);
108+
75109
const auto n_vals = static_cast<rle_size_t>(expected.size());
76110
ASSERT_LE(expected_skip, n_vals);
77111

78112
Decoder decoder(run);
79113
const auto advanced = decoder.Advance(expected_skip);
80114
ASSERT_EQ(advanced, expected_skip);
81-
const auto rest = n_vals - expected_skip;
115+
const auto n_vals_to_decode = n_vals - expected_skip;
82116

83-
// Output buffer with one guard byte to catch out-of-bounds writes.
84-
std::vector<uint8_t> out(static_cast<size_t>(bit_util::BytesForBits(rest)) + 1, 0);
85-
const uint8_t guard = 0xA5;
86-
out.back() = guard;
117+
// Output buffer
118+
const auto n_bytes = static_cast<size_t>(bit_util::BytesForBits(n_vals_to_decode));
119+
std::vector<uint8_t> out(n_bytes, 0);
87120

88-
rle_size_t read = 0;
89-
while (read < rest) {
90-
const auto want = std::min(chunk_size, rest - read);
121+
rle_size_t n_val_read = 0;
122+
while (n_val_read < n_vals_to_decode) {
123+
const auto want = std::min(chunk_size, n_vals_to_decode - n_val_read);
91124
const auto got =
92-
decoder.GetBatch(BitmapSpanMut(out.data(), /*bit_start=*/read), want);
93-
EXPECT_EQ(got, want) << "at pos " << read;
94-
ASSERT_GT(got, 0) << "at pos " << read; // break on failure
95-
read += got;
96-
EXPECT_EQ(decoder.remaining(), rest - read);
125+
decoder.GetBatch(BitmapSpanMut(out.data(), /*bit_start=*/n_val_read), want);
126+
EXPECT_EQ(got, want) << "at pos " << n_val_read;
127+
ASSERT_GT(got, 0) << "at pos " << n_val_read; // break on failure
128+
n_val_read += got;
129+
EXPECT_EQ(decoder.remaining(), n_vals_to_decode - n_val_read);
97130
}
98131

99132
EXPECT_EQ(decoder.remaining(), 0);
100-
EXPECT_EQ(out.back(), guard) << "decoder wrote past the end of the output";
101-
CheckDecodedBits(out, expected, /*count=*/rest, /*out_offset=*/0, expected_skip);
133+
CheckDecodedBits({
134+
.actual = out,
135+
.expected = expected,
136+
.count = n_vals_to_decode,
137+
.actual_start_bit = 0,
138+
.expected_start_idx = expected_skip,
139+
});
140+
}
141+
142+
/// Decode a chunk of data into a known output to check for out of bounds write.
143+
///
144+
/// @see CheckDecoderValuesChunked
145+
template <typename Decoder>
146+
void CheckDecoderClobber(const typename Decoder::RunType& run,
147+
const std::vector<bool>& expected, rle_size_t chunk_size = 1,
148+
rle_size_t expected_skip = 0) {
149+
ARROW_SCOPED_TRACE("chunk_size = ", chunk_size, ", expected_skip = ", expected_skip);
150+
151+
const auto n_vals = static_cast<rle_size_t>(expected.size());
152+
ASSERT_LE(expected_skip, n_vals);
153+
154+
Decoder decoder(run);
155+
const auto advanced = decoder.Advance(expected_skip);
156+
ASSERT_EQ(advanced, expected_skip);
157+
const auto n_vals_to_decode = n_vals - expected_skip;
158+
159+
// Output buffer with enough capacity to store a full chunk plus extra bytes as
160+
// clobbers/guard to check for out of bounds write.
161+
const auto n_bytes = static_cast<size_t>(bit_util::BytesForBits(chunk_size) +
162+
bit_util::CeilDiv(n_vals, chunk_size) + 2);
163+
// This seed is arbitrary and of little importance. We are simply trying to avoid an
164+
// unlikely case where guards have the same pattern in all invocations.
165+
const auto out_pattern =
166+
MakeRandomBytes(n_bytes, /* seed= */ (chunk_size << 16) ^ expected_skip);
167+
auto out = out_pattern;
168+
169+
rle_size_t n_val_read = 0;
170+
rle_size_t out_bit_start = 0;
171+
while (n_val_read < n_vals_to_decode) {
172+
// Clean output buffer
173+
out = out_pattern;
174+
const auto want = std::min(chunk_size, n_vals_to_decode - n_val_read);
175+
const auto got = decoder.GetBatch(BitmapSpanMut(out.data(), out_bit_start), want);
176+
ASSERT_GT(got, 0) << "at pos " << n_val_read; // break on failure
177+
EXPECT_EQ(got, want) << "at pos " << n_val_read;
178+
// Check that the leading bits have not been modified
179+
CheckBitsEqual({.actual = out, .expected = out_pattern, .count = out_bit_start});
180+
// Check that the trailing bits have not been modified
181+
CheckBitsEqual({
182+
.actual = out,
183+
.expected = out_pattern,
184+
.count = static_cast<rle_size_t>(8 * n_bytes) - (out_bit_start + want),
185+
.actual_start_bit = out_bit_start + want,
186+
.expected_start_bit = out_bit_start + want,
187+
});
188+
// Check decoded bits are also correct
189+
CheckDecodedBits({
190+
.actual = out,
191+
.expected = expected,
192+
.count = want,
193+
.actual_start_bit = out_bit_start,
194+
.expected_start_idx = expected_skip + n_val_read,
195+
});
196+
197+
n_val_read += got;
198+
++out_bit_start;
199+
EXPECT_EQ(decoder.remaining(), n_vals_to_decode - n_val_read);
200+
}
102201
}
103202

104203
/// All the checks shared by both decoder types.
@@ -127,7 +226,7 @@ void CheckBitmapDecoder(const typename Decoder::RunType& run,
127226
// Decode the whole run in several chunks.
128227
for (const rle_size_t chunk_size : {rle_size_t{1}, rle_size_t{3}, rle_size_t{7},
129228
rle_size_t{8}, rle_size_t{9}, n_vals, n_vals + 1}) {
130-
CheckChunkedDecode<Decoder>(run, expected, chunk_size);
229+
CheckDecoderValuesChunked<Decoder>(run, expected, chunk_size);
131230
}
132231

133232
// Decode the whole run in several chunks, after an initial Advance that shifts
@@ -136,7 +235,10 @@ void CheckBitmapDecoder(const typename Decoder::RunType& run,
136235
rle_size_t{8}, rle_size_t{9}, n_vals, n_vals + 1}) {
137236
for (rle_size_t expected_skip = 1; expected_skip < 8 && expected_skip < n_vals;
138237
++expected_skip) {
139-
CheckChunkedDecode<Decoder>(run, expected, chunk_size, expected_skip);
238+
// Check the decoding happens as expected
239+
CheckDecoderValuesChunked<Decoder>(run, expected, chunk_size, expected_skip);
240+
// Check the decoding does not write out of bounds
241+
CheckDecoderClobber<Decoder>(run, expected, chunk_size, expected_skip);
140242
}
141243
}
142244

@@ -155,7 +257,7 @@ void CheckBitmapDecoder(const typename Decoder::RunType& run,
155257
const auto advanced = decoder.Advance(1);
156258
EXPECT_EQ(advanced, 0);
157259
EXPECT_EQ(decoder.remaining(), 0);
158-
CheckDecodedBits(out, expected, /*count=*/n_vals);
260+
CheckDecodedBits({.actual = out, .expected = expected, .count = n_vals});
159261
}
160262

161263
// Advancing more than available stops at the run boundary.
@@ -179,7 +281,7 @@ void CheckBitmapDecoder(const typename Decoder::RunType& run,
179281
std::vector<uint8_t> out_2(static_cast<size_t>(bit_util::BytesForBits(n_vals)), 0);
180282
const auto got = decoder.GetBatch(BitmapSpanMut(out_2.data()), n_vals);
181283
EXPECT_EQ(got, n_vals);
182-
CheckDecodedBits(out_2, expected, /*count=*/n_vals);
284+
CheckDecodedBits({.actual = out_2, .expected = expected, .count = n_vals});
183285
}
184286
}
185287

@@ -337,7 +439,12 @@ void CheckRleBitPackedDecode(const std::vector<uint8_t>& bytes,
337439
EXPECT_TRUE(decoder.exhausted());
338440

339441
EXPECT_EQ(out.back(), guard) << "decoder wrote past the end of the output";
340-
CheckDecodedBits(out, expected, /*count=*/n_vals, out_offset);
442+
CheckDecodedBits({
443+
.actual = out,
444+
.expected = expected,
445+
.count = n_vals,
446+
.actual_start_bit = out_offset,
447+
});
341448
}
342449

343450
/// Run the decode check over a battery of chunk sizes and output offsets.
@@ -438,7 +545,7 @@ TEST(RleBitPackedToBitmapDecoder, ReadPastEnd) {
438545
EXPECT_TRUE(decoder.exhausted());
439546
got = decoder.GetBatch(BitmapSpanMut(out.data()), 10);
440547
EXPECT_EQ(got, 0);
441-
CheckDecodedBits(out, expected, /*count=*/n_vals);
548+
CheckDecodedBits({.actual = out, .expected = expected, .count = n_vals});
442549
}
443550

444551
TEST(RleBitPackedToBitmapDecoder, Reset) {
@@ -463,7 +570,7 @@ TEST(RleBitPackedToBitmapDecoder, Reset) {
463570
const auto got_2 = decoder.GetBatch(BitmapSpanMut(out_2.data()), n_vals);
464571
EXPECT_EQ(got_2, n_vals);
465572
EXPECT_TRUE(decoder.exhausted());
466-
CheckDecodedBits(out_2, expected, /*count=*/n_vals);
573+
CheckDecodedBits({.actual = out_2, .expected = expected, .count = n_vals});
467574
}
468575

469576
TEST(RleBitPackedToBitmapDecoder, Truncated) {
@@ -485,7 +592,7 @@ TEST(RleBitPackedToBitmapDecoder, Truncated) {
485592
const auto got = decoder.GetBatch(BitmapSpanMut(out.data()), 1000);
486593
EXPECT_EQ(got, 10);
487594
EXPECT_FALSE(decoder.exhausted());
488-
CheckDecodedBits(out, expected, /*count=*/10);
595+
CheckDecodedBits({.actual = out, .expected = expected, .count = 10});
489596
}
490597

491598
} // namespace arrow::util

0 commit comments

Comments
 (0)