 
 #include <executorch/examples/models/llama/runner/runner.h>
 
-#include <algorithm>
-#include <ctime>
-
 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -35,129 +32,161 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
-static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 } // namespace
 
-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
+    std::optional<const std::string> data_path,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
       model_path.c_str(),
       tokenizer_path.c_str());
-}
 
-[[deprecated(
-    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
-Runner::Runner(
-    const std::string& model_path,
-    const std::string& tokenizer_path,
-    const float temperature,
-    std::optional<const std::string> data_path)
-    : Runner(model_path, tokenizer_path, std::move(data_path)) {
-  temperature_ = temperature;
-}
+  // Create the Module
+  std::unique_ptr<Module> module;
+  if (data_path.has_value()) {
+    module = std::make_unique<Module>(
+        model_path, data_path.value(), Module::LoadMode::File);
+  } else {
+    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
+  }
 
-bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
-      text_prefiller_ && text_token_generator_;
-}
+  // Initialize metadata with default values
+  std::unordered_map<std::string, int64_t> metadata({
+      {kEnableDynamicShape, false},
+      {kMaxSeqLen, 128},
+      {kMaxContextLen, 128},
+      {kUseKVCache, true},
+  });
 
-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
-  }
-  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
-  ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != ::tokenizers::Error::Ok) {
+  // Create and load tokenizer
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = get_tiktoken_for_llama();
+  ::tokenizers::Error tk_err = tokenizer->load(tokenizer_path);
+
+  // Fallback to BPE tokenizer if tiktoken fails
+  if (tk_err != ::tokenizers::Error::Ok) {
     ET_LOG(
         Info,
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    ET_CHECK_TK_OK_OR_RETURN_ERROR(
-        err,
-        "Failed to load %s as a llama2.c tokenizer artifact",
-        tokenizer_path_.c_str());
+        tokenizer_path.c_str());
+    tokenizer.reset();
+    tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    tk_err = tokenizer->load(tokenizer_path);
+    if (tk_err != ::tokenizers::Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to load %s as a llama2.c tokenizer artifact",
+          tokenizer_path.c_str());
+      return nullptr;
+    }
   }
 
   ET_LOG(Info, "Reading metadata from model");
 
-  metadata_[kBosId] = tokenizer_->bos_tok();
+  // Set tokenizer-related metadata
+  metadata[kBosId] = tokenizer->bos_tok();
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
-  metadata_[kVocabSize] = tokenizer_->vocab_size();
-
-  const auto method_names =
-      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
+  metadata[kVocabSize] = tokenizer->vocab_size();
+
+  // Read metadata from the model
+  auto method_names_result = module->method_names();
+  if (method_names_result.error() != Error::Ok) {
+    ET_LOG(Error, "Failed reading method names");
+    return nullptr;
+  }
+  const auto method_names = method_names_result.get();
 
-  for (auto& pair : metadata_) {
+  for (auto& pair : metadata) {
     const auto& method_name = pair.first;
     auto& value = pair.second;
 
     if (method_names.count(method_name)) {
-      value = ET_UNWRAP(module_->get(method_name))
-                  .toScalar()
-                  .to<decltype(metadata_)::mapped_type>();
+      auto get_result = module->get(method_name);
+      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
     } else {
       ET_LOG(
           Info,
-          "Methond %s not found, using the default value %" PRId64,
+          "Method %s not found, using the default value %" PRId64,
           method_name.c_str(),
           value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Get EOS IDs if available
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
-    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+    auto execute_result = module->execute(kEosIds);
+    if (execute_result.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to execute %s", kEosIds);
+      return nullptr;
+    }
+    for (const auto& eos_id : execute_result.get()) {
       auto value = eos_id.toScalar().to<int64_t>();
       eos_ids->emplace(value);
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
-  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape),
-      metadata_.at(kMaxSeqLen));
-
-  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
+
+  // Create text_decoder_runner
+  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
+      module.get(), metadata.at(kUseKVCache));
+
+  // Create text_prefiller
+  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
+      metadata.at(kEnableDynamicShape),
+      metadata.at(kMaxSeqLen));
+
+  // Create text_token_generator with stats
+  auto stats = new llm::Stats();
+  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats);
+
+  // Create and return the Runner instance
+  return std::make_unique<Runner>(
+      std::move(metadata),
+      std::move(tokenizer),
+      std::move(text_prefiller),
+      std::move(text_token_generator),
+      temperature);
+}
 
+Runner::Runner(
+    std::unordered_map<std::string, int64_t> metadata,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
+    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+        text_token_generator,
+    float temperature)
+    : tokenizer_(std::move(tokenizer)),
+      metadata_(std::move(metadata)),
+      text_prefiller_(std::move(text_prefiller)),
+      text_token_generator_(std::move(text_token_generator)),
+      temperature_(temperature) {
+  // Note: This constructor assumes that text_prefiller and text_token_generator
+  // already have references to the Module and TextDecoderRunner they need
+}
+
+bool Runner::is_loaded() const {
+  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
+  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
   return Error::Ok;
 }
 
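For reviewers, a minimal sketch of how a caller migrates from the removed constructors to the factory introduced in this diff. The create() signature and the nullptr-on-failure behavior come from the changes above; the `example` namespace, the `executorch::runtime::Error` spelling, and the file paths are assumptions for illustration, not part of this change.

    #include <executorch/examples/models/llama/runner/runner.h>

    #include <memory>
    #include <optional>

    int main() {
      // create() now performs all fallible setup (tokenizer loading,
      // metadata reads) and returns nullptr on failure, instead of leaving
      // behind a partially initialized Runner as the old constructor could.
      std::unique_ptr<example::Runner> runner = example::Runner::create(
          "llama3.pte", // model_path (hypothetical)
          "tokenizer.model", // tokenizer_path (hypothetical)
          std::nullopt, // data_path
          0.8f); // temperature
      if (!runner) {
        return 1; // tokenizer or model metadata could not be loaded
      }
      // load() is now a thin wrapper over the prefiller and token
      // generator, both already wired to the Module inside create().
      if (runner->load() != executorch::runtime::Error::Ok) {
        return 1;
      }
      return 0;
    }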