From 6d0cd2f9d64ef6fe24655a489a40475a9f592742 Mon Sep 17 00:00:00 2001 From: sangjanai Date: Fri, 7 Mar 2025 11:07:29 +0700 Subject: [PATCH] feat: remote embeddings --- .../extensions/remote-engine/remote_engine.cc | 271 +++++++++++++++--- .../extensions/remote-engine/remote_engine.h | 12 +- 2 files changed, 234 insertions(+), 49 deletions(-) diff --git a/engine/extensions/remote-engine/remote_engine.cc b/engine/extensions/remote-engine/remote_engine.cc index 1640b7fac..21d413d9d 100644 --- a/engine/extensions/remote-engine/remote_engine.cc +++ b/engine/extensions/remote-engine/remote_engine.cc @@ -15,6 +15,9 @@ constexpr const int k409Conflict = 409; constexpr const int k500InternalServerError = 500; constexpr const int kFileLoggerOption = 0; +constexpr const auto kChatCompletions = "chat_completions"; +constexpr const auto kEmbeddings = "embeddings"; + constexpr const std::array kAnthropicModels = { "claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022", "claude-3-opus-20240229", "claude-3-sonnet-20240229", @@ -117,9 +120,8 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( std::string full_url = chat_url_; - if (config.transform_req["chat_completions"]["url"]) { - full_url = - config.transform_req["chat_completions"]["url"].as(); + if (config.transform_req[kChatCompletions]["url"]) { + full_url = config.transform_req[kChatCompletions]["url"].as(); } CTL_DBG("full_url: " << full_url); @@ -134,11 +136,11 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest( headers = curl_slist_append(headers, "Connection: keep-alive"); std::string stream_template = chat_res_template_; - if (config.transform_resp["chat_completions"] && - config.transform_resp["chat_completions"]["template"]) { + if (config.transform_resp[kChatCompletions] && + config.transform_resp[kChatCompletions]["template"]) { // Model level overrides engine level stream_template = - config.transform_resp["chat_completions"]["template"].as(); + config.transform_resp[kChatCompletions]["template"].as(); } StreamContext context{ @@ -283,9 +285,9 @@ CurlResponse RemoteEngine::MakeGetModelsRequest( return response; } -CurlResponse RemoteEngine::MakeChatCompletionRequest( - const ModelConfig& config, const std::string& body, - const std::string& method) { +CurlResponse RemoteEngine::MakeNonStreamRequest(const ModelConfig& config, + const std::string& body, + const RequestType& req_type) { CURL* curl = curl_easy_init(); CurlResponse response; @@ -294,12 +296,21 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( response.error_message = "Failed to initialize CURL"; return response; } - std::string full_url = chat_url_; - - if (config.transform_req["chat_completions"]["url"]) { - full_url = - config.transform_req["chat_completions"]["url"].as(); + std::string full_url; + if (req_type == RequestType::kChatCompletions) { + full_url = chat_url_; + if (config.transform_req[kChatCompletions]["url"]) { + full_url = + config.transform_req[kChatCompletions]["url"].as(); + } + } else { + full_url = embed_url_; + CTL_INF("embed_url_: " << embed_url_); + if (config.transform_req[kEmbeddings]["url"]) { + full_url = config.transform_req[kEmbeddings]["url"].as(); + } } + CTL_DBG("full_url: " << full_url); struct curl_slist* headers = nullptr; @@ -311,10 +322,7 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest( curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str()); curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); - - if (method == "POST") { - curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); - } + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str()); std::string response_string; curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback); @@ -432,29 +440,51 @@ void RemoteEngine::LoadModel( if (json_body->isMember("metadata")) { metadata_ = (*json_body)["metadata"]; - if (!metadata_["transform_req"].isNull() && - !metadata_["transform_req"]["chat_completions"].isNull() && - !metadata_["transform_req"]["chat_completions"]["template"].isNull()) { - chat_req_template_ = - metadata_["transform_req"]["chat_completions"]["template"].asString(); - CTL_INF(chat_req_template_); + if (!metadata_["transform_req"].isNull()) { + if (!metadata_["transform_req"][kChatCompletions].isNull() && + !metadata_["transform_req"][kChatCompletions]["template"].isNull()) { + chat_req_template_ = + metadata_["transform_req"][kChatCompletions]["template"].asString(); + CTL_INF("Chat request template: " << chat_req_template_); + } + + if (!metadata_["transform_req"][kEmbeddings].isNull() && + !metadata_["transform_req"][kEmbeddings]["template"].isNull()) { + embed_req_template_ = + metadata_["transform_req"][kEmbeddings]["template"].asString(); + CTL_INF("Embedding request template: " << embed_req_template_); + } } - if (!metadata_["transform_resp"].isNull() && - !metadata_["transform_resp"]["chat_completions"].isNull() && - !metadata_["transform_resp"]["chat_completions"]["template"].isNull()) { - chat_res_template_ = - metadata_["transform_resp"]["chat_completions"]["template"] - .asString(); - CTL_INF(chat_res_template_); + if (!metadata_["transform_resp"].isNull()) { + if (!metadata_["transform_resp"][kChatCompletions].isNull() && + !metadata_["transform_resp"][kChatCompletions]["template"].isNull()) { + chat_res_template_ = + metadata_["transform_resp"][kChatCompletions]["template"] + .asString(); + CTL_INF("Chat response template: " << chat_res_template_); + } + if (!metadata_["transform_resp"][kEmbeddings].isNull() && + !metadata_["transform_resp"][kEmbeddings]["template"].isNull()) { + embed_res_template_ = + metadata_["transform_resp"][kEmbeddings]["template"].asString(); + CTL_INF("Embedding request template: " << embed_res_template_); + } } - if (!metadata_["transform_req"].isNull() && - !metadata_["transform_req"]["chat_completions"].isNull() && - !metadata_["transform_req"]["chat_completions"]["url"].isNull()) { - chat_url_ = - metadata_["transform_req"]["chat_completions"]["url"].asString(); - CTL_INF(chat_url_); + if (!metadata_["transform_req"].isNull()) { + if (!metadata_["transform_req"][kChatCompletions].isNull() && + !metadata_["transform_req"][kChatCompletions]["url"].isNull()) { + chat_url_ = + metadata_["transform_req"][kChatCompletions]["url"].asString(); + CTL_INF("chat_url: " << chat_url_); + } + + if (!metadata_["transform_req"][kEmbeddings].isNull() && + !metadata_["transform_req"][kEmbeddings]["url"].isNull()) { + embed_url_ = metadata_["transform_req"][kEmbeddings]["url"].asString(); + CTL_INF("embeddings_url: " << embed_url_); + } } } @@ -569,10 +599,10 @@ void RemoteEngine::HandleChatCompletion( CTL_DBG("Use engine transform request template: " << chat_req_template_); template_str = chat_req_template_; } - if (model_config->transform_req["chat_completions"] && - model_config->transform_req["chat_completions"]["template"]) { + if (model_config->transform_req[kChatCompletions] && + model_config->transform_req[kChatCompletions]["template"]) { // Model level overrides engine level - template_str = model_config->transform_req["chat_completions"]["template"] + template_str = model_config->transform_req[kChatCompletions]["template"] .as(); CTL_DBG("Use model transform request template: " << template_str); } @@ -597,7 +627,7 @@ void RemoteEngine::HandleChatCompletion( }); } else { - auto response = MakeChatCompletionRequest(*model_config, result); + auto response = MakeNonStreamRequest(*model_config, result); if (response.error) { Json::Value status; @@ -635,11 +665,11 @@ void RemoteEngine::HandleChatCompletion( "Use engine transform response template: " << chat_res_template_); template_str = chat_res_template_; } - if (model_config->transform_resp["chat_completions"] && - model_config->transform_resp["chat_completions"]["template"]) { + if (model_config->transform_resp[kChatCompletions] && + model_config->transform_resp[kChatCompletions]["template"]) { // Model level overrides engine level template_str = - model_config->transform_resp["chat_completions"]["template"] + model_config->transform_resp[kChatCompletions]["template"] .as(); CTL_DBG("Use model transform request template: " << template_str); } @@ -722,9 +752,158 @@ void RemoteEngine::GetModelStatus( // Implement remaining virtual functions void RemoteEngine::HandleEmbedding( - std::shared_ptr, + std::shared_ptr json_body, std::function&& callback) { - callback(Json::Value(), Json::Value()); + + if (!json_body->isMember("model")) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = "Missing required fields: model"; + callback(std::move(status), std::move(error)); + return; + } + + const std::string& model = (*json_body)["model"].asString(); + auto* model_config = GetModelConfig(model); + + if (!model_config || !model_config->is_loaded) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = "Model not found or not loaded: " + model; + callback(std::move(status), std::move(error)); + return; + } + + // Transform request + std::string result; + try { + // Validate JSON body + if (!json_body || json_body->isNull()) { + throw std::runtime_error("Invalid or null JSON body"); + } + + // Get template string with error check + std::string template_str; + if (!embed_req_template_.empty()) { + CTL_DBG("Use engine transform request template: " << embed_req_template_); + template_str = embed_req_template_; + } + if (model_config->transform_req[kEmbeddings] && + model_config->transform_req[kEmbeddings]["template"]) { + // Model level overrides engine level + template_str = model_config->transform_req[kEmbeddings]["template"] + .as(); + CTL_DBG("Use model transform request template: " << template_str); + } + + // Render with error handling + try { + result = renderer_.Render(template_str, *json_body); + } catch (const std::exception& e) { + throw std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error in TransformRequest: " << e.what(); + LOG_WARN << "Using original request body"; + result = (*json_body).toStyledString(); + } + + auto response = + MakeNonStreamRequest(*model_config, result, RequestType::kEmbeddings); + + if (response.error) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k400BadRequest; + Json::Value error; + error["error"] = response.error_message; + callback(std::move(status), std::move(error)); + return; + } + + Json::Value response_json; + Json::Reader reader; + if (!reader.parse(response.body, response_json)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + LOG_WARN << "Failed to parse response: " << response.body; + callback(std::move(status), std::move(error)); + return; + } + + // Transform Response + std::string response_str; + try { + std::string template_str; + if (!chat_res_template_.empty()) { + CTL_DBG("Use engine transform response template: " << chat_res_template_); + template_str = chat_res_template_; + } + if (model_config->transform_resp[kChatCompletions] && + model_config->transform_resp[kChatCompletions]["template"]) { + // Model level overrides engine level + template_str = model_config->transform_resp[kChatCompletions]["template"] + .as(); + CTL_DBG("Use model transform request template: " << template_str); + } + + try { + response_json["stream"] = false; + if (!response_json.isMember("model")) { + response_json["model"] = model; + } + response_str = renderer_.Render(template_str, response_json); + } catch (const std::exception& e) { + throw std::runtime_error("Template rendering error: " + + std::string(e.what())); + } + } catch (const std::exception& e) { + // Log error and potentially rethrow or handle accordingly + LOG_WARN << "Error: " << e.what(); + LOG_WARN << "Response: " << response.body; + LOG_WARN << "Using original body"; + response_str = response_json.toStyledString(); + } + + Json::Reader reader_final; + Json::Value response_json_final; + if (!reader_final.parse(response_str, response_json_final)) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = true; + status["is_stream"] = false; + status["status_code"] = k500InternalServerError; + Json::Value error; + error["error"] = "Failed to parse response"; + callback(std::move(status), std::move(error)); + LOG_WARN << "Failed to parse response: " << response_str; + return; + } + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = k200OK; + + callback(std::move(status), std::move(response_json_final)); } Json::Value RemoteEngine::GetRemoteModels(const std::string& url, diff --git a/engine/extensions/remote-engine/remote_engine.h b/engine/extensions/remote-engine/remote_engine.h index 6f1b731c6..b41859f3f 100644 --- a/engine/extensions/remote-engine/remote_engine.h +++ b/engine/extensions/remote-engine/remote_engine.h @@ -34,6 +34,8 @@ struct CurlResponse { std::string error_message; }; +enum class RequestType { kChatCompletions, kEmbeddings }; + class RemoteEngine : public RemoteEngineI { protected: // Model configuration @@ -58,11 +60,15 @@ class RemoteEngine : public RemoteEngineI { std::string engine_name_; std::string chat_url_; trantor::ConcurrentTaskQueue q_; + // TODO(sang) + std::string embed_req_template_; + std::string embed_res_template_; + std::string embed_url_; // Helper functions - CurlResponse MakeChatCompletionRequest(const ModelConfig& config, - const std::string& body, - const std::string& method = "POST"); + CurlResponse MakeNonStreamRequest( + const ModelConfig& config, const std::string& body, + const RequestType& req_type = RequestType::kChatCompletions); CurlResponse MakeStreamingChatCompletionRequest( const ModelConfig& config, const std::string& body, const std::function& callback);