Skip to content

Commit d0280e4

Browse files
committed
GH-49753: [C++][Gandiva] Fix overflow in string functions.
Fixes potential integer-overflow/invalid-length issues in Gandiva string functions by adding overflow-checked allocation sizing and expanding unit tests to cover extreme and negative lengths. Fixed memcpy call in gdv_fn_substring_index function since, the lenth argument is of type size_t. Incorporated review comments.
1 parent d7a02c1 commit d0280e4

4 files changed

Lines changed: 183 additions & 28 deletions

File tree

cpp/src/gandiva/gdv_function_stubs_test.cc

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,26 @@ TEST(TestGdvFnStubs, TestUpper) {
640640
EXPECT_THAT(ctx.get_error(),
641641
::testing::HasSubstr(
642642
"unexpected byte \\c3 encountered while decoding utf8 string"));
643+
644+
ctx.Reset();
645+
646+
// Max Len Test
647+
out_len = -1;
648+
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
649+
const char* out = gdv_fn_upper_utf8(ctx_ptr, "dummy", bad_len, &out_len);
650+
// Expect failure
651+
EXPECT_EQ(out_len, 0);
652+
EXPECT_STREQ(out, "");
653+
EXPECT_THAT(ctx.get_error(),
654+
::testing::HasSubstr("Would overflow maximum output size"));
655+
ctx.Reset();
656+
657+
// Negative length test
658+
out_len = -1;
659+
out = gdv_fn_upper_utf8(ctx_ptr, "abc", -105, &out_len);
660+
EXPECT_EQ(out_len, 0);
661+
EXPECT_STREQ(out, "");
662+
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
643663
ctx.Reset();
644664

645665
std::string e(
@@ -697,6 +717,26 @@ TEST(TestGdvFnStubs, TestLower) {
697717
out_str = gdv_fn_lower_utf8(ctx_ptr, "", 0, &out_len);
698718
EXPECT_EQ(std::string(out_str, out_len), "");
699719
EXPECT_FALSE(ctx.has_error());
720+
ctx.Reset();
721+
722+
// Max Len Test
723+
out_len = -1;
724+
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
725+
const char* out = gdv_fn_lower_utf8(ctx_ptr, "dummy", bad_len, &out_len);
726+
// Expect failure
727+
EXPECT_EQ(out_len, 0);
728+
EXPECT_STREQ(out, "");
729+
EXPECT_THAT(ctx.get_error(),
730+
::testing::HasSubstr("Would overflow maximum output size"));
731+
ctx.Reset();
732+
733+
// Negative length test
734+
out_len = -1;
735+
out = gdv_fn_lower_utf8(ctx_ptr, "abc", -105, &out_len);
736+
EXPECT_EQ(out_len, 0);
737+
EXPECT_STREQ(out, "");
738+
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
739+
ctx.Reset();
700740

701741
std::string d("AbOJjÜoß\xc3");
702742
out_str = gdv_fn_lower_utf8(ctx_ptr, d.data(), static_cast<int>(d.length()), &out_len);
@@ -796,6 +836,25 @@ TEST(TestGdvFnStubs, TestInitCap) {
796836
"unexpected byte \\c3 encountered while decoding utf8 string"));
797837
ctx.Reset();
798838

839+
// Max Len Test
840+
out_len = -1;
841+
int32_t bad_len = std::numeric_limits<int32_t>::max() / 2 + 1;
842+
const char* out = gdv_fn_initcap_utf8(ctx_ptr, "dummy", bad_len, &out_len);
843+
// Expect failure
844+
EXPECT_EQ(out_len, 0);
845+
EXPECT_STREQ(out, "");
846+
EXPECT_THAT(ctx.get_error(),
847+
::testing::HasSubstr("Would overflow maximum output size"));
848+
ctx.Reset();
849+
850+
// Negative length test
851+
out_len = -1;
852+
out = gdv_fn_initcap_utf8(ctx_ptr, "abc", -105, &out_len);
853+
EXPECT_EQ(out_len, 0);
854+
EXPECT_STREQ(out, "");
855+
EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Invalid (negative) data length"));
856+
ctx.Reset();
857+
799858
std::string e(
800859
"åbÑg\xe0\xa0"
801860
"åBUå");
@@ -1127,6 +1186,15 @@ TEST(TestGdvFnStubs, TestTranslate) {
11271186
result = translate_utf8_utf8_utf8(ctx_ptr, "987654321", 9, "123456789", 9, "0123456789",
11281187
10, &out_len);
11291188
EXPECT_EQ(expected, std::string(result, out_len));
1189+
1190+
int32_t bad_in_len = std::numeric_limits<int32_t>::max() / 4 + 1;
1191+
out_len = -1;
1192+
result =
1193+
translate_utf8_utf8_utf8(ctx_ptr, "ABCDE", bad_in_len, "B", 1, "C", 1, &out_len);
1194+
EXPECT_EQ(out_len, 0);
1195+
EXPECT_STREQ(result, "");
1196+
EXPECT_THAT(ctx.get_error(),
1197+
::testing::HasSubstr("Would overflow maximum output size"));
11301198
}
11311199

11321200
TEST(TestGdvFnStubs, TestToUtcTimezone) {

cpp/src/gandiva/gdv_string_function_stubs.cc

Lines changed: 76 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,25 @@ int32_t gdv_fn_utf8_char_length(char c) {
213213
return 0;
214214
}
215215

216+
static inline bool compute_alloc_len(int64_t context, int32_t data_len,
217+
int32_t* alloc_len, int32_t* out_len) {
218+
// Reject negative lengths
219+
if (ARROW_PREDICT_FALSE(data_len < 0)) {
220+
gdv_fn_context_set_error_msg(context, "Invalid (negative) data length");
221+
*out_len = 0;
222+
return false;
223+
}
224+
225+
// Check overflow: 2 * data_len
226+
if (ARROW_PREDICT_FALSE(
227+
arrow::internal::MultiplyWithOverflow(2, data_len, alloc_len))) {
228+
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
229+
*out_len = 0;
230+
return false;
231+
}
232+
return true;
233+
}
234+
216235
// Convert an utf8 string to its corresponding lowercase string
217236
GANDIVA_EXPORT
218237
const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_len,
@@ -222,10 +241,16 @@ const char* gdv_fn_lower_utf8(int64_t context, const char* data, int32_t data_le
222241
return "";
223242
}
224243

244+
int32_t alloc_length = 0;
245+
if (ARROW_PREDICT_FALSE(
246+
!compute_alloc_len(context, data_len, &alloc_length, out_len))) {
247+
return "";
248+
}
249+
225250
// If it is a single-byte character (ASCII), corresponding lowercase is always 1-byte
226251
// long; if it is >= 2 bytes long, lowercase can be at most 4 bytes long, so length of
227252
// the output can be at most twice the length of the input
228-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
253+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
229254
if (out == nullptr) {
230255
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
231256
*out_len = 0;
@@ -294,10 +319,16 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le
294319
return "";
295320
}
296321

322+
int32_t alloc_length = 0;
323+
if (ARROW_PREDICT_FALSE(
324+
!compute_alloc_len(context, data_len, &alloc_length, out_len))) {
325+
return "";
326+
}
327+
297328
// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
298329
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
299330
// the output can be at most twice the length of the input
300-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
331+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
301332
if (out == nullptr) {
302333
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
303334
*out_len = 0;
@@ -445,8 +476,12 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
445476
return out;
446477

447478
} else {
479+
if (txt_len < 0) {
480+
*out_len = 0;
481+
return "";
482+
}
483+
memcpy(out, txt, static_cast<size_t>(txt_len));
448484
*out_len = txt_len;
449-
memcpy(out, txt, txt_len);
450485
return out;
451486
}
452487
}
@@ -480,10 +515,16 @@ const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_
480515
return "";
481516
}
482517

518+
int32_t alloc_length = 0;
519+
if (ARROW_PREDICT_FALSE(
520+
!compute_alloc_len(context, data_len, &alloc_length, out_len))) {
521+
return "";
522+
}
523+
483524
// If it is a single-byte character (ASCII), corresponding uppercase is always 1-byte
484525
// long; if it is >= 2 bytes long, uppercase can be at most 4 bytes long, so length of
485526
// the output can be at most twice the length of the input
486-
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, 2 * data_len));
527+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
487528
if (out == nullptr) {
488529
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
489530
*out_len = 0;
@@ -579,15 +620,24 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
579620
return in;
580621
}
581622

