Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 88 additions & 34 deletions cpp/src/gandiva/encrypt_mode_dispatcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,39 @@
#include <string>
#include <sstream>
#include <stdexcept>
#include <vector>

namespace gandiva {

// Supported encryption modes
static const std::vector<std::string_view> SUPPORTED_MODES = {
AES_ECB_MODE, AES_ECB_PKCS7_MODE, AES_ECB_NONE_MODE,
AES_CBC_MODE, AES_CBC_PKCS7_MODE, AES_CBC_NONE_MODE,
AES_GCM_MODE
};

enum class EncryptionMode {
ECB,
ECB_PKCS7,
ECB_NONE,
CBC,
CBC_PKCS7,
CBC_NONE,
GCM,
UNKNOWN
};

EncryptionMode ParseEncryptionMode(std::string_view mode_str) {
if (mode_str == AES_ECB_MODE) return EncryptionMode::ECB;
if (mode_str == AES_ECB_PKCS7_MODE) return EncryptionMode::ECB_PKCS7;
if (mode_str == AES_ECB_NONE_MODE) return EncryptionMode::ECB_NONE;
if (mode_str == AES_CBC_MODE) return EncryptionMode::CBC;
if (mode_str == AES_CBC_PKCS7_MODE) return EncryptionMode::CBC_PKCS7;
if (mode_str == AES_CBC_NONE_MODE) return EncryptionMode::CBC_NONE;
if (mode_str == AES_GCM_MODE) return EncryptionMode::GCM;
return EncryptionMode::UNKNOWN;
}

int32_t EncryptModeDispatcher::encrypt(
const char* plaintext, int32_t plaintext_len, const char* key,
int32_t key_len, const char* mode, int32_t mode_len, const char* iv,
Expand All @@ -34,23 +64,35 @@ int32_t EncryptModeDispatcher::encrypt(
std::string mode_str =
arrow::internal::AsciiToUpper(std::string_view(mode, mode_len));

if (mode_str == AES_ECB_MODE) {
return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, cipher);
} else if (mode_str == AES_CBC_PKCS7_MODE) {
return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len,
iv, iv_len, true, cipher);
} else if (mode_str == AES_CBC_NONE_MODE) {
return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len,
iv, iv_len, false, cipher);
} else if (mode_str == AES_GCM_MODE) {
return aes_encrypt_gcm(plaintext, plaintext_len, key, key_len,
iv, iv_len, fifth_argument, fifth_argument_len, cipher);
} else {
std::ostringstream oss;
oss << "Unsupported encryption mode: " << mode_str
<< ". Supported modes: " << AES_ECB_MODE << ", " << AES_CBC_PKCS7_MODE
<< ", " << AES_CBC_NONE_MODE << ", " << AES_GCM_MODE;
throw std::runtime_error(oss.str());
switch (ParseEncryptionMode(mode_str)) {
case EncryptionMode::ECB:
case EncryptionMode::ECB_PKCS7:
// Shorthand AES-ECB and explicit AES-ECB-PKCS7 both use ECB with PKCS7
return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, cipher);
case EncryptionMode::ECB_NONE:
// ECB mode doesn't use padding, but we still call the same function
// since ECB doesn't have padding options
return aes_encrypt_ecb(plaintext, plaintext_len, key, key_len, cipher);
case EncryptionMode::CBC:
case EncryptionMode::CBC_PKCS7:
// Shorthand AES-CBC and explicit AES-CBC-PKCS7 both use CBC with PKCS7
return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len,
iv, iv_len, true, cipher);
case EncryptionMode::CBC_NONE:
// CBC without padding
return aes_encrypt_cbc(plaintext, plaintext_len, key, key_len,
iv, iv_len, false, cipher);
case EncryptionMode::GCM:
return aes_encrypt_gcm(plaintext, plaintext_len, key, key_len,
iv, iv_len, fifth_argument, fifth_argument_len, cipher);
case EncryptionMode::UNKNOWN:
default: {
std::string modes_str = arrow::internal::JoinStrings(SUPPORTED_MODES, ", ");
std::ostringstream oss;
oss << "Unsupported encryption mode: " << mode_str
<< ". Supported modes: " << modes_str;
throw std::runtime_error(oss.str());
}
}
}

