diff --git a/engine/CMakeLists.txt b/engine/CMakeLists.txt index f7a20b58b..23d20e5a1 100644 --- a/engine/CMakeLists.txt +++ b/engine/CMakeLists.txt @@ -177,6 +177,8 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc ${CMAKE_CURRENT_SOURCE_DIR}/extensions/template_renderer.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/python-engines/python_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/extensions/python-engines/vllm_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/dylib_path_manager.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/process/utils.cc diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 4163042d0..2f1252ac3 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -74,6 +74,8 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/database_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/remote-engine/remote_engine.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/python-engines/python_utils.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/python-engines/vllm_engine.cc ${CMAKE_CURRENT_SOURCE_DIR}/../extensions/template_renderer.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc diff --git a/engine/cli/commands/chat_completion_cmd.cc b/engine/cli/commands/chat_completion_cmd.cc index 77ee4fca3..6b52464f3 100644 --- a/engine/cli/commands/chat_completion_cmd.cc +++ b/engine/cli/commands/chat_completion_cmd.cc @@ -137,7 +137,11 @@ void ChatCompletionCmd::Exec(const std::string& host, int port, new_data["content"] = user_input; histories_.push_back(std::move(new_data)); - Json::Value json_data = mc.ToJson(); + // vLLM doesn't support the params from the model config + Json::Value json_data; + if (mc.engine != kVllmEngine) { + json_data = mc.ToJson(); + } json_data["engine"] = mc.engine; Json::Value msgs_array(Json::arrayValue); diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index bebfdb8ce..c03e72b01 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -7,6 +7,13 @@ #include "utils/string_utils.h" namespace commands { + +// NOTE: should have a single source of truth between CLI and server +static bool NeedCudaDownload(const std::string& engine) { + return !system_info_utils::GetDriverAndCudaVersion().second.empty() && + engine == kLlamaRepo; +} + bool EngineInstallCmd::Exec(const std::string& engine, const std::string& version, const std::string& src) { @@ -35,15 +42,18 @@ bool EngineInstallCmd::Exec(const std::string& engine, if (show_menu_) { DownloadProgress dp; dp.Connect(host_, port_); + bool need_cuda_download = NeedCudaDownload(engine); // engine can be small, so need to start ws first - auto dp_res = std::async(std::launch::deferred, [&dp] { - bool need_cuda_download = - !system_info_utils::GetDriverAndCudaVersion().second.empty(); + auto dp_res = std::async(std::launch::deferred, [&dp, need_cuda_download, engine] { + // non-llama.cpp engines (e.g. vLLM) have no download events to wait for + if (engine != kLlamaRepo) + return dp.Handle({}); if (need_cuda_download) { return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); } else { return dp.Handle({DownloadType::Engine}); } }); auto releases_url = url_parser::Url{ @@ -151,15 +161,18 @@ bool EngineInstallCmd::Exec(const std::string& engine, // default DownloadProgress dp; dp.Connect(host_, port_); + bool need_cuda_download = NeedCudaDownload(engine); // engine can be small, so need to start ws first - auto dp_res = std::async(std::launch::deferred, [&dp] { - bool need_cuda_download = - !system_info_utils::GetDriverAndCudaVersion().second.empty(); + auto dp_res = std::async(std::launch::deferred, [&dp, need_cuda_download, engine] { + // non-llama.cpp engines (e.g. vLLM) have no download events to wait for + if (engine != kLlamaRepo) + return dp.Handle({}); if (need_cuda_download) { return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); } else { return dp.Handle({DownloadType::Engine}); } }); auto install_url = url_parser::Url{ diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index b20d7596e..52a7b6326 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -67,8 +67,12 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, auto download_url = res.value()["downloadUrl"].asString(); if (downloaded.empty() && avails.empty()) { - model_id = id; - model = download_url; + if (res.value()["modelSource"].asString() == "huggingface") { + model = id; + } else { + model_id = id; + model = download_url; + } } else { if (is_cortexso) { auto selection = cli_selection_utils::PrintModelSelection( diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc index c01d3d806..25f3ae45d 100644 --- a/engine/cli/commands/run_cmd.cc +++ b/engine/cli/commands/run_cmd.cc @@ -84,11 +84,18 @@ void RunCmd::Exec(bool run_detach, CLI_LOG("Error: " + model_entry.error()); return; } - yaml_handler.ModelConfigFromFile( - fmu::ToAbsoluteCortexDataPath( - fs::path(model_entry.value().path_to_model_yaml)) - .string()); - auto mc = yaml_handler.GetModelConfig(); + + config::ModelConfig mc; + if (model_entry.value().engine == kVllmEngine) { + // vLLM engine doesn't have a model config file + mc.engine = kVllmEngine; + } else { + yaml_handler.ModelConfigFromFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + mc = yaml_handler.GetModelConfig(); + } // Check if engine existed.
If not, download it { diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 2071407f5..b853163db 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -28,7 +28,7 @@ void Models::PullModel(const HttpRequestPtr& req, return; } - auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); + auto model_handle = req->getJsonObject()->get("model", "").asString(); if (model_handle.empty()) { Json::Value ret; ret["result"] = "Bad Request"; @@ -39,52 +39,19 @@ void Models::PullModel(const HttpRequestPtr& req, } std::optional desired_model_id = std::nullopt; - auto id = (*(req->getJsonObject())).get("id", "").asString(); + auto id = req->getJsonObject()->get("id", "").asString(); if (!id.empty()) { desired_model_id = id; } std::optional desired_model_name = std::nullopt; - auto name_value = (*(req->getJsonObject())).get("name", "").asString(); - + auto name_value = req->getJsonObject()->get("name", "").asString(); if (!name_value.empty()) { desired_model_name = name_value; } - auto handle_model_input = - [&, model_handle]() -> cpp::result { - CTL_INF("Handle model input, model handle: " + model_handle); - if (string_utils::StartsWith(model_handle, "https")) { - return model_service_->HandleDownloadUrlAsync( - model_handle, desired_model_id, desired_model_name); - } else if (model_handle.find(":") != std::string::npos) { - auto model_and_branch = string_utils::SplitBy(model_handle, ":"); - if (model_and_branch.size() == 3) { - auto mh = url_parser::Url{ - /* .protocol = */ "https", - /* .host = */ kHuggingFaceHost, - /* .pathParams = */ - { - model_and_branch[0], - model_and_branch[1], - "resolve", - "main", - model_and_branch[2], - }, - /* queries= */ {}, - } - .ToFullPath(); - return model_service_->HandleDownloadUrlAsync(mh, desired_model_id, - desired_model_name); - } - return model_service_->DownloadModelFromCortexsoAsync( - model_and_branch[0], model_and_branch[1], desired_model_id); - } - - return cpp::fail("Invalid model handle or not supported!"); - }; - - auto result = handle_model_input(); + auto result = model_service_->PullModel(model_handle, desired_model_id, + desired_model_name); if (result.has_error()) { Json::Value ret; ret["message"] = result.error(); @@ -213,6 +180,17 @@ void Models::ListModel( data.append(std::move(obj)); continue; } + + if (model_entry.engine == kVllmEngine) { + Json::Value obj; + obj["id"] = model_entry.model; + obj["model"] = model_entry.model; + obj["engine"] = model_entry.engine; + obj["status"] = "downloaded"; + data.append(std::move(obj)); + continue; + } + yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.path_to_model_yaml)) diff --git a/engine/controllers/server.cc b/engine/controllers/server.cc index 079b69423..6ea733a70 100644 --- a/engine/controllers/server.cc +++ b/engine/controllers/server.cc @@ -138,7 +138,7 @@ void server::ProcessStreamRes(std::function cb, auto err_or_done = std::make_shared(false); auto chunked_content_provider = [this, q, err_or_done, engine_type, model_id]( char* buf, - std::size_t buf_size) -> std::size_t { + std::size_t buf_size) -> std::size_t { if (buf == nullptr) { LOG_TRACE << "Buf is null"; if (!(*err_or_done)) { diff --git a/engine/e2e-test/api/engines/test_api_engine.py b/engine/e2e-test/api/engines/test_api_engine.py index 7356ef904..dbdf2dbe9 100644 --- a/engine/e2e-test/api/engines/test_api_engine.py +++ b/engine/e2e-test/api/engines/test_api_engine.py @@ -20,12 +20,12 @@ def setup_and_teardown(self): # Teardown 
stop_server() - + # engines get def test_engines_get_llamacpp_should_be_successful(self): response = requests.get("http://localhost:3928/engines/llama-cpp") assert response.status_code == 200 - + # engines install def test_engines_install_llamacpp_specific_version_and_variant(self): data = {"version": "v0.1.40-b4354", "variant": "linux-amd64-avx"} @@ -40,7 +40,7 @@ def test_engines_install_llamacpp_specific_version_and_null_variant(self): "http://localhost:3928/v1/engines/llama-cpp/install", json=data ) assert response.status_code == 200 - + # engines uninstall @pytest.mark.asyncio async def test_engines_install_uninstall_llamacpp_should_be_successful(self): diff --git a/engine/extensions/python-engines/python_utils.cc b/engine/extensions/python-engines/python_utils.cc new file mode 100644 index 000000000..965b4c324 --- /dev/null +++ b/engine/extensions/python-engines/python_utils.cc @@ -0,0 +1,119 @@ +#include "python_utils.h" +#include + +#include "utils/archive_utils.h" +#include "utils/curl_utils.h" +#include "utils/file_manager_utils.h" +#include "utils/set_permission_utils.h" +#include "utils/system_info_utils.h" + +namespace python_utils { + +std::filesystem::path GetPythonEnginesPath() { + return file_manager_utils::GetCortexDataPath() / "python_engines"; +} +std::filesystem::path GetEnvsPath() { + return GetPythonEnginesPath() / "envs"; +} +std::filesystem::path GetUvPath() { + auto system_info = system_info_utils::GetSystemInfo(); + const auto bin_name = system_info->os == kWindowsOs ? "uv.exe" : "uv"; + return GetPythonEnginesPath() / "bin" / bin_name; +} +bool UvCleanCache() { + auto cmd = UvBuildCommand("cache"); + cmd.push_back("clean"); + auto result = cortex::process::SpawnProcess(cmd); + if (result.has_error()) { + CTL_INF(result.error()); + return false; + } + return cortex::process::WaitProcess(result.value()); +} + +bool UvIsInstalled() { + return std::filesystem::exists(GetUvPath()); +} +cpp::result UvInstall() { + const auto py_bin_path = GetPythonEnginesPath() / "bin"; + std::filesystem::create_directories(py_bin_path); + + // NOTE: do we need a mechanism to update uv, or just pin uv version with cortex release? + const std::string uv_version = "0.6.11"; + + // build download url based on system info + std::stringstream fname_stream; + fname_stream << "uv-"; + + auto system_info = system_info_utils::GetSystemInfo(); + if (system_info->arch == "amd64") + fname_stream << "x86_64"; + else if (system_info->arch == "arm64") + fname_stream << "aarch64"; + + // NOTE: there is also a musl linux version + if (system_info->os == kMacOs) + fname_stream << "-apple-darwin.tar.gz"; + else if (system_info->os == kWindowsOs) + fname_stream << "-pc-windows-msvc.zip"; + else if (system_info->os == kLinuxOs) + fname_stream << "-unknown-linux-gnu.tar.gz"; + + const std::string fname = fname_stream.str(); + const std::string base_url = + "https://github.com/astral-sh/uv/releases/download/"; + + std::stringstream url_stream; + url_stream << base_url << uv_version << "/" << fname; + const std::string url = url_stream.str(); + CTL_INF("Download uv from " << url); + + const auto save_path = py_bin_path / fname; + auto res = curl_utils::SimpleDownload(url, save_path.string()); + if (res.has_error()) + return res; + + archive_utils::ExtractArchive(save_path, py_bin_path.string(), true); + set_permission_utils::SetExecutePermissionsRecursive(py_bin_path); + std::filesystem::remove(save_path); + + // install Python3.10 from Astral. this will be preferred over system + // Python when possible. 
+ // NOTE: currently this will install to a user-wide directory. we can + // install to a specific location using `--install-dir`, but later + // invocation of `uv run` needs to have `UV_PYTHON_INSTALL_DIR` set to use + // this Python installation. + // we can add this once we allow passing custom env var to SpawnProcess(). + // https://docs.astral.sh/uv/reference/cli/#uv-python-install + std::vector command = UvBuildCommand("python"); + command.push_back("install"); + command.push_back("3.10"); + + auto result = cortex::process::SpawnProcess(command); + if (result.has_error()) + return cpp::fail(result.error()); + + if (!cortex::process::WaitProcess(result.value())) { + const auto msg = "Process spawned but fail to wait"; + CTL_ERR(msg); + return cpp::fail(msg); + } + + return {}; +} + +std::vector UvBuildCommand(const std::string& action, + const std::string& directory) { + // use our own cache dir so that when users delete cortexcpp/, everything is deleted. + const auto cache_dir = GetPythonEnginesPath() / "cache" / "uv"; + std::vector command = {GetUvPath().string(), "--cache-dir", + cache_dir.string()}; + if (!directory.empty()) { + command.push_back("--directory"); + command.push_back(directory); + } + command.push_back(action); + return command; +} + +} // namespace python_utils diff --git a/engine/extensions/python-engines/python_utils.h b/engine/extensions/python-engines/python_utils.h new file mode 100644 index 000000000..5206eb7f1 --- /dev/null +++ b/engine/extensions/python-engines/python_utils.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +#include "services/download_service.h" +#include "utils/process/utils.h" + +namespace python_utils { + +// paths +std::filesystem::path GetPythonEnginesPath(); +std::filesystem::path GetEnvsPath(); +std::filesystem::path GetUvPath(); + +// UV-related functions +bool UvIsInstalled(); +cpp::result UvInstall(); +std::vector UvBuildCommand(const std::string& action, + const std::string& directory = ""); +bool UvCleanCache(); + +struct PythonSubprocess { + cortex::process::ProcessInfo proc_info; + int port; + uint64_t start_time; + + bool IsAlive() { return cortex::process::IsProcessAlive(proc_info); } + bool Kill() { return cortex::process::KillProcess(proc_info); } +}; +} // namespace python_utils diff --git a/engine/extensions/python-engines/vllm_engine.cc b/engine/extensions/python-engines/vllm_engine.cc new file mode 100644 index 000000000..b05e651c5 --- /dev/null +++ b/engine/extensions/python-engines/vllm_engine.cc @@ -0,0 +1,544 @@ +// Note on subprocess lifecycle +// In LoadModel(), we will wait until /health returns 200. Thus, in subsequent +// calls to the subprocess, if the server is working normally, /health is +// guaranteed to return 200. If it doesn't, it either means the subprocess has +// died or the server hangs (for whatever reason). 
+ +#include "vllm_engine.h" +#include +#include "services/engine_service.h" +#include "utils/curl_utils.h" +#include "utils/logging_utils.h" +#include "utils/system_info_utils.h" + +namespace { +static std::pair CreateResponse( + const std::string& msg, int code) { + Json::Value status, res; + status["status_code"] = code; + status["has_error"] = code != 200; + res["message"] = msg; + return {status, res}; +} + +// this is mostly copied from local_engine.cc +struct StreamContext { + std::shared_ptr> callback; + bool need_stop; + + static size_t write_callback(char* ptr, size_t size, size_t nmemb, + void* userdata) { + auto* ctx = static_cast(userdata); + size_t data_length = size * nmemb; + if (data_length <= 6) + return data_length; + + std::string chunk{ptr, data_length}; + CTL_INF(chunk); + Json::Value status; + status["is_stream"] = true; + status["has_error"] = false; + status["status_code"] = 200; + Json::Value chunk_json; + chunk_json["data"] = chunk; + + if (chunk.find("[DONE]") != std::string::npos) { + status["is_done"] = true; + ctx->need_stop = false; + } else { + status["is_done"] = false; + } + + (*ctx->callback)(std::move(status), std::move(chunk_json)); + return data_length; + }; +}; + +} // namespace + +VllmEngine::VllmEngine() + : cortex_port_{std::stoi( + file_manager_utils::GetCortexConfig().apiServerPort)}, + port_offsets_{true}, // cortex_port + 0 is always used (by cortex itself) + queue_{2 /* threadNum */, "vLLM engine"} {} + +VllmEngine::~VllmEngine() { + // NOTE: what happens if we can't kill subprocess? + std::unique_lock write_lock(mutex_); + for (auto& [model_name, py_proc] : model_process_map_) { + if (py_proc.IsAlive()) + py_proc.Kill(); + } +} + +std::vector VllmEngine::GetVariants() { + const auto vllm_path = python_utils::GetEnvsPath() / "vllm"; + + namespace fs = std::filesystem; + if (!fs::exists(vllm_path)) + return {}; + + std::vector variants; + for (const auto& entry : fs::directory_iterator(vllm_path)) { + const auto name = "linux-amd64-cuda"; // arbitrary + // TODO: after llama-server is merged, check if we need to add "v" + const auto version_str = "v" + entry.path().filename().string(); + const EngineVariantResponse variant{name, version_str, kVllmEngine}; + variants.push_back(variant); + } + return variants; +} + +// TODO: once llama-server is merged, check if checking 'v' is still needed +void VllmEngine::Load(EngineLoadOption opts) { + version_ = opts.engine_path; // engine path actually contains version info + if (version_[0] == 'v') + version_ = version_.substr(1); + return; +}; + +void VllmEngine::Unload(EngineUnloadOption opts) {}; + +void VllmEngine::HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) { + + // NOTE: request validation should be in controller + if (!json_body->isMember("model")) { + auto [status, error] = + CreateResponse("Missing required fields: model", 400); + callback(std::move(status), std::move(error)); + return; + } + + const std::string model = (*json_body)["model"].asString(); + int port; + // check if model has started + { + std::shared_lock read_lock(mutex_); + if (model_process_map_.find(model) == model_process_map_.end()) { + const std::string msg = "Model " + model + " has not been loaded yet."; + auto [status, error] = CreateResponse(msg, 400); + callback(std::move(status), std::move(error)); + return; + } + port = model_process_map_[model].port; + } + + const std::string url = + "http://127.0.0.1:" + std::to_string(port) + "/v1/chat/completions"; + const std::string json_str = 
json_body->toStyledString(); + + bool stream = (*json_body)["stream"].asBool(); + if (stream) { + queue_.runTaskInQueue([url = std::move(url), json_str = std::move(json_str), + callback = std::move(callback)] { + CURL* curl = curl_easy_init(); + if (!curl) { + auto [status, res] = CreateResponse("Internal server error", 500); + callback(std::move(status), std::move(res)); + return; + } + + struct curl_slist* headers = nullptr; + headers = curl_slist_append(headers, "Content-Type: application/json"); + + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_str.c_str()); + curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, json_str.length()); + curl_easy_setopt(curl, CURLOPT_TCP_KEEPALIVE, 1L); + + StreamContext ctx; + ctx.callback = + std::make_shared>( + callback); + ctx.need_stop = true; + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, + StreamContext::write_callback); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &ctx); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + auto msg = curl_easy_strerror(res); + auto [status, err_res] = CreateResponse(msg, 500); + callback(std::move(status), std::move(err_res)); + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + if (ctx.need_stop) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = 200; + callback(std::move(status), Json::Value{}); + } + + return; + }); + } else { + // non-streaming + auto result = curl_utils::SimplePostJson(url, json_str); + + if (result.has_error()) { + auto [status, res] = CreateResponse(result.error(), 400); + callback(std::move(status), std::move(res)); + return; + } + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), std::move(result.value())); + } +}; + +// NOTE: we don't have an option to pass --task embed to vLLM spawn yet +void VllmEngine::HandleEmbedding( + std::shared_ptr json_body, + std::function&& callback) { + + if (!json_body->isMember("model")) { + auto [status, error] = + CreateResponse("Missing required fields: model", 400); + callback(std::move(status), std::move(error)); + return; + } + + const std::string model = (*json_body)["model"].asString(); + int port; + // check if model has started + { + std::shared_lock read_lock(mutex_); + if (model_process_map_.find(model) == model_process_map_.end()) { + const std::string msg = "Model " + model + " has not been loaded yet."; + auto [status, error] = CreateResponse(msg, 400); + callback(std::move(status), std::move(error)); + return; + } + port = model_process_map_[model].port; + } + + const std::string url = + "http://127.0.0.1:" + std::to_string(port) + "/v1/embeddings"; + const std::string json_str = json_body->toStyledString(); + + auto result = curl_utils::SimplePostJson(url, json_str); + + if (result.has_error()) { + auto [status, res] = CreateResponse(result.error(), 400); + callback(std::move(status), std::move(res)); + return; + } + + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), std::move(result.value())); +}; + +void VllmEngine::LoadModel( + std::shared_ptr json_body, + std::function&& callback) { + + if (!json_body->isMember("model")) { + auto [status, error] = +
CreateResponse("Missing required fields: model", 400); + callback(std::move(status), std::move(error)); + return; + } + + const std::string model = (*json_body)["model"].asString(); + + { + std::unique_lock write_lock(mutex_); + if (model_process_map_.find(model) != model_process_map_.end()) { + auto proc = model_process_map_[model]; + + // NOTE: each vLLM instance can only serve 1 task. It means that the + // following logic will not allow serving the same model for 2 different + // tasks at the same time. + // To support it, we also need to know how vLLM decides the default task. + if (proc.IsAlive()) { + auto [status, error] = CreateResponse("Model already loaded!", 409); + callback(std::move(status), std::move(error)); + return; + } else { + // if model has exited, try to load model again? + CTL_WRN("Model " << model << " has exited unexpectedly"); + model_process_map_.erase(model); + port_offsets_[proc.port - cortex_port_] = false; // free the port + } + } + } + + pid_t pid; + try { + namespace fs = std::filesystem; + + const auto model_path = file_manager_utils::GetCortexDataPath() / "models" / + kHuggingFaceHost / model; + + auto env_dir = python_utils::GetEnvsPath() / "vllm" / version_; + if (!fs::exists(env_dir)) + throw std::runtime_error(env_dir.string() + " does not exist"); + + int offset = 1; + for (;; offset++) { + // add this guard to prevent endless loop + if (offset >= 100) + throw std::runtime_error("Unable to find an available port"); + + if (port_offsets_.size() <= offset) + port_offsets_.push_back(false); + + // check if port is used + if (!port_offsets_[offset]) + break; + } + const int port = cortex_port_ + offset; + + // https://docs.astral.sh/uv/reference/cli/#uv-run + std::vector cmd = + python_utils::UvBuildCommand("run", env_dir.string()); + cmd.push_back("vllm"); + cmd.push_back("serve"); + cmd.push_back(model_path.string()); + cmd.push_back("--port"); + cmd.push_back(std::to_string(port)); + cmd.push_back("--served-model-name"); + cmd.push_back(model); + + // NOTE: we might want to adjust max-model-len automatically, since vLLM + // may OOM for large models as it tries to allocate full context length. + const std::string EXTRA_ARGS[] = {"task", "max-model-len"}; + for (const auto arg : EXTRA_ARGS) { + if (json_body->isMember(arg)) { + cmd.push_back("--" + arg); + cmd.push_back((*json_body)[arg].asString()); + } + } + + const auto stdout_file = env_dir / "stdout.log"; + const auto stderr_file = env_dir / "stderr.log"; + + // create empty files for redirection + // TODO: add limit on file size? + if (!std::filesystem::exists(stdout_file)) + std::ofstream(stdout_file).flush(); + if (!std::filesystem::exists(stderr_file)) + std::ofstream(stderr_file).flush(); + + auto result = cortex::process::SpawnProcess(cmd, stdout_file.string(), + stderr_file.string()); + if (result.has_error()) { + throw std::runtime_error(result.error()); + } + auto proc_info = result.value(); + pid = proc_info.pid; + + // wait for server to be up + // NOTE: should we add a timeout to avoid endless loop? + while (true) { + CTL_INF("Wait for vLLM server to be up. 
Sleep for 5s"); + std::this_thread::sleep_for(std::chrono::seconds(5)); + if (!cortex::process::IsProcessAlive(proc_info)) + throw std::runtime_error("vLLM subprocess fails to start"); + + const auto url = "http://127.0.0.1:" + std::to_string(port) + "/health"; + if (curl_utils::SimpleGet(url).has_value()) + break; + } + + python_utils::PythonSubprocess py_proc; + py_proc.proc_info = proc_info; + py_proc.port = port; + py_proc.start_time = std::chrono::system_clock::now().time_since_epoch() / + std::chrono::milliseconds(1); + + std::unique_lock write_lock(mutex_); + model_process_map_[model] = py_proc; + + } catch (const std::exception& e) { + auto e_msg = e.what(); + auto [status, error] = CreateResponse(e_msg, 500); + callback(std::move(status), std::move(error)); + return; + } + + auto [status, res] = CreateResponse( + "Model loaded successfully with pid: " + std::to_string(pid), 200); + callback(std::move(status), std::move(res)); +}; + +void VllmEngine::UnloadModel( + std::shared_ptr json_body, + std::function&& callback) { + if (!json_body->isMember("model")) { + auto [status, error] = CreateResponse("Missing required field: model", 400); + callback(std::move(status), std::move(error)); + return; + } + + const std::string model = (*json_body)["model"].asString(); + + // check if model has started + { + std::shared_lock read_lock(mutex_); + if (model_process_map_.find(model) == model_process_map_.end()) { + const std::string msg = "Model " + model + " has not been loaded yet."; + auto [status, error] = CreateResponse(msg, 400); + callback(std::move(status), std::move(error)); + return; + } + } + + // we know that model has started + { + std::unique_lock write_lock(mutex_); + auto proc = model_process_map_[model]; + + // check if subprocess is still alive + // NOTE: is this step necessary? the subprocess could have terminated + // after .IsAlive() and before .Kill() later. + if (!proc.IsAlive()) { + model_process_map_.erase(model); + port_offsets_[proc.port - cortex_port_] = false; // free the port + + const std::string msg = "Model " + model + " stopped running."; + auto [status, error] = CreateResponse(msg, 400); + callback(std::move(status), std::move(error)); + return; + } + + // subprocess is alive. we kill it here. 
+ if (!model_process_map_[model].Kill()) { + const std::string msg = "Unable to kill process of model " + model; + auto [status, error] = CreateResponse(msg, 500); + callback(std::move(status), std::move(error)); + return; + } + + model_process_map_.erase(model); + port_offsets_[proc.port - cortex_port_] = false; // free the port + } + + auto [status, res] = CreateResponse("Unload model successfully", 200); + callback(std::move(status), std::move(res)); +}; + +void VllmEngine::GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) { + + if (!json_body->isMember("model")) { + auto [status, error] = CreateResponse("Missing required field: model", 400); + callback(std::move(status), std::move(error)); + return; + } + + const std::string model = (*json_body)["model"].asString(); + // check if model has started + { + std::shared_lock read_lock(mutex_); + if (model_process_map_.find(model) == model_process_map_.end()) { + const std::string msg = "Model " + model + " has not been loaded yet."; + auto [status, error] = CreateResponse(msg, 400); + callback(std::move(status), std::move(error)); + return; + } + } + + // we know that model has started + { + std::unique_lock write_lock(mutex_); + auto py_proc = model_process_map_[model]; + + // health check endpoint + const auto url = + "http://127.0.0.1:" + std::to_string(py_proc.port) + "/health"; + if (curl_utils::SimpleGet(url).has_value()) { + Json::Value status; + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + callback(std::move(status), Json::Value{}); + } else { + // try to kill the subprocess to free resources, in case the server hangs + // instead of subprocess has died. + py_proc.Kill(); + + CTL_WRN("Model " << model << " has exited unexpectedly."); + model_process_map_.erase(model); + const std::string msg = "Model " + model + " stopped running."; + auto [status, error] = CreateResponse(msg, 400); + callback(std::move(status), std::move(error)); + } + } +}; + +bool VllmEngine::IsSupported(const std::string& f) { + return true; +}; + +void VllmEngine::GetModels( + std::shared_ptr jsonBody, + std::function&& callback) { + Json::Value res, model_list(Json::arrayValue), status; + { + std::unique_lock write_lock(mutex_); + for (auto& [model_name, py_proc] : model_process_map_) { + const auto url = + "http://127.0.0.1:" + std::to_string(py_proc.port) + "/health"; + if (curl_utils::SimpleGet(url).has_error()) { + // try to kill the subprocess to free resources, in case the server hangs + // instead of subprocess has died. 
+ py_proc.Kill(); + + CTL_WRN("Model " << model_name << " has exited unexpectedly."); + model_process_map_.erase(model_name); + continue; + } + + Json::Value val; + val["id"] = model_name; + val["engine"] = kVllmEngine; + val["start_time"] = py_proc.start_time; + val["port"] = py_proc.port; + val["object"] = "model"; + // TODO + // val["ram"]; + // val["vram"]; + model_list.append(val); + } + } + + res["object"] = "list"; + res["data"] = model_list; + + status["is_done"] = true; + status["has_error"] = false; + status["is_stream"] = false; + status["status_code"] = 200; + + callback(std::move(status), std::move(res)); +}; + +bool VllmEngine::SetFileLogger(int max_log_lines, const std::string& log_path) { + CTL_WRN("Not implemented"); + throw std::runtime_error("Not implemented"); +}; +void VllmEngine::SetLogLevel(trantor::Logger::LogLevel logLevel) { + CTL_WRN("Not implemented"); + throw std::runtime_error("Not implemented"); +}; + +void VllmEngine::StopInferencing(const std::string& model_id) { + CTL_WRN("Not implemented"); + throw std::runtime_error("Not implemented"); +}; diff --git a/engine/extensions/python-engines/vllm_engine.h b/engine/extensions/python-engines/vllm_engine.h new file mode 100644 index 000000000..d7724b703 --- /dev/null +++ b/engine/extensions/python-engines/vllm_engine.h @@ -0,0 +1,62 @@ +#include +#include "common/engine_servicei.h" +#include "cortex-common/EngineI.h" +#include "python_utils.h" +#include "trantor/utils/ConcurrentTaskQueue.h" + +class VllmEngine : public EngineI { + private: + std::string version_; + int cortex_port_; + + // port_offsets_[i] == true means cortex_port + i is used + // otherwise, cortex_port + i is not used + std::vector port_offsets_; + + mutable std::shared_mutex mutex_; + std::unordered_map + model_process_map_; + + // TODO: will use cortex's main TaskQueue once llama.cpp PR is merged + trantor::ConcurrentTaskQueue queue_; + + public: + VllmEngine(); + ~VllmEngine(); + + static std::vector GetVariants(); + + void Load(EngineLoadOption opts) override; + void Unload(EngineUnloadOption opts) override; + + // cortex.llamacpp interface + void HandleChatCompletion( + std::shared_ptr json_body, + std::function&& callback) override; + void HandleEmbedding( + std::shared_ptr json_body, + std::function&& callback) override; + void LoadModel( + std::shared_ptr json_body, + std::function&& callback) override; + void UnloadModel( + std::shared_ptr json_body, + std::function&& callback) override; + void GetModelStatus( + std::shared_ptr json_body, + std::function&& callback) override; + + // For backward compatible checking + bool IsSupported(const std::string& f) override; + + // Get list of running models + void GetModels( + std::shared_ptr jsonBody, + std::function&& callback) override; + + bool SetFileLogger(int max_log_lines, const std::string& log_path) override; + void SetLogLevel(trantor::Logger::LogLevel logLevel) override; + + // Stop inflight chat completion in stream mode + void StopInferencing(const std::string& model_id) override; +}; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 48cc6ff37..9df6b74a2 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -9,6 +9,7 @@ #include "config/model_config.h" #include "database/engines.h" #include "database/models.h" +#include "extensions/python-engines/vllm_engine.h" #include "extensions/remote-engine/remote_engine.h" #include "utils/archive_utils.h" @@ -184,19 +185,34 @@ cpp::result 
EngineService::UninstallEngineVariant( } std::optional path_to_remove = std::nullopt; - if (version == std::nullopt && variant == std::nullopt) { - // if no version and variant provided, remove all engines variant of that engine - path_to_remove = file_manager_utils::GetEnginesContainerPath() / ne; - } else if (version != std::nullopt && variant != std::nullopt) { - // if both version and variant are provided, we only remove that variant - path_to_remove = file_manager_utils::GetEnginesContainerPath() / ne / - variant.value() / version.value(); - } else if (version == std::nullopt) { - // if only have variant, we remove all of that variant - path_to_remove = - file_manager_utils::GetEnginesContainerPath() / ne / variant.value(); + + if (ne == kLlamaRepo) { + if (version == std::nullopt && variant == std::nullopt) { + // if no version and variant provided, remove all engines variant of that engine + path_to_remove = file_manager_utils::GetEnginesContainerPath() / ne; + } else if (version != std::nullopt && variant != std::nullopt) { + // if both version and variant are provided, we only remove that variant + path_to_remove = file_manager_utils::GetEnginesContainerPath() / ne / + variant.value() / version.value(); + } else if (version == std::nullopt) { + // if only have variant, we remove all of that variant + path_to_remove = + file_manager_utils::GetEnginesContainerPath() / ne / variant.value(); + } else { + return cpp::fail("No variant provided"); + } + } else if (ne == kVllmEngine) { + // variant is ignored for vLLM + if (version == std::nullopt) { + path_to_remove = python_utils::GetEnvsPath() / "vllm"; + + // we only clean uv cache when all vLLM versions are deleted + python_utils::UvCleanCache(); + } else { + path_to_remove = python_utils::GetEnvsPath() / "vllm" / version.value(); + } } else { - return cpp::fail("No variant provided"); + return cpp::fail("Not implemented for engine " + ne); } if (path_to_remove == std::nullopt) { @@ -220,6 +236,18 @@ cpp::result EngineService::DownloadEngine( const std::string& engine, const std::string& version, const std::optional variant_name) { + if (engine == kLlamaRepo) + return DownloadLlamaCpp(version, variant_name); + if (engine == kVllmEngine) + return DownloadVllm(version, variant_name); + + return cpp::fail("Unknown engine " + engine); +} + +cpp::result EngineService::DownloadLlamaCpp( + const std::string& version, const std::optional variant_name) { + + const std::string engine = kLlamaRepo; auto normalized_version = version == "latest" ? 
"latest" : string_utils::RemoveSubstring(version, "v"); @@ -360,10 +388,86 @@ cpp::result EngineService::DownloadEngine( return {}; } +cpp::result EngineService::DownloadVllm( + const std::string& version, const std::optional variant_name) { + + auto system_info = system_info_utils::GetSystemInfo(); + if (!(system_info->os == kLinuxOs && system_info->arch == "amd64" && + system_info_utils::IsNvidiaSmiAvailable())) + return cpp::fail( + "vLLM engine is only supported on Linux x86_64 with Nvidia GPU."); + + if (variant_name.has_value()) { + return cpp::fail("variant_name must be empty"); + } + + // NOTE: everything below is not async + // to make it async, we have to run everything in a thread (spawning and waiting + // for subprocesses) + if (!python_utils::UvIsInstalled()) { + auto result = python_utils::UvInstall(); + if (result.has_error()) + return result; + } + + std::string concrete_version = version; + if (version == "latest") { + auto result = curl_utils::SimpleGetJson("https://pypi.org/pypi/vllm/json"); + if (result.has_error()) + return cpp::fail(result.error()); + + auto version_value = result.value()["info"]["version"]; + if (version_value.isNull()) + return cpp::fail("Can't find version in the response"); + concrete_version = version_value.asString(); + } + CTL_INF("Download vLLM " << concrete_version); + namespace fs = std::filesystem; + + const auto vllm_path = + python_utils::GetEnvsPath() / "vllm" / concrete_version; + fs::create_directories(vllm_path); + + // initialize venv + if (!fs::exists(vllm_path / ".venv")) { + std::vector cmd = + python_utils::UvBuildCommand("venv", vllm_path.string()); + cmd.push_back("--relocatable"); + cmd.push_back("--seed"); + auto result = cortex::process::SpawnProcess(cmd); + if (result.has_error()) + return cpp::fail(result.error()); + + // TODO: check return code + cortex::process::WaitProcess(result.value()); + } + + // install vLLM + { + std::vector cmd = + python_utils::UvBuildCommand("pip", vllm_path.string()); + cmd.push_back("install"); + cmd.push_back("vllm==" + concrete_version); + auto result = cortex::process::SpawnProcess(cmd); + if (result.has_error()) + return cpp::fail(result.error()); + + // TODO: check return code + // one reason this may fail is that the requested version does not exist + cortex::process::WaitProcess(result.value()); + } + + auto result = SetDefaultEngineVariant(kVllmEngine, concrete_version, ""); + if (result.has_error()) + return cpp::fail(result.error()); + + return {}; +} + cpp::result EngineService::DownloadCuda( const std::string& engine, bool async) { - if (hw_inf_.sys_inf->os == "mac") { - // mac does not require cuda toolkit + if (hw_inf_.sys_inf->os == "mac" || engine != kLlamaRepo) { + // mac and non-llama.cpp engine do not require cuda toolkit return true; } @@ -550,8 +654,14 @@ EngineService::SetDefaultEngineVariant(const std::string& engine, auto normalized_version = string_utils::RemoveSubstring(version, "v"); auto config = file_manager_utils::GetCortexConfig(); - config.llamacppVersion = "v" + normalized_version; - config.llamacppVariant = variant; + if (ne == kLlamaRepo) { + config.llamacppVersion = "v" + normalized_version; + config.llamacppVariant = variant; + } else if (ne == kVllmEngine) { + config.vllmVersion = "v" + normalized_version; + } else { + return cpp::fail("Unrecognized engine " + engine); + } auto result = file_manager_utils::UpdateCortexConfig(config); if (result.has_error()) { return cpp::fail(result.error()); @@ -591,18 +701,23 @@ cpp::result 
EngineService::IsEngineVariantReady( cpp::result EngineService::GetDefaultEngineVariant(const std::string& engine) { auto ne = cortex::engine::NormalizeEngine(engine); - // current we don't support other engine - if (ne != kLlamaRepo) { - return cpp::fail("Engine " + engine + " is not supported yet!"); - } auto config = file_manager_utils::GetCortexConfig(); - auto variant = config.llamacppVariant; - auto version = config.llamacppVersion; - - if (variant.empty() || version.empty()) { - return cpp::fail("Default engine variant for " + engine + - " is not set yet!"); + std::string variant, version; + if (engine == kLlamaRepo) { + variant = config.llamacppVariant; + version = config.llamacppVersion; + if (variant.empty() || version.empty()) + return cpp::fail("Default engine version and variant for " + engine + + " is not set yet!"); + } else if (engine == kVllmEngine) { + variant = ""; + version = config.vllmVersion; + if (version.empty()) + return cpp::fail("Default engine version for " + engine + + " is not set yet!"); + } else { + return cpp::fail("Engine " + engine + " is not supported yet!"); } return DefaultEngineVariant{ @@ -617,6 +732,9 @@ EngineService::GetInstalledEngineVariants(const std::string& engine) const { auto ne = cortex::engine::NormalizeEngine(engine); auto os = hw_inf_.sys_inf->os; + if (ne == kVllmEngine) + return VllmEngine::GetVariants(); + auto engines_variants_dir = file_manager_utils::GetEnginesContainerPath() / ne; @@ -681,6 +799,7 @@ cpp::result EngineService::LoadEngine( CTL_INF("Engine " << ne << " is already loaded"); return {}; } + CTL_INF("Loading engine: " << ne); // Check for remote engine if (IsRemoteEngine(engine_name)) { @@ -697,9 +816,24 @@ cpp::result EngineService::LoadEngine( return {}; } + // check for vLLM engine + if (engine_name == kVllmEngine) { + auto engine = new VllmEngine(); + EngineI::EngineLoadOption load_opts; + + auto result = GetDefaultEngineVariant(engine_name); + if (result.has_error()) + return cpp::fail(result.error()); + + // we set version to engine_path + load_opts.engine_path = result.value().version; + engine->Load(load_opts); + engines_[engine_name].engine = engine; + return {}; + } + // End hard code - CTL_INF("Loading engine: " << ne); #if defined(_WIN32) || defined(_WIN64) || defined(__linux__) CTL_INF("CPU Info: " << cortex::cpuid::CpuInfo().to_string()); #endif @@ -912,8 +1046,6 @@ cpp::result EngineService::IsEngineReady( return true; } - auto os = hw_inf_.sys_inf->os; - auto installed_variants = GetInstalledEngineVariants(engine); if (installed_variants.has_error()) { return cpp::fail(installed_variants.error()); @@ -1100,6 +1232,11 @@ cpp::result EngineService::GetRemoteModels( bool EngineService::IsRemoteEngine(const std::string& engine_name) const { auto ne = Repo2Engine(engine_name); + + if (ne == kLlamaEngine || ne == kVllmEngine) + return false; + return true; + auto local_engines = file_manager_utils::GetCortexConfig().supportedEngines; for (auto const& le : local_engines) { if (le == ne) @@ -1110,5 +1247,6 @@ bool EngineService::IsRemoteEngine(const std::string& engine_name) const { cpp::result, std::string> EngineService::GetSupportedEngineNames() { + return config_yaml_utils::kDefaultSupportedEngines; return file_manager_utils::GetCortexConfig().supportedEngines; } diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 7e6be74c5..a054993c6 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -166,6 +166,14 @@ class EngineService : public 
EngineServiceI { const std::string& engine, const std::string& version = "latest", const std::optional variant_name = std::nullopt); + cpp::result DownloadLlamaCpp( + const std::string& version = "latest", + const std::optional variant_name = std::nullopt); + + cpp::result DownloadVllm( + const std::string& version = "latest", + const std::optional variant_name = std::nullopt); + cpp::result DownloadCuda(const std::string& engine, bool async = false); diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc index a1646495b..86d452c75 100644 --- a/engine/services/inference_service.cc +++ b/engine/services/inference_service.cc @@ -111,7 +111,9 @@ cpp::result InferenceService::HandleEmbedding( std::shared_ptr q, std::shared_ptr json_body) { std::string engine_type; if (!HasFieldInReq(json_body, "engine")) { - engine_type = kLlamaRepo; + auto engine_type_maybe = + GetEngineByModelId((*json_body)["model"].asString()); + engine_type = engine_type_maybe.empty() ? kLlamaRepo : engine_type_maybe; } else { engine_type = (*(json_body)).get("engine", kLlamaRepo).asString(); } @@ -161,18 +163,16 @@ InferResult InferenceService::LoadModel( } // might need mutex here - auto engine_result = engine_service_->GetLoadedEngine(engine_type); + auto engine = engine_service_->GetLoadedEngine(engine_type).value(); auto cb = [&stt, &r](Json::Value status, Json::Value res) { stt = status; r = res; }; - if (std::holds_alternative(engine_result.value())) { - std::get(engine_result.value()) - ->LoadModel(json_body, std::move(cb)); + if (std::holds_alternative(engine)) { + std::get(engine)->LoadModel(json_body, std::move(cb)); } else { - std::get(engine_result.value()) - ->LoadModel(json_body, std::move(cb)); + std::get(engine)->LoadModel(json_body, std::move(cb)); } // Save model config to reload if needed auto model_id = json_body->get("model", "").asString(); @@ -203,12 +203,13 @@ InferResult InferenceService::UnloadModel(const std::string& engine_name, stt = status; r = res; }; - if (std::holds_alternative(engine_result.value())) { - std::get(engine_result.value()) - ->UnloadModel(std::make_shared(json_body), std::move(cb)); + auto engine = engine_result.value(); + if (std::holds_alternative(engine)) { + std::get(engine)->UnloadModel( + std::make_shared(json_body), std::move(cb)); } else { - std::get(engine_result.value()) - ->UnloadModel(std::make_shared(json_body), std::move(cb)); + std::get(engine)->UnloadModel( + std::make_shared(json_body), std::move(cb)); } return std::make_pair(stt, r); @@ -241,12 +242,11 @@ InferResult InferenceService::GetModelStatus( stt = status; r = res; }; - if (std::holds_alternative(engine_result.value())) { - std::get(engine_result.value()) - ->GetModelStatus(json_body, std::move(cb)); + auto engine = engine_result.value(); + if (std::holds_alternative(engine)) { + std::get(engine)->GetModelStatus(json_body, std::move(cb)); } else { - std::get(engine_result.value()) - ->GetModelStatus(json_body, std::move(cb)); + std::get(engine)->GetModelStatus(json_body, std::move(cb)); } return std::make_pair(stt, r); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index d9359b698..accc9787e 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -101,12 +101,14 @@ void ParseGguf(DatabaseService& db_service, } } -cpp::result GetDownloadTask( - const std::string& modelId, const std::string& branch = "main") { +cpp::result GetCloneRepoDownloadTask( + const std::string& author_id, const std::string& 
modelId, + const std::string& branch, const std::vector& save_dir, + const std::string& task_id) { url_parser::Url url = {/* .protocol = */ "https", /* .host = */ kHuggingFaceHost, /* .pathParams = */ - {"api", "models", "cortexso", modelId, "tree", branch}, + {"api", "models", author_id, modelId, "tree", branch}, {}}; auto result = curl_utils::SimpleGetJsonRecursive(url.ToFullPath()); @@ -115,8 +117,9 @@ cpp::result GetDownloadTask( } std::vector download_items{}; - auto model_container_path = file_manager_utils::GetModelsContainerPath() / - "cortex.so" / modelId / branch; + auto model_container_path = file_manager_utils::GetModelsContainerPath(); + for (auto subdir : save_dir) + model_container_path /= subdir; file_manager_utils::CreateDirectoryRecursively(model_container_path.string()); for (const auto& value : result.value()) { @@ -129,7 +132,7 @@ cpp::result GetDownloadTask( url_parser::Url download_url = { /* .protocol = */ "https", /* .host = */ kHuggingFaceHost, - /* .pathParams = */ {"cortexso", modelId, "resolve", branch, path}, + /* .pathParams = */ {author_id, modelId, "resolve", branch, path}, {}}; auto local_path = model_container_path / path; @@ -147,7 +150,7 @@ cpp::result GetDownloadTask( } return DownloadTask{ - /* .id = */ branch == "main" ? modelId : modelId + "-" + branch, + /* .id = */ task_id, /* .status = */ DownloadTask::Status::Pending, /* .type = */ DownloadType::Model, /* .items = */ download_items}; @@ -188,7 +191,12 @@ void ModelService::ForceIndexingModelList() { if (model_entry.status != cortex::db::ModelStatus::Downloaded) { continue; } + if (model_entry.engine == kVllmEngine) { + // TODO: check if folder still exists? + continue; + } try { + // check if path_to_model_yaml still exists CTL_DBG(fmu::ToAbsoluteCortexDataPath( fs::path(model_entry.path_to_model_yaml)) .string()); @@ -212,22 +220,29 @@ std::optional ModelService::GetDownloadedModel( const std::string& modelId) const { config::YamlHandler yaml_handler; - auto model_entry = db_service_->GetModelInfo(modelId); - if (!model_entry.has_value()) { + auto result = db_service_->GetModelInfo(modelId); + if (result.has_error()) { return std::nullopt; } + auto model_entry = result.value(); + + // ignore all other params + if (model_entry.engine == kVllmEngine) { + config::ModelConfig cfg; + cfg.engine = kVllmEngine; + return cfg; + } try { config::YamlHandler yaml_handler; namespace fs = std::filesystem; namespace fmu = file_manager_utils; yaml_handler.ModelConfigFromFile( - fmu::ToAbsoluteCortexDataPath( - fs::path(model_entry.value().path_to_model_yaml)) + fmu::ToAbsoluteCortexDataPath(fs::path(model_entry.path_to_model_yaml)) .string()); return yaml_handler.GetModelConfig(); } catch (const std::exception& e) { - LOG_ERROR << "Error reading yaml file '" << model_entry->path_to_model_yaml + LOG_ERROR << "Error reading yaml file '" << model_entry.path_to_model_yaml << "': " << e.what(); return std::nullopt; } @@ -316,6 +331,56 @@ cpp::result ModelService::HandleDownloadUrlAsync( return download_service_->AddTask(downloadTask, on_finished); } +cpp::result ModelService::DownloadHfModelAsync( + const std::string& author_id, const std::string& model_id) { + + const std::string unique_model_id = author_id + "/" + model_id; + auto model_entry = db_service_->GetModelInfo(unique_model_id); + if (model_entry.has_value() && + model_entry->status == cortex::db::ModelStatus::Downloaded) + return cpp::fail("Please delete the model before downloading again"); + + auto download_task = GetCloneRepoDownloadTask( + 
author_id, model_id, "main", {kHuggingFaceHost, author_id, model_id}, + unique_model_id); + if (download_task.has_error()) + return download_task; + + // TODO: validate that this is a vllm-compatible model + auto on_finished = [this, author_id, + unique_model_id](const DownloadTask& finishedTask) { + if (!db_service_->HasModel(unique_model_id)) { + CTL_INF("Before creating model entry"); + cortex::db::ModelEntry model_entry{ + .model = unique_model_id, + .author_repo_id = author_id, + .branch_name = "main", + .path_to_model_yaml = "", + .model_alias = unique_model_id, + .status = cortex::db::ModelStatus::Downloaded, + .engine = kVllmEngine}; + + auto result = db_service_->AddModelEntry(model_entry); + if (result.has_error()) { + CTL_ERR("Error adding model to modellist: " + result.error()); + } + } else { + if (auto m = db_service_->GetModelInfo(unique_model_id); m.has_value()) { + auto upd_m = m.value(); + upd_m.status = cortex::db::ModelStatus::Downloaded; + if (auto r = db_service_->UpdateModelEntry(unique_model_id, upd_m); + r.has_error()) { + CTL_ERR(r.error()); + } + } else { + CTL_WRN("Could not get model entry with model id: " << unique_model_id); + } + } + }; + + return download_service_->AddTask(download_task.value(), on_finished); +} + std::optional ModelService::GetEstimation( const std::string& model_handle) { std::lock_guard l(es_mtx_); @@ -375,27 +440,24 @@ bool ModelService::HasModel(const std::string& id) const { cpp::result ModelService::DownloadModelFromCortexsoAsync( - const std::string& name, const std::string& branch, + const std::string& model_name, const std::string& branch, std::optional temp_model_id) { - auto download_task = GetDownloadTask(name, branch); - if (download_task.has_error()) { - return cpp::fail(download_task.error()); - } - - std::string unique_model_id = ""; - if (temp_model_id.has_value()) { - unique_model_id = temp_model_id.value(); - } else { - unique_model_id = name + ":" + branch; - } - + std::string unique_model_id = + temp_model_id.value_or(model_name + ":" + branch); auto model_entry = db_service_->GetModelInfo(unique_model_id); if (model_entry.has_value() && model_entry->status == cortex::db::ModelStatus::Downloaded) { return cpp::fail("Please delete the model before downloading again"); } + auto download_task = GetCloneRepoDownloadTask( + "cortexso", model_name, branch, {"cortex.so", model_name, branch}, + unique_model_id); + if (download_task.has_error()) { + return cpp::fail(download_task.error()); + } + auto on_finished = [this, unique_model_id, branch](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; @@ -467,9 +529,7 @@ ModelService::DownloadModelFromCortexsoAsync( } }; - auto task = download_task.value(); - task.id = unique_model_id; - return download_service_->AddTask(task, on_finished); + return download_service_->AddTask(download_task.value(), on_finished); } cpp::result ModelService::DeleteModel( @@ -564,14 +624,35 @@ cpp::result ModelService::StartModel( Json::Value json_data; // Currently we don't support download vision models, so we need to bypass check if (!bypass_model_check) { - auto model_entry = db_service_->GetModelInfo(model_handle); - if (model_entry.has_error()) { - CTL_WRN("Error: " + model_entry.error()); - return cpp::fail(model_entry.error()); + auto result = db_service_->GetModelInfo(model_handle); + if (result.has_error()) { + CTL_WRN("Error: " + result.error()); + return cpp::fail(result.error()); } + auto model_entry = result.value(); + + if (model_entry.engine == kVllmEngine) 
{ + Json::Value json_data; + json_data["model"] = model_handle; + json_data["engine"] = kVllmEngine; + auto [status, data] = + inference_svc_->LoadModel(std::make_shared(json_data)); + + auto status_code = status["status_code"].asInt(); + if (status_code == drogon::k200OK) { + return StartModelResult{true, ""}; + } else if (status_code == drogon::k409Conflict) { + CTL_INF("Model '" + model_handle + "' is already loaded"); + return StartModelResult{true, ""}; + } else { + return cpp::fail("Model failed to start: " + + data["message"].asString()); + } + } + yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( - fs::path(model_entry.value().path_to_model_yaml)) + fs::path(model_entry.path_to_model_yaml)) .string()); auto mc = yaml_handler.GetModelConfig(); @@ -579,17 +660,15 @@ cpp::result ModelService::StartModel( if (engine_svc_->IsRemoteEngine(mc.engine)) { (void)engine_svc_->LoadEngine(mc.engine); config::RemoteModelConfig remote_mc; - remote_mc.LoadFromYamlFile( - fmu::ToAbsoluteCortexDataPath( - fs::path(model_entry.value().path_to_model_yaml)) - .string()); - auto remote_engine_entry = - engine_svc_->GetEngineByNameAndVariant(mc.engine); - if (remote_engine_entry.has_error()) { - CTL_WRN("Remote engine error: " + model_entry.error()); - return cpp::fail(remote_engine_entry.error()); + remote_mc.LoadFromYamlFile(fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.path_to_model_yaml)) + .string()); + auto result = engine_svc_->GetEngineByNameAndVariant(mc.engine); + if (result.has_error()) { + CTL_WRN("Remote engine error: " + result.error()); + return cpp::fail(result.error()); } - auto remote_engine_json = remote_engine_entry.value().ToJson(); + auto remote_engine_json = result.value().ToJson(); json_data = remote_mc.ToJson(); json_data["api_key"] = std::move(remote_engine_json["api_key"]); @@ -597,10 +676,9 @@ cpp::result ModelService::StartModel( !v.empty() && v != "latest") { json_data["version"] = v; } - json_data["model_path"] = - fmu::ToAbsoluteCortexDataPath( - fs::path(model_entry.value().path_to_model_yaml)) - .string(); + json_data["model_path"] = fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.path_to_model_yaml)) + .string(); json_data["metadata"] = std::move(remote_engine_json["metadata"]); auto ir = @@ -735,17 +813,23 @@ cpp::result ModelService::StopModel( bypass_stop_check_set_.end()); std::string engine_name = ""; if (!bypass_check) { - auto model_entry = db_service_->GetModelInfo(model_handle); - if (model_entry.has_error()) { - CTL_WRN("Error: " + model_entry.error()); - return cpp::fail(model_entry.error()); + auto result = db_service_->GetModelInfo(model_handle); + if (result.has_error()) { + CTL_WRN("Error: " + result.error()); + return cpp::fail(result.error()); + } + + const auto model_entry = result.value(); + if (model_entry.engine == kVllmEngine) { + engine_name = kVllmEngine; + } else { + yaml_handler.ModelConfigFromFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.path_to_model_yaml)) + .string()); + auto mc = yaml_handler.GetModelConfig(); + engine_name = mc.engine; } - yaml_handler.ModelConfigFromFile( - fmu::ToAbsoluteCortexDataPath( - fs::path(model_entry.value().path_to_model_yaml)) - .string()); - auto mc = yaml_handler.GetModelConfig(); - engine_name = mc.engine; } if (bypass_check) { engine_name = kLlamaEngine; @@ -882,23 +966,19 @@ cpp::result ModelService::GetModelPullInfo( huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name); if (!repo_info.has_value()) { - return cpp::fail("Model not found"); + 
return cpp::fail("Model not found on " + std::string{kHuggingFaceHost}); } - if (!repo_info->gguf.has_value()) { - return cpp::fail( - "Not a GGUF model. Currently, only GGUF single file is " - "supported."); - } - - std::vector options{}; - for (const auto& sibling : repo_info->siblings) { - if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { - options.push_back(sibling.rfilename); + // repo containing GGUF files + if (repo_info->gguf.has_value()) { + std::vector options{}; + for (const auto& sibling : repo_info->siblings) { + if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { + options.push_back(sibling.rfilename); + } } - } - return ModelPullInfo{ + return ModelPullInfo{ /* .id = */ author + ":" + model_name, /* .default_branch = */ "main", /* .downloaded_models = */ {}, @@ -906,6 +986,23 @@ cpp::result ModelService::GetModelPullInfo( /* .model_source = */ "", /* .download_url = */ huggingface_utils::GetDownloadableUrl(author, model_name, "")}; + } + + // repo that is supported by HF transformers + // we will download the whole repo + if (repo_info->library_name.value_or("") == "transformers") { + return ModelPullInfo{ + /* .id = */ author + "/" + model_name, + /* .default_branch = */ "main", + /* .downloaded_models = */ {}, + /* .available_models = */ {}, + /* .model_source = */ "huggingface", + /* .download_url = */ ""}; + } + + return cpp::fail( + "Unsupported model. Currently, only GGUF models and HF models are " + "supported."); } } auto branches = @@ -953,6 +1050,54 @@ cpp::result ModelService::GetModelPullInfo( /* .download_url = */ ""}; } +cpp::result ModelService::PullModel( + const std::string& model_handle, + const std::optional& desired_model_id, + const std::optional& desired_model_name) { + CTL_INF("Handle model input, model handle: " + model_handle); + + if (string_utils::StartsWith(model_handle, "https")) + return HandleDownloadUrlAsync(model_handle, desired_model_id, + desired_model_name); + + // HF model handle + if (model_handle.find("/") != std::string::npos) { + const auto author_model = string_utils::SplitBy(model_handle, "/"); + if (author_model.size() != 2) + return cpp::fail("Invalid model handle"); + + return DownloadHfModelAsync(author_model[0], author_model[1]); + } + + if (model_handle.find(":") == std::string::npos) + return cpp::fail("Invalid model handle or not supported!"); + + auto model_and_branch = string_utils::SplitBy(model_handle, ":"); + + // cortexso format - model:branch + if (model_and_branch.size() == 2) + return DownloadModelFromCortexsoAsync( + model_and_branch[0], model_and_branch[1], desired_model_id); + + if (model_and_branch.size() == 3) { + // single GGUF file + // author_id:model_name:filename + auto mh = url_parser::Url{ + .protocol = "https", + .host = kHuggingFaceHost, + .pathParams = { + model_and_branch[0], + model_and_branch[1], + "resolve", + "main", + model_and_branch[2], + }}.ToFullPath(); + return HandleDownloadUrlAsync(mh, desired_model_id, desired_model_name); + } + + return cpp::fail("Invalid model handle or not supported!"); +} + cpp::result ModelService::AbortDownloadModel( const std::string& task_id) { return download_service_->StopTask(task_id); @@ -1107,6 +1252,10 @@ std::string ModelService::GetEngineByModelId( CTL_WRN("Error: " + model_entry.error()); return ""; } + + if (model_entry.value().engine == kVllmEngine) + return kVllmEngine; + config::YamlHandler yaml_handler; yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( diff --git a/engine/services/model_service.h 
b/engine/services/model_service.h index beba91f8c..e61d17171 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -39,13 +39,14 @@ class ModelService { std::shared_ptr engine_svc, cortex::TaskQueue& task_queue); + cpp::result PullModel( + const std::string& model_handle, + const std::optional& desired_model_id, + const std::optional& desired_model_name); + cpp::result AbortDownloadModel( const std::string& task_id); - cpp::result DownloadModelFromCortexsoAsync( - const std::string& name, const std::string& branch = "main", - std::optional temp_model_id = std::nullopt); - std::optional GetDownloadedModel( const std::string& modelId) const; @@ -67,10 +68,6 @@ class ModelService { cpp::result GetModelPullInfo( const std::string& model_handle); - cpp::result HandleDownloadUrlAsync( - const std::string& url, std::optional temp_model_id, - std::optional temp_name); - bool HasModel(const std::string& id) const; std::optional GetEstimation( @@ -89,6 +86,17 @@ class ModelService { std::string GetEngineByModelId(const std::string& model_id) const; private: + cpp::result HandleDownloadUrlAsync( + const std::string& url, std::optional temp_model_id, + std::optional temp_name); + + cpp::result DownloadModelFromCortexsoAsync( + const std::string& name, const std::string& branch = "main", + std::optional temp_model_id = std::nullopt); + + cpp::result DownloadHfModelAsync( + const std::string& author_id, const std::string& model_id); + cpp::result, std::string> MayFallbackToCpu( const std::string& model_path, int ngl, int ctx_len, int n_batch = 2048, int n_ubatch = 2048, const std::string& kv_cache_type = "f16"); diff --git a/engine/utils/config_yaml_utils.cc b/engine/utils/config_yaml_utils.cc index e6843c04c..b3a6b9962 100644 --- a/engine/utils/config_yaml_utils.cc +++ b/engine/utils/config_yaml_utils.cc @@ -36,6 +36,7 @@ cpp::result CortexConfigMgr::DumpYamlConfig( node["gitHubToken"] = config.gitHubToken; node["llamacppVariant"] = config.llamacppVariant; node["llamacppVersion"] = config.llamacppVersion; + node["vllmVersion"] = config.vllmVersion; node["enableCors"] = config.enableCors; node["allowedOrigins"] = config.allowedOrigins; node["proxyUrl"] = config.proxyUrl; @@ -80,7 +81,8 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, !node["logOnnxPath"] || !node["huggingFaceToken"] || !node["gitHubUserAgent"] || !node["gitHubToken"] || !node["llamacppVariant"] || !node["llamacppVersion"] || - !node["enableCors"] || !node["allowedOrigins"] || !node["proxyUrl"] || + !node["vllmVersion"] || !node["enableCors"] || + !node["allowedOrigins"] || !node["proxyUrl"] || !node["proxyUsername"] || !node["proxyPassword"] || !node["verifyPeerSsl"] || !node["verifyHostSsl"] || !node["verifyProxySsl"] || !node["verifyProxyHostSsl"] || @@ -141,6 +143,9 @@ CortexConfig CortexConfigMgr::FromYaml(const std::string& path, /* .llamacppVersion = */ node["llamacppVersion"] ? node["llamacppVersion"].as() : default_cfg.llamacppVersion, + /* .vllmVersion = */ + node["vllmVersion"] ? node["vllmVersion"].as() + : default_cfg.vllmVersion, /* .enableCors = */ node["enableCors"] ? 
node["enableCors"].as() : default_cfg.enableCors, diff --git a/engine/utils/config_yaml_utils.h b/engine/utils/config_yaml_utils.h index c871fd100..f41b00e54 100644 --- a/engine/utils/config_yaml_utils.h +++ b/engine/utils/config_yaml_utils.h @@ -24,7 +24,7 @@ constexpr const auto kDefaultCorsEnabled = true; const std::vector kDefaultEnabledOrigins{ "http://localhost:39281", "http://127.0.0.1:39281", "http://0.0.0.0:39281"}; constexpr const auto kDefaultNoProxy = "example.com,::1,localhost,127.0.0.1"; -const std::vector kDefaultSupportedEngines{kLlamaEngine}; +const std::vector kDefaultSupportedEngines{kLlamaEngine, kVllmEngine}; struct CortexConfig { std::string logFolderPath; @@ -48,6 +48,7 @@ struct CortexConfig { std::string gitHubToken; std::string llamacppVariant; std::string llamacppVersion; + std::string vllmVersion; bool enableCors; std::vector allowedOrigins; diff --git a/engine/utils/curl_utils.cc b/engine/utils/curl_utils.cc index 1d0be2f70..00aac430c 100644 --- a/engine/utils/curl_utils.cc +++ b/engine/utils/curl_utils.cc @@ -373,4 +373,48 @@ cpp::result SimplePatchJson(const std::string& url, return root; } + +cpp::result SimpleDownload(const std::string& url, + const std::string& save_path, + const int timeout) { + auto curl = curl_easy_init(); + if (!curl) { + return cpp::fail("Failed to init CURL"); + } + + auto headers = GetHeaders(url); + curl_slist* curl_headers = nullptr; + if (headers) { + for (const auto& [key, value] : headers->m) { + auto header = key + ": " + value; + curl_headers = curl_slist_append(curl_headers, header.c_str()); + } + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_headers); + } + + auto file = fopen(save_path.c_str(), "wb"); + if (!file) + return cpp::fail("Failed to open " + save_path); + + SetUpProxy(curl, url); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, fwrite); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, file); + if (timeout > 0) { + curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeout); + } + + // Perform the request + auto res = curl_easy_perform(curl); + fclose(file); + curl_slist_free_all(curl_headers); + curl_easy_cleanup(curl); + if (res != CURLE_OK) { + return cpp::fail("CURL request failed: " + + std::string{curl_easy_strerror(res)}); + } + + return {}; +} } // namespace curl_utils diff --git a/engine/utils/curl_utils.h b/engine/utils/curl_utils.h index 9035b6b3c..91a67077e 100644 --- a/engine/utils/curl_utils.h +++ b/engine/utils/curl_utils.h @@ -37,8 +37,8 @@ cpp::result ReadRemoteYaml(const std::string& url); */ cpp::result SimpleGetJson(const std::string& url, const int timeout = -1); -cpp::result SimpleGetJsonRecursive(const std::string& url, - const int timeout = -1); +cpp::result SimpleGetJsonRecursive( + const std::string& url, const int timeout = -1); cpp::result SimplePostJson( const std::string& url, const std::string& body = ""); @@ -49,4 +49,7 @@ cpp::result SimpleDeleteJson( cpp::result SimplePatchJson( const std::string& url, const std::string& body = ""); +cpp::result SimpleDownload(const std::string& url, + const std::string& save_path, + const int timeout = -1); } // namespace curl_utils diff --git a/engine/utils/engine_constants.h b/engine/utils/engine_constants.h index 2c5cd1be3..cf9a6904e 100644 --- a/engine/utils/engine_constants.h +++ b/engine/utils/engine_constants.h @@ -1,13 +1,13 @@ #pragma once constexpr const auto kLlamaEngine = "llama-cpp"; +constexpr const auto kVllmEngine = "vllm"; constexpr const 
auto kRemote = "remote"; constexpr const auto kLocal = "local"; constexpr const auto kLlamaRepo = "cortex.llamacpp"; - constexpr const auto kLlamaLibPath = "./engines/cortex.llamacpp"; // other constants diff --git a/engine/utils/file_manager_utils.cc b/engine/utils/file_manager_utils.cc index f4ffb99db..1e476b443 100644 --- a/engine/utils/file_manager_utils.cc +++ b/engine/utils/file_manager_utils.cc @@ -208,6 +208,7 @@ config_yaml_utils::CortexConfig GetDefaultConfig() { /* .gitHubToken = */ "", /* .llamacppVariant = */ "", /* .llamacppVersion = */ "", + /* .vllmVersion = */ "", /* .enableCors = */ config_yaml_utils::kDefaultCorsEnabled, /* .allowedOrigins = */ config_yaml_utils::kDefaultEnabledOrigins, /* .proxyUrl = */ "", diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h index ad1524fc4..277185c30 100644 --- a/engine/utils/huggingface_utils.h +++ b/engine/utils/huggingface_utils.h @@ -170,6 +170,7 @@ struct HuggingFaceModelRepoInfo { int downloads; int likes; + std::optional library_name; std::optional gguf; std::vector siblings; std::vector spaces; @@ -178,6 +179,10 @@ struct HuggingFaceModelRepoInfo { static cpp::result FromJson( const Json::Value& body) { + std::optional library_name = std::nullopt; + if (body["library_name"]) + library_name = body["library_name"].asString(); + std::optional gguf = std::nullopt; auto gguf_result = HuggingFaceGgufInfo::FromJson(body["gguf"]); if (gguf_result.has_value()) { @@ -208,6 +213,7 @@ struct HuggingFaceModelRepoInfo { /* .downloads = */ body["downloads"].asInt(), /* .likes = */ body["likes"].asInt(), + /* .library_name = */ library_name, /* .gguf = */ gguf, /* .siblings = */ siblings, /* .spaces = */ diff --git a/engine/utils/process/utils.cc b/engine/utils/process/utils.cc index f63de5c5e..8c0120394 100644 --- a/engine/utils/process/utils.cc +++ b/engine/utils/process/utils.cc @@ -11,6 +11,44 @@ extern char** environ; // environment variables #include #endif +namespace { +// retrieve current env vars, make a copy, then add new env vars from input +std::vector BuildEnvVars( + const std::unordered_map& new_env_vars) { +#if defined(_WIN32) + throw std::runtime_error("Not implemented"); +#endif + + // parse current env var to an unordered map + std::unordered_map env_vars_map; + for (int i = 0; environ[i] != nullptr; i++) { + std::string env_var{environ[i]}; + auto split_idx = env_var.find("="); + + if (split_idx == std::string::npos) { + throw std::runtime_error( + "Error while parsing current environment variables"); + } + + env_vars_map[env_var.substr(0, split_idx)] = env_var.substr(split_idx + 1); + } + + // add new env vars. 
they will override any existing env vars + for (const auto& [key, value] : new_env_vars) { + env_vars_map[key] = value; + } + + // convert back to key=value format + std::vector env_vars_vector; + for (const auto& [key, value] : env_vars_map) { + env_vars_vector.push_back(key + "=" + value); + } + + return env_vars_vector; +} + +} // namespace + namespace cortex::process { std::string ConstructWindowsCommandLine(const std::vector& args) { @@ -42,7 +80,10 @@ std::vector ConvertToArgv(const std::vector& args) { cpp::result SpawnProcess( const std::vector& command, const std::string& stdout_file, - const std::string& stderr_file) { + const std::string& stderr_file, + std::optional>> + env_vars) { std::stringstream ss; for (const auto& item : command) { ss << item << " "; @@ -191,6 +232,8 @@ cpp::result SpawnProcess( posix_spawn_file_actions_destroy(action_ptr); throw std::runtime_error("Unable to add stdout to file action"); } + } else { + CTL_WRN(stdout_file + " does not exist"); } } @@ -203,18 +246,33 @@ cpp::result SpawnProcess( posix_spawn_file_actions_destroy(action_ptr); throw std::runtime_error("Unable to add stderr to file action"); } + } else { + CTL_WRN(stderr_file + " does not exist"); } } } + char** envp; + // we declare these two here so that their lifetimes span the entire function + std::vector env_vars_vector; + std::vector env_vars_; + if (env_vars.has_value()) { + env_vars_vector = BuildEnvVars(env_vars.value()); + env_vars_ = ConvertToArgv(env_vars_vector); + envp = env_vars_.data(); + } else { + envp = environ; // simply inherit current env + } + // Use posix_spawn for cross-platform compatibility + // NOTE: posix_spawn() returns after the fork() step. This means we may + // need to keep argv and envp data alive until the exec() step finishes. auto spawn_result = posix_spawn(&pid, // pid output command[0].c_str(), // executable path action_ptr, // file actions NULL, // spawn attributes argv.data(), // argument vector - environ // environment (inherit) - ); + envp); // environment // NOTE: it seems like it's ok to destroy this immediately before // subprocess terminates. diff --git a/engine/utils/process/utils.h b/engine/utils/process/utils.h index 19b821cef..db1ac7460 100644 --- a/engine/utils/process/utils.h +++ b/engine/utils/process/utils.h @@ -12,7 +12,9 @@ using pid_t = DWORD; #include #endif +#include #include +#include #include #include "utils/result.hpp" @@ -36,7 +38,10 @@ std::vector ConvertToArgv(const std::vector& args); cpp::result SpawnProcess( const std::vector& command, - const std::string& stdout_file = "", const std::string& stderr_file = ""); + const std::string& stdout_file = "", const std::string& stderr_file = "", + std::optional>> + env_vars = {}); bool IsProcessAlive(ProcessInfo& proc_info); bool WaitProcess(ProcessInfo& proc_info); bool KillProcess(ProcessInfo& proc_info);
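The new ModelService::PullModel entry point dispatches purely on the shape of the model handle: an https URL goes to HandleDownloadUrlAsync, author/model pulls the whole Hugging Face repo via DownloadHfModelAsync (and is served by the vLLM engine), model:branch keeps the existing cortexso path, and author:model:filename is rewritten into a huggingface.co resolve URL for a single GGUF file. The standalone sketch below mirrors that branching under those assumptions; HandleKind, ClassifyHandle, and the local SplitBy helper are illustrative names only and are not part of the patch.

// Standalone sketch of the handle-routing rules in ModelService::PullModel.
// The real code dispatches to the ModelService download helpers instead of
// returning an enum.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

enum class HandleKind { DirectUrl, HfRepo, CortexsoBranch, SingleGguf, Invalid };

static std::vector<std::string> SplitBy(const std::string& s, char delim) {
  std::vector<std::string> parts;
  std::stringstream ss(s);
  std::string part;
  while (std::getline(ss, part, delim)) parts.push_back(part);
  return parts;
}

HandleKind ClassifyHandle(const std::string& handle) {
  if (handle.rfind("https", 0) == 0) return HandleKind::DirectUrl;
  if (handle.find('/') != std::string::npos) {
    // author/model -> download the whole Hugging Face repo
    return SplitBy(handle, '/').size() == 2 ? HandleKind::HfRepo
                                            : HandleKind::Invalid;
  }
  auto parts = SplitBy(handle, ':');
  if (parts.size() == 2) return HandleKind::CortexsoBranch;  // model:branch
  if (parts.size() == 3) return HandleKind::SingleGguf;      // author:model:file.gguf
  return HandleKind::Invalid;
}

int main() {
  for (const auto& h : {std::string("https://example.com/model.gguf"),
                        std::string("author/model-name"),
                        std::string("model:branch"),
                        std::string("author:model:file.gguf")}) {
    std::cout << h << " -> " << static_cast<int>(ClassifyHandle(h)) << "\n";
  }
}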
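curl_utils::SimpleDownload is a thin blocking helper: it reuses the shared header and proxy setup, follows redirects, streams the response body straight to save_path via fwrite, and reports failures through cpp::fail. A minimal call site might look like the sketch below; the URL, output path, timeout, and the FetchBootstrapScript wrapper are placeholders, and the CTL_* logging macros are assumed to come from utils/logging_utils.h as elsewhere in the engine.

// Hypothetical call site for the new curl_utils::SimpleDownload helper.
// The URL and destination path are placeholders for illustration.
#include "utils/curl_utils.h"
#include "utils/logging_utils.h"

void FetchBootstrapScript() {
  auto res = curl_utils::SimpleDownload(
      "https://example.com/get-script.sh",  // placeholder URL
      "/tmp/get-script.sh",                 // placeholder save path
      /*timeout=*/600);                     // seconds; default -1 means no timeout
  if (res.has_error()) {
    CTL_ERR("Download failed: " + res.error());
    return;
  }
  CTL_INF("Saved to /tmp/get-script.sh");
}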
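SpawnProcess now takes an optional set of environment variables: BuildEnvVars copies the parent environment, lets the supplied entries override duplicates, and the merged block is handed to posix_spawn as envp, while omitting the argument preserves the old inherit-environ behaviour (the Windows branch of BuildEnvVars still throws). Below is a sketch of a call site, assuming the map is keyed and valued by std::string as BuildEnvVars implies; the command, log paths, and values are placeholders.

// Hypothetical POSIX-only call site for the extended SpawnProcess signature.
// Command, log paths, and variable values are placeholders.
#include <string>
#include <unordered_map>
#include <vector>

#include "utils/process/utils.h"

bool RunWithExtraEnv() {
  std::vector<std::string> cmd{"/usr/bin/env", "python3", "-c",
                               "import os; print(os.environ['MY_FLAG'])"};
  std::unordered_map<std::string, std::string> extra_env{
      {"MY_FLAG", "1"},                  // added on top of the parent env
      {"VIRTUAL_ENV", "/path/to/venv"},  // placeholder value
  };
  auto proc = cortex::process::SpawnProcess(cmd, "/tmp/out.log", "/tmp/err.log",
                                            extra_env);
  if (proc.has_error()) {
    return false;  // proc.error() carries the failure message
  }
  // proc.value() is a ProcessInfo that can later be passed to
  // IsProcessAlive() / WaitProcess() / KillProcess().
  return true;
}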