diff --git a/src/chat_completion_request.h b/src/chat_completion_request.h
index b78dd66..c64a453 100644
--- a/src/chat_completion_request.h
+++ b/src/chat_completion_request.h
@@ -1,5 +1,6 @@
 #pragma once
 #include "json/value.h"
+#include "sampling.h"
 
 namespace llama::inferences {
 struct ChatCompletionRequest {
@@ -12,10 +13,29 @@ struct ChatCompletionRequest {
   Json::Value stop = Json::Value(Json::arrayValue);
   Json::Value messages = Json::Value(Json::arrayValue);
   std::string model_id;
+
+  int seed = -1;
+  float dynatemp_range = 0.0f;
+  float dynatemp_exponent = 1.0f;
+  int top_k = 40;
+  float min_p = 0.05f;
+  float tfs_z = 1.0f;
+  float typ_p = 1.0f;
+  int repeat_last_n = 64;
+  float penalty_repeat = 1.0f;
+  bool mirostat = false;
+  float mirostat_tau = 5.0f;
+  float mirostat_eta = 0.1f;
+  bool penalize_nl = false;
+  bool ignore_eos = false;
+  int n_probs = 0;
+  int min_keep = 0;
+  std::string grammar;
 };
 
 inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
   ChatCompletionRequest completion;
+  gpt_sampler_params default_params;
   if (jsonBody) {
     completion.stream = (*jsonBody).get("stream", false).asBool();
     completion.max_tokens = (*jsonBody).get("max_tokens", 500).asInt();
@@ -28,6 +48,24 @@ inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
     completion.messages = (*jsonBody)["messages"];
     completion.stop = (*jsonBody)["stop"];
     completion.model_id = (*jsonBody).get("model", {}).asString();
+
+    completion.seed = (*jsonBody).get("seed", -1).asInt();
+    completion.dynatemp_range = (*jsonBody).get("dynatemp_range", 0.0f).asFloat();
+    completion.dynatemp_exponent = (*jsonBody).get("dynatemp_exponent", 0.0f).asFloat();
+    completion.top_k = (*jsonBody).get("top_k", 40).asInt();
+    completion.min_p = (*jsonBody).get("min_p", 0.05f).asFloat();
+    completion.tfs_z = (*jsonBody).get("tfs_z", 1.0f).asFloat();
+    completion.typ_p = (*jsonBody).get("typ_p", 1.0f).asFloat();
+    completion.repeat_last_n = (*jsonBody).get("repeat_last_n", 64).asInt();
+    completion.penalty_repeat = (*jsonBody).get("repeat_penalty", 1.1f).asFloat();
+    completion.mirostat = (*jsonBody).get("mirostat", false).asBool();
+    completion.mirostat_tau = (*jsonBody).get("mirostat_tau", 5.0f).asFloat();
+    completion.mirostat_eta = (*jsonBody).get("mirostat_eta", 0.1f).asFloat();
+    completion.penalize_nl = (*jsonBody).get("penalize_nl", true).asBool();
+    completion.ignore_eos = (*jsonBody).get("ignore_eos", false).asBool();
+    completion.n_probs = (*jsonBody).get("n_probs", 0).asInt();
+    completion.min_keep = (*jsonBody).get("min_keep", 0).asInt();
+    completion.grammar = (*jsonBody).get("grammar", "").asString();
   }
   return completion;
 }
diff --git a/src/llama_engine.cc b/src/llama_engine.cc
index b21d842..2865ee6 100644
--- a/src/llama_engine.cc
+++ b/src/llama_engine.cc
@@ -480,6 +480,10 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> json_body) {
   if (!params.use_mmap) {
     LOG_DEBUG << "Disabled mmap";
   }
+  params.n_predict = json_body->get("n_predict", -1).asInt();
+  params.prompt = json_body->get("prompt", "").asString();
+  params.conversation = json_body->get("conversation", false).asBool();
+  params.special = json_body->get("special", false).asBool();
 
   server_map_[model_id].caching_enabled =
       json_body->get("caching_enabled", true).asBool();
@@ -599,6 +603,24 @@ void LlamaEngine::HandleInferenceImpl(
   data["temperature"] = completion.temperature;
   data["frequency_penalty"] = completion.frequency_penalty;
   data["presence_penalty"] = completion.presence_penalty;
+  data["seed"] = completion.seed;
+  data["dynatemp_range"] = completion.dynatemp_range;
+  data["dynatemp_exponent"] = completion.dynatemp_exponent;
+  data["top_k"] = completion.top_k;
+  data["min_p"] = completion.min_p;
+  data["tfs_z"] = completion.tfs_z;
+  data["typical_p"] = completion.typ_p;
+  data["repeat_last_n"] = completion.repeat_last_n;
+  data["repeat_penalty"] = completion.penalty_repeat;
+  data["mirostat"] = completion.mirostat;
+  data["mirostat_tau"] = completion.mirostat_tau;
+  data["mirostat_eta"] = completion.mirostat_eta;
+  data["penalize_nl"] = completion.penalize_nl;
+  data["ignore_eos"] = completion.ignore_eos;
+  data["n_probs"] = completion.n_probs;
+  data["min_keep"] = completion.min_keep;
+  data["grammar"] = completion.grammar;
+  int n_probs = completion.n_probs;
   const Json::Value& messages = completion.messages;
 
   if (!si.grammar_file_content.empty()) {
@@ -717,12 +739,17 @@ void LlamaEngine::HandleInferenceImpl(
   auto state = CreateInferenceState(si.ctx);
 
   // Queued task
-  si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id]() {
+  si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id, n_probs]() {
     state->task_id = state->llama.RequestCompletion(data, false, false, -1);
     while (state->llama.model_loaded_external) {
       TaskResult result = state->llama.NextResult(state->task_id);
       if (!result.error) {
-        std::string to_send = result.result_json["content"];
+        std::string to_send;
+        if (n_probs > 0){
+          to_send = result.result_json["completion_probabilities"].dump();
+        }else{
+          to_send = result.result_json["content"];
+        }
         // trim the leading space if it is the first token
         if (std::exchange(state->is_first_token, false)) {
           llama_utils::ltrim(to_send);
diff --git a/src/llama_server_context.cc b/src/llama_server_context.cc
index 11f4a82..e8e7205 100644
--- a/src/llama_server_context.cc
+++ b/src/llama_server_context.cc
@@ -459,6 +459,15 @@ bool LlamaServerContext::LaunchSlotWithData(LlamaClientSlot*& slot, json data) {
   slot->params.seed = json_value(data, "seed", default_params.seed);
   slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
   slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
+  slot->sparams.min_keep =
+      json_value(data, "min_keep", default_sparams.min_keep);
+  slot->sparams.seed = json_value(data, "seed", default_sparams.seed);
+  slot->sparams.dynatemp_range =
+      json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
+  slot->sparams.dynatemp_exponent =
+      json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
+  slot->sparams.ignore_eos =
+      json_value(data, "ignore_eos", default_sparams.ignore_eos);
 
   // infill
   if (data.count("input_prefix") != 0) {
@@ -970,8 +979,13 @@ void LlamaServerContext::SendFinalResponse(LlamaClientSlot& slot) {
           slot.generated_token_probs.begin(),
          slot.generated_token_probs.begin() + slot.sent_token_probs_index);
     }
-    res.result_json["completion_probabilities"] =
+    if(!slot.params.stream ){
+      res.result_json["completion_probabilities"] =
         probs_vector_to_json(ctx, probs);
+    }
+    else{
+      res.result_json["completion_probabilities"] = std::move(json());
+    }
   }
 
   if (slot.oaicompat) {
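Usage sketch (not part of the diff): the snippet below builds a Json::Value request body that exercises the new sampling fields and runs it through llama::inferences::fromJson() from the header above. Only the field and key names come from the diff; the main() harness and the chosen values are illustrative assumptions, and it presumes the engine's include paths (jsoncpp, llama.cpp's sampling.h) are available.

// Illustrative only: field/key names taken from chat_completion_request.h;
// the harness and values are assumptions for demonstration.
#include <memory>
#include "chat_completion_request.h"

int main() {
  auto body = std::make_shared<Json::Value>();
  (*body)["messages"] = Json::Value(Json::arrayValue);
  (*body)["seed"] = 42;             // fixed seed for reproducible sampling
  (*body)["top_k"] = 20;
  (*body)["min_p"] = 0.1;
  (*body)["repeat_penalty"] = 1.1;  // parsed into completion.penalty_repeat
  (*body)["n_probs"] = 5;           // > 0 makes the engine stream completion_probabilities
  auto req = llama::inferences::fromJson(body);
  return (req.seed == 42 && req.top_k == 20 && req.n_probs == 5) ? 0 : 1;
}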