 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
-#include <pytorch/tokenizers/llama2c_tokenizer.h>
 #include <pytorch/tokenizers/hf_tokenizer.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>
 
 namespace example {
 
@@ -35,6 +35,40 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+
+std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
+    std::string tokenizer_path) {
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = nullptr;
+  ::tokenizers::Error err;
+
+  // First try to load as a json tokenizer.
+  tokenizer = std::make_unique<tokenizers::HFTokenizer>();
+  err = tokenizer->load(tokenizer_path);
+  if (err == ::tokenizers::Error::Ok) {
+    ET_LOG(Info, "Loaded json tokenizer");
+    return std::move(tokenizer);
+  }
+
+  // Try to load as tiktoken tokenizer.
+  tokenizer.reset();
+  tokenizer = get_tiktoken_for_llama();
+  err = tokenizer->load(tokenizer_path);
+  if (err == ::tokenizers::Error::Ok) {
+    ET_LOG(Info, "Loaded TikToken tokenizer");
+    return std::move(tokenizer);
+  }
+
+  // Try to load as BPE tokenizer.
+  tokenizer.reset();
+  tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+  err = tokenizer->load(tokenizer_path);
+  if (err == ::tokenizers::Error::Ok) {
+    ET_LOG(Info, "Loaded BPE tokenizer");
+    return std::move(tokenizer);
+  }
+
+  return nullptr;
+}
 } // namespace
 
 Runner::Runner(
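Note on the helper added above: it makes the probing order explicit (HF JSON first, then Tiktoken, then llama2.c BPE) and returns nullptr only once all three loaders have rejected the file. Because load_tokenizer sits in the file-local anonymous namespace, the sketch below restates the same try-loaders-in-order pattern in isolation rather than calling it directly; every name here (Tok, load_first_match) is hypothetical and not part of this commit.

#include <functional>
#include <memory>
#include <string>
#include <vector>

// Stand-in for the tokenizer interface: load() reports success as bool.
struct Tok {
  virtual ~Tok() = default;
  virtual bool load(const std::string& path) = 0;
};

// Try each factory in order; the first tokenizer whose load() accepts
// the file wins. Mirrors the HF -> Tiktoken -> llama2.c chain above.
std::unique_ptr<Tok> load_first_match(
    const std::string& path,
    const std::vector<std::function<std::unique_ptr<Tok>()>>& factories) {
  for (const auto& make : factories) {
    auto candidate = make();
    if (candidate && candidate->load(path)) {
      return candidate;
    }
  }
  return nullptr;  // every format was rejected
}

The first loader to return Ok wins, so more specific formats should be probed before more permissive ones.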
@@ -76,35 +110,15 @@ Error Runner::load() {
     return Error::Ok;
   }
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
+
   // Load tokenizer.
-  tokenizer_ = nullptr;
-  // Check if tokenizer_path_ ends with ".json".
-  if (tokenizer_path_.size() >= 5 &&
-
-      tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) {
-    tokenizer_ = std::make_unique<tokenizers::HFTokenizer>();
-    ET_LOG(Info, "Loading json tokenizer");
-    tokenizer_->load(tokenizer_path_);
+  tokenizer_ = load_tokenizer(tokenizer_path_);
+  if (tokenizer_ == nullptr) {
     ET_LOG(
-        Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str());
-  } else {
-    ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-    tokenizer_ = get_tiktoken_for_llama();
-    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-    // fallback to BPE tokenizer.
-    if (err != ::tokenizers::Error::Ok) {
-      ET_LOG(
-          Info,
-          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-          tokenizer_path_.c_str());
-      tokenizer_.reset();
-      tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-      err = tokenizer_->load(tokenizer_path_);
-      ET_CHECK_TK_OK_OR_RETURN_ERROR(
-          err,
-          "Failed to load %s as a llama2.c tokenizer artifact",
-          tokenizer_path_.c_str());
-    }
+        Error,
+        "Failed to load %s as a llama2.c tokenizer artifact",
+        tokenizer_path_.c_str());
+    return ::executorch::runtime::Error::InvalidArgument;
   }
 
   ET_LOG(Info, "Reading metadata from model");
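Note on the caller-visible change: with the fallback chain factored out, a missing or unrecognized tokenizer file now surfaces from load() as Error::InvalidArgument instead of the old ET_CHECK_TK_OK_OR_RETURN_ERROR inside the BPE branch. A minimal caller-side sketch under that assumption; the include path and the run wrapper are illustrative, not part of this commit.

#include <executorch/examples/models/llama/runner/runner.h>  // assumed header path

using executorch::runtime::Error;

// Hypothetical wrapper: distinguish "no tokenizer format matched"
// from other load failures.
int run(example::Runner& runner) {
  Error err = runner.load();
  if (err == Error::InvalidArgument) {
    // HF JSON, Tiktoken, and llama2.c loaders all rejected the file.
    return 1;
  }
  return err == Error::Ok ? 0 : 2;
}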