623+
int32_t alloc_length = 0;
624+
// Check overflow: 4 * in_len
625+
if (ARROW_PREDICT_FALSE(
626+
arrow::internal::MultiplyWithOverflow(4, in_len, &alloc_length))) {
627+
gdv_fn_context_set_error_msg(context, "Would overflow maximum output size");
628+
*out_len = 0;
629+
return "";
630+
}
631+
582632
// This variable is to control if there are multi-byte utf8 entries
583633
bool has_multi_byte = false;
584634

585635
// This variable is to store the final result
586636
char* result;
587-
int result_len;
637+
int32_t result_len;
588638

589639
// Searching multi-bytes in In
590-
for (int i = 0; i < in_len; i++) {
640+
for (int32_t i = 0; i < in_len; i++) {
591641
unsigned char char_single_byte = in[i];
592642
if (char_single_byte > 127) {
593643
// found a multi-byte utf-8 char
@@ -598,7 +648,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
598648

599649
// Searching multi-bytes in From
600650
if (!has_multi_byte) {
601-
for (int i = 0; i < from_len; i++) {
651+
for (int32_t i = 0; i < from_len; i++) {
602652
unsigned char char_single_byte = from[i];
603653
if (char_single_byte > 127) {
604654
// found a multi-byte utf-8 char
@@ -610,7 +660,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
610660

611661
// Searching multi-bytes in To
612662
if (!has_multi_byte) {
613-
for (int i = 0; i < to_len; i++) {
663+
for (int32_t i = 0; i < to_len; i++) {
614664
unsigned char char_single_byte = to[i];
615665
if (char_single_byte > 127) {
616666
// found a multi-byte utf-8 char
@@ -621,7 +671,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
621671
}
622672

623673
// If there are no multibytes in the input, work only with char
624-
if (!has_multi_byte) {
674+
if (not has_multi_byte) {
625675
// This variable is for receive the substitutions
626676
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, in_len));
627677

@@ -638,7 +688,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
638688

639689
// This variable is for controlling the position in entry TO, for never repeat the
640690
// changes
641-
int start_compare;
691+
int32_t start_compare;
642692

643693
if (to_len > 0) {
644694
start_compare = 0;
@@ -650,15 +700,15 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
650700
// list, to mark deletion positions
651701
const char empty = '\0';
652702

653-
for (int in_for = 0; in_for < in_len; in_for++) {
703+
for (int32_t in_for = 0; in_for < in_len; in_for++) {
654704
if (subs_list.find(in[in_for]) != subs_list.end()) {
655705
if (subs_list[in[in_for]] != empty) {
656706
// If exist in map, only add the correspondent value in result
657707
result[result_len] = subs_list[in[in_for]];
658708
result_len++;
659709
}
660710
} else {
661-
for (int from_for = 0; from_for <= from_len; from_for++) {
711+
for (int32_t from_for = 0; from_for <= from_len; from_for++) {
662712
if (from_for == from_len) {
663713
// If it's not in the FROM list, just add it to the map and the result.
664714
subs_list.insert(std::pair<char, char>(in[in_for], in[in_for]));
@@ -686,10 +736,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
686736
}
687737
}
688738
}
689-
} else { // If there are no multibytes in the input, work with std::strings
739+
} else {
740+
// If there are multibytes in the input, work with std::strings
690741
// This variable is for receive the substitutions, malloc is in_len * 4 to receive
691742
// possible inputs with 4 bytes
692-
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, in_len * 4));
743+
result = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
693744

694745
if (result == nullptr) {
695746
gdv_fn_context_set_error_msg(context,
@@ -704,7 +755,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
704755

705756
// This variable is for controlling the position in entry TO, for never repeat the
706757
// changes
707-
int start_compare;
758+
int32_t start_compare;
708759

709760
if (to_len > 0) {
710761
start_compare = 0;
@@ -717,11 +768,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
717768
const std::string empty = "";
718769

719770
// This variables is to control len of multi-bytes entries
720-
int len_char_in = 0;
721-
int len_char_from = 0;
722-
int len_char_to = 0;
771+
int32_t len_char_in = 0;
772+
int32_t len_char_from = 0;
773+
int32_t len_char_to = 0;
723774

724-
for (int in_for = 0; in_for < in_len; in_for += len_char_in) {
775+
for (int32_t in_for = 0; in_for < in_len; in_for += len_char_in) {
725776
// Updating len to char in this position
726777
len_char_in = gdv_fn_utf8_char_length(in[in_for]);
727778
// Making copy to std::string with length for this char position
@@ -734,11 +785,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
734785
result_len += static_cast<int>(subs_list[insert_copy_key].length());
735786
}
736787
} else {
737-
for (int from_for = 0; from_for <= from_len; from_for += len_char_from) {
738-
// Updating len to char in this position
739-
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
740-
// Making copy to std::string with length for this char position
741-
std::string copy_from_compare(from + from_for, len_char_from);
788+
for (int32_t from_for = 0; from_for <= from_len; from_for += len_char_from) {
742789
if (from_for == from_len) {
743790
// If it's not in the FROM list, just add it to the map and the result.
744791
std::string insert_copy_value(in + in_for, len_char_in);
@@ -751,6 +798,11 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in
751798
break;
752799
}
753800

801+
// Updating len to char in this position
802+
len_char_from = gdv_fn_utf8_char_length(from[from_for]);
803+
// Making copy to std::string with length for this char position
804+
std::string copy_from_compare(from + from_for, len_char_from);
805+
754806
if (insert_copy_key != copy_from_compare) {
755807
// If this character does not exist in FROM list, don't need treatment
756808
continue;

cpp/src/gandiva/precompiled/string_ops.cc

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1924,9 +1924,17 @@ const char* quote_utf8(gdv_int64 context, const char* in, gdv_int32 in_len,
19241924
*out_len = 0;
19251925
return "";
19261926
}
1927+
1928+
int32_t alloc_length = 0;
1929+
if (ARROW_PREDICT_FALSE(
1930+
arrow::internal::AddWithOverflow(2, (2 * in_len), &alloc_length))) {
1931+
gdv_fn_context_set_error_msg(context, "Memory allocation size too large");
1932+
*out_len = 0;
1933+
return "";
1934+
}
1935+
19271936
// try to allocate double size output string (worst case)
1928-
auto out =
1929-
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, (in_len * 2) + 2));
1937+
auto out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length));
19301938
if (out == nullptr) {
19311939
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
19321940
*out_len = 0;
@@ -2824,13 +2832,22 @@ const char* elt_int32_utf8_utf8_utf8_utf8_utf8(
28242832
FORCE_INLINE
28252833
const char* to_hex_binary(int64_t context, const char* text, int32_t text_len,
28262834
int32_t* out_len) {
2827-
if (text_len == 0) {
2835+
if (ARROW_PREDICT_FALSE(text_len <= 0)) {
2836+
*out_len = 0;
2837+
return "";
2838+
}
2839+
2840+
int32_t alloc_length = 0;
2841+
2842+
// Check overflow: 2 * text_len
2843+
if (ARROW_PREDICT_FALSE(
2844+
arrow::internal::MultiplyWithOverflow(2, text_len, &alloc_length))) {
28282845
*out_len = 0;
28292846
return "";
28302847
}
28312848

28322849
auto ret =
2833-
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, text_len * 2 + 1));
2850+
reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, alloc_length + 1));
28342851

28352852
if (ret == nullptr) {
28362853
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");

0 commit comments

Comments
 (0)