diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index 0d17141f83ea..883fbf8fac8e 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -467,6 +467,7 @@ struct llama_server_context
     bool all_slots_are_idle = false;
     bool add_bos_token = true;
     bool has_eos_token = true;
+    bool has_gpu = false;

     bool grammar_lazy = false;
     std::vector<common_grammar_trigger> grammar_triggers;
@@ -511,7 +512,10 @@ struct llama_server_context
         if (!params.mmproj.empty()) {
             multimodal = true;
             LOG_INFO("Multi Modal Mode Enabled", {});
-            clp_ctx = clip_model_load(params.mmproj.c_str(), /*verbosity=*/ 1);
+            clp_ctx = clip_init(params.mmproj.c_str(), clip_context_params {
+                /* use_gpu */ has_gpu,
+                /*verbosity=*/ 1,
+            });
             if(clp_ctx == nullptr) {
                 LOG_ERR("unable to load clip model: %s", params.mmproj.c_str());
                 return false;
@@ -2314,7 +2318,7 @@ static std::string get_all_kv_cache_types() {
 }

 static void params_parse(const backend::ModelOptions* request,
-                                common_params & params) {
+                                common_params & params, llama_server_context &llama) {

     // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809

@@ -2352,6 +2356,20 @@ static void params_parse(const backend::ModelOptions* request,
         add_rpc_devices(std::string(llama_grpc_servers));
     }

+    // decode options. Options are in the form optname:optvalue, or just optname for booleans.
+    for (int i = 0; i < request->options_size(); i++) {
+        std::string opt = request->options(i);
+        char *optname = strtok(&opt[0], ":");
+        char *optval = strtok(NULL, ":");
+        if (optval == NULL) {
+            optval = "true";
+        }
+
+        if (!strcmp(optname, "gpu")) {
+            llama.has_gpu = true;
+        }
+    }
+
     // TODO: Add yarn

     if (!request->tensorsplit().empty()) {
@@ -2445,7 +2463,7 @@ class BackendServiceImpl final : public backend::Backend::Service {
   grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
     // Implement LoadModel RPC
     common_params params;
-    params_parse(request, params);
+    params_parse(request, params, llama);

     llama_backend_init();
     llama_numa_init(params.numa);
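
For reference, the new loop in `params_parse` splits each request option on the first `:` and treats a bare name as a boolean flag, so passing `gpu` (or `gpu:true`) sets `llama.has_gpu`, which in turn enables GPU offload when the CLIP model is loaded via `clip_init`. Below is a minimal standalone sketch of that parsing behaviour, using `std::string` instead of `strtok` and a hypothetical `options` vector in place of the gRPC request field:

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
    // Hypothetical stand-in for the repeated `options` field of the gRPC
    // ModelOptions request; each entry is either "name" or "name:value".
    std::vector<std::string> options = {"gpu", "some_option:42"};

    bool has_gpu = false;
    for (const std::string &opt : options) {
        const std::size_t sep = opt.find(':');
        // substr(0, npos) yields the whole string, so a bare name still works.
        const std::string name = opt.substr(0, sep);
        // A bare option name is treated as a boolean set to "true",
        // mirroring the strtok-based loop in params_parse above.
        const std::string value =
            (sep == std::string::npos) ? "true" : opt.substr(sep + 1);

        if (name == "gpu") {
            has_gpu = true;
        }
        std::cout << name << " = " << value << "\n";
    }
    std::cout << "has_gpu: " << std::boolalpha << has_gpu << "\n";
}
```

The `some_option:42` entry is purely illustrative; only the `gpu` option is inspected by the patch above.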