From 231a797ea8d7561233387a4eeff764f680f83121 Mon Sep 17 00:00:00 2001
From: mudler
Date: Fri, 6 Oct 2023 18:41:15 +0200
Subject: [PATCH] update to latest llama.cpp breaking API changes

Signed-off-by: mudler
---
 .github/workflows/test-gpu.yaml |   4 +-
 Makefile                        |   5 +-
 binding.cpp                     | 130 +++++++++++++++++---------
 binding.h                       |   5 +-
 llama.cpp                       |   2 +-
 llama.go                        |  16 ++--
 llama_test.go                   |   7 +-
 options.go                      |  21 ++++--
 patches/1902-cuda.patch         |  38 +++++-----
 9 files changed, 127 insertions(+), 101 deletions(-)

diff --git a/.github/workflows/test-gpu.yaml b/.github/workflows/test-gpu.yaml
index 0d1d922..bb263f1 100644
--- a/.github/workflows/test-gpu.yaml
+++ b/.github/workflows/test-gpu.yaml
@@ -45,10 +45,12 @@ jobs:
           sudo DEBIAN_FRONTEND=noninteractive apt-get install -y pip wget
       - name: Build and test
         run: |
+          set -o pipefail
           GPU_TESTS=true BUILD_TYPE=cublas CMAKE_ARGS="-DLLAMA_METAL=OFF -DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF" \
           make test 2>&1 | tee test_log.log
+          set +o pipefail
           if grep -q "using CUDA for GPU acceleration" test_log.log; then
-            echo "All good";
+            echo "GPU was used";
           else
             echo "No CUDA found";
             exit 1;
diff --git a/Makefile b/Makefile
index 7dc8ffa..bb51e12 100644
--- a/Makefile
+++ b/Makefile
@@ -232,7 +232,8 @@ binding.o: prepare llama.cpp/ggml.o llama.cpp/llama.o llama.cpp/common.o llama.c

 ## https://github.com/ggerganov/llama.cpp/pull/1902
 prepare:
-	cd llama.cpp && patch -p1 < ../patches/1902-cuda.patch
+	cd llama.cpp && \
+	patch -p1 < ../patches/1902-cuda.patch
 	touch $@

 libbinding.a: prepare binding.o llama.cpp/k_quants.o llama.cpp/grammar-parser.o llama.cpp/ggml-alloc.o $(EXTRA_TARGETS)
@@ -248,4 +249,4 @@ ggllm-test-model.bin:
 	wget -q https://huggingface.co/TheBloke/CodeLlama-7B-Instruct-GGUF/resolve/main/codellama-7b-instruct.Q2_K.gguf -O ggllm-test-model.bin

 test: ggllm-test-model.bin libbinding.a
-	C_INCLUDE_PATH=${INCLUDE_PATH} CGO_LDFLAGS=${CGO_LDFLAGS} LIBRARY_PATH=${LIBRARY_PATH} TEST_MODEL=ggllm-test-model.bin go run github.com/onsi/ginkgo/v2/ginkgo --label-filter="$(TEST_LABEL)" --flake-attempts 5 -v -r ./...
\ No newline at end of file
+	C_INCLUDE_PATH=${INCLUDE_PATH} CGO_LDFLAGS=${CGO_LDFLAGS} LIBRARY_PATH=${LIBRARY_PATH} TEST_MODEL=$(abspath ./)/ggllm-test-model.bin go run github.com/onsi/ginkgo/v2/ginkgo --label-filter="$(TEST_LABEL)" -v -r ./...
\ No newline at end of file diff --git a/binding.cpp b/binding.cpp index bd2f895..65bb50d 100644 --- a/binding.cpp +++ b/binding.cpp @@ -49,19 +49,19 @@ int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings) { int n_past = 0; - const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + const bool add_bos = llama_vocab_type(state->model) == LLAMA_VOCAB_TYPE_SPM; // tokenize the prompt auto embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos); if (embd_inp.size() > 0) { - if (llama_eval(ctx, embd_inp.data(), embd_inp.size(), n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), embd_inp.size(), n_past,0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } } - const int n_embd = llama_n_embd(ctx); + const int n_embd = llama_n_embd(state->model); const auto embeddings = llama_get_embeddings(ctx); @@ -99,7 +99,7 @@ int eval(void* params_ptr,void* state_pr,char *text) { auto tokens = std::vector(params_p->n_ctx); std::string str = std::string(text); - auto n_prompt_tokens = llama_tokenize(ctx, str.data(), str.length(), tokens.data(), tokens.size(), true); + auto n_prompt_tokens = llama_tokenize(state->model, str.data(), str.length(), tokens.data(), tokens.size(), true); if (n_prompt_tokens < 1) { fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); @@ -107,7 +107,7 @@ int eval(void* params_ptr,void* state_pr,char *text) { } // evaluate prompt - return llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params_p->n_threads); + return llama_decode(ctx, llama_batch_get_one( tokens.data(), n_prompt_tokens, n_past, 0)); } static llama_context ** g_ctx; @@ -185,7 +185,7 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) { } } } - const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + const bool add_bos = llama_vocab_type(state->model) == LLAMA_VOCAB_TYPE_SPM; std::vector embd_inp; if ( !params.prompt.empty() || session_tokens.empty() ) { @@ -306,24 +306,20 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) { std::vector embd; std::vector embd_guidance; - const int n_vocab = llama_n_vocab(ctx); + + const int n_vocab = llama_n_vocab(state->model); + std::vector candidates; candidates.reserve(n_vocab); std::string res = ""; - - { - const std::vector tmp = { llama_token_bos(ctx), }; - llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); - llama_reset_timings(ctx); - } // set the seed before actually predicting llama_set_rng_seed(ctx, params.seed); while (n_remain != 0) { - // predict - if (embd.size() > 0) { + // predict + if (!embd.empty()) { // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via // --prompt or --file which uses the same value. 
auto max_embd_size = n_ctx - 4; @@ -338,15 +334,18 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) { // - take the n_keep first tokens from the original prompt (via n_past) // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { - const int n_left = n_past - params.n_keep; - - // always keep the first token - BOS - n_past = std::max(1, params.n_keep); - n_past_guidance = std::max(1, params.n_keep + guidance_offset); - - // insert n_left/2 tokens at the start of embd from last_tokens - embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size()); - + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, n_left, n_ctx, params.n_keep, n_discard); + + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + n_past -= n_discard; + if (ctx_guidance) { + n_past_guidance -= n_discard; + } + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); // stop saving session if we run out of context path_session.clear(); } @@ -372,6 +371,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) { if (i > 0) { embd.erase(embd.begin(), embd.begin() + i); } + // remove any "future" tokens that we might have inherited from the session from the KV cache + llama_kv_cache_tokens_rm(ctx, n_past, -1); } // evaluate tokens in batches @@ -409,8 +410,7 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) { for (int i = 0; i < input_size; i += params.n_batch) { int n_eval = std::min(input_size - i, params.n_batch); - if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); + if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } @@ -424,7 +424,7 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) { if (n_eval > params.n_batch) { n_eval = params.n_batch; } - if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } @@ -586,16 +586,16 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads); - llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads); - llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads); + llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); + llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); + llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0)); const auto t_enc_end = ggml_time_us(); // the 2 models should have the same vocab const int n_ctx = llama_n_ctx(ctx_tgt); - const int n_vocab = llama_n_vocab(ctx_tgt); - //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft)); + const int n_vocab = llama_n_vocab(model_tgt); + //GGML_ASSERT(n_vocab == 
llama_n_vocab(model_dft)); // how many tokens to draft each time const int n_draft = params.n_draft; @@ -648,8 +648,8 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model while (true) { // sample from the target model - // const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); - const llama_token id = llama_sample_token_binding(ctx_tgt, NULL, grammar_tgt, params_p, last_tokens, candidates, i_dft); + // llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); + llama_token id = llama_sample_token_binding(ctx_tgt, NULL, grammar_tgt, params_p, last_tokens, candidates, i_dft); // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); last_tokens.push_back(id); @@ -687,7 +687,9 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model } // the drafted token was rejected or we are out of drafted tokens - llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); + llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); + llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0)); + ++n_past_dft; drafted.clear(); @@ -750,7 +752,8 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model } // evaluate the drafted token on the draft model - llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); + llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1); + llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0)); ++n_past_cur; if (grammar_dft != NULL) { @@ -759,7 +762,8 @@ int speculative_sampling(void* params_ptr, void* target_model, void* draft_model } // evaluate the target model on the drafted tokens - llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads); + llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1); + llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0)); ++n_past_tgt; // the first token is always proposed by the traget model before the speculation loop @@ -813,9 +817,9 @@ int llama_tokenize_string(void* params_ptr, void* state_pr, int* result) { llama_binding_state* state = (llama_binding_state*) state_pr; llama_context* ctx = state->ctx; - const bool add_bos = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM; + const bool add_bos = llama_vocab_type(state->model) == LLAMA_VOCAB_TYPE_SPM; - return llama_tokenize(ctx, params_p->prompt.data(), params_p->prompt.length(), result, params_p->n_ctx, add_bos); + return llama_tokenize(state->model, params_p->prompt.data(), params_p->prompt.length(), result, params_p->n_ctx, add_bos); } @@ -872,14 +876,16 @@ void save_state(void *ctx, char *dst, char*modes) { } } -void* llama_allocate_params(const char *prompt, int seed, int threads, int tokens, int top_k, +void* llama_allocate_params(const char *prompt, int seed, int threads, int batch_threads, int tokens, int top_k, float top_p, float temp, float repeat_penalty, int repeat_last_n, bool ignore_eos, bool memory_f16, int n_batch, int n_keep, const char** antiprompt, int antiprompt_count, float tfs_z, float typical_p, float frequency_penalty, float presence_penalty, int mirostat, float mirostat_eta, float mirostat_tau, bool penalize_nl, const char *logit_bias, const char *session_file, bool prompt_cache_all, bool mlock, bool mmap, - const char *maingpu,const char *tensorsplit , bool prompt_cache_ro, const char *grammar, + const char *maingpu,const char 
*tensorsplit , bool prompt_cache_ro, const char *grammar, float rope_freq_base, float rope_freq_scale, float negative_prompt_scale, const char* negative_prompt, int n_draft) { gpt_params* params = new gpt_params; params->seed = seed; params->n_threads = threads; + params->n_threads_batch = batch_threads; + params->n_threads_batch = params->n_threads_batch == -1 ? params->n_threads : params->n_threads_batch; params->n_predict = tokens; params->repeat_last_n = repeat_last_n; params->prompt_cache_ro = prompt_cache_ro; @@ -948,8 +954,8 @@ void* llama_allocate_params(const char *prompt, int seed, int threads, int token return params; } -void* load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity) { - return load_binding_model(fname, n_ctx, n_seed, memory_f16, mlock, embeddings, mmap, low_vram, n_gpu_layers, n_batch, maingpu, tensorsplit, numa, rope_freq_base, rope_freq_scale, mul_mat_q, lora, lora_base, perplexity); +void* load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, float lora_scale, bool logits_all) { + return load_binding_model(fname, n_ctx, n_seed, memory_f16, mlock, embeddings, mmap, n_gpu_layers, n_batch, maingpu, tensorsplit, numa, rope_freq_base, rope_freq_scale, mul_mat_q, lora, lora_base, lora_scale, logits_all); } /* @@ -967,7 +973,7 @@ struct llama_binding_state { llama_model * model; }; -void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity); +void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, float lora_scale, bool logits_all); llama_token llama_sample_token_binding( struct llama_context * ctx, @@ -980,14 +986,19 @@ llama_token llama_sample_token_binding( common.cpp: -gpt_params* create_gpt_params(const std::string& fname,const std::string& lora,const std::string& lora_base) { +gpt_params* create_gpt_params(const std::string& fname,const std::string& lora,const std::string& lora_base, float lora_scale) { gpt_params* lparams = new gpt_params; fprintf(stderr, "%s: loading model %s\n", __func__, fname.c_str()); // Initialize the 'model' member with the 'fname' parameter lparams->model = fname; lparams->lora_base = lora_base; - lparams->lora_adapter = lora; + if (lora_scale == 0 && !lora_base.empty()) { + lora_scale = 1.0f; + } + if (!lora.empty()) { + lparams->lora_adapter.push_back(std::make_tuple(lora, lora_scale)); + } if (lparams->lora_adapter.empty()) { lparams->use_mmap = false; } @@ -1003,14 +1014,14 @@ gpt_params* create_gpt_params_cuda(const std::string& fname) { return lparams; } -void* 
load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity) { +void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, float lora_scale, bool logits_all) { // load the model gpt_params * lparams; // Temporary workaround for https://github.com/go-skynet/go-llama.cpp/issues/218 #ifdef GGML_USE_CUBLAS lparams = create_gpt_params_cuda(fname); #else - lparams = create_gpt_params(fname, lora, lora_base); + lparams = create_gpt_params(fname, lora, lora_base, lora_scale); #endif llama_model * model; llama_binding_state * state; @@ -1022,10 +1033,8 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f lparams->embedding = embeddings; lparams->use_mlock = mlock; lparams->n_gpu_layers = n_gpu_layers; - lparams->perplexity = perplexity; + lparams->logits_all = logits_all; lparams->use_mmap = mmap; - - lparams->low_vram = low_vram; if (rope_freq_base != 0.0f) { lparams->rope_freq_base = rope_freq_base; } else { @@ -1042,7 +1051,7 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f if (maingpu[0] != '\0') { lparams->main_gpu = std::stoi(maingpu); } - + if (tensorsplit[0] != '\0') { std::string arg_next = tensorsplit; // split string by , and / @@ -1081,15 +1090,15 @@ llama_token llama_sample_token_binding( struct llama_context * ctx, struct llama_context * ctx_guidance, struct llama_grammar * grammar, - const struct gpt_params * g_params, // NOTE: this is our patch + const struct gpt_params * g_params, const std::vector & last_tokens, std::vector & candidates, int idx) { - - struct gpt_params params = *g_params; // NOTE: this is our patch + struct gpt_params params = *g_params; + const int n_ctx = llama_n_ctx(ctx); - const int n_vocab = llama_n_vocab(ctx); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? 
n_vocab : params.top_k; @@ -1107,7 +1116,7 @@ llama_token llama_sample_token_binding( llama_token id = 0; - float * logits = llama_get_logits(ctx) + idx * n_vocab; + float * logits = llama_get_logits_ith(ctx, idx); // Apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { @@ -1158,11 +1167,11 @@ llama_token llama_sample_token_binding( if (mirostat == 1) { static float mirostat_mu = 2.0f * mirostat_tau; const int mirostat_m = 100; - llama_sample_temperature(ctx, &cur_p, temp); + llama_sample_temp(ctx, &cur_p, temp); id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temperature(ctx, &cur_p, temp); + llama_sample_temp(ctx, &cur_p, temp); id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else { // Temperature sampling @@ -1170,7 +1179,7 @@ llama_token llama_sample_token_binding( llama_sample_tail_free (ctx, &cur_p, tfs_z, 1); llama_sample_typical (ctx, &cur_p, typical_p, 1); llama_sample_top_p (ctx, &cur_p, top_p, 1); - llama_sample_temperature(ctx, &cur_p, temp); + llama_sample_temp(ctx, &cur_p, temp); { const int n_top = 10; @@ -1195,5 +1204,4 @@ llama_token llama_sample_token_binding( return id; } - */ diff --git a/binding.h b/binding.h index 44664eb..7ad05df 100644 --- a/binding.h +++ b/binding.h @@ -21,7 +21,6 @@ void* load_model(const char *fname, bool mlock, bool embeddings, bool mmap, - bool low_vram, int n_gpu, int n_batch, const char *maingpu, @@ -29,14 +28,14 @@ void* load_model(const char *fname, bool numa, float rope_freq_base, float rope_freq_scale, - bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity + bool mul_mat_q, const char *lora, const char *lora_base, float lora_scale, bool perplexity ); int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings); int get_token_embeddings(void* params_ptr, void* state_pr, int *tokens, int tokenSize, float * res_embeddings); -void* llama_allocate_params(const char *prompt, int seed, int threads, int tokens, +void* llama_allocate_params(const char *prompt, int seed, int threads, int batch_threads, int tokens, int top_k, float top_p, float temp, float repeat_penalty, int repeat_last_n, bool ignore_eos, bool memory_f16, int n_batch, int n_keep, const char** antiprompt, int antiprompt_count, diff --git a/llama.cpp b/llama.cpp index ac43576..0e797c2 160000 --- a/llama.cpp +++ b/llama.cpp @@ -1 +1 @@ -Subproject commit ac43576124a75c2de6e333ac31a3444ff9eb9458 +Subproject commit 0e797c2fc571b866090f7d60ac7d39d8533593f2 diff --git a/llama.go b/llama.go index c1ebc2c..67a6272 100644 --- a/llama.go +++ b/llama.go @@ -38,10 +38,10 @@ func New(model string, opts ...ModelOption) (*LLama, error) { result := C.load_model(modelPath, C.int(mo.ContextSize), C.int(mo.Seed), - C.bool(mo.F16Memory), C.bool(mo.MLock), C.bool(mo.Embeddings), C.bool(mo.MMap), C.bool(mo.LowVRAM), + C.bool(mo.F16Memory), C.bool(mo.MLock), C.bool(mo.Embeddings), C.bool(mo.MMap), C.int(mo.NGPULayers), C.int(mo.NBatch), C.CString(mo.MainGPU), C.CString(mo.TensorSplit), C.bool(mo.NUMA), C.float(mo.FreqRopeBase), C.float(mo.FreqRopeScale), - C.bool(MulMatQ), loraAdapter, loraBase, C.bool(mo.Perplexity), + C.bool(MulMatQ), loraAdapter, loraBase, C.float(mo.LoraScale), C.bool(mo.Perplexity), ) if result == nil { @@ -112,7 +112,7 @@ func (l *LLama) TokenEmbeddings(tokens []int, opts ...PredictOption) ([]float32, // 
float tfs_z, float typical_p, float frequency_penalty, float presence_penalty, int mirostat, float mirostat_eta, float mirostat_tau, bool penalize_nl, const char *logit_bias, const char *session_file, bool prompt_cache_all, bool mlock, bool mmap, const char *maingpu, const char *tensorsplit , bool prompt_cache_ro, // float rope_freq_base, float rope_freq_scale, float negative_prompt_scale, const char* negative_prompt // ); - params := C.llama_allocate_params(C.CString(""), C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + params := C.llama_allocate_params(C.CString(""), C.int(po.Seed), C.int(po.Threads), C.int(po.BatchThreads), C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), C.bool(po.IgnoreEOS), C.bool(po.F16KV), C.int(po.Batch), C.int(po.NKeep), nil, C.int(0), @@ -154,7 +154,7 @@ func (l *LLama) Embeddings(text string, opts ...PredictOption) ([]float32, error pass = &reversePrompt[0] } - params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.BatchThreads), C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), C.bool(po.IgnoreEOS), C.bool(po.F16KV), C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount), @@ -193,7 +193,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error { pass = &reversePrompt[0] } - params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.BatchThreads), C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), C.bool(po.IgnoreEOS), C.bool(po.F16KV), C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount), @@ -238,7 +238,7 @@ func (l *LLama) SpeculativeSampling(ll *LLama, text string, opts ...PredictOptio pass = &reversePrompt[0] } - params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.BatchThreads), C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), C.bool(po.IgnoreEOS), C.bool(po.F16KV), C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount), @@ -296,7 +296,7 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) { pass = &reversePrompt[0] } - params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.BatchThreads), C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), C.bool(po.IgnoreEOS), C.bool(po.F16KV), C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount), @@ -346,7 +346,7 @@ func (l *LLama) TokenizeString(text string, opts ...PredictOption) (int32, []int var fakeDblPtr **C.char // copy pasted and modified minimally. 
Should I simplify down / do we need an "allocate defaults" - params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK), + params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.BatchThreads), C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat), C.bool(po.IgnoreEOS), C.bool(po.F16KV), C.int(po.Batch), C.int(po.NKeep), fakeDblPtr, C.int(0), diff --git a/llama_test.go b/llama_test.go index 8c266ef..ded03f2 100644 --- a/llama_test.go +++ b/llama_test.go @@ -71,7 +71,7 @@ how much is 2+2? Expect(err).ToNot(HaveOccurred()) Expect(model).ToNot(BeNil()) text, err := model.SpeculativeSampling(model2, `[INST] Answer to the following question: -how much is 2+2? +Do a simple math calculation: How much is 2+2? [/INST]`, llama.SetNDraft(16), ) Expect(err).ToNot(HaveOccurred(), text) @@ -97,7 +97,10 @@ how much is 2+2? getModel := func() (*LLama, error) { model, err := New( testModelPath, - llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(10), + llama.EnableF16Memory, + llama.SetContext(128), + llama.EnableEmbeddings, + llama.SetGPULayers(10), ) Expect(err).ToNot(HaveOccurred()) Expect(model).ToNot(BeNil()) diff --git a/options.go b/options.go index b36c671..42c66bb 100644 --- a/options.go +++ b/options.go @@ -7,7 +7,6 @@ type ModelOptions struct { F16Memory bool MLock bool MMap bool - LowVRAM bool Embeddings bool NUMA bool NGPULayers int @@ -16,6 +15,7 @@ type ModelOptions struct { FreqRopeBase float32 FreqRopeScale float32 MulMatQ *bool + LoraScale float32 LoraBase string LoraAdapter string Perplexity bool @@ -29,6 +29,7 @@ type PredictOptions struct { DebugMode bool StopPrompts []string IgnoreEOS bool + BatchThreads int TailFreeSamplingZ float32 TypicalP float32 @@ -68,7 +69,6 @@ var DefaultModelOptions ModelOptions = ModelOptions{ MLock: false, Embeddings: false, MMap: true, - LowVRAM: false, NBatch: 512, FreqRopeBase: 10000, FreqRopeScale: 1.0, @@ -79,6 +79,7 @@ var DefaultOptions PredictOptions = PredictOptions{ Threads: 4, Tokens: 128, Penalty: 1.1, + BatchThreads: -1, Repeat: 64, Batch: 512, NKeep: 64, @@ -109,6 +110,18 @@ func SetLoraBase(s string) ModelOption { } } +func SetBatchThreads(b int) PredictOption { + return func(p *PredictOptions) { + p.BatchThreads = b + } +} + +func SetLoraScale(f float32) ModelOption { + return func(p *ModelOptions) { + p.LoraScale = f + } +} + func SetLoraAdapter(s string) ModelOption { return func(p *ModelOptions) { p.LoraAdapter = s @@ -219,10 +232,6 @@ func SetNegativePrompt(np string) PredictOption { } } -var EnabelLowVRAM ModelOption = func(p *ModelOptions) { - p.LowVRAM = true -} - var EnableNUMA ModelOption = func(p *ModelOptions) { p.NUMA = true } diff --git a/patches/1902-cuda.patch b/patches/1902-cuda.patch index aed2fd4..2658056 100644 --- a/patches/1902-cuda.patch +++ b/patches/1902-cuda.patch @@ -1,20 +1,25 @@ diff --git a/common/common.cpp b/common/common.cpp -index 2597ba0..e42ae73 100644 +index ec181c6..9ba699b 100644 --- a/common/common.cpp +++ b/common/common.cpp -@@ -1268,3 +1268,218 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l +@@ -1345,3 +1345,222 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "typical_p: %f # default: 1.0\n", params.typical_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? 
"true" : "false"); } + -+gpt_params* create_gpt_params(const std::string& fname,const std::string& lora,const std::string& lora_base) { ++gpt_params* create_gpt_params(const std::string& fname,const std::string& lora,const std::string& lora_base, float lora_scale) { + gpt_params* lparams = new gpt_params; + fprintf(stderr, "%s: loading model %s\n", __func__, fname.c_str()); + + // Initialize the 'model' member with the 'fname' parameter + lparams->model = fname; + lparams->lora_base = lora_base; -+ lparams->lora_adapter = lora; ++ if (lora_scale == 0 && !lora_base.empty()) { ++ lora_scale = 1.0f; ++ } ++ if (!lora.empty()) { ++ lparams->lora_adapter.push_back(std::make_tuple(lora, lora_scale)); ++ } + if (lparams->lora_adapter.empty()) { + lparams->use_mmap = false; + } @@ -30,14 +35,14 @@ index 2597ba0..e42ae73 100644 + return lparams; +} + -+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity) { ++void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, float lora_scale, bool logits_all) { + // load the model + gpt_params * lparams; +// Temporary workaround for https://github.com/go-skynet/go-llama.cpp/issues/218 +#ifdef GGML_USE_CUBLAS + lparams = create_gpt_params_cuda(fname); +#else -+ lparams = create_gpt_params(fname, lora, lora_base); ++ lparams = create_gpt_params(fname, lora, lora_base, lora_scale); +#endif + llama_model * model; + llama_binding_state * state; @@ -49,10 +54,8 @@ index 2597ba0..e42ae73 100644 + lparams->embedding = embeddings; + lparams->use_mlock = mlock; + lparams->n_gpu_layers = n_gpu_layers; -+ lparams->perplexity = perplexity; ++ lparams->logits_all = logits_all; + lparams->use_mmap = mmap; -+ -+ lparams->low_vram = low_vram; + if (rope_freq_base != 0.0f) { + lparams->rope_freq_base = rope_freq_base; + } else { @@ -114,8 +117,9 @@ index 2597ba0..e42ae73 100644 + int idx) { + + struct gpt_params params = *g_params; ++ + const int n_ctx = llama_n_ctx(ctx); -+ const int n_vocab = llama_n_vocab(ctx); ++ const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? 
n_vocab : params.top_k; @@ -133,7 +137,7 @@ index 2597ba0..e42ae73 100644 + + llama_token id = 0; + -+ float * logits = llama_get_logits(ctx) + idx * n_vocab; ++ float * logits = llama_get_logits_ith(ctx, idx); + + // Apply params.logit_bias map + for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { @@ -184,11 +188,11 @@ index 2597ba0..e42ae73 100644 + if (mirostat == 1) { + static float mirostat_mu = 2.0f * mirostat_tau; + const int mirostat_m = 100; -+ llama_sample_temperature(ctx, &cur_p, temp); ++ llama_sample_temp(ctx, &cur_p, temp); + id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); + } else if (mirostat == 2) { + static float mirostat_mu = 2.0f * mirostat_tau; -+ llama_sample_temperature(ctx, &cur_p, temp); ++ llama_sample_temp(ctx, &cur_p, temp); + id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu); + } else { + // Temperature sampling @@ -196,7 +200,7 @@ index 2597ba0..e42ae73 100644 + llama_sample_tail_free (ctx, &cur_p, tfs_z, 1); + llama_sample_typical (ctx, &cur_p, typical_p, 1); + llama_sample_top_p (ctx, &cur_p, top_p, 1); -+ llama_sample_temperature(ctx, &cur_p, temp); ++ llama_sample_temp(ctx, &cur_p, temp); + + { + const int n_top = 10; @@ -223,10 +227,10 @@ index 2597ba0..e42ae73 100644 +} \ No newline at end of file diff --git a/common/common.h b/common/common.h -index 18aea38..ca7a168 100644 +index 0e2d3fa..9992d2b 100644 --- a/common/common.h +++ b/common/common.h -@@ -209,3 +209,19 @@ std::string get_sortable_timestamp(); +@@ -221,3 +221,19 @@ std::string get_sortable_timestamp(); void dump_non_result_info_yaml( FILE * stream, const gpt_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); @@ -236,7 +240,7 @@ index 18aea38..ca7a168 100644 + llama_model * model; +}; + -+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity); ++void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, float lora_scale, bool logits_all); + +llama_token llama_sample_token_binding( + struct llama_context * ctx,
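
A few notes on the upstream API changes this patch tracks, with minimal C++ sketches; the helper names below are illustrative and are not part of the patch. The central break is the removal of llama_eval: every evaluation in binding.cpp now goes through llama_decode over a llama_batch built by llama_batch_get_one, and the per-call thread count argument is gone (thread counts now live in gpt_params and the context, see the n_threads_batch note further down). Assuming llama.h from the pinned submodule:

#include "llama.h"

#include <cstdint>
#include <vector>

// Evaluate a chunk of tokens starting at position n_past on sequence 0.
// The binding treats any non-zero return value as "failed to eval".
static int eval_tokens(llama_context * ctx, std::vector<llama_token> & tokens, int n_past) {
    // old API: llama_eval(ctx, tokens.data(), tokens.size(), n_past, n_threads);
    return llama_decode(ctx, llama_batch_get_one(tokens.data(), (int32_t) tokens.size(), n_past, /*seq_id=*/0));
}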
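
Several accessors moved from the context to the model, which is what the get_embeddings, eval, llama_predict and llama_tokenize_string hunks adjust: llama_vocab_type, llama_n_vocab, llama_n_embd and llama_tokenize now take a llama_model, and llama_get_model(ctx) bridges the gap where only a context is available. A sketch under the same assumptions:

#include "llama.h"

#include <string>
#include <vector>

// Tokenize a prompt with the model-scoped API (was context-scoped).
static std::vector<llama_token> tokenize_prompt(llama_context * ctx, const std::string & text) {
    const llama_model * model = llama_get_model(ctx);

    // also model-scoped now: llama_n_vocab(model), llama_n_embd(model)
    const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;

    std::vector<llama_token> tokens(text.size() + (add_bos ? 1 : 0));
    const int n = llama_tokenize(model, text.data(), (int) text.size(),
                                 tokens.data(), (int) tokens.size(), add_bos);
    if (n < 0) {
        return {}; // buffer too small; -n is the required token count
    }
    tokens.resize(n);
    return tokens;
}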
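
When the context window fills up, llama_predict no longer re-feeds tokens from last_tokens; it edits the KV cache in place with the new llama_kv_cache_seq_* calls, discarding half of the non-kept tokens and shifting the remainder back so positions stay contiguous. The same logic isolated into an illustrative helper:

#include "llama.h"

// Mirror of the "context full, swapping" branch in llama_predict.
static void shift_context(llama_context * ctx, int & n_past, int n_keep) {
    const int n_left    = n_past - n_keep - 1;
    const int n_discard = n_left / 2;

    // drop half of the tokens that follow the kept prefix...
    llama_kv_cache_seq_rm   (ctx, /*seq_id=*/0, n_keep + 1, n_keep + 1 + n_discard);
    // ...and slide the remaining entries back by n_discard positions
    llama_kv_cache_seq_shift(ctx, /*seq_id=*/0, n_keep + 1 + n_discard, n_past, -n_discard);

    n_past -= n_discard;
}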
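
With the batch API, re-decoding at an earlier position does not implicitly replace what is already in the cache, so the speculative_sampling hunks clear the stale entries first, and llama_predict likewise drops "future" tokens inherited from a restored session with llama_kv_cache_tokens_rm(ctx, n_past, -1). A sketch of the rollback-then-decode pattern (the claim about cache semantics is my reading of the patch, not upstream documentation):

#include "llama.h"

// Re-decode a single token at position n_past after discarding anything the
// cache still holds from n_past onwards (p1 = -1 means "to the end").
static int redecode_at(llama_context * ctx, llama_token id, int & n_past) {
    llama_kv_cache_seq_rm(ctx, /*seq_id=*/0, n_past, -1);
    const int rc = llama_decode(ctx, llama_batch_get_one(&id, 1, n_past, 0));
    if (rc == 0) {
        ++n_past;
    }
    return rc;
}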
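
gpt_params::lora_adapter is no longer a single path string: it is a list of (path, scale) tuples, which is why load_model, load_binding_model and create_gpt_params grew a lora_scale argument and options.go gained SetLoraScale. The patched create_gpt_params also defaults the scale to 1.0 when a lora_base is given without an explicit scale. Roughly, assuming common.h from the submodule:

#include "common.h"

#include <string>
#include <tuple>

// Register a LoRA adapter the way the patched create_gpt_params does.
static void set_lora(gpt_params & params, const std::string & lora,
                     const std::string & lora_base, float lora_scale) {
    params.lora_base = lora_base;
    if (lora_scale == 0.0f && !lora_base.empty()) {
        lora_scale = 1.0f; // same default the binding applies
    }
    if (!lora.empty()) {
        params.lora_adapter.push_back(std::make_tuple(lora, lora_scale));
    }
}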
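
llama_allocate_params now also receives a batch thread count: gpt_params carries it as n_threads_batch, and the Go-side default of -1 (PredictOptions.BatchThreads, settable via SetBatchThreads) falls back to the regular thread count. In isolation:

#include "common.h"

// Thread plumbing as done at the top of the patched llama_allocate_params.
static void set_threads(gpt_params & params, int threads, int batch_threads) {
    params.n_threads       = threads;
    params.n_threads_batch = batch_threads == -1 ? threads : batch_threads;
}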
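
llama_sample_token_binding tracks two smaller renames: llama_sample_temperature is now llama_sample_temp, and the logits for a given batch index come from llama_get_logits_ith(ctx, idx) instead of pointer arithmetic on llama_get_logits. A compact sketch of the new calls (not the binding's full sampling path, which also applies penalties, grammar and mirostat):

#include "llama.h"

#include <vector>

// Temperature-only sampling against the logits of batch index idx.
static llama_token sample_with_temp(llama_context * ctx, int idx, float temp) {
    const int n_vocab = llama_n_vocab(llama_get_model(ctx)); // was llama_n_vocab(ctx)

    // was: llama_get_logits(ctx) + idx * n_vocab
    float * logits = llama_get_logits_ith(ctx, idx);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates.push_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

    llama_sample_temp(ctx, &cur_p, temp); // was llama_sample_temperature
    return llama_sample_token(ctx, &cur_p);
}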