
Commit 5921b8f

llama : cache llama_token_to_piece (ggml-org#7587)
* llama : cache llama_token_to_piece (ggml-ci)
* llama : use vectors and avoid has_cache (ggml-ci)
* llama : throw on unknown tokenizer types (ggml-ci)
* llama : print a log of the total cache size
1 parent 5dcdf94 commit 5921b8f
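The change has two halves: at model load time, llm_load_vocab detokenizes every vocabulary entry once and stores the pieces in vocab.cache_token_to_piece / vocab.cache_token_to_piece_special; at runtime, llama_token_to_piece and the grammar sampling code read from those vectors instead of re-decoding each token. The standalone C++ sketch below mirrors that build-once/look-up pattern; decode_token_slow, TokenPieceCache, and the 32k vocabulary size are illustrative stand-ins, not part of the llama.cpp API.

// Standalone sketch of the build-once token-to-piece cache pattern from this
// commit. decode_token_slow() stands in for the expensive per-token
// detokenization; the cache is filled once for the whole vocabulary and then
// indexed directly, as llm_load_vocab does for vocab.cache_token_to_piece and
// vocab.cache_token_to_piece_special.
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// illustrative stand-in for the slow decode path
static std::string decode_token_slow(int token, bool special) {
    std::string piece = "piece_" + std::to_string(token);
    if (special) {
        piece = "<" + piece + ">";
    }
    return piece;
}

struct TokenPieceCache {
    std::vector<std::string> plain;        // special = false
    std::vector<std::string> with_special; // special = true

    void build(int n_vocab) {
        size_t size_cache = 0;

        std::vector<std::string> cache_plain  (n_vocab);
        std::vector<std::string> cache_special(n_vocab);

        for (int id = 0; id < n_vocab; ++id) {
            cache_plain  [id] = decode_token_slow(id, false);
            cache_special[id] = decode_token_slow(id, true);

            size_cache += cache_plain  [id].size();
            size_cache += cache_special[id].size();
        }

        std::swap(plain,        cache_plain);
        std::swap(with_special, cache_special);

        printf("token to piece cache size = %.4f MB\n", size_cache / 1024.0 / 1024.0);
    }

    const std::string & lookup(int token, bool special) const {
        return special ? with_special.at(token) : plain.at(token);
    }
};

int main() {
    TokenPieceCache cache;
    cache.build(32000);                            // e.g. a 32k vocabulary
    printf("%s\n", cache.lookup(7, true).c_str()); // O(1) lookup, no re-decoding
    return 0;
}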

File tree: 2 files changed, +119 -84 lines

llama.cpp (+117 -82)
@@ -1702,12 +1702,13 @@ struct llama_mlock {
 };
 using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
 
-static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+// NOTE: avoid ever using this except for building the token_to_piece caches
+static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+        int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
         GGML_ASSERT(check == -n_tokens);
     }
     else {
@@ -2162,7 +2163,9 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
-    std::vector<id> special_tokens_cache;
+    std::vector<id>    cache_special_tokens;
+    std::vector<token> cache_token_to_piece;         // llama_token_to_piece(special = false);
+    std::vector<token> cache_token_to_piece_special; // llama_token_to_piece(special = true);
 
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
@@ -4592,20 +4595,14 @@ static void llm_load_vocab(
             vocab.special_cls_id  = 101;
             vocab.special_mask_id = 103;
             vocab.add_space_prefix = false;
-        } else {
-            if (tokenizer_model == "gpt2") {
-                vocab.type = LLAMA_VOCAB_TYPE_BPE;
+        } else if (tokenizer_model == "gpt2") {
+            vocab.type = LLAMA_VOCAB_TYPE_BPE;
 
-                const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
-                if (add_space_prefix_keyidx != -1) {
-                    vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
-                }
-            } else {
-                LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str());
-                LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
-                vocab.type = LLAMA_VOCAB_TYPE_SPM;
-                return;
+            const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
+            if (add_space_prefix_keyidx != -1) {
+                vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
             }
+
             // read bpe merges and populate bpe ranks
             const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
             if (merges_keyidx == -1) {
@@ -4639,6 +4636,8 @@ static void llm_load_vocab(
             vocab.special_pad_id  = -1;
             vocab.special_cls_id  = -1;
             vocab.special_mask_id = -1;
+        } else {
+            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }
 
         // for now, only BPE models have pre-tokenizers
@@ -4833,17 +4832,38 @@ static void llm_load_vocab(
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
             if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
-                vocab.special_tokens_cache.push_back(id);
+                vocab.cache_special_tokens.push_back(id);
             }
         }
 
-        std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
+        std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
             [&] (const llama_vocab::id a, const llama_vocab::id b) {
                 return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
             }
         );
 
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
+    }
+
+    // build token to piece caches
+    {
+        size_t size_cache = 0;
+
+        std::vector<llama_vocab::token> cache_token_to_piece        (n_vocab);
+        std::vector<llama_vocab::token> cache_token_to_piece_special(n_vocab);
+
+        for (uint32_t id = 0; id < n_vocab; ++id) {
+            cache_token_to_piece[id]         = llama_token_to_piece(&model, id, false);
+            cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
+
+            size_cache += cache_token_to_piece[id].size();
+            size_cache += cache_token_to_piece_special[id].size();
+        }
+
+        std::swap(vocab.cache_token_to_piece,         cache_token_to_piece);
+        std::swap(vocab.cache_token_to_piece_special, cache_token_to_piece_special);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
     }
 }
 
@@ -13233,7 +13253,7 @@ struct fragment_buffer_variant {
 
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
-    for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
+    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
         const auto & special_token = vocab.id_to_token[special_id].text;
 
         // for each text fragment
@@ -14392,7 +14412,7 @@ void llama_sample_repetition_penalties(
 
 void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
-    const int64_t t_start_sample_us = ggml_time_us();
+    int64_t t_start_sample_us = ggml_time_us();
 
     bool allow_eog = false;
     for (const auto & stack : grammar->stacks) {
@@ -14404,12 +14424,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
 
     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     candidates_decoded.reserve(candidates->size);
-    std::vector<llama_grammar_candidate>                              candidates_grammar;
+
+    std::vector<llama_grammar_candidate> candidates_grammar;
     candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id    = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
+        const llama_token id      = candidates->data[i].id;
+        const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
 
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
@@ -14609,7 +14630,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
     const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -18292,69 +18313,83 @@ static std::string llama_decode_text(const std::string & text) {
 
 // does not write null-terminator to buf
 int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
+    // if we have a cache - use it
+    {
+        const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
+
+        if (!cache.empty()) {
+            const auto & res = cache.at(token);
+            if (length < (int) res.size()) {
+                return -(int) res.size();
+            }
+            memcpy(buf, res.c_str(), res.size());
+            return res.size();
+        }
+    }
+
     if (0 <= token && token < llama_n_vocab(model)) {
         switch (llama_vocab_get_type(model->vocab)) {
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_SPM: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                llama_unescape_whitespace(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                (llama_is_user_defined_token(model->vocab, token)) ||
-                (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-                if (length < 3) {
-                    return -3;
-                }
-                memcpy(buf, "\xe2\x96\x85", 3);
-                return 3;
-            } else if (llama_is_byte_token(model->vocab, token)) {
-                if (length < 1) {
-                    return -1;
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    llama_unescape_whitespace(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                        (llama_is_user_defined_token(model->vocab, token)) ||
+                        (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
+                    if (length < 3) {
+                        return -3;
+                    }
+                    memcpy(buf, "\xe2\x96\x85", 3);
+                    return 3;
+                } else if (llama_is_byte_token(model->vocab, token)) {
+                    if (length < 1) {
+                        return -1;
+                    }
+                    buf[0] = llama_token_to_byte(model->vocab, token);
+                    return 1;
                 }
-                buf[0] = llama_token_to_byte(model->vocab, token);
-                return 1;
+                break;
             }
-            break;
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                result = llama_decode_text(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                (llama_is_user_defined_token(model->vocab, token)) ||
-                (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    result = llama_decode_text(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                        (llama_is_user_defined_token(model->vocab, token)) ||
+                        (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
                 }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
+                break;
             }
-            break;
-        }
-        default:
-            GGML_ASSERT(false);
+            default:
+                GGML_ASSERT(false);
         }
     }
     return 0;

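For callers that still go through the public C API, the buffer convention visible in the helper above is unchanged: llama_token_to_piece() returns the number of bytes written, or a negative value whose magnitude is the required buffer size. A minimal caller-side wrapper built on that convention might look like the sketch below; token_to_piece_str is a hypothetical helper name, not part of the upstream API, and it assumes llama.h is on the include path and `model` was loaded elsewhere.

// Sketch of a caller-side wrapper over the public llama_token_to_piece()
// convention shown above: a negative return value means the buffer was too
// small and its magnitude is the required size, so grow the buffer and retry.
#include <string>
#include <vector>

#include "llama.h"

static std::string token_to_piece_str(const struct llama_model * model, llama_token token, bool special) {
    std::vector<char> result(8, 0); // small first guess, like the internal helper
    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
    if (n_tokens < 0) {
        result.resize(-n_tokens);   // grow to the reported size and call again
        const int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
        (void) check;               // the helper in llama.cpp asserts check == -n_tokens here
    } else {
        result.resize(n_tokens);
    }
    return std::string(result.data(), result.size());
}

With the caches populated at load time, both calls above are served by a memcpy out of the cached piece, so the grow-and-retry path stays cheap.
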
llama.h (+2 -2)
@@ -424,8 +424,8 @@ extern "C" {
 
     LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
 
-    LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type  llama_rope_type (const struct llama_model * model);
+    LLAMA_API enum llama_vocab_type  llama_vocab_type (const struct llama_model * model);
+    LLAMA_API enum llama_rope_type   llama_rope_type  (const struct llama_model * model);
 
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