Expand All @@ -62,23 +104,35 @@ int32_t EncryptModeDispatcher::decrypt(
std::string mode_str =
arrow::internal::AsciiToUpper(std::string_view(mode, mode_len));

if (mode_str == AES_ECB_MODE) {
return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, plaintext);
} else if (mode_str == AES_CBC_PKCS7_MODE) {
return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len,
iv, iv_len, true, plaintext);
} else if (mode_str == AES_CBC_NONE_MODE) {
return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len,
iv, iv_len, false, plaintext);
} else if (mode_str == AES_GCM_MODE) {
return aes_decrypt_gcm(ciphertext, ciphertext_len, key, key_len,
iv, iv_len, fifth_argument, fifth_argument_len, plaintext);
} else {
std::ostringstream oss;
oss << "Unsupported decryption mode: " << mode_str
<< ". Supported modes: " << AES_ECB_MODE << ", " << AES_CBC_PKCS7_MODE
<< ", " << AES_CBC_NONE_MODE << ", " << AES_GCM_MODE;
throw std::runtime_error(oss.str());
switch (ParseEncryptionMode(mode_str)) {
case EncryptionMode::ECB:
case EncryptionMode::ECB_PKCS7:
// Shorthand AES-ECB and explicit AES-ECB-PKCS7 both use ECB with PKCS7
return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, plaintext);
case EncryptionMode::ECB_NONE:
// ECB mode doesn't use padding, but we still call the same function
// since ECB doesn't have padding options
return aes_decrypt_ecb(ciphertext, ciphertext_len, key, key_len, plaintext);
case EncryptionMode::CBC:
case EncryptionMode::CBC_PKCS7:
// Shorthand AES-CBC and explicit AES-CBC-PKCS7 both use CBC with PKCS7
return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len,
iv, iv_len, true, plaintext);
case EncryptionMode::CBC_NONE:
// CBC without padding
return aes_decrypt_cbc(ciphertext, ciphertext_len, key, key_len,
iv, iv_len, false, plaintext);
case EncryptionMode::GCM:
return aes_decrypt_gcm(ciphertext, ciphertext_len, key, key_len,
iv, iv_len, fifth_argument, fifth_argument_len, plaintext);
case EncryptionMode::UNKNOWN:
default: {
std::string modes_str = arrow::internal::JoinStrings(SUPPORTED_MODES, ", ");
std::ostringstream oss;
oss << "Unsupported decryption mode: " << mode_str
<< ". Supported modes: " << modes_str;
throw std::runtime_error(oss.str());
}
}
}

Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/encrypt_utils_cbc.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
namespace gandiva {

// CBC mode identifiers
constexpr const char* AES_CBC_MODE = "AES-CBC";
constexpr const char* AES_CBC_PKCS7_MODE = "AES-CBC-PKCS7";
constexpr const char* AES_CBC_NONE_MODE = "AES-CBC-NONE";

Expand Down
4 changes: 3 additions & 1 deletion cpp/src/gandiva/encrypt_utils_ecb.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@

namespace gandiva {

// ECB mode identifier
// ECB mode identifiers
constexpr const char* AES_ECB_MODE = "AES-ECB";
constexpr const char* AES_ECB_PKCS7_MODE = "AES-ECB-PKCS7";
constexpr const char* AES_ECB_NONE_MODE = "AES-ECB-NONE";

/**
* Encrypt data using AES-ECB algorithm (legacy, insecure)
Expand Down
24 changes: 12 additions & 12 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,30 +505,30 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

// Parameters: data, key, mode (e.g. ECB mode)
NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(),
kResultNullIfNull, "gdv_fn_aes_encrypt_dispatcher_3args",
NativeFunction("encrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(),
kResultNullIfNull, "gdv_fn_encrypt_dispatcher_3args",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(),
kResultNullIfNull, "gdv_fn_aes_decrypt_dispatcher_3args",
NativeFunction("decrypt", {}, DataTypeVector{binary(), binary(), utf8()}, binary(),
kResultNullIfNull, "gdv_fn_decrypt_dispatcher_3args",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

// Parameters: data, key, mode, iv (e.g. CBC mode)
NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(),
kResultNullIfNull, "gdv_fn_aes_encrypt_dispatcher_4args",
NativeFunction("encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(),
kResultNullIfNull, "gdv_fn_encrypt_dispatcher_4args",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(),
kResultNullIfNull, "gdv_fn_aes_decrypt_dispatcher_4args",
NativeFunction("decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary()}, binary(),
kResultNullIfNull, "gdv_fn_decrypt_dispatcher_4args",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

// Parameters: data, key, mode, iv, fifth_argument (e.g. GCM mode)
NativeFunction("aes_encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(),
kResultNullIfNull, "gdv_fn_aes_encrypt_dispatcher_5args",
NativeFunction("encrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(),
kResultNullIfNull, "gdv_fn_encrypt_dispatcher_5args",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("aes_decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(),
kResultNullIfNull, "gdv_fn_aes_decrypt_dispatcher_5args",
NativeFunction("decrypt", {}, DataTypeVector{binary(), binary(), utf8(), binary(), binary()}, binary(),
kResultNullIfNull, "gdv_fn_decrypt_dispatcher_5args",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("mask_first_n", {}, DataTypeVector{utf8(), int32()}, utf8(),
Expand Down
60 changes: 30 additions & 30 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ const char* gdv_fn_aes_encrypt_ecb_legacy(int64_t context, const char* data,
// This function is ECB-only, so we enforce the mode
const char* mode = "AES-ECB";
int32_t mode_len = 7;
const char* result = gdv_fn_aes_encrypt_dispatcher_3args(
const char* result = gdv_fn_encrypt_dispatcher_3args(
context, data, data_len, key_data, key_data_len, mode, mode_len, out_len);

// Add null terminator for string compatibility
Expand All @@ -895,7 +895,7 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data,
// This function is ECB-only, so we enforce the mode
const char* mode = "AES-ECB";
int32_t mode_len = 7;
const char* result = gdv_fn_aes_decrypt_dispatcher_3args(
const char* result = gdv_fn_decrypt_dispatcher_3args(
context, data, data_len, key_data, key_data_len, mode, mode_len, out_len);

// Add null terminator for string compatibility
Expand All @@ -910,47 +910,47 @@ const char* gdv_fn_aes_decrypt_ecb_legacy(int64_t context, const char* data,

// The 3- and 4-arg signatures exist to support optional IV and other arguments
extern "C" GANDIVA_EXPORT
const char* gdv_fn_aes_encrypt_dispatcher_3args(
const char* gdv_fn_encrypt_dispatcher_3args(
int64_t context, const char* data, int32_t data_len, const char* key_data,
int32_t key_data_len, const char* mode, int32_t mode_len,
int32_t* out_len) {
return gdv_fn_aes_encrypt_dispatcher_5args(
return gdv_fn_encrypt_dispatcher_5args(
context, data, data_len, key_data, key_data_len, mode, mode_len, nullptr,
0, nullptr, 0, out_len);
}

extern "C" GANDIVA_EXPORT
const char* gdv_fn_aes_decrypt_dispatcher_3args(
const char* gdv_fn_decrypt_dispatcher_3args(
int64_t context, const char* data, int32_t data_len, const char* key_data,
int32_t key_data_len, const char* mode, int32_t mode_len,
int32_t* out_len) {
return gdv_fn_aes_decrypt_dispatcher_5args(
return gdv_fn_decrypt_dispatcher_5args(
context, data, data_len, key_data, key_data_len, mode, mode_len, nullptr,
0, nullptr, 0, out_len);
}

extern "C" GANDIVA_EXPORT
const char* gdv_fn_aes_encrypt_dispatcher_4args(
const char* gdv_fn_encrypt_dispatcher_4args(
int64_t context, const char* data, int32_t data_len, const char* key_data,
int32_t key_data_len, const char* mode, int32_t mode_len,
const char* iv_data, int32_t iv_data_len, int32_t* out_len) {
return gdv_fn_aes_encrypt_dispatcher_5args(
return gdv_fn_encrypt_dispatcher_5args(
context, data, data_len, key_data, key_data_len, mode, mode_len, iv_data,
iv_data_len, nullptr, 0, out_len);
}

extern "C" GANDIVA_EXPORT
const char* gdv_fn_aes_decrypt_dispatcher_4args(
const char* gdv_fn_decrypt_dispatcher_4args(
int64_t context, const char* data, int32_t data_len, const char* key_data,
int32_t key_data_len, const char* mode, int32_t mode_len,
const char* iv_data, int32_t iv_data_len, int32_t* out_len) {
return gdv_fn_aes_decrypt_dispatcher_5args(
return gdv_fn_decrypt_dispatcher_5args(
context, data, data_len, key_data, key_data_len, mode, mode_len, iv_data,
iv_data_len, nullptr, 0, out_len);
}

extern "C" GANDIVA_EXPORT
const char* gdv_fn_aes_encrypt_dispatcher_5args(
const char* gdv_fn_encrypt_dispatcher_5args(
int64_t context, const char* data, int32_t data_len, const char* key_data,
int32_t key_data_len, const char* mode, int32_t mode_len,
const char* iv_data, int32_t iv_data_len, const char* fifth_argument,
Expand Down Expand Up @@ -980,7 +980,7 @@ const char* gdv_fn_aes_encrypt_dispatcher_5args(
}

extern "C" GANDIVA_EXPORT
const char* gdv_fn_aes_decrypt_dispatcher_5args(
const char* gdv_fn_decrypt_dispatcher_5args(
int64_t context, const char* data, int32_t data_len, const char* key_data,
int32_t key_data_len, const char* mode, int32_t mode_len,
const char* iv_data, int32_t iv_data_len, const char* fifth_argument,
Expand Down Expand Up @@ -1241,7 +1241,7 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const {
types->i32_ptr_type() // out_length
};

// gdv_fn_aes_encrypt_dispatcher_3args (data, key, mode)
// gdv_fn_encrypt_dispatcher_3args (data, key, mode)
args = {
types->i64_type(), // context
types->i8_ptr_type(), // data
Expand All @@ -1254,11 +1254,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const {
};

engine->AddGlobalMappingForFunc(
"gdv_fn_aes_encrypt_dispatcher_3args",
"gdv_fn_encrypt_dispatcher_3args",
types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_aes_encrypt_dispatcher_3args));
reinterpret_cast<void*>(gdv_fn_encrypt_dispatcher_3args));

// gdv_fn_aes_decrypt_dispatcher_3args (data, key, mode)
// gdv_fn_decrypt_dispatcher_3args (data, key, mode)
args = {
types->i64_type(), // context
types->i8_ptr_type(), // data
Expand All @@ -1271,11 +1271,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const {
};

engine->AddGlobalMappingForFunc(
"gdv_fn_aes_decrypt_dispatcher_3args",
"gdv_fn_decrypt_dispatcher_3args",
types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_aes_decrypt_dispatcher_3args));
reinterpret_cast<void*>(gdv_fn_decrypt_dispatcher_3args));

// gdv_fn_aes_encrypt_dispatcher_4args (data, key, mode, iv)
// gdv_fn_encrypt_dispatcher_4args (data, key, mode, iv)
args = {
types->i64_type(), // context
types->i8_ptr_type(), // data
Expand All @@ -1290,11 +1290,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const {
};

engine->AddGlobalMappingForFunc(
"gdv_fn_aes_encrypt_dispatcher_4args",
"gdv_fn_encrypt_dispatcher_4args",
types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_aes_encrypt_dispatcher_4args));
reinterpret_cast<void*>(gdv_fn_encrypt_dispatcher_4args));

// gdv_fn_aes_decrypt_dispatcher_4args (data, key, mode, iv)
// gdv_fn_decrypt_dispatcher_4args (data, key, mode, iv)
args = {
types->i64_type(), // context
types->i8_ptr_type(), // data
Expand All @@ -1309,11 +1309,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const {
};

engine->AddGlobalMappingForFunc(
"gdv_fn_aes_decrypt_dispatcher_4args",
"gdv_fn_decrypt_dispatcher_4args",
types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_aes_decrypt_dispatcher_4args));
reinterpret_cast<void*>(gdv_fn_decrypt_dispatcher_4args));

// gdv_fn_aes_encrypt_dispatcher_5args (data, key, mode, iv,
// gdv_fn_encrypt_dispatcher_5args (data, key, mode, iv,
// fifth_argument)
args = {
types->i64_type(), // context
Expand All @@ -1331,11 +1331,11 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const {
};

engine->AddGlobalMappingForFunc(
"gdv_fn_aes_encrypt_dispatcher_5args",
"gdv_fn_encrypt_dispatcher_5args",
types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_aes_encrypt_dispatcher_5args));
reinterpret_cast<void*>(gdv_fn_encrypt_dispatcher_5args));

// gdv_fn_aes_decrypt_dispatcher_5args (data, key, mode, iv,
// gdv_fn_decrypt_dispatcher_5args (data, key, mode, iv,
// fifth_argument)
args = {
types->i64_type(), // context
Expand All @@ -1353,9 +1353,9 @@ arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const {
};

engine->AddGlobalMappingForFunc(
"gdv_fn_aes_decrypt_dispatcher_5args",
"gdv_fn_decrypt_dispatcher_5args",
types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_aes_decrypt_dispatcher_5args));
reinterpret_cast<void*>(gdv_fn_decrypt_dispatcher_5args));

// gdv_mask_first_n and gdv_mask_last_n
std::vector<llvm::Type*> mask_args = {
Expand Down
Loading
Loading