4 changes: 0 additions & 4 deletions include/llama.h
@@ -64,8 +64,6 @@ extern "C" {

typedef struct llama_memory_i * llama_memory_t;

struct llama_kv_cache; // DEPRECATED (use llama_memory instead)

typedef int32_t llama_pos;
typedef int32_t llama_token;
typedef int32_t llama_seq_id;
@@ -469,8 +467,6 @@ extern "C" {
LLAMA_API llama_memory_t llama_get_memory (const struct llama_context * ctx);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type

DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");

LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);

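Callers migrating off the removed llama_get_kv_self() accessor can go through llama_get_memory() instead, which remains declared above; a minimal sketch, assuming the llama_memory_* helpers from the surrounding API rather than anything introduced by this diff:

    // before (removed in this PR):
    //   llama_kv_cache * kv = llama_get_kv_self(ctx);

    // after: operate on the memory handle instead
    llama_memory_t mem = llama_get_memory(ctx);
    llama_memory_clear(mem, /*data=*/true);  // assumed llama_memory_* helper; true also clears the data buffers
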
4 changes: 2 additions & 2 deletions src/CMakeLists.txt
@@ -20,8 +20,8 @@ add_library(llama
llama-hparams.cpp
llama-impl.cpp
llama-io.cpp
llama-kv-cache-unified.cpp
llama-kv-cache-unified-iswa.cpp
llama-kv-cache.cpp
llama-kv-cache-iswa.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-recurrent.cpp
5 changes: 0 additions & 5 deletions src/llama-context.cpp
@@ -2338,11 +2338,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
return &ctx->get_model();
}

// deprecated
llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
}

// deprecated
void llama_kv_self_update(llama_context * ctx) {
ctx->kv_self_update(false);
54 changes: 27 additions & 27 deletions src/llama-graph.cpp
@@ -4,8 +4,8 @@
#include "llama-batch.h"
#include "llama-cparams.h"

#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-recurrent.h"

@@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
const llama_seq_id s0 = ubatch->seq_id[i0][0];

// TODO: reimplement this like in llama_kv_cache_unified
// TODO: reimplement this like in llama_kv_cache
if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
if (hparams.use_alibi) {
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
@@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
}
}

void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);

mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);

this->mctx = mctx;

@@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
return res;
}

void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);

@@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);

this->mctx = mctx;

@@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
}

ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);

auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);

@@ -1399,17 +1399,17 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}

static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
ggml_context * ctx0,
const llama_ubatch & ubatch,
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_context * mctx_cur) {
const llama_kv_cache_context * mctx_cur) {

auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);

{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");

const auto n_kv = mctx_cur->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
@@ -1427,16 +1427,16 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
return inp;
}

llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);

auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);

return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified * inp,
llm_graph_input_attn_kv * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1488,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_attn(
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1513,7 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
}

ggml_tensor * llm_graph_context::build_attn_with_sinks(
llm_graph_input_attn_kv_unified_iswa * inp,
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
@@ -1636,10 +1636,10 @@ ggml_tensor * llm_graph_context::build_attn(
// TODO: maybe separate the inner implementation into a separate function
// like with the non-sliding window equivalent
// once sliding-window hybrid caches are a thing.
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);

auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);

const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

@@ -1656,7 +1656,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
}

{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");

const auto n_kv = mctx_cur->get_swa()->get_n_kv();

@@ -1669,7 +1669,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
}

return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_rs(
@@ -1792,7 +1792,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);

auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());

auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);

50 changes: 25 additions & 25 deletions src/llama-graph.h
@@ -19,8 +19,8 @@ struct llama_cparams;

struct llama_memory_context_i;

class llama_kv_cache_unified_context;
class llama_kv_cache_unified_iswa_context;
class llama_kv_cache_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;

@@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
public:
llm_graph_input_pos_bucket_kv(
const llama_hparams & hparams,
const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
virtual ~llm_graph_input_pos_bucket_kv() = default;

void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +161,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {

const llama_hparams hparams;

const llama_kv_cache_unified_context * mctx;
const llama_kv_cache_context * mctx;
};

class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +257,17 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
const llama_cparams cparams;
};

class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
class llm_graph_input_attn_kv : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified(
llm_graph_input_attn_kv(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_context * mctx) :
const llama_kv_cache_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified() = default;
~llm_graph_input_attn_kv() = default;

void set_input(const llama_ubatch * ubatch) override;

@@ -290,20 +290,20 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
const llama_hparams hparams;
const llama_cparams cparams;

const llama_kv_cache_unified_context * mctx;
const llama_kv_cache_context * mctx;
};

class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified_iswa(
llm_graph_input_attn_kv_iswa(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_unified_iswa_context * mctx) :
const llama_kv_cache_iswa_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_kv_unified_iswa() = default;
~llm_graph_input_attn_kv_iswa() = default;

void set_input(const llama_ubatch * ubatch) override;

@@ -330,7 +330,7 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
const llama_hparams hparams;
const llama_cparams cparams;

const llama_kv_cache_unified_iswa_context * mctx;
const llama_kv_cache_iswa_context * mctx;
};

class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +351,7 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
llm_graph_input_mem_hybrid(
std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
std::unique_ptr<llm_graph_input_rs> inp_rs,
const llama_memory_hybrid_context * mctx) :
inp_attn(std::move(inp_attn)),
@@ -361,11 +361,11 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {

void set_input(const llama_ubatch * ubatch) override;

std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;
std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
std::unique_ptr<llm_graph_input_rs> inp_rs;

llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }

const llama_memory_hybrid_context * mctx;
};
@@ -703,10 +703,10 @@ struct llm_graph_context {
float kq_scale,
int il) const;

llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
llm_graph_input_attn_kv * build_attn_inp_kv() const;

ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified * inp,
llm_graph_input_attn_kv * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -717,11 +717,11 @@ struct llm_graph_context {
float kq_scale,
int il) const;

llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;

// note: if k_cur or v_cur are not provided, they will not be stored in the memory
ggml_tensor * build_attn(
llm_graph_input_attn_kv_unified_iswa * inp,
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -734,7 +734,7 @@ struct llm_graph_context {

// TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
ggml_tensor * build_attn_with_sinks(
llm_graph_input_attn_kv_unified_iswa * inp,
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -765,7 +765,7 @@ struct llm_graph_context {
//

// TODO: move this implementation to llama_memory_recurrent.
// this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
// this is analogous to llama_kv_cache::cpy_k / cpy_v
// when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
// implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
// `llama_memory_recurrent`
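Downstream model-graph code tracks the same rename; a rough call-site sketch, where the surrounding variable names are placeholders and the arguments collapsed in the diff above are left elided:

    // previously: llm_graph_input_attn_kv_unified * inp = build_attn_inp_kv_unified();
    llm_graph_input_attn_kv * inp = build_attn_inp_kv();

    // build_attn() keeps its parameter list; only the input type is renamed
    cur = build_attn(inp, wo, wo_b, q_cur, /* k_cur, v_cur, ... */ kq_scale, il);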