diff --git a/cpp/src/gandiva/encrypt_mode_dispatcher.cc b/cpp/src/gandiva/encrypt_mode_dispatcher.cc index b9ee8497ee5..1f63a96be68 100644 --- a/cpp/src/gandiva/encrypt_mode_dispatcher.cc +++ b/cpp/src/gandiva/encrypt_mode_dispatcher.cc @@ -23,9 +23,39 @@ #include #include #include +#include namespace gandiva { +// Supported encryption modes +static const std::vector SUPPORTED_MODES = { + AES_ECB_MODE, AES_ECB_PKCS7_MODE, AES_ECB_NONE_MODE, + AES_CBC_MODE, AES_CBC_PKCS7_MODE, AES_CBC_NONE_MODE, + AES_GCM_MODE +}; + +enum class EncryptionMode { + ECB, + ECB_PKCS7, + ECB_NONE, + CBC, + CBC_PKCS7, + CBC_NONE, + GCM, + UNKNOWN +}; + +EncryptionMode ParseEncryptionMode(std::string_view mode_str) { + if (mode_str == AES_ECB_MODE) return EncryptionMode::ECB; + if (mode_str == AES_ECB_PKCS7_MODE) return EncryptionMode::ECB_PKCS7; + if (mode_str == AES_ECB_NONE_MODE) return EncryptionMode::ECB_NONE; + if (mode_str == AES_CBC_MODE) return EncryptionMode::CBC; + if (mode_str == AES_CBC_PKCS7_MODE) return EncryptionMode::CBC_PKCS7; + if (mode_str == AES_CBC_NONE_MODE) return EncryptionMode::CBC_NONE; + if (mode_str == AES_GCM_MODE) return EncryptionMode::GCM; + return EncryptionMode::UNKNOWN; +} + int32_t EncryptModeDispatcher::encrypt( const char* plaintext, int32_t plaintext_len, const char* key, int32_t key_len, const char* mode, int32_t mode_len, const char* iv, @@ -34,23 +64,35 @@ int32_t EncryptModeDispatcher::encrypt( std::string mode_str = arrow::internal::AsciiToUpper(std::string_view(mode, mode_len)); - if (mode_str == AES_ECB_MODE) { - return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, cipher); - } else if (mode_str == AES_CBC_PKCS7_MODE) { - return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, - iv, iv_len, true, cipher); - } else if (mode_str == AES_CBC_NONE_MODE) { - return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, - iv, iv_len, false, cipher); - } else if (mode_str == AES_GCM_MODE) { - return aes_encrypt_gcm(plaintext, plaintext_len, key, key_len, - iv, iv_len, fifth_argument, fifth_argument_len, cipher); - } else { - std::ostringstream oss; - oss << "Unsupported encryption mode: " << mode_str - << ". Supported modes: " << AES_ECB_MODE << ", " << AES_CBC_PKCS7_MODE - << ", " << AES_CBC_NONE_MODE << ", " << AES_GCM_MODE; - throw std::runtime_error(oss.str()); + switch (ParseEncryptionMode(mode_str)) { + case EncryptionMode::ECB: + case EncryptionMode::ECB_PKCS7: + // Shorthand AES-ECB and explicit AES-ECB-PKCS7 both use ECB with PKCS7 + return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, cipher); + case EncryptionMode::ECB_NONE: + // ECB mode doesn't use padding, but we still call the same function + // since ECB doesn't have padding options + return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, cipher); + case EncryptionMode::CBC: + case EncryptionMode::CBC_PKCS7: + // Shorthand AES-CBC and explicit AES-CBC-PKCS7 both use CBC with PKCS7 + return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, + iv, iv_len, true, cipher); + case EncryptionMode::CBC_NONE: + // CBC without padding + return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, + iv, iv_len, false, cipher); + case EncryptionMode::GCM: + return aes_encrypt_gcm(plaintext, plaintext_len, key, key_len, + iv, iv_len, fifth_argument, fifth_argument_len, cipher); + case EncryptionMode::UNKNOWN: + default: { + std::string modes_str = arrow::internal::JoinStrings(SUPPORTED_MODES, ", "); + std::ostringstream oss; + oss << "Unsupported encryption mode: " << mode_str + << ". Supported modes: " << modes_str; + throw std::runtime_error(oss.str()); + } } } @@ -62,23 +104,35 @@ int32_t EncryptModeDispatcher::decrypt( std::string mode_str = arrow::internal::AsciiToUpper(std::string_view(mode, mode_len)); - if (mode_str == AES_ECB_MODE) { - return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, plaintext); - } else if (mode_str == AES_CBC_PKCS7_MODE) { - return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, - iv, iv_len, true, plaintext); - } else if (mode_str == AES_CBC_NONE_MODE) { - return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, - iv, iv_len, false, plaintext); - } else if (mode_str == AES_GCM_MODE) { - return aes_decrypt_gcm(ciphertext, ciphertext_len, key, key_len, - iv, iv_len, fifth_argument, fifth_argument_len, plaintext); - } else { - std::ostringstream oss; - oss << "Unsupported decryption mode: " << mode_str - << ". Supported modes: " << AES_ECB_MODE << ", " << AES_CBC_PKCS7_MODE - << ", " << AES_CBC_NONE_MODE << ", " << AES_GCM_MODE; - throw std::runtime_error(oss.str()); + switch (ParseEncryptionMode(mode_str)) { + case EncryptionMode::ECB: + case EncryptionMode::ECB_PKCS7: + // Shorthand AES-ECB and explicit AES-ECB-PKCS7 both use ECB with PKCS7 + return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, plaintext); + case EncryptionMode::ECB_NONE: + // ECB mode doesn't use padding, but we still call the same function + // since ECB doesn't have padding options + return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, plaintext); + case EncryptionMode::CBC: + case EncryptionMode::CBC_PKCS7: + // Shorthand AES-CBC and explicit AES-CBC-PKCS7 both use CBC with PKCS7 + return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, + iv, iv_len, true, plaintext); + case EncryptionMode::CBC_NONE: + // CBC without padding + return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, + iv, iv_len, false, plaintext); + case EncryptionMode::GCM: + return aes_decrypt_gcm(ciphertext, ciphertext_len, key, key_len, + iv, iv_len, fifth_argument, fifth_argument_len, plaintext); + case EncryptionMode::UNKNOWN: + default: { + std::string modes_str = arrow::internal::JoinStrings(SUPPORTED_MODES, ", "); + std::ostringstream oss; + oss << "Unsupported decryption mode: " << mode_str + << ". Supported modes: " << modes_str; + throw std::runtime_error(oss.str()); + } } } diff --git a/cpp/src/gandiva/encrypt_utils_cbc.h b/cpp/src/gandiva/encrypt_utils_cbc.h index 1f9749e2874..b083d6f0a2d 100644 --- a/cpp/src/gandiva/encrypt_utils_cbc.h +++ b/cpp/src/gandiva/encrypt_utils_cbc.h @@ -24,6 +24,7 @@ namespace gandiva { // CBC mode identifiers +constexpr const char* AES_CBC_MODE = "AES-CBC"; constexpr const char* AES_CBC_PKCS7_MODE = "AES-CBC-PKCS7"; constexpr const char* AES_CBC_NONE_MODE = "AES-CBC-NONE"; diff --git a/cpp/src/gandiva/encrypt_utils_ecb.h b/cpp/src/gandiva/encrypt_utils_ecb.h index 6ed3f2e7e37..f8a44632e80 100644 --- a/cpp/src/gandiva/encrypt_utils_ecb.h +++ b/cpp/src/gandiva/encrypt_utils_ecb.h @@ -23,8 +23,10 @@ namespace gandiva { -// ECB mode identifier +// ECB mode identifiers constexpr const char* AES_ECB_MODE = "AES-ECB"; +constexpr const char* AES_ECB_PKCS7_MODE = "AES-ECB-PKCS7"; +constexpr const char* AES_ECB_NONE_MODE = "AES-ECB-NONE"; /** * Encrypt data using AES-ECB algorithm (legacy, insecure) diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 721f13d90eb..7750421360e 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -505,30 +505,30 @@ std::vector GetStringFunctionRegistry() { NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), // Parameters: data, key, mode (e.g. ECB mode) - NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), - kResultNullIfNull, "gdv_fn_aes_encrypt_dispatcher_3args", + NativeFunction("encrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), + kResultNullIfNull, "gdv_fn_encrypt_dispatcher_3args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), - NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), - kResultNullIfNull, "gdv_fn_aes_decrypt_dispatcher_3args", + NativeFunction("decrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(), + kResultNullIfNull, "gdv_fn_decrypt_dispatcher_3args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), // Parameters: data, key, mode, iv (e.g. CBC mode) - NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(), - kResultNullIfNull, "gdv_fn_aes_encrypt_dispatcher_4args", + NativeFunction("encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_encrypt_dispatcher_4args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), - NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(), - kResultNullIfNull, "gdv_fn_aes_decrypt_dispatcher_4args", + NativeFunction("decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_decrypt_dispatcher_4args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), // Parameters: data, key, mode, iv, fifth_argument (e.g. GCM mode) - NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(), - kResultNullIfNull, "gdv_fn_aes_encrypt_dispatcher_5args", + NativeFunction("encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_encrypt_dispatcher_5args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), - NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(), - kResultNullIfNull, "gdv_fn_aes_decrypt_dispatcher_5args", + NativeFunction("decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(), + kResultNullIfNull, "gdv_fn_decrypt_dispatcher_5args", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), NativeFunction("mask_first_n", {}, DataTypeVector{utf8(), int32()}, utf8(), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index a4c98373e72..a33483e8a00 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -868,7 +868,7 @@ const char* gdv_fn_aes_encrypt_ecb_legacy(int64_t context, const char* data, // This function is ECB-only, so we enforce the mode const char* mode = "AES-ECB"; int32_t mode_len = 7; - const char* result = gdv_fn_aes_encrypt_dispatcher_3args( + const char* result = gdv_fn_encrypt_dispatcher_3args( context, data, data_len, key_data, key_data_len, mode, mode_len, out_len); // Add null terminator for string compatibility @@ -895,7 +895,7 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, // This function is ECB-only, so we enforce the mode const char* mode = "AES-ECB"; int32_t mode_len = 7; - const char* result = gdv_fn_aes_decrypt_dispatcher_3args( + const char* result = gdv_fn_decrypt_dispatcher_3args( context, data, data_len, key_data, key_data_len, mode, mode_len, out_len); // Add null terminator for string compatibility @@ -910,47 +910,47 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, // The 3- and 4-arg signatures exist to support optional IV and other arguments extern "C" GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_dispatcher_3args( +const char* gdv_fn_encrypt_dispatcher_3args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, int32_t* out_len) { - return gdv_fn_aes_encrypt_dispatcher_5args( + return gdv_fn_encrypt_dispatcher_5args( context, data, data_len, key_data, key_data_len, mode, mode_len, nullptr, 0, nullptr, 0, out_len); } extern "C" GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_dispatcher_3args( +const char* gdv_fn_decrypt_dispatcher_3args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, int32_t* out_len) { - return gdv_fn_aes_decrypt_dispatcher_5args( + return gdv_fn_decrypt_dispatcher_5args( context, data, data_len, key_data, key_data_len, mode, mode_len, nullptr, 0, nullptr, 0, out_len); } extern "C" GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_dispatcher_4args( +const char* gdv_fn_encrypt_dispatcher_4args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, const char* iv_data, int32_t iv_data_len, int32_t* out_len) { - return gdv_fn_aes_encrypt_dispatcher_5args( + return gdv_fn_encrypt_dispatcher_5args( context, data, data_len, key_data, key_data_len, mode, mode_len, iv_data, iv_data_len, nullptr, 0, out_len); } extern "C" GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_dispatcher_4args( +const char* gdv_fn_decrypt_dispatcher_4args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, const char* iv_data, int32_t iv_data_len, int32_t* out_len) { - return gdv_fn_aes_decrypt_dispatcher_5args( + return gdv_fn_decrypt_dispatcher_5args( context, data, data_len, key_data, key_data_len, mode, mode_len, iv_data, iv_data_len, nullptr, 0, out_len); } extern "C" GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_dispatcher_5args( +const char* gdv_fn_encrypt_dispatcher_5args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, const char* iv_data, int32_t iv_data_len, const char* fifth_argument, @@ -980,7 +980,7 @@ const char* gdv_fn_aes_encrypt_dispatcher_5args( } extern "C" GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_dispatcher_5args( +const char* gdv_fn_decrypt_dispatcher_5args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, const char* iv_data, int32_t iv_data_len, const char* fifth_argument, @@ -1241,7 +1241,7 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { types->i32_ptr_type() // out_length }; - // gdv_fn_aes_encrypt_dispatcher_3args (data, key, mode) + // gdv_fn_encrypt_dispatcher_3args (data, key, mode) args = { types->i64_type(), // context types->i8_ptr_type(), // data @@ -1254,11 +1254,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { }; engine->AddGlobalMappingForFunc( - "gdv_fn_aes_encrypt_dispatcher_3args", + "gdv_fn_encrypt_dispatcher_3args", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_encrypt_dispatcher_3args)); + reinterpret_cast(gdv_fn_encrypt_dispatcher_3args)); - // gdv_fn_aes_decrypt_dispatcher_3args (data, key, mode) + // gdv_fn_decrypt_dispatcher_3args (data, key, mode) args = { types->i64_type(), // context types->i8_ptr_type(), // data @@ -1271,11 +1271,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { }; engine->AddGlobalMappingForFunc( - "gdv_fn_aes_decrypt_dispatcher_3args", + "gdv_fn_decrypt_dispatcher_3args", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_decrypt_dispatcher_3args)); + reinterpret_cast(gdv_fn_decrypt_dispatcher_3args)); - // gdv_fn_aes_encrypt_dispatcher_4args (data, key, mode, iv) + // gdv_fn_encrypt_dispatcher_4args (data, key, mode, iv) args = { types->i64_type(), // context types->i8_ptr_type(), // data @@ -1290,11 +1290,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { }; engine->AddGlobalMappingForFunc( - "gdv_fn_aes_encrypt_dispatcher_4args", + "gdv_fn_encrypt_dispatcher_4args", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_encrypt_dispatcher_4args)); + reinterpret_cast(gdv_fn_encrypt_dispatcher_4args)); - // gdv_fn_aes_decrypt_dispatcher_4args (data, key, mode, iv) + // gdv_fn_decrypt_dispatcher_4args (data, key, mode, iv) args = { types->i64_type(), // context types->i8_ptr_type(), // data @@ -1309,11 +1309,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { }; engine->AddGlobalMappingForFunc( - "gdv_fn_aes_decrypt_dispatcher_4args", + "gdv_fn_decrypt_dispatcher_4args", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_decrypt_dispatcher_4args)); + reinterpret_cast(gdv_fn_decrypt_dispatcher_4args)); - // gdv_fn_aes_encrypt_dispatcher_5args (data, key, mode, iv, + // gdv_fn_encrypt_dispatcher_5args (data, key, mode, iv, // fifth_argument) args = { types->i64_type(), // context @@ -1331,11 +1331,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { }; engine->AddGlobalMappingForFunc( - "gdv_fn_aes_encrypt_dispatcher_5args", + "gdv_fn_encrypt_dispatcher_5args", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_encrypt_dispatcher_5args)); + reinterpret_cast(gdv_fn_encrypt_dispatcher_5args)); - // gdv_fn_aes_decrypt_dispatcher_5args (data, key, mode, iv, + // gdv_fn_decrypt_dispatcher_5args (data, key, mode, iv, // fifth_argument) args = { types->i64_type(), // context @@ -1353,9 +1353,9 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { }; engine->AddGlobalMappingForFunc( - "gdv_fn_aes_decrypt_dispatcher_5args", + "gdv_fn_decrypt_dispatcher_5args", types->i8_ptr_type() /*return_type*/, args, - reinterpret_cast(gdv_fn_aes_decrypt_dispatcher_5args)); + reinterpret_cast(gdv_fn_decrypt_dispatcher_5args)); // gdv_mask_first_n and gdv_mask_last_n std::vector mask_args = { diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 8224378f60a..54480ac7f6f 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -206,27 +206,27 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data, // 3-argument dispatcher: (data, key, mode) GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_dispatcher_3args( +const char* gdv_fn_encrypt_dispatcher_3args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, int32_t* out_len); GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_dispatcher_3args( +const char* gdv_fn_decrypt_dispatcher_3args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, int32_t* out_len); // 4-argument dispatcher: (data, key, mode, iv) GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_dispatcher_4args( +const char* gdv_fn_encrypt_dispatcher_4args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, const char* iv_data, int32_t iv_data_len, int32_t* out_len); GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_dispatcher_4args( +const char* gdv_fn_decrypt_dispatcher_4args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, const char* iv_data, int32_t iv_data_len, @@ -234,7 +234,7 @@ const char* gdv_fn_aes_decrypt_dispatcher_4args( // 5-argument dispatcher: (data, key, mode, iv, fifth_argument) GANDIVA_EXPORT -const char* gdv_fn_aes_encrypt_dispatcher_5args( +const char* gdv_fn_encrypt_dispatcher_5args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, const char* iv_data, int32_t iv_data_len, @@ -242,7 +242,7 @@ const char* gdv_fn_aes_encrypt_dispatcher_5args( int32_t* out_len); GANDIVA_EXPORT -const char* gdv_fn_aes_decrypt_dispatcher_5args( +const char* gdv_fn_decrypt_dispatcher_5args( int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, const char* mode, int32_t mode_len, const char* iv_data, int32_t iv_data_len, diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index eeb5a54042f..bfb34eeb31d 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -1360,10 +1360,10 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) { auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_dispatcher_3args( + const char* cipher = gdv_fn_encrypt_dispatcher_3args( ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &cipher_len); - const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_3args( + const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &decrypted_len); @@ -1384,11 +1384,11 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_dispatcher_3args( + const char* cipher = gdv_fn_encrypt_dispatcher_3args( ctx_ptr, data.c_str(), data_len, key24.c_str(), key24_len, mode.c_str(), mode_len, &cipher_len); - const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_3args( + const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( ctx_ptr, cipher, cipher_len, key24.c_str(), key24_len, mode.c_str(), mode_len, &decrypted_len); @@ -1409,11 +1409,11 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_dispatcher_3args( + const char* cipher = gdv_fn_encrypt_dispatcher_3args( ctx_ptr, data.c_str(), data_len, key32.c_str(), key32_len, mode.c_str(), mode_len, &cipher_len); - const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_3args( + const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( ctx_ptr, cipher, cipher_len, key32.c_str(), key32_len, mode.c_str(), mode_len, &decrypted_len); @@ -1435,14 +1435,14 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) { std::string cipher = "12345678abcdefgh12345678abcdefghb"; auto cipher_len = static_cast(cipher.length()); - gdv_fn_aes_encrypt_dispatcher_3args(ctx_ptr, data.c_str(), data_len, + gdv_fn_encrypt_dispatcher_3args(ctx_ptr, data.c_str(), data_len, key33.c_str(), key33_len, mode.c_str(), mode_len, &cipher_len); EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Unsupported key length for AES-ECB")); ctx.Reset(); - gdv_fn_aes_decrypt_dispatcher_3args(ctx_ptr, cipher.c_str(), cipher_len, + gdv_fn_decrypt_dispatcher_3args(ctx_ptr, cipher.c_str(), cipher_len, key33.c_str(), key33_len, mode.c_str(), mode_len, &decrypted_len); EXPECT_THAT(ctx.get_error(), @@ -1463,12 +1463,12 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeEcb) { auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_dispatcher_3args( + const char* cipher = gdv_fn_encrypt_dispatcher_3args( ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &cipher_len); EXPECT_GT(cipher_len, 0); - const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_3args( + const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, &decrypted_len); EXPECT_EQ(data, @@ -1489,7 +1489,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { int64_t ctx_ptr = reinterpret_cast(&ctx); // Test encrypt with invalid mode - gdv_fn_aes_encrypt_dispatcher_3args(ctx_ptr, data.c_str(), data_len, + gdv_fn_encrypt_dispatcher_3args(ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, invalid_mode.c_str(), invalid_mode_len, &cipher_len); @@ -1500,7 +1500,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { // Test decrypt with invalid mode std::string cipher = "12345678abcdefgh12345678abcdefgh"; auto cipher_len_val = static_cast(cipher.length()); - gdv_fn_aes_decrypt_dispatcher_3args(ctx_ptr, cipher.c_str(), cipher_len_val, + gdv_fn_decrypt_dispatcher_3args(ctx_ptr, cipher.c_str(), cipher_len_val, key16.c_str(), key16_len, invalid_mode.c_str(), invalid_mode_len, &decrypted_len); @@ -1524,12 +1524,12 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmIvOnly) { auto iv_len = static_cast(iv.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_dispatcher_5args( + const char* cipher = gdv_fn_encrypt_dispatcher_5args( ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), mode_len, iv.c_str(), iv_len, nullptr, 0, &cipher_len); EXPECT_GT(cipher_len, 0); - const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_5args( + const char* decrypted_value = gdv_fn_decrypt_dispatcher_5args( ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, iv.c_str(), iv_len, nullptr, 0, &decrypted_len); @@ -1554,12 +1554,12 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmWithAad) { auto aad_len = static_cast(aad.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); - const char* cipher = gdv_fn_aes_encrypt_dispatcher_5args( + const char* cipher = gdv_fn_encrypt_dispatcher_5args( ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), mode_len, iv.c_str(), iv_len, aad.c_str(), aad_len, &cipher_len); EXPECT_GT(cipher_len, 0); - const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_5args( + const char* decrypted_value = gdv_fn_decrypt_dispatcher_5args( ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), mode_len, iv.c_str(), iv_len, aad.c_str(), aad_len, &decrypted_len); @@ -1568,4 +1568,146 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmWithAad) { decrypted_len)); } +// Tests for shorthand mode: AES-ECB (defaults to PKCS7) +TEST(TestGdvFnStubs, TestAesEncryptDecryptShorthandEcb) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = AES_ECB_MODE; // Shorthand mode + auto mode_len = static_cast(mode.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, &decrypted_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); +} + +// Tests for explicit mode: AES-ECB-PKCS7 +TEST(TestGdvFnStubs, TestAesEncryptDecryptExplicitEcbPkcs7) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = AES_ECB_PKCS7_MODE; // Explicit mode + auto mode_len = static_cast(mode.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_encrypt_dispatcher_3args( + ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_decrypt_dispatcher_3args( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, &decrypted_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); +} + +// Tests for shorthand mode: AES-CBC (defaults to PKCS7) +TEST(TestGdvFnStubs, TestAesEncryptDecryptShorthandCbc) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = AES_CBC_MODE; // Shorthand mode + auto mode_len = static_cast(mode.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, &decrypted_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); +} + +// Tests for explicit mode: AES-CBC-PKCS7 +TEST(TestGdvFnStubs, TestAesEncryptDecryptExplicitCbcPkcs7) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = AES_CBC_PKCS7_MODE; // Explicit mode + auto mode_len = static_cast(mode.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, &decrypted_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); +} + +// Tests for explicit mode: AES-CBC-NONE (no padding) +TEST(TestGdvFnStubs, TestAesEncryptDecryptCbcNone) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + // Use exactly 16 bytes (one block) for no-padding mode + std::string data = "1234567890123456"; + auto data_len = static_cast(data.length()); + std::string mode = AES_CBC_NONE_MODE; // No padding mode + auto mode_len = static_cast(mode.length()); + std::string iv = "1234567890123456"; + auto iv_len = static_cast(iv.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_encrypt_dispatcher_4args( + ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_decrypt_dispatcher_4args( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, &decrypted_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); +} + } // namespace gandiva