diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 2b5fe0d9a6f..13a202be62d 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -59,6 +59,7 @@ set(SRC_FILES encrypt_utils_common.cc encrypt_utils_ecb.cc encrypt_utils_cbc.cc + encrypt_utils_gcm.cc encrypt_mode_dispatcher.cc expr_decomposer.cc expr_validator.cc @@ -261,6 +262,7 @@ add_gandiva_test(internals-test tree_expr_test.cc encrypt_utils_ecb_test.cc encrypt_utils_cbc_test.cc + encrypt_utils_gcm_test.cc encrypt_utils_common_test.cc expr_decomposer_test.cc exported_funcs_registry_test.cc diff --git a/cpp/src/gandiva/encrypt_mode_dispatcher.cc b/cpp/src/gandiva/encrypt_mode_dispatcher.cc index dc93779ed08..b9ee8497ee5 100644 --- a/cpp/src/gandiva/encrypt_mode_dispatcher.cc +++ b/cpp/src/gandiva/encrypt_mode_dispatcher.cc @@ -18,6 +18,7 @@ #include "gandiva/encrypt_mode_dispatcher.h" #include "gandiva/encrypt_utils_ecb.h" #include "gandiva/encrypt_utils_cbc.h" +#include "gandiva/encrypt_utils_gcm.h" #include "arrow/util/string.h" #include #include @@ -33,20 +34,22 @@ int32_t EncryptModeDispatcher::encrypt( std::string mode_str = arrow::internal::AsciiToUpper(std::string_view(mode, mode_len)); - if (mode_str == "AES-ECB") { + if (mode_str == AES_ECB_MODE) { return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, cipher); - } else if (mode_str == "AES-CBC-PKCS7") { + } else if (mode_str == AES_CBC_PKCS7_MODE) { return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, iv, iv_len, true, cipher); - } else if (mode_str == "AES-CBC-NONE") { + } else if (mode_str == AES_CBC_NONE_MODE) { return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len, iv, iv_len, false, cipher); - } else if (mode_str == "AES-GCM") { - throw std::runtime_error("AES-GCM encryption mode is not yet implemented"); + } else if (mode_str == AES_GCM_MODE) { + return aes_encrypt_gcm(plaintext, plaintext_len, key, key_len, + iv, iv_len, fifth_argument, fifth_argument_len, cipher); } else { std::ostringstream oss; oss << "Unsupported encryption mode: " << mode_str - << ". Supported modes: AES-ECB, AES-CBC-PKCS7, AES-CBC-NONE"; + << ". Supported modes: " << AES_ECB_MODE << ", " << AES_CBC_PKCS7_MODE + << ", " << AES_CBC_NONE_MODE << ", " << AES_GCM_MODE; throw std::runtime_error(oss.str()); } } @@ -59,20 +62,22 @@ int32_t EncryptModeDispatcher::decrypt( std::string mode_str = arrow::internal::AsciiToUpper(std::string_view(mode, mode_len)); - if (mode_str == "AES-ECB") { + if (mode_str == AES_ECB_MODE) { return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, plaintext); - } else if (mode_str == "AES-CBC-PKCS7") { + } else if (mode_str == AES_CBC_PKCS7_MODE) { return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, iv, iv_len, true, plaintext); - } else if (mode_str == "AES-CBC-NONE") { + } else if (mode_str == AES_CBC_NONE_MODE) { return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len, iv, iv_len, false, plaintext); - } else if (mode_str == "AES-GCM") { - throw std::runtime_error("AES-GCM decryption mode is not yet implemented"); + } else if (mode_str == AES_GCM_MODE) { + return aes_decrypt_gcm(ciphertext, ciphertext_len, key, key_len, + iv, iv_len, fifth_argument, fifth_argument_len, plaintext); } else { std::ostringstream oss; oss << "Unsupported decryption mode: " << mode_str - << ". Supported modes: AES-ECB, AES-CBC-PKCS7, AES-CBC-NONE"; + << ". Supported modes: " << AES_ECB_MODE << ", " << AES_CBC_PKCS7_MODE + << ", " << AES_CBC_NONE_MODE << ", " << AES_GCM_MODE; throw std::runtime_error(oss.str()); } } diff --git a/cpp/src/gandiva/encrypt_utils_cbc.h b/cpp/src/gandiva/encrypt_utils_cbc.h index 5cec4eb39c3..1f9749e2874 100644 --- a/cpp/src/gandiva/encrypt_utils_cbc.h +++ b/cpp/src/gandiva/encrypt_utils_cbc.h @@ -23,6 +23,10 @@ namespace gandiva { +// CBC mode identifiers +constexpr const char* AES_CBC_PKCS7_MODE = "AES-CBC-PKCS7"; +constexpr const char* AES_CBC_NONE_MODE = "AES-CBC-NONE"; + /** * Encrypt data using AES-CBC algorithm with explicit padding mode * diff --git a/cpp/src/gandiva/encrypt_utils_ecb.h b/cpp/src/gandiva/encrypt_utils_ecb.h index 51be2644120..6ed3f2e7e37 100644 --- a/cpp/src/gandiva/encrypt_utils_ecb.h +++ b/cpp/src/gandiva/encrypt_utils_ecb.h @@ -23,6 +23,9 @@ namespace gandiva { +// ECB mode identifier +constexpr const char* AES_ECB_MODE = "AES-ECB"; + /** * Encrypt data using AES-ECB algorithm (legacy, insecure) * diff --git a/cpp/src/gandiva/encrypt_utils_gcm.cc b/cpp/src/gandiva/encrypt_utils_gcm.cc new file mode 100644 index 00000000000..f028243da59 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_gcm.cc @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/encrypt_utils_gcm.h" +#include "gandiva/encrypt_utils_common.h" +#include +#include +#include +#include +#include + +namespace gandiva { + +namespace { + +const EVP_CIPHER* get_gcm_cipher_algo(int32_t key_length) { + switch (key_length) { + case 16: + return EVP_aes_128_gcm(); + case 24: + return EVP_aes_192_gcm(); + case 32: + return EVP_aes_256_gcm(); + default: { + std::ostringstream oss; + oss << "Unsupported key length for AES-GCM: " << key_length + << " bytes. Supported lengths: 16, 24, 32 bytes"; + throw std::runtime_error(oss.str()); + } + } +} + +} // namespace + +GANDIVA_EXPORT +int32_t aes_encrypt_gcm(const char* plaintext, int32_t plaintext_len, + const char* key, int32_t key_len, const char* iv, + int32_t iv_len, const char* aad, int32_t aad_len, + unsigned char* cipher) { + if (iv_len <= 0) { + throw std::runtime_error( + "Invalid IV length for AES-GCM: IV length must be greater than 0"); + } + + int32_t cipher_len = 0; + int32_t len = 0; + EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_gcm_cipher_algo(key_len); + + if (!en_ctx) { + throw std::runtime_error("Could not create EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + try { + if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr, + reinterpret_cast(key), + reinterpret_cast(iv))) { + throw std::runtime_error( + "Could not initialize EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + // Set IV length for GCM mode + if (!EVP_CIPHER_CTX_ctrl(en_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, nullptr)) { + throw std::runtime_error("Could not set GCM IV length: " + + get_openssl_error_string()); + } + + // Process AAD if provided + if (aad != nullptr && aad_len > 0) { + if (!EVP_EncryptUpdate(en_ctx, nullptr, &len, + reinterpret_cast(aad), aad_len)) { + throw std::runtime_error("Could not process AAD for encryption: " + + get_openssl_error_string()); + } + } + + // Encrypt plaintext + if (!EVP_EncryptUpdate(en_ctx, cipher, &len, + reinterpret_cast(plaintext), + plaintext_len)) { + throw std::runtime_error("Could not update EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + cipher_len += len; + + // Finalize encryption + if (!EVP_EncryptFinal_ex(en_ctx, cipher + len, &len)) { + throw std::runtime_error("Could not finalize EVP cipher context for encryption: " + + get_openssl_error_string()); + } + + cipher_len += len; + + // Get the authentication tag and append it to ciphertext + if (!EVP_CIPHER_CTX_ctrl(en_ctx, EVP_CTRL_GCM_GET_TAG, GCM_TAG_LENGTH, + cipher + cipher_len)) { + throw std::runtime_error("Could not get GCM authentication tag: " + + get_openssl_error_string()); + } + cipher_len += GCM_TAG_LENGTH; + } catch (...) { + EVP_CIPHER_CTX_free(en_ctx); + throw; + } + + EVP_CIPHER_CTX_free(en_ctx); + return cipher_len; +} + +GANDIVA_EXPORT +int32_t aes_decrypt_gcm(const char* ciphertext, int32_t ciphertext_len, + const char* key, int32_t key_len, const char* iv, + int32_t iv_len, const char* aad, int32_t aad_len, + unsigned char* plaintext) { + if (iv_len <= 0) { + throw std::runtime_error( + "Invalid IV length for AES-GCM: IV length must be greater than 0"); + } + + if (ciphertext_len < GCM_TAG_LENGTH) { + throw std::runtime_error( + "Ciphertext too short for AES-GCM: must be at least 16 bytes for tag"); + } + + int32_t plaintext_len = 0; + int32_t len = 0; + EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new(); + const EVP_CIPHER* cipher_algo = get_gcm_cipher_algo(key_len); + + if (!de_ctx) { + throw std::runtime_error("Could not create EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + try { + if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr, + reinterpret_cast(key), + reinterpret_cast(iv))) { + throw std::runtime_error( + "Could not initialize EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + // Set IV length for GCM mode + if (!EVP_CIPHER_CTX_ctrl(de_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, nullptr)) { + throw std::runtime_error("Could not set GCM IV length: " + + get_openssl_error_string()); + } + + // Process AAD if provided + if (aad != nullptr && aad_len > 0) { + if (!EVP_DecryptUpdate(de_ctx, nullptr, &len, + reinterpret_cast(aad), aad_len)) { + throw std::runtime_error("Could not process AAD for decryption: " + + get_openssl_error_string()); + } + } + + // Extract tag from end of ciphertext + int32_t actual_ciphertext_len = ciphertext_len - GCM_TAG_LENGTH; + const unsigned char* tag = + reinterpret_cast(ciphertext + actual_ciphertext_len); + + // Set the authentication tag + if (!EVP_CIPHER_CTX_ctrl(de_ctx, EVP_CTRL_GCM_SET_TAG, GCM_TAG_LENGTH, + const_cast(tag))) { + throw std::runtime_error("Could not set GCM authentication tag: " + + get_openssl_error_string()); + } + + // Decrypt ciphertext + if (!EVP_DecryptUpdate(de_ctx, plaintext, &len, + reinterpret_cast(ciphertext), + actual_ciphertext_len)) { + throw std::runtime_error("Could not update EVP cipher context for decryption: " + + get_openssl_error_string()); + } + + plaintext_len += len; + + // Finalize decryption (this verifies the tag) + if (!EVP_DecryptFinal_ex(de_ctx, plaintext + len, &len)) { + throw std::runtime_error("GCM tag verification failed or decryption error: " + + get_openssl_error_string()); + } + plaintext_len += len; + } catch (...) { + EVP_CIPHER_CTX_free(de_ctx); + throw; + } + + EVP_CIPHER_CTX_free(de_ctx); + return plaintext_len; +} + +} // namespace gandiva + diff --git a/cpp/src/gandiva/encrypt_utils_gcm.h b/cpp/src/gandiva/encrypt_utils_gcm.h new file mode 100644 index 00000000000..07a597af0b6 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_gcm.h @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include "gandiva/visibility.h" + +namespace gandiva { + +// GCM mode identifier +constexpr const char* AES_GCM_MODE = "AES-GCM"; + +// GCM authentication tag length in bytes +constexpr int32_t GCM_TAG_LENGTH = 16; + +/** + * Encrypt data using AES-GCM algorithm + * + * @param plaintext The data to encrypt + * @param plaintext_len Length of plaintext in bytes + * @param key The encryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param iv The initialization vector (variable length, typically 12 bytes) + * @param iv_len Length of IV in bytes + * @param aad Optional additional authenticated data (can be null) + * @param aad_len Length of AAD in bytes (0 if aad is null) + * @param cipher Output buffer for encrypted data (must be at least plaintext_len + 16 bytes) + * @return Length of encrypted data in bytes (plaintext_len + 16 for the tag) + * @throws std::runtime_error on encryption failure or invalid parameters + */ +GANDIVA_EXPORT +int32_t aes_encrypt_gcm(const char* plaintext, int32_t plaintext_len, const char* key, + int32_t key_len, const char* iv, int32_t iv_len, + const char* aad, int32_t aad_len, unsigned char* cipher); + +/** + * Decrypt data using AES-GCM algorithm + * + * @param ciphertext The data to decrypt (includes 16-byte authentication tag at the end) + * @param ciphertext_len Length of ciphertext in bytes (includes tag) + * @param key The decryption key (16, 24, or 32 bytes for 128, 192, 256-bit keys) + * @param key_len Length of key in bytes + * @param iv The initialization vector (variable length, typically 12 bytes) + * @param iv_len Length of IV in bytes + * @param aad Optional additional authenticated data (can be null) + * @param aad_len Length of AAD in bytes (0 if aad is null) + * @param plaintext Output buffer for decrypted data + * @return Length of decrypted data in bytes (ciphertext_len - 16) + * @throws std::runtime_error on decryption failure, invalid parameters, or tag verification failure + */ +GANDIVA_EXPORT +int32_t aes_decrypt_gcm(const char* ciphertext, int32_t ciphertext_len, const char* key, + int32_t key_len, const char* iv, int32_t iv_len, + const char* aad, int32_t aad_len, unsigned char* plaintext); + +} // namespace gandiva + diff --git a/cpp/src/gandiva/encrypt_utils_gcm_test.cc b/cpp/src/gandiva/encrypt_utils_gcm_test.cc new file mode 100644 index 00000000000..2156132bc62 --- /dev/null +++ b/cpp/src/gandiva/encrypt_utils_gcm_test.cc @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/encrypt_utils_gcm.h" + +#include +#include +#include + +// Test IV-only GCM with 16-byte key +TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_16) { + auto* key = "12345678abcdefgh"; + auto* iv = "123456789012"; // 12-byte IV + auto* to_encrypt = "some test string"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, nullptr, 0, cipher); + + // Ciphertext should be plaintext_len + 16 (tag) + EXPECT_EQ(cipher_len, to_encrypt_len + 16); + + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + nullptr, 0, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test IV + AAD GCM with 16-byte key +TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptWithAad_16) { + auto* key = "12345678abcdefgh"; + auto* iv = "123456789012"; + auto* to_encrypt = "some test string"; + auto* aad = "additional authenticated data"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + auto aad_len = static_cast(strlen(aad)); + unsigned char cipher[128]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, aad, aad_len, cipher); + + EXPECT_EQ(cipher_len, to_encrypt_len + 16); + + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + aad, aad_len, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test IV-only GCM with 24-byte key +TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_24) { + auto* key = "12345678abcdefgh12345678"; + auto* iv = "123456789012"; + auto* to_encrypt = "test data"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, nullptr, 0, cipher); + + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + nullptr, 0, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test IV-only GCM with 32-byte key +TEST(TestAesGcmEncryptUtils, TestAesEncryptDecryptIvOnly_32) { + auto* key = "12345678abcdefgh12345678abcdefgh"; + auto* iv = "123456789012"; + auto* to_encrypt = "another test"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, nullptr, 0, cipher); + + unsigned char decrypted[128]; + int32_t decrypted_len = gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + nullptr, 0, decrypted); + + EXPECT_EQ(std::string(to_encrypt, to_encrypt_len), + std::string(reinterpret_cast(decrypted), decrypted_len)); +} + +// Test tag verification failure +TEST(TestAesGcmEncryptUtils, TestTagVerificationFailure) { + auto* key = "12345678abcdefgh"; + auto* iv = "123456789012"; + auto* to_encrypt = "some test string"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + int32_t cipher_len = gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, nullptr, 0, cipher); + + // Corrupt the tag (last byte) + cipher[cipher_len - 1] ^= 0xFF; + + unsigned char decrypted[128]; + EXPECT_THROW(gandiva::aes_decrypt_gcm(reinterpret_cast(cipher), + cipher_len, key, key_len, iv, iv_len, + nullptr, 0, decrypted), + std::runtime_error); +} + +// Test invalid IV length +TEST(TestAesGcmEncryptUtils, TestInvalidIvLength) { + auto* key = "12345678abcdefgh"; + auto* iv = ""; // Empty IV + auto* to_encrypt = "some test string"; + + auto key_len = static_cast(strlen(key)); + auto iv_len = static_cast(strlen(iv)); + auto to_encrypt_len = static_cast(strlen(to_encrypt)); + unsigned char cipher[128]; + + EXPECT_THROW(gandiva::aes_encrypt_gcm(to_encrypt, to_encrypt_len, key, key_len, + iv, iv_len, nullptr, 0, cipher), + std::runtime_error); +} + diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index 7a7cc361d1e..eeb5a54042f 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -23,6 +23,9 @@ #include "arrow/util/logging.h" #include "gandiva/execution_context.h" +#include "gandiva/encrypt_utils_ecb.h" +#include "gandiva/encrypt_utils_cbc.h" +#include "gandiva/encrypt_utils_gcm.h" namespace gandiva { @@ -1353,7 +1356,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "AES-ECB"; + std::string mode = AES_ECB_MODE; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); @@ -1377,7 +1380,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "AES-ECB"; + std::string mode = AES_ECB_MODE; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); @@ -1402,7 +1405,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "AES-ECB"; + std::string mode = AES_ECB_MODE; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); @@ -1426,7 +1429,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "AES-ECB"; + std::string mode = AES_ECB_MODE; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); std::string cipher = "12345678abcdefgh12345678abcdefghb"; @@ -1456,7 +1459,7 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeEcb) { int32_t decrypted_len = 0; std::string data = "test string"; auto data_len = static_cast(data.length()); - std::string mode = "AES-ECB"; + std::string mode = AES_ECB_MODE; auto mode_len = static_cast(mode.length()); int64_t ctx_ptr = reinterpret_cast(&ctx); @@ -1506,4 +1509,63 @@ TEST(TestGdvFnStubs, TestAesEncryptDecryptModeValidation) { ctx.Reset(); } +// Tests for AES-GCM mode +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmIvOnly) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = AES_GCM_MODE; + auto mode_len = static_cast(mode.length()); + std::string iv = "123456789012"; + auto iv_len = static_cast(iv.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt_dispatcher_5args( + ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, nullptr, 0, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_5args( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, nullptr, 0, &decrypted_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); +} + +TEST(TestGdvFnStubs, TestAesEncryptDecryptGcmWithAad) { + gandiva::ExecutionContext ctx; + std::string key16 = "12345678abcdefgh"; + auto key16_len = static_cast(key16.length()); + int32_t cipher_len = 0; + int32_t decrypted_len = 0; + std::string data = "test string"; + auto data_len = static_cast(data.length()); + std::string mode = AES_GCM_MODE; + auto mode_len = static_cast(mode.length()); + std::string iv = "123456789012"; + auto iv_len = static_cast(iv.length()); + std::string aad = "additional authenticated data"; + auto aad_len = static_cast(aad.length()); + int64_t ctx_ptr = reinterpret_cast(&ctx); + + const char* cipher = gdv_fn_aes_encrypt_dispatcher_5args( + ctx_ptr, data.c_str(), data_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, aad.c_str(), aad_len, &cipher_len); + EXPECT_GT(cipher_len, 0); + + const char* decrypted_value = gdv_fn_aes_decrypt_dispatcher_5args( + ctx_ptr, cipher, cipher_len, key16.c_str(), key16_len, mode.c_str(), + mode_len, iv.c_str(), iv_len, aad.c_str(), aad_len, &decrypted_len); + + EXPECT_EQ(data, + std::string(reinterpret_cast(decrypted_value), + decrypted_len)); +} + } // namespace gandiva