@@ -1702,12 +1702,13 @@ struct llama_mlock {
 };
 using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;

-static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+// NOTE: avoid ever using this except for building the token_to_piece caches
+static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+        int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
         GGML_ASSERT(check == -n_tokens);
     }
     else {
@@ -2162,7 +2163,9 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;

-    std::vector<id> special_tokens_cache;
+    std::vector<id>    cache_special_tokens;
+    std::vector<token> cache_token_to_piece;         // llama_token_to_piece(special = false);
+    std::vector<token> cache_token_to_piece_special; // llama_token_to_piece(special = true);


     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
@@ -4592,20 +4595,14 @@ static void llm_load_vocab(
             vocab.special_cls_id  = 101;
             vocab.special_mask_id = 103;
             vocab.add_space_prefix = false;
-        } else {
-            if (tokenizer_model == "gpt2") {
-                vocab.type = LLAMA_VOCAB_TYPE_BPE;
+        } else if (tokenizer_model == "gpt2") {
+            vocab.type = LLAMA_VOCAB_TYPE_BPE;

-                const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
-                if (add_space_prefix_keyidx != -1) {
-                    vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
-                }
-            } else {
-                LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_model.c_str());
-                LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
-                vocab.type = LLAMA_VOCAB_TYPE_SPM;
-                return;
+            const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
+            if (add_space_prefix_keyidx != -1) {
+                vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
             }
+
             // read bpe merges and populate bpe ranks
             const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
             if (merges_keyidx == -1) {
@@ -4639,6 +4636,8 @@ static void llm_load_vocab(
             vocab.special_pad_id  = -1;
             vocab.special_cls_id  = -1;
             vocab.special_mask_id = -1;
+        } else {
+            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }

         // for now, only BPE models have pre-tokenizers
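The two hunks above also change error handling: a GGUF whose `tokenizer.ggml.model` is unrecognized no longer falls back to the SPM ("llama") tokenizer with a warning; loading now fails with an exception. A minimal sketch of what a caller would observe, assuming llama.cpp's usual pattern of surfacing load-time errors as a null model pointer (the file name is a placeholder):

```cpp
#include "llama.h"
#include <cstdio>

int main() {
    llama_model_params mparams = llama_model_default_params();

    // with this change, a model whose tokenizer.ggml.model value is unknown
    // fails here instead of silently loading with the SPM fallback tokenizer
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_free_model(model);
    return 0;
}
```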
@@ -4833,17 +4832,38 @@ static void llm_load_vocab(
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
             if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
-                vocab.special_tokens_cache.push_back(id);
+                vocab.cache_special_tokens.push_back(id);
             }
         }

-        std::sort(vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
+        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
            [&] (const llama_vocab::id a, const llama_vocab::id b) {
                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
            }
        );

-        LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
+    }
+
+    // build token to piece caches
+    {
+        size_t size_cache = 0;
+
+        std::vector<llama_vocab::token> cache_token_to_piece        (n_vocab);
+        std::vector<llama_vocab::token> cache_token_to_piece_special(n_vocab);
+
+        for (uint32_t id = 0; id < n_vocab; ++id) {
+            cache_token_to_piece[id]         = llama_token_to_piece(&model, id, false);
+            cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
+
+            size_cache += cache_token_to_piece[id].size();
+            size_cache += cache_token_to_piece_special[id].size();
+        }
+
+        std::swap(vocab.cache_token_to_piece,         cache_token_to_piece);
+        std::swap(vocab.cache_token_to_piece_special, cache_token_to_piece_special);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
     }
 }
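Side note (not part of the diff): the two vectors built above are indexed directly by token id, with the `special = true` variant also rendering control tokens as text. A minimal sketch of the intended lookup, reusing the member names from this change; the helper function itself is hypothetical:

```cpp
// Hypothetical helper, assuming both caches above were populated for every id
// in [0, n_vocab): picks the right cache and returns a borrowed reference,
// so hot paths avoid re-running the detokenizer for each token.
static const llama_vocab::token & llama_token_piece_cached(const llama_vocab & vocab, llama_vocab::id id, bool special) {
    const auto & cache = special ? vocab.cache_token_to_piece_special
                                 : vocab.cache_token_to_piece;
    return cache.at(id); // .at() throws std::out_of_range for ids outside the vocab
}
```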
@@ -13233,7 +13253,7 @@ struct fragment_buffer_variant {

 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
-    for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
+    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
         const auto & special_token = vocab.id_to_token[special_id].text;

         // for each text fragment
@@ -14392,7 +14412,7 @@ void llama_sample_repetition_penalties(

 void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
-    const int64_t t_start_sample_us = ggml_time_us();
+    int64_t t_start_sample_us = ggml_time_us();

     bool allow_eog = false;
     for (const auto & stack : grammar->stacks) {
@@ -14404,12 +14424,13 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {

     std::vector<std::pair<std::vector<uint32_t>, llama_partial_utf8>> candidates_decoded;
     candidates_decoded.reserve(candidates->size);
-    std::vector<llama_grammar_candidate>                              candidates_grammar;
+
+    std::vector<llama_grammar_candidate> candidates_grammar;
     candidates_grammar.reserve(candidates->size);

     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id    = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
+        const llama_token id      = candidates->data[i].id;
+        const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);

         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
@@ -14609,7 +14630,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }

-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);

     // Note terminating 0 in decoded string
     const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -18292,69 +18313,83 @@ static std::string llama_decode_text(const std::string & text) {

 // does not write null-terminator to buf
 int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
+    // if we have a cache - use it
+    {
+        const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
+
+        if (!cache.empty()) {
+            const auto & res = cache.at(token);
+            if (length < (int) res.size()) {
+                return -(int) res.size();
+            }
+            memcpy(buf, res.c_str(), res.size());
+            return res.size();
+        }
+    }
+
     if (0 <= token && token < llama_n_vocab(model)) {
         switch (llama_vocab_get_type(model->vocab)) {
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_SPM: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                llama_unescape_whitespace(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                (llama_is_user_defined_token(model->vocab, token)) ||
-                (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-                if (length < 3) {
-                    return -3;
-                }
-                memcpy(buf, "\xe2\x96\x85", 3);
-                return 3;
-            } else if (llama_is_byte_token(model->vocab, token)) {
-                if (length < 1) {
-                    return -1;
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    llama_unescape_whitespace(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                    (llama_is_user_defined_token(model->vocab, token)) ||
+                    (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
+                    if (length < 3) {
+                        return -3;
+                    }
+                    memcpy(buf, "\xe2\x96\x85", 3);
+                    return 3;
+                } else if (llama_is_byte_token(model->vocab, token)) {
+                    if (length < 1) {
+                        return -1;
+                    }
+                    buf[0] = llama_token_to_byte(model->vocab, token);
+                    return 1;
                 }
-                buf[0] = llama_token_to_byte(model->vocab, token);
-                return 1;
+                break;
             }
-            break;
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                result = llama_decode_text(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                (llama_is_user_defined_token(model->vocab, token)) ||
-                (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    result = llama_decode_text(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                    (llama_is_user_defined_token(model->vocab, token)) ||
+                    (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
                 }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
+                break;
             }
-            break;
-        }
-        default:
-            GGML_ASSERT(false);
+            default:
+                GGML_ASSERT(false);
         }
     }
     return 0;
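With the cache in place, the public `llama_token_to_piece` call above becomes a bounds-checked lookup plus a `memcpy` once the vocab has been loaded, and only falls back to the per-type `switch` when the cache is empty. The negative return value keeps its existing meaning: the required buffer size, negated. A caller-side sketch of the grow-and-retry pattern, mirroring the static helper at the top of this diff (names local to the sketch):

```cpp
#include "llama.h"
#include <string>
#include <vector>

// Sketch: detokenize one token via the C API, retrying once with a larger
// buffer when the first call reports (as a negative value) the size it needs.
static std::string token_to_piece(const struct llama_model * model, llama_token token, bool special) {
    std::vector<char> buf(8, 0);
    int32_t n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), special);
    if (n < 0) {
        buf.resize(-n); // -n is the length of the piece
        n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), special);
    }
    return std::string(buf.data(), n);
}
```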