 #include <executorch/examples/models/llama/runner/runner.h>

-#include <algorithm>
-#include <ctime>
-
 #include <executorch/extension/llm/runner/util.h>

 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -35,130 +32,165 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
-static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 } // namespace

-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
+    std::optional<const std::string> data_path,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
       model_path.c_str(),
       tokenizer_path.c_str());
-}

-[[deprecated(
-    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
-Runner::Runner(
-    const std::string& model_path,
-    const std::string& tokenizer_path,
-    const float temperature,
-    std::optional<const std::string> data_path)
-    : Runner(model_path, tokenizer_path, std::move(data_path)) {
-  temperature_ = temperature;
-}
+  // Create the Module
+  std::unique_ptr<Module> module;
+  if (data_path.has_value()) {
+    module = std::make_unique<Module>(
+        model_path, data_path.value(), Module::LoadMode::File);
+  } else {
+    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
+  }

-bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
-      text_prefiller_ && text_token_generator_;
-}
+  // Initialize metadata with default values
+  std::unordered_map<std::string, int64_t> metadata({
+      {kEnableDynamicShape, false},
+      {kMaxSeqLen, 128},
+      {kMaxContextLen, 128},
+      {kUseKVCache, true},
+  });

-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
-  }
-  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
-  ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != ::tokenizers::Error::Ok) {
+  // Create and load the tokenizer, assuming tiktoken is the default
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = get_tiktoken_for_llama();
+  ::tokenizers::Error tk_err = tokenizer->load(tokenizer_path);
+
+  // Fall back to the BPE tokenizer if tiktoken fails
+  if (tk_err != ::tokenizers::Error::Ok) {
     ET_LOG(
         Info,
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    ET_CHECK_TK_OK_OR_RETURN_ERROR(
-        err,
-        "Failed to load %s as a llama2.c tokenizer artifact",
-        tokenizer_path_.c_str());
+        tokenizer_path.c_str());
+    tokenizer.reset();
+    tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    tk_err = tokenizer->load(tokenizer_path);
+    if (tk_err != ::tokenizers::Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to load %s as a llama2.c tokenizer artifact",
+          tokenizer_path.c_str());
+      return nullptr;
+    }
   }

   ET_LOG(Info, "Reading metadata from model");

-  metadata_[kBosId] = tokenizer_->bos_tok();
+  // Set tokenizer-related metadata
+  metadata[kBosId] = tokenizer->bos_tok();
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
-  metadata_[kVocabSize] = tokenizer_->vocab_size();
-
-  const auto method_names =
-      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
+  metadata[kVocabSize] = tokenizer->vocab_size();
+
+  // Read metadata from the model
+  auto method_names_result = module->method_names();
+  if (method_names_result.error() != Error::Ok) {
+    ET_LOG(Error, "Failed reading method names");
+    return nullptr;
+  }
+  const auto method_names = method_names_result.get();

-  for (auto& pair : metadata_) {
+  for (auto& pair : metadata) {
     const auto& method_name = pair.first;
     auto& value = pair.second;

     if (method_names.count(method_name)) {
-      value = ET_UNWRAP(module_->get(method_name))
-                  .toScalar()
-                  .to<decltype(metadata_)::mapped_type>();
+      auto get_result = module->get(method_name);
+      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
     } else {
       ET_LOG(
           Info,
-          "Methond %s not found, using the default value %" PRId64,
+          "Method %s not found, using the default value %" PRId64,
           method_name.c_str(),
           value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Get EOS IDs if available
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
-    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+    auto execute_result = module->execute(kEosIds);
+    if (execute_result.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to execute %s", kEosIds);
+      return nullptr;
+    }
+    for (const auto& eos_id : execute_result.get()) {
       auto value = eos_id.toScalar().to<int64_t>();
       eos_ids->emplace(value);
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
-  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape),
-      metadata_.at(kMaxSeqLen));
-
-  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
+
+  // Create text_decoder_runner. TextPrefiller and TextTokenGenerator hold raw
+  // pointers to it, so ownership moves into the Runner below to keep it (and
+  // the Module it references) alive.
+  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
+      module.get(), metadata.at(kUseKVCache));
+
+  // Create text_prefiller
+  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
+      metadata.at(kEnableDynamicShape),
+      metadata.at(kMaxSeqLen));
+
+  // Create text_token_generator with stats
+  auto stats = std::make_unique<llm::Stats>();
+  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats.get());
+
+  // Create and return the Runner instance. The Module and TextDecoderRunner
+  // are moved in as well so the raw pointers handed out above stay valid.
+  return std::make_unique<Runner>(
+      std::move(metadata),
+      std::move(tokenizer),
+      std::move(module),
+      std::move(text_decoder_runner),
+      std::move(text_prefiller),
+      std::move(text_token_generator),
+      std::move(stats),
+      temperature);
+}

+Runner::Runner(
+    std::unordered_map<std::string, int64_t> metadata,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::unique_ptr<::executorch::extension::Module> module,
+    std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
+        text_decoder_runner,
+    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
+    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+        text_token_generator,
+    std::unique_ptr<::executorch::extension::llm::Stats> stats,
+    float temperature)
+    : tokenizer_(std::move(tokenizer)),
+      metadata_(std::move(metadata)),
+      module_(std::move(module)),
+      text_decoder_runner_(std::move(text_decoder_runner)),
+      text_prefiller_(std::move(text_prefiller)),
+      text_token_generator_(std::move(text_token_generator)),
+      stats_(std::move(stats)),
+      temperature_(temperature) {
+  // Note: text_prefiller and text_token_generator already hold raw pointers
+  // to the Module and TextDecoderRunner owned here; no further wiring needed.
+}
+
+bool Runner::is_loaded() const {
+  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
+  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
   return Error::Ok;
 }

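For review context, here is a minimal sketch of the call site this factory enables. It is illustrative only: the `example::` namespace, the placeholder paths, and the `Error::Internal` failure code are assumptions inferred from the signatures in this hunk, not verified against the header.

    #include <memory>
    #include <optional>

    #include <executorch/examples/models/llama/runner/runner.h>

    using executorch::runtime::Error;

    // Hypothetical helper; the model and tokenizer paths are placeholders.
    Error make_and_load_runner(std::unique_ptr<example::Runner>& out) {
      auto runner = example::Runner::create(
          "/data/llama.pte",
          "/data/tokenizer.model",
          /*data_path=*/std::nullopt,
          /*temperature=*/-1.0f);  // -1.0f defers to config.temperature later
      if (!runner) {
        // create() signals failure by returning nullptr rather than throwing.
        return Error::Internal;
      }
      // load() can be called eagerly; generate() also loads lazily on demand.
      ET_CHECK_OK_OR_RETURN_ERROR(runner->load());
      out = std::move(runner);
      return Error::Ok;
    }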
@@ -179,9 +211,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -207,7 +239,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -248,8 +280,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -270,7 +302,7 @@ Error Runner::generate(
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -283,17 +315,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;
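And a sketch of generate() from the caller's side, showing that moving stats_ behind a unique_ptr is invisible to the API: print_report and stats_callback still receive a plain llm::Stats reference. The GenerationConfig field names and the generate() parameter order here are assumptions inferred from the config.temperature and config.warming reads above, not a verified signature.

    #include <cinttypes>
    #include <cstdio>
    #include <string>

    #include <executorch/examples/models/llama/runner/runner.h>

    using executorch::extension::llm::GenerationConfig;
    using executorch::extension::llm::Stats;
    using executorch::runtime::Error;

    // Hypothetical driver over an already-created runner.
    Error run_prompt(example::Runner& runner, const std::string& prompt) {
      GenerationConfig config;
      config.max_new_tokens = 128;  // assumed field name
      config.temperature = 0.8f;    // used because create() received -1.0f

      return runner.generate(
          prompt,
          config,
          [](const std::string& piece) { std::fputs(piece.c_str(), stdout); },
          [](const Stats& stats) {
            // generate() dereferences the heap-allocated stats_ before
            // invoking this callback, so callers never see the unique_ptr.
            std::printf(
                "\nprompt tokens: %" PRId64 ", generated: %" PRId64 "\n",
                stats.num_prompt_tokens,
                stats.num_generated_tokens);
          });
    }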