Skip to content

Commit 14851b5

Browse files
committed
GH-49753: [C++][Gandiva] Fix overflow in string functions.
Fixed overflows and added unit-tests.
1 parent d7a02c1 commit 14851b5

4 files changed

Lines changed: 158 additions & 20 deletions

File tree

cpp/src/gandiva/gdv_function_stubs_test.cc

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,24 @@ TEST(TestGdvFnStubs, TestUpper) {
640640
EXPECT_THAT(ctx.get_error(),
641641
::testing::HasSubstr(
642642
"unexpected byte \\c3 encountered while decoding utf8 string"));
643+
644+
// Max Len Test
645+
out_len = -1;
646+
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
647+
const char* out = gdv_fn_upper_utf8(ctx_ptr, "dummy", bad_len, &out_len);
648+
// Expect failure
649+
EXPECT_EQ(out_len, 0);
650+
EXPECT_STREQ(out, "");
651+
EXPECT_THAT(ctx.get_error(),
652+
::testing::HasSubstr("Would overflow maximum output size"));
653+
654+
// Negative length test
655+
out_len = -1;
656+
out = gdv_fn_upper_utf8(ctx_ptr, "abc", -105, &out_len);
657+
EXPECT_EQ(out_len, 0);
658+
EXPECT_STREQ(out, "");
659+
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
660+
643661
ctx.Reset();
644662

645663
std::string e(
@@ -698,6 +716,23 @@ TEST(TestGdvFnStubs, TestLower) {
698716
EXPECT_EQ(std::string(out_str, out_len), "");
699717
EXPECT_FALSE(ctx.has_error());
700718

719+
// Max Len Test
720+
out_len = -1;
721+
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
722+
const char* out = gdv_fn_lower_utf8(ctx_ptr, "dummy", bad_len, &out_len);
723+
// Expect failure
724+
EXPECT_EQ(out_len, 0);
725+
EXPECT_STREQ(out, "");
726+
EXPECT_THAT(ctx.get_error(),
727+
::testing::HasSubstr("Would overflow maximum output size"));
728+
729+
// Negative length test
730+
out_len = -1;
731+
out = gdv_fn_lower_utf8(ctx_ptr, "abc", -105, &out_len);
732+
EXPECT_EQ(out_len, 0);
733+
EXPECT_STREQ(out, "");
734+
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
735+
701736
std::string d("AbOJjÜoß\xc3");
702737
out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
703738
EXPECT_EQ(std::string(out_str, out_len), "");
@@ -794,6 +829,24 @@ TEST(TestGdvFnStubs, TestInitCap) {
794829
EXPECT_THAT(ctx.get_error(),
795830
::testing::HasSubstr(
796831
"unexpected byte \\c3 encountered while decoding utf8 string"));
832+
833+
// Max Len Test
834+
out_len = -1;
835+
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
836+
const char* out = gdv_fn_initcap_utf8(ctx_ptr, "dummy", bad_len, &out_len);
837+
// Expect failure
838+
EXPECT_EQ(out_len, 0);
839+
EXPECT_STREQ(out, "");
840+
EXPECT_THAT(ctx.get_error(),
841+
::testing::HasSubstr("Would overflow maximum output size"));
842+
843+
// Negative length test
844+
out_len = -1;
845+
out = gdv_fn_initcap_utf8(ctx_ptr, "abc", -105, &out_len);
846+
EXPECT_EQ(out_len, 0);
847+
EXPECT_STREQ(out, "");
848+
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
849+
797850
ctx.Reset();
798851

799852
std::string e(
@@ -1127,6 +1180,15 @@ TEST(TestGdvFnStubs, TestTranslate) {
11271180
result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9, "0123456789",
11281181
10, &out_len);
11291182
EXPECT_EQ(expected, std::string(result, out_len));
1183+
1184+
int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 4 + 1;
1185+
out_len = -1;
1186+
result =
1187+
translate_utf8_utf8_utf8(ctx_ptr, "ABCDE", bad_in_len, "B", 1, "C", 1, &out_len);
1188+
EXPECT_EQ(out_len, 0);
1189+
EXPECT_STREQ(result, "");
1190+
EXPECT_THAT(ctx.get_error(),
1191+
::testing::HasSubstr("Would overflow maximum output size"));
11301192
}
11311193

11321194
TEST(TestGdvFnStubs, TestToUtcTimezone) {

cpp/src/gandiva/gdv_string_function_stubs.cc

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,25 @@ int32_t gdv_fn_utf8_char_length(char c) {
213213
return 0;
214214
}
215215

216+
static inline bool compute_alloc_len(int64_t context, int32_t data_len,
217+
int32_t* alloc_len, int32_t* out_len) {
218+
// Reject negative lengths
219+
if (ARROW_PREDICT_FALSE(data_len < 0)) {
220+
gdv_fn_context_set_error_msg(context, "Invalid (negative) data length");
221+
*out_len = 0;
222+
return false;
223+
}
224+
225+
// Check overflow: 2 * data_len
226+
if (ARROW_PREDICT_FALSE(
227+
arrow::internal::MultiplyWithOverflow(2, data_len, alloc_len))) {
228+
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
229+
*out_len = 0;
230+
return false;
231+
}
232+
return true;
233+
}
234+
216235
// Convert an utf8 string to its corresponding lowercase string
217236
GANDIVA_EXPORT
218237
const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len,
@@ -222,10 +241,16 @@ const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_le
222241
return "";
223242
}
224243

244+
int32_t alloc_length = 0;
245+
if (ARROW_PREDICT_FALSE(
246+
!compute_alloc_len(context, data_len, &alloc_length, out_len))) {
247+
return "";
248+
}
249+
225250
// If it is a single-byte character (ASCII), corresponding lowercase is always 1-byte
226251
// long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so length of
227252
// the output can be at most twice the length of the input
228-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
253+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
229254
if (out == nullptr) {
230255
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
231256
*out_len = 0;
@@ -294,10 +319,16 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le
294319
return "";
295320
}
296321

322+
int32_t alloc_length = 0;
323+
if (ARROW_PREDICT_FALSE(
324+
!compute_alloc_len(context, data_len, &alloc_length, out_len))) {
325+
return "";
326+
}
327+
297328
// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
298329
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
299330
// the output can be at most twice the length of the input
300-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
331+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
301332
if (out == nullptr) {
302333
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
303334
*out_len = 0;
@@ -480,10 +511,16 @@ const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_
480511
return "";
481512
}
482513

514+
int32_t alloc_length = 0;
515+
if (ARROW_PREDICT_FALSE(
516+
!compute_alloc_len(context, data_len, &alloc_length, out_len))) {
517+
return "";
518+
}
519+
483520
// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
484521
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
485522
// the output can be at most twice the length of the input
486-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
523+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
487524
if (out == nullptr) {
488525
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
489526
*out_len = 0;
@@ -579,15 +616,24 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
579616
return in;
580617
}
581618

619+
int32_t alloc_length = 0;
620+
// Check overflow: 4 * in_len
621+
if (ARROW_PREDICT_FALSE(
622+
arrow::internal::MultiplyWithOverflow(4, in_len, &alloc_length))) {
623+
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
624+
*out_len = 0;
625+
return "";
626+
}
627+
582628
// This variable is to control if there are multi-byte utf8 entries
583629
bool has_multi_byte = false;
584630

585631
// This variable is to store the final result
586632
char* result;
587-
int result_len;
633+
int32_t result_len;
588634

589635
// Searching multi-bytes in In
590-
for (int i = 0; i < in_len; i++) {
636+
for (int32_t i = 0; i < in_len; i++) {
591637
unsigned char char_single_byte = in[i];
592638
if (char_single_byte > 127) {
593639
// found a multi-byte utf-8 char
@@ -598,7 +644,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
598644

599645
// Searching multi-bytes in From
600646
if (!has_multi_byte) {
601-
for (int i = 0; i < from_len; i++) {
647+
for (int32_t i = 0; i < from_len; i++) {
602648
unsigned char char_single_byte = from[i];
603649
if (char_single_byte > 127) {
604650
// found a multi-byte utf-8 char
@@ -610,7 +656,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
610656

611657
// Searching multi-bytes in To
612658
if (!has_multi_byte) {
613-
for (int i = 0; i < to_len; i++) {
659+
for (int32_t i = 0; i < to_len; i++) {
614660
unsigned char char_single_byte = to[i];
615661
if (char_single_byte > 127) {
616662
// found a multi-byte utf-8 char
@@ -638,7 +684,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
638684

639685
// This variable is for controlling the position in entry TO, for never repeat the
640686
// changes
641-
int start_compare;
687+
int32_t start_compare;
642688

643689
if (to_len > 0) {
644690
start_compare = 0;
@@ -650,15 +696,15 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
650696
// list, to mark deletion positions
651697
const char empty = '\0';
652698

653-
for (int in_for = 0; in_for < in_len; in_for++) {
699+
for (int32_t in_for = 0; in_for < in_len; in_for++) {
654700
if (subs_list.find(in[in_for]) != subs_list.end()) {
655701
if (subs_list[in[in_for]] != empty) {
656702
// If exist in map, only add the correspondent value in result
657703
result[result_len] = subs_list[in[in_for]];
658704
result_len++;
659705
}
660706
} else {
661-
for (int from_for = 0; from_for <= from_len; from_for++) {
707+
for (int32_t from_for = 0; from_for <= from_len; from_for++) {
662708
if (from_for == from_len) {
663709
// If it's not in the FROM list, just add it to the map and the result.
664710
subs_list.insert(std::pair<char, char>(in[in_for], in[in_for]));
@@ -686,10 +732,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
686732
}
687733
}
688734
}
689-
} else { // If there are no multibytes in the input, work with std::strings
735+
} else {
736+
// If there are no multibytes in the input, work with std::strings
690737
// This variable is for receive the substitutions, malloc is in_len * 4 to receive
691738
// possible inputs with 4 bytes
692-
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, in_len * 4));
739+
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
693740

694741
if (result == nullptr) {
695742
gdv_fn_context_set_error_msg(context,
@@ -704,7 +751,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
704751

705752
// This variable is for controlling the position in entry TO, for never repeat the
706753
// changes
707-
int start_compare;
754+
int32_t start_compare;
708755

709756
if (to_len > 0) {
710757
start_compare = 0;
@@ -717,11 +764,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
717764
const std::string empty = "";
718765

719766
// This variables is to control len of multi-bytes entries
720-
int len_char_in = 0;
721-
int len_char_from = 0;
722-
int len_char_to = 0;
767+
int32_t len_char_in = 0;
768+
int32_t len_char_from = 0;
769+
int32_t len_char_to = 0;
723770

724-
for (int in_for = 0; in_for < in_len; in_for += len_char_in) {
771+
for (int32_t in_for = 0; in_for < in_len; in_for += len_char_in) {
725772
// Updating len to char in this position
726773
len_char_in = gdv_fn_utf8_char_length(in[in_for]);
727774
// Making copy to std::string with length for this char position
@@ -734,7 +781,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
734781
result_len += static_cast<int>(subs_list[insert_copy_key].length());
735782
}
736783
} else {
737-
for (int from_for = 0; from_for <= from_len; from_for += len_char_from) {
784+
for (int32_t from_for = 0; from_for <= from_len; from_for += len_char_from) {
738785
// Updating len to char in this position
739786
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
740787
// Making copy to std::string with length for this char position

cpp/src/gandiva/precompiled/string_ops.cc

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,9 +1924,19 @@ const char* quote_utf8(gdv_int64 context, const char* in, gdv_int32 in_len,
19241924
*out_len = 0;
19251925
return "";
19261926
}
1927+
1928+
int32_t alloc_length = 0;
1929+
// Check overflow: 2 * in_len
1930+
if (ARROW_PREDICT_FALSE(
1931+
arrow::internal::MultiplyWithOverflow(2, in_len, &alloc_length))) {
1932+
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
1933+
*out_len = 0;
1934+
return "";
1935+
}
1936+
19271937
// try to allocate double size output string (worst case)
19281938
auto out =
1929-
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, (in_len * 2) + 2));
1939+
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length + 2));
19301940
if (out == nullptr) {
19311941
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
19321942
*out_len = 0;
@@ -2829,8 +2839,17 @@ const char* to_hex_binary(int64_t context, const char* text, int32_t text_len,
28292839
return "";
28302840
}
28312841

2842+
int32_t alloc_length = 0;
2843+
// Check overflow: 2 * text_len
2844+
if (ARROW_PREDICT_FALSE(
2845+
arrow::internal::MultiplyWithOverflow(2, text_len, &alloc_length))) {
2846+
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
2847+
*out_len = 0;
2848+
return "";
2849+
}
2850+
28322851
auto ret =
2833-
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, text_len * 2 + 1));
2852+
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length + 1));
28342853

28352854
if (ret == nullptr) {
28362855
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");

cpp/src/gandiva/precompiled/string_ops_test.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,11 @@ TEST(TestStringOps, TestQuote) {
11651165
out_str = quote_utf8(ctx_ptr, "'''''''''", 9, &out_len);
11661166
EXPECT_EQ(std::string(out_str, out_len), "'\\'\\'\\'\\'\\'\\'\\'\\'\\''");
11671167
EXPECT_FALSE(ctx.has_error());
1168+
1169+
int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 2 + 20;
1170+
out_str = quote_utf8(ctx_ptr, "ABCDE", bad_in_len, &out_len);
1171+
EXPECT_EQ(out_len, 0);
1172+
EXPECT_EQ(out_str, "");
11681173
}
11691174

11701175
TEST(TestStringOps, TestLtrim) {
@@ -2498,6 +2503,11 @@ TEST(TestStringOps, TestToHex) {
24982503
output = std::string(out_str, out_len);
24992504
EXPECT_EQ(out_len, 2 * in_len);
25002505
EXPECT_EQ(output, "090A090A090A090A0A0A092061206C657474405D6572");
2506+
2507+
int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 2 + 20;
2508+
out_str = to_hex_binary(ctx_ptr, binary_string, bad_in_len, &out_len);
2509+
EXPECT_EQ(out_len, 0);
2510+
EXPECT_EQ(out_str, "");
25012511
}
25022512

25032513
TEST(TestStringOps, TestToHexInt64) {

0 commit comments

Comments
 (0)