
Commit 75dd028: "Scott pr review"
1 parent: 0883f32

1 file changed

examples/models/llama/runner/runner.cpp

Lines changed: 42 additions & 28 deletions
@@ -16,8 +16,8 @@
 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
-#include <pytorch/tokenizers/llama2c_tokenizer.h>
 #include <pytorch/tokenizers/hf_tokenizer.h>
+#include <pytorch/tokenizers/llama2c_tokenizer.h>
 
 namespace example {
 

@@ -35,6 +35,40 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+
+std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
+    std::string tokenizer_path) {
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = nullptr;
+  ::tokenizers::Error err;
+
+  // First try to load as a json tokenizer.
+  tokenizer = std::make_unique<tokenizers::HFTokenizer>();
+  err = tokenizer->load(tokenizer_path);
+  if (err == ::tokenizers::Error::Ok) {
+    ET_LOG(Info, "Loaded json tokenizer");
+    return std::move(tokenizer);
+  }
+
+  // Try to load as tiktoken tokenizer.
+  tokenizer.reset();
+  tokenizer = get_tiktoken_for_llama();
+  err = tokenizer->load(tokenizer_path);
+  if (err == ::tokenizers::Error::Ok) {
+    ET_LOG(Info, "Loaded TikToken tokenizer");
+    return std::move(tokenizer);
+  }
+
+  // Try to load as BPE tokenizer.
+  tokenizer.reset();
+  tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+  err = tokenizer->load(tokenizer_path);
+  if (err == ::tokenizers::Error::Ok) {
+    ET_LOG(Info, "Loaded BPE tokenizer");
+    return std::move(tokenizer);
+  }
+
+  return nullptr;
+}
 } // namespace
 
 Runner::Runner(
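
The new helper replaces extension-based dispatch with a first-success-wins probe: each tokenizer implementation simply attempts to load the artifact, and the first one whose load() returns Ok is used. Below is a minimal standalone sketch of that pattern. The Tokenizer interface, Error enum, and the JsonTokenizer/BpeTokenizer loaders are hypothetical stand-ins, not the real pytorch/tokenizers classes; the stand-ins key on path suffixes for brevity, while the real loaders parse the artifact contents.

// Sketch of the first-success-wins loader chain used by load_tokenizer()
// above. All type names here are simplified stand-ins.
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

enum class Error { Ok, ParseFailure };

struct Tokenizer {
  virtual ~Tokenizer() = default;
  virtual Error load(const std::string& path) = 0;
};

struct JsonTokenizer : Tokenizer {
  Error load(const std::string& path) override {
    // Stand-in acceptance test; the real HFTokenizer parses the json itself.
    return path.size() >= 5 && path.compare(path.size() - 5, 5, ".json") == 0
        ? Error::Ok
        : Error::ParseFailure;
  }
};

struct BpeTokenizer : Tokenizer {
  Error load(const std::string& path) override {
    return path.size() >= 4 && path.compare(path.size() - 4, 4, ".bin") == 0
        ? Error::Ok
        : Error::ParseFailure;
  }
};

std::unique_ptr<Tokenizer> load_tokenizer(const std::string& path) {
  // Try each candidate in order; the first successful load() wins.
  std::vector<std::function<std::unique_ptr<Tokenizer>()>> factories = {
      [] { return std::make_unique<JsonTokenizer>(); },
      [] { return std::make_unique<BpeTokenizer>(); },
  };
  for (const auto& make : factories) {
    std::unique_ptr<Tokenizer> tokenizer = make();
    if (tokenizer->load(path) == Error::Ok) {
      return tokenizer;
    }
  }
  return nullptr;  // Caller maps this to an error, as Runner::load() does.
}

int main() {
  std::cout << (load_tokenizer("tokenizer.bin") != nullptr) << "\n";   // 1
  std::cout << (load_tokenizer("tokenizer.model") != nullptr) << "\n"; // 0
}

With this shape, supporting a new artifact format means appending one more probe before the final return nullptr, rather than growing the if/else ladder that the old Runner::load() used.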
@@ -76,35 +110,15 @@ Error Runner::load() {
     return Error::Ok;
   }
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
+
   // Load tokenizer.
-  tokenizer_ = nullptr;
-  // Check if tokenizer_path_ ends with ".json".
-  if (tokenizer_path_.size() >= 5 &&
-      tokenizer_path_.compare(tokenizer_path_.size() - 5, 5, ".json") == 0) {
-    tokenizer_ = std::make_unique<tokenizers::HFTokenizer>();
-    ET_LOG(Info, "Loading json tokenizer");
-    tokenizer_->load(tokenizer_path_);
+  tokenizer_ = load_tokenizer(tokenizer_path_);
+  if (tokenizer_ == nullptr) {
     ET_LOG(
-        Info, "Loaded tokenizer %s as HF tokenizer", tokenizer_path_.c_str());
-  } else {
-    ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-    tokenizer_ = get_tiktoken_for_llama();
-    // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-    // fallback to BPE tokenizer.
-    if (err != ::tokenizers::Error::Ok) {
-      ET_LOG(
-          Info,
-          "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-          tokenizer_path_.c_str());
-      tokenizer_.reset();
-      tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-      err = tokenizer_->load(tokenizer_path_);
-      ET_CHECK_TK_OK_OR_RETURN_ERROR(
-          err,
-          "Failed to load %s as a llama2.c tokenizer artifact",
-          tokenizer_path_.c_str());
-    }
+        Error,
+        "Failed to load %s as a llama2.c tokenizer artifact",
+        tokenizer_path_.c_str());
+    return ::executorch::runtime::Error::InvalidArgument;
   }
 
   ET_LOG(Info, "Reading metadata from model");
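
Note the change in failure handling visible in this hunk: the removed ".json" branch never checked the status returned by tokenizer_->load(), so a malformed json artifact was logged as loaded anyway, whereas the new code funnels every unloadable artifact through a single nullptr check and returns Error::InvalidArgument. A compact, self-contained sketch of the resulting call-site contract follows; the Error enum is a stand-in for ::executorch::runtime::Error, and the stub loader always fails so the error path is exercised.

#include <cstdio>
#include <memory>

// Stand-in for ::executorch::runtime::Error.
enum class Error { Ok, InvalidArgument };

struct Tokenizer {};

// Stand-in for load_tokenizer(): pretend the json, tiktoken, and BPE
// probes all failed.
std::unique_ptr<Tokenizer> load_tokenizer_stub(const char* /*path*/) {
  return nullptr;
}

Error load(const char* tokenizer_path) {
  std::unique_ptr<Tokenizer> tokenizer = load_tokenizer_stub(tokenizer_path);
  if (tokenizer == nullptr) {
    // Mirrors the diff: log and return a recoverable status instead of
    // proceeding with a null tokenizer.
    std::fprintf(stderr, "Failed to load %s\n", tokenizer_path);
    return Error::InvalidArgument;
  }
  return Error::Ok;
}

int main() {
  return load("tokenizer.xyz") == Error::Ok ? 0 : 1;  // exits 1
}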
