|
17 | 17 |
|
18 | 18 | #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h> |
19 | 19 | #include <pytorch/tokenizers/llama2c_tokenizer.h> |
| 20 | +#include <pytorch/tokenizers/hf_tokenizer.h> |
20 | 21 |
|
21 | 22 | namespace example { |
22 | 23 |
|
@@ -75,24 +76,35 @@ Error Runner::load() { |
75 | 76 | return Error::Ok; |
76 | 77 | } |
77 | 78 | ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); |
78 | | - // load tokenizer. Assuming tiktoken is the default tokenizer |
| 79 | + // Load tokenizer. |
79 | 80 | tokenizer_ = nullptr; |
80 | | - tokenizer_ = get_tiktoken_for_llama(); |
81 | | - ::tokenizers::Error err = tokenizer_->load(tokenizer_path_); |
82 | | - // Rely on tiktoken to throw error if the artifact is incompatible. Then we |
83 | | - // fallback to BPE tokenizer. |
84 | | - if (err != ::tokenizers::Error::Ok) { |
| 81 | + // Check if tokenizer_path_ ends with ".json". |
| 82 | + if (tokenizer_path_.size() >= 5 && |
| 83 | + tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) { |
| 84 | + tokenizer_ = std::make_unique<::tokenizers::HFTokenizer>(); |
| 85 | + ::tokenizers::Error err = tokenizer_->load(tokenizer_path_); |
| 86 | + ET_CHECK_TK_OK_OR_RETURN_ERROR( |
| 87 | + err, "Failed to load %s as an HF tokenizer artifact", tokenizer_path_.c_str()); |
85 | 88 | ET_LOG( |
86 | | - Info, |
87 | | - "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", |
88 | | - tokenizer_path_.c_str()); |
89 | | - tokenizer_.reset(); |
90 | | - tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>(); |
91 | | - err = tokenizer_->load(tokenizer_path_); |
92 | | - ET_CHECK_TK_OK_OR_RETURN_ERROR( |
93 | | - err, |
94 | | - "Failed to load %s as a llama2.c tokenizer artifact", |
95 | | - tokenizer_path_.c_str()); |
| 89 | + Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str()); |
| 90 | + } else { |
| 91 | + tokenizer_ = get_tiktoken_for_llama(); |
| 92 | + ::tokenizers::Error err = tokenizer_->load(tokenizer_path_); |
| 93 | + // Rely on tiktoken to report an error if the artifact is incompatible; then |
| 94 | + // we fall back to the BPE tokenizer. |
| 95 | + if (err != ::tokenizers::Error::Ok) { |
| 96 | + ET_LOG( |
| 97 | + Info, |
| 98 | + "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", |
| 99 | + tokenizer_path_.c_str()); |
| 100 | + tokenizer_.reset(); |
| 101 | + tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>(); |
| 102 | + err = tokenizer_->load(tokenizer_path_); |
| 103 | + ET_CHECK_TK_OK_OR_RETURN_ERROR( |
| 104 | + err, |
| 105 | + "Failed to load %s as a llama2.c tokenizer artifact", |
| 106 | + tokenizer_path_.c_str()); |
| 107 | + } |
96 | 108 | } |
97 | 109 |
|
98 | 110 | ET_LOG(Info, "Reading metadata from model"); |
|
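For reference, a minimal standalone sketch of the ".json" suffix dispatch this patch introduces. The has_suffix helper and the main driver are illustrative only and not part of the runner; the runner inlines the compare() call directly. Under C++20 the same test is simply tokenizer_path_.ends_with(".json").

#include <iostream>
#include <string>

// Hypothetical helper for illustration; the runner inlines this check.
static bool has_suffix(const std::string& s, const std::string& suffix) {
  return s.size() >= suffix.size() &&
      s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main() {
  // ".json" artifacts select the HF tokenizer; everything else goes down
  // the tiktoken -> llama2.c (BPE) fallback chain, as in the patch above.
  for (const char* path :
       {"tokenizer.json", "tokenizer.model", "tokenizer.bin"}) {
    std::cout << path << " -> "
              << (has_suffix(path, ".json") ? "HFTokenizer" : "tiktoken/BPE")
              << "\n";
  }
  return 0;
}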