@@ -54,10 +54,10 @@ Runner::Runner(
           {kUseSDPAWithKVCache, false},
       }) {
   if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
+    module_ = std::make_shared<Module>(
         model_path, data_path.value(), Module::LoadMode::File);
   } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
+    module_ = std::make_shared<Module>(model_path, Module::LoadMode::File);
   }
   ET_LOG(
       Info,
@@ -89,7 +89,7 @@ Error Runner::load() {
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
   // load tokenizer. Assuming tiktoken is the default tokenizer
   tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
+  tokenizer_ = get_tiktoken_for_llama<decltype(tokenizer_)>();
   ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
   // Rely on tiktoken to throw error if the artifact is incompatible. Then we
   // fallback to BPE tokenizer.
@@ -99,7 +99,7 @@ Error Runner::load() {
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
         tokenizer_path_.c_str());
     tokenizer_.reset();
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    tokenizer_ = std::make_shared<::tokenizers::Llama2cTokenizer>();
     err = tokenizer_->load(tokenizer_path_);
     ET_CHECK_TK_OK_OR_RETURN_ERROR(
         err,
@@ -143,20 +143,21 @@ Error Runner::load() {
     }
   }
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
+  text_decoder_runner_ = std::make_shared<llm::TextDecoderRunner>(
+      module_, metadata_.at(kUseKVCache));
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
+      text_decoder_runner_,
       metadata_.at(kUseKVCache),
       metadata_.at(kEnableDynamicShape),
       metadata_.at(kMaxSeqLen));

+  stats_ = std::make_shared<llm::Stats>();
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
+      tokenizer_,
+      text_decoder_runner_,
       metadata_.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats_);

   return Error::Ok;
 }
@@ -178,9 +179,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -206,7 +207,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -247,8 +248,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -269,7 +270,7 @@ Error Runner::generate(
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -282,17 +283,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;
@@ -307,7 +308,7 @@ Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   Error err = generate(prompt, config);

   // Reset stats after warmup
-  stats_.reset();
+  stats_->reset();
   return err;
 }
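The diff above consistently replaces raw pointers (obtained via .get() or &stats_) and std::unique_ptr with std::shared_ptr, so the Runner and the components it builds (TextDecoderRunner, TextPrefiller, TextTokenGenerator) co-own the Module, tokenizer, and Stats objects instead of borrowing them. Below is a minimal, self-contained sketch of that ownership pattern, not the real ExecuTorch classes; Stats and TokenGenerator here are simplified stand-ins for llm::Stats and llm::TextTokenGenerator.

// Sketch of the shared-ownership pattern adopted by the diff (simplified types).
#include <cstdint>
#include <iostream>
#include <memory>

// Stand-in for llm::Stats: mutable counters shared by runner and generator.
struct Stats {
  int64_t num_generated_tokens = 0;
  void reset() { num_generated_tokens = 0; }
};

// Stand-in for llm::TextTokenGenerator: it stores a shared_ptr to Stats
// instead of a raw Stats*, so the stats object cannot dangle even if the
// owning runner is destroyed or its members are reordered.
class TokenGenerator {
 public:
  explicit TokenGenerator(std::shared_ptr<Stats> stats)
      : stats_(std::move(stats)) {}
  void generate_one() { ++stats_->num_generated_tokens; }

 private:
  std::shared_ptr<Stats> stats_;
};

int main() {
  auto stats = std::make_shared<Stats>();  // owner #1: the "runner"
  TokenGenerator generator(stats);         // owner #2: the component
  generator.generate_one();
  // Both owners observe the same object; no &stats raw pointer is passed.
  std::cout << stats->num_generated_tokens << "\n";  // prints 1
  stats->reset();  // e.g. clearing stats after a warmup run
  return 0;
}

With unique_ptr plus .get(), the components held non-owning pointers whose validity depended on the Runner outliving them; shared_ptr removes that implicit lifetime contract at the cost of reference-counting overhead.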