Skip to content

feat: remote embeddings #2080

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 225 additions & 46 deletions engine/extensions/remote-engine/remote_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ constexpr const int k409Conflict = 409;
constexpr const int k500InternalServerError = 500;
constexpr const int kFileLoggerOption = 0;

constexpr const auto kChatCompletions = "chat_completions";
constexpr const auto kEmbeddings = "embeddings";

constexpr const std::array<std::string_view, 5> kAnthropicModels = {
"claude-3-5-sonnet-20241022", "claude-3-5-haiku-20241022",
"claude-3-opus-20240229", "claude-3-sonnet-20240229",
Expand Down Expand Up @@ -117,9 +120,8 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(

std::string full_url = chat_url_;

if (config.transform_req["chat_completions"]["url"]) {
full_url =
config.transform_req["chat_completions"]["url"].as<std::string>();
if (config.transform_req[kChatCompletions]["url"]) {
full_url = config.transform_req[kChatCompletions]["url"].as<std::string>();
}
CTL_DBG("full_url: " << full_url);

Expand All @@ -134,11 +136,11 @@ CurlResponse RemoteEngine::MakeStreamingChatCompletionRequest(
headers = curl_slist_append(headers, "Connection: keep-alive");

std::string stream_template = chat_res_template_;
if (config.transform_resp["chat_completions"] &&
config.transform_resp["chat_completions"]["template"]) {
if (config.transform_resp[kChatCompletions] &&
config.transform_resp[kChatCompletions]["template"]) {
// Model level overrides engine level
stream_template =
config.transform_resp["chat_completions"]["template"].as<std::string>();
config.transform_resp[kChatCompletions]["template"].as<std::string>();
}

StreamContext context{
Expand Down Expand Up @@ -283,9 +285,9 @@ CurlResponse RemoteEngine::MakeGetModelsRequest(
return response;
}

CurlResponse RemoteEngine::MakeChatCompletionRequest(
const ModelConfig& config, const std::string& body,
const std::string& method) {
CurlResponse RemoteEngine::MakeNonStreamRequest(const ModelConfig& config,
const std::string& body,
const RequestType& req_type) {
CURL* curl = curl_easy_init();
CurlResponse response;

Expand All @@ -294,12 +296,21 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest(
response.error_message = "Failed to initialize CURL";
return response;
}
std::string full_url = chat_url_;

if (config.transform_req["chat_completions"]["url"]) {
full_url =
config.transform_req["chat_completions"]["url"].as<std::string>();
std::string full_url;
if (req_type == RequestType::kChatCompletions) {
full_url = chat_url_;
if (config.transform_req[kChatCompletions]["url"]) {
full_url =
config.transform_req[kChatCompletions]["url"].as<std::string>();
}
} else {
full_url = embed_url_;
CTL_INF("embed_url_: " << embed_url_);
if (config.transform_req[kEmbeddings]["url"]) {
full_url = config.transform_req[kEmbeddings]["url"].as<std::string>();
}
}

CTL_DBG("full_url: " << full_url);

struct curl_slist* headers = nullptr;
Expand All @@ -311,10 +322,7 @@ CurlResponse RemoteEngine::MakeChatCompletionRequest(

curl_easy_setopt(curl, CURLOPT_URL, full_url.c_str());
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);

if (method == "POST") {
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());
}
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, body.c_str());

std::string response_string;
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, WriteCallback);
Expand Down Expand Up @@ -432,29 +440,51 @@ void RemoteEngine::LoadModel(

if (json_body->isMember("metadata")) {
metadata_ = (*json_body)["metadata"];
if (!metadata_["transform_req"].isNull() &&
!metadata_["transform_req"]["chat_completions"].isNull() &&
!metadata_["transform_req"]["chat_completions"]["template"].isNull()) {
chat_req_template_ =
metadata_["transform_req"]["chat_completions"]["template"].asString();
CTL_INF(chat_req_template_);
if (!metadata_["transform_req"].isNull()) {
if (!metadata_["transform_req"][kChatCompletions].isNull() &&
!metadata_["transform_req"][kChatCompletions]["template"].isNull()) {
chat_req_template_ =
metadata_["transform_req"][kChatCompletions]["template"].asString();
CTL_INF("Chat request template: " << chat_req_template_);
}

if (!metadata_["transform_req"][kEmbeddings].isNull() &&
!metadata_["transform_req"][kEmbeddings]["template"].isNull()) {
embed_req_template_ =
metadata_["transform_req"][kEmbeddings]["template"].asString();
CTL_INF("Embedding request template: " << embed_req_template_);
}
}

if (!metadata_["transform_resp"].isNull() &&
!metadata_["transform_resp"]["chat_completions"].isNull() &&
!metadata_["transform_resp"]["chat_completions"]["template"].isNull()) {
chat_res_template_ =
metadata_["transform_resp"]["chat_completions"]["template"]
.asString();
CTL_INF(chat_res_template_);
if (!metadata_["transform_resp"].isNull()) {
if (!metadata_["transform_resp"][kChatCompletions].isNull() &&
!metadata_["transform_resp"][kChatCompletions]["template"].isNull()) {
chat_res_template_ =
metadata_["transform_resp"][kChatCompletions]["template"]
.asString();
CTL_INF("Chat response template: " << chat_res_template_);
}
if (!metadata_["transform_resp"][kEmbeddings].isNull() &&
!metadata_["transform_resp"][kEmbeddings]["template"].isNull()) {
embed_res_template_ =
metadata_["transform_resp"][kEmbeddings]["template"].asString();
CTL_INF("Embedding request template: " << embed_res_template_);
}
}

if (!metadata_["transform_req"].isNull() &&
!metadata_["transform_req"]["chat_completions"].isNull() &&
!metadata_["transform_req"]["chat_completions"]["url"].isNull()) {
chat_url_ =
metadata_["transform_req"]["chat_completions"]["url"].asString();
CTL_INF(chat_url_);
if (!metadata_["transform_req"].isNull()) {
if (!metadata_["transform_req"][kChatCompletions].isNull() &&
!metadata_["transform_req"][kChatCompletions]["url"].isNull()) {
chat_url_ =
metadata_["transform_req"][kChatCompletions]["url"].asString();
CTL_INF("chat_url: " << chat_url_);
}

if (!metadata_["transform_req"][kEmbeddings].isNull() &&
!metadata_["transform_req"][kEmbeddings]["url"].isNull()) {
embed_url_ = metadata_["transform_req"][kEmbeddings]["url"].asString();
CTL_INF("embeddings_url: " << embed_url_);
}
}
}

Expand Down Expand Up @@ -569,10 +599,10 @@ void RemoteEngine::HandleChatCompletion(
CTL_DBG("Use engine transform request template: " << chat_req_template_);
template_str = chat_req_template_;
}
if (model_config->transform_req["chat_completions"] &&
model_config->transform_req["chat_completions"]["template"]) {
if (model_config->transform_req[kChatCompletions] &&
model_config->transform_req[kChatCompletions]["template"]) {
// Model level overrides engine level
template_str = model_config->transform_req["chat_completions"]["template"]
template_str = model_config->transform_req[kChatCompletions]["template"]
.as<std::string>();
CTL_DBG("Use model transform request template: " << template_str);
}
Expand All @@ -597,7 +627,7 @@ void RemoteEngine::HandleChatCompletion(
});
} else {

auto response = MakeChatCompletionRequest(*model_config, result);
auto response = MakeNonStreamRequest(*model_config, result);

if (response.error) {
Json::Value status;
Expand Down Expand Up @@ -635,11 +665,11 @@ void RemoteEngine::HandleChatCompletion(
"Use engine transform response template: " << chat_res_template_);
template_str = chat_res_template_;
}
if (model_config->transform_resp["chat_completions"] &&
model_config->transform_resp["chat_completions"]["template"]) {
if (model_config->transform_resp[kChatCompletions] &&
model_config->transform_resp[kChatCompletions]["template"]) {
// Model level overrides engine level
template_str =
model_config->transform_resp["chat_completions"]["template"]
model_config->transform_resp[kChatCompletions]["template"]
.as<std::string>();
CTL_DBG("Use model transform request template: " << template_str);
}
Expand Down Expand Up @@ -722,9 +752,158 @@ void RemoteEngine::GetModelStatus(

// Implement remaining virtual functions
void RemoteEngine::HandleEmbedding(
std::shared_ptr<Json::Value>,
std::shared_ptr<Json::Value> json_body,
std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
callback(Json::Value(), Json::Value());

if (!json_body->isMember("model")) {
Json::Value status;
status["is_done"] = true;
status["has_error"] = true;
status["is_stream"] = false;
status["status_code"] = k400BadRequest;
Json::Value error;
error["error"] = "Missing required fields: model";
callback(std::move(status), std::move(error));
return;
}

const std::string& model = (*json_body)["model"].asString();
auto* model_config = GetModelConfig(model);

if (!model_config || !model_config->is_loaded) {
Json::Value status;
status["is_done"] = true;
status["has_error"] = true;
status["is_stream"] = false;
status["status_code"] = k400BadRequest;
Json::Value error;
error["error"] = "Model not found or not loaded: " + model;
callback(std::move(status), std::move(error));
return;
}

// Transform request
std::string result;
try {
// Validate JSON body
if (!json_body || json_body->isNull()) {
throw std::runtime_error("Invalid or null JSON body");
}

// Get template string with error check
std::string template_str;
if (!embed_req_template_.empty()) {
CTL_DBG("Use engine transform request template: " << embed_req_template_);
template_str = embed_req_template_;
}
if (model_config->transform_req[kEmbeddings] &&
model_config->transform_req[kEmbeddings]["template"]) {
// Model level overrides engine level
template_str = model_config->transform_req[kEmbeddings]["template"]
.as<std::string>();
CTL_DBG("Use model transform request template: " << template_str);
}

// Render with error handling
try {
result = renderer_.Render(template_str, *json_body);
} catch (const std::exception& e) {
throw std::runtime_error("Template rendering error: " +
std::string(e.what()));
}
} catch (const std::exception& e) {
// Log error and potentially rethrow or handle accordingly
LOG_WARN << "Error in TransformRequest: " << e.what();
LOG_WARN << "Using original request body";
result = (*json_body).toStyledString();
}

auto response =
MakeNonStreamRequest(*model_config, result, RequestType::kEmbeddings);

if (response.error) {
Json::Value status;
status["is_done"] = true;
status["has_error"] = true;
status["is_stream"] = false;
status["status_code"] = k400BadRequest;
Json::Value error;
error["error"] = response.error_message;
callback(std::move(status), std::move(error));
return;
}

Json::Value response_json;
Json::Reader reader;
if (!reader.parse(response.body, response_json)) {
Json::Value status;
status["is_done"] = true;
status["has_error"] = true;
status["is_stream"] = false;
status["status_code"] = k500InternalServerError;
Json::Value error;
error["error"] = "Failed to parse response";
LOG_WARN << "Failed to parse response: " << response.body;
callback(std::move(status), std::move(error));
return;
}

// Transform Response
std::string response_str;
try {
std::string template_str;
if (!chat_res_template_.empty()) {
CTL_DBG("Use engine transform response template: " << chat_res_template_);
template_str = chat_res_template_;
}
if (model_config->transform_resp[kChatCompletions] &&
model_config->transform_resp[kChatCompletions]["template"]) {
// Model level overrides engine level
template_str = model_config->transform_resp[kChatCompletions]["template"]
.as<std::string>();
CTL_DBG("Use model transform request template: " << template_str);
}

try {
response_json["stream"] = false;
if (!response_json.isMember("model")) {
response_json["model"] = model;
}
response_str = renderer_.Render(template_str, response_json);
} catch (const std::exception& e) {
throw std::runtime_error("Template rendering error: " +
std::string(e.what()));
}
} catch (const std::exception& e) {
// Log error and potentially rethrow or handle accordingly
LOG_WARN << "Error: " << e.what();
LOG_WARN << "Response: " << response.body;
LOG_WARN << "Using original body";
response_str = response_json.toStyledString();
}

Json::Reader reader_final;
Json::Value response_json_final;
if (!reader_final.parse(response_str, response_json_final)) {
Json::Value status;
status["is_done"] = true;
status["has_error"] = true;
status["is_stream"] = false;
status["status_code"] = k500InternalServerError;
Json::Value error;
error["error"] = "Failed to parse response";
callback(std::move(status), std::move(error));
LOG_WARN << "Failed to parse response: " << response_str;
return;
}

Json::Value status;
status["is_done"] = true;
status["has_error"] = false;
status["is_stream"] = false;
status["status_code"] = k200OK;

callback(std::move(status), std::move(response_json_final));
}

Json::Value RemoteEngine::GetRemoteModels(const std::string& url,
Expand Down
12 changes: 9 additions & 3 deletions engine/extensions/remote-engine/remote_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ struct CurlResponse {
std::string error_message;
};

enum class RequestType { kChatCompletions, kEmbeddings };

class RemoteEngine : public RemoteEngineI {
protected:
// Model configuration
Expand All @@ -58,11 +60,15 @@ class RemoteEngine : public RemoteEngineI {
std::string engine_name_;
std::string chat_url_;
trantor::ConcurrentTaskQueue q_;
// TODO(sang)
std::string embed_req_template_;
std::string embed_res_template_;
std::string embed_url_;

// Helper functions
CurlResponse MakeChatCompletionRequest(const ModelConfig& config,
const std::string& body,
const std::string& method = "POST");
CurlResponse MakeNonStreamRequest(
const ModelConfig& config, const std::string& body,
const RequestType& req_type = RequestType::kChatCompletions);
CurlResponse MakeStreamingChatCompletionRequest(
const ModelConfig& config, const std::string& body,
const std::function<void(Json::Value&&, Json::Value&&)>& callback);
Expand Down
Loading