
Commit 5680dbf

larryliu0820 authored and facebook-github-bot committed
Use unique_ptr and shared_ptr properly for runner components (#10338)
Summary: The ownership of these components needs clarification:

* `Module` should be shared by `TextDecoderRunner` and potentially `TextPrefiller` (or `ImagePrefiller` in the multimodal runner).
* `TextDecoderRunner` should be shared by `TextPrefiller` and `TextTokenGenerator`.
* `Tokenizer` should be owned by the `Runner` as well as `TextTokenGenerator`.

Reviewed By: kirklandsign

Differential Revision: D73399600
1 parent c0593ff commit 5680dbf
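For illustration, a minimal sketch of the ownership layout the summary describes: the runner owns each component through std::unique_ptr and hands non-owning raw pointers to the components it wires together. The class and member names below are simplified stand-ins, not the actual ExecuTorch declarations.

#include <memory>

// Simplified stand-ins for the real ExecuTorch runner components.
struct Module {};
struct Tokenizer {};
struct Stats {};

struct TextDecoderRunner {
  explicit TextDecoderRunner(Module* module) : module_(module) {}
  Module* module_;  // not owned; the Runner keeps it alive
};

struct TextPrefiller {
  explicit TextPrefiller(TextDecoderRunner* decoder) : decoder_(decoder) {}
  TextDecoderRunner* decoder_;  // not owned
};

struct TextTokenGenerator {
  TextTokenGenerator(TextDecoderRunner* decoder, Tokenizer* tokenizer, Stats* stats)
      : decoder_(decoder), tokenizer_(tokenizer), stats_(stats) {}
  TextDecoderRunner* decoder_;  // not owned
  Tokenizer* tokenizer_;        // not owned
  Stats* stats_;                // not owned
};

// The Runner owns every component. Members are declared owners-first so that,
// with reverse destruction order, owners outlive the borrowers built from them.
struct Runner {
  std::unique_ptr<Module> module_ = std::make_unique<Module>();
  std::unique_ptr<Tokenizer> tokenizer_ = std::make_unique<Tokenizer>();
  std::unique_ptr<Stats> stats_ = std::make_unique<Stats>();
  std::unique_ptr<TextDecoderRunner> decoder_ =
      std::make_unique<TextDecoderRunner>(module_.get());
  std::unique_ptr<TextPrefiller> prefiller_ =
      std::make_unique<TextPrefiller>(decoder_.get());
  std::unique_ptr<TextTokenGenerator> generator_ = std::make_unique<TextTokenGenerator>(
      decoder_.get(), tokenizer_.get(), stats_.get());
};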

File tree

10 files changed: +117 −65 lines changed

examples/models/llama/runner/runner.cpp

+16-13
@@ -52,7 +52,9 @@ Runner::Runner(
           {kMaxContextLen, 128},
           {kUseKVCache, true},
           {kUseSDPAWithKVCache, false},
-      }) {
+      }),
+      // @lint-ignore CLANGTIDY facebook-hte-Deprecated
+      stats_(std::make_unique<llm::Stats>()) {
   if (data_path.has_value()) {
     module_ = std::make_unique<Module>(
         model_path, data_path.value(), Module::LoadMode::File);
@@ -99,6 +101,7 @@ Error Runner::load() {
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
         tokenizer_path_.c_str());
     tokenizer_.reset();
+    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
     tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
     err = tokenizer_->load(tokenizer_path_);
     ET_CHECK_TK_OK_OR_RETURN_ERROR(
@@ -156,7 +159,7 @@ Error Runner::load() {
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats_.get());

   return Error::Ok;
 }
@@ -178,9 +181,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -206,7 +209,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -247,8 +250,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -269,7 +272,7 @@ Error Runner::generate(
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -282,17 +285,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;
@@ -307,7 +310,7 @@ Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   Error err = generate(prompt, config);

   // Reset stats after warmup
-  stats_.reset();
+  stats_->reset();
   return err;
 }
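A condensed sketch of the mechanical change above, assuming a stats field that used to be a plain value member: once stats_ becomes a std::unique_ptr, member access goes through ->, collaborators that previously received &stats_ get the non-owning stats_.get(), and APIs that still take a Stats reference are passed *stats_. The names here are placeholders, not the actual ExecuTorch signatures.

#include <cstdint>
#include <cstdio>
#include <memory>

struct Stats {
  int64_t inference_start_ms = 0;
  int64_t inference_end_ms = 0;
  void reset() { *this = Stats{}; }
};

// Placeholder standing in for ::executorch::llm::print_report(const Stats&).
void print_report(const Stats& stats) {
  std::printf("inference took %lld ms\n",
              static_cast<long long>(stats.inference_end_ms - stats.inference_start_ms));
}

struct Runner {
  std::unique_ptr<Stats> stats_ = std::make_unique<Stats>();

  void generate() {
    stats_->inference_start_ms = 0;  // was: stats_.inference_start_ms
    stats_->inference_end_ms = 5;
    print_report(*stats_);           // was: print_report(stats_)
  }

  Stats* borrow_stats() { return stats_.get(); }  // was: &stats_
};

int main() {
  Runner runner;
  runner.generate();
  runner.borrow_stats()->reset();  // collaborators mutate the same Stats the Runner owns
  return 0;
}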

examples/models/llama/runner/runner.h

+1-1
@@ -71,7 +71,7 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
       text_token_generator_;

   // stats
-  ::executorch::extension::llm::Stats stats_;
+  std::unique_ptr<::executorch::extension::llm::Stats> stats_;

   // temperature.
   // Deprecated, we should rely on the temperature in GenerationConfig instead.

examples/models/llama/tokenizer/llama_tiktoken.cpp

+58-32
@@ -11,51 +11,71 @@
 namespace example {

 using ::tokenizers::Tiktoken;
+using ::tokenizers::Tokenizer;

 namespace {
 static constexpr int32_t kSpecialTokensSize = 256;
 static constexpr size_t kBOSTokenIndex = 0;
 static constexpr size_t kEOSTokenIndex = 1;

-static inline std::unique_ptr<std::vector<std::string>>
-_get_default_special_tokens() {
-  auto special_tokens =
-      std::make_unique<std::vector<std::string>>(std::vector<std::string>{
-          "<|begin_of_text|>",
-          "<|end_of_text|>",
-          "<|reserved_special_token_0|>",
-          "<|reserved_special_token_1|>",
-          "<|finetune_right_pad_id|>",
-          "<|step_id|>",
-          "<|start_header_id|>",
-          "<|end_header_id|>",
-          "<|eom_id|>",
-          "<|eot_id|>",
-          "<|python_tag|>"});
-  // pad the rest of the special tokens with reserved tokens
-  ssize_t reserved_special_token_num = 2;
-  while (special_tokens->size() < kSpecialTokensSize) {
-    special_tokens->emplace_back(
-        "<|reserved_special_token_" +
-        std::to_string(reserved_special_token_num++) + "|>");
+// Compile-time special tokens selection using templates
+template <Version V>
+struct SpecialTokensSelector {
+  static std::unique_ptr<std::vector<std::string>> create();
+};
+
+// Compile-time special tokens selection using templates
+template <>
+struct SpecialTokensSelector<Version::Default> {
+  static std::unique_ptr<std::vector<std::string>> create() {
+    auto special_tokens =
+        std::make_unique<std::vector<std::string>>(std::vector<std::string>{
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|finetune_right_pad_id|>",
+            "<|step_id|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|eom_id|>",
+            "<|eot_id|>",
+            "<|python_tag|>"});
+    // pad the rest of the special tokens with reserved tokens
+    ssize_t reserved_special_token_num = 2;
+    while (special_tokens->size() < kSpecialTokensSize) {
+      special_tokens->emplace_back(
+          "<|reserved_special_token_" +
+          std::to_string(reserved_special_token_num++) + "|>");
+    }
+    return special_tokens;
   }
-  return special_tokens;
-}
+};

-std::unique_ptr<std::vector<std::string>> _get_special_tokens(Version version) {
-  switch (version) {
-    case Version::Multimodal:
-      return get_multimodal_special_tokens();
-    default:
-      return _get_default_special_tokens();
+// Specialization for Multimodal version
+template <>
+struct SpecialTokensSelector<Version::Multimodal> {
+  static std::unique_ptr<std::vector<std::string>> create() {
+    return get_multimodal_special_tokens();
   }
-}
+};

 } // namespace

-std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
+namespace detail {
+// Helper function to create a Tiktoken with the given version
+template <Version V>
+std::unique_ptr<Tiktoken> create_tiktoken() {
+  std::unique_ptr<std::vector<std::string>> special_tokens =
+      example::SpecialTokensSelector<V>::create();
   return std::make_unique<Tiktoken>(
-      _get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex);
+      std::move(special_tokens), kBOSTokenIndex, kEOSTokenIndex);
+}
+} // namespace detail
+
+template <Version V>
+std::unique_ptr<::tokenizers::Tiktoken> get_tiktoken_for_llama() {
+  return detail::create_tiktoken<V>();
 }

 std::unique_ptr<std::vector<std::string>> get_multimodal_special_tokens() {
@@ -87,4 +107,10 @@ std::unique_ptr<std::vector<std::string>> get_multimodal_special_tokens() {
   return special_tokens;
 }

+// specialization
+
+template std::unique_ptr<Tiktoken>
+get_tiktoken_for_llama<Version::Multimodal>();
+
+template std::unique_ptr<Tiktoken> get_tiktoken_for_llama<Version::Default>();
 } // namespace example
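A self-contained sketch of the pattern the new tokenizer code uses: an enum-keyed template with per-value specializations, plus explicit instantiations at the end of the .cpp so that callers in other translation units can link against a template whose definition never appears in the header. All names below are stand-ins, not the ExecuTorch ones.

#include <memory>
#include <string>
#include <vector>

enum class Version { Default, Multimodal };

// Primary template is only declared; behavior comes from the specializations.
template <Version V>
struct SpecialTokensSelector;

template <>
struct SpecialTokensSelector<Version::Default> {
  static std::unique_ptr<std::vector<std::string>> create() {
    return std::make_unique<std::vector<std::string>>(
        std::vector<std::string>{"<|begin_of_text|>", "<|end_of_text|>"});
  }
};

template <>
struct SpecialTokensSelector<Version::Multimodal> {
  static std::unique_ptr<std::vector<std::string>> create() {
    return std::make_unique<std::vector<std::string>>(
        std::vector<std::string>{"<|begin_of_text|>", "<|image|>"});
  }
};

// Definition lives in this .cpp, not in the header.
template <Version V>
std::unique_ptr<std::vector<std::string>> get_special_tokens() {
  return SpecialTokensSelector<V>::create();
}

// Explicit instantiations: without these, calls from other translation units
// would compile but fail at link time with undefined-symbol errors.
template std::unique_ptr<std::vector<std::string>>
get_special_tokens<Version::Default>();
template std::unique_ptr<std::vector<std::string>>
get_special_tokens<Version::Multimodal>();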

examples/models/llama/tokenizer/llama_tiktoken.h

+3-2
@@ -17,9 +17,10 @@ enum class Version {
   Multimodal,
 };

-std::unique_ptr<::tokenizers::Tiktoken> get_tiktoken_for_llama(
-    Version version = Version::Default);
+template <Version V = Version::Default>
+std::unique_ptr<::tokenizers::Tiktoken> get_tiktoken_for_llama();

+// For backward compatibility
 std::unique_ptr<std::vector<std::string>> get_multimodal_special_tokens();

 } // namespace example
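At call sites, the runtime argument becomes a template argument, with Version::Default used when the argument is omitted. A minimal usage sketch, assuming the header above is reachable at the include path shown (the path is inferred from the file's location in the tree, not verified).

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>

int main() {
  // Before this commit:
  //   auto tok = example::get_tiktoken_for_llama(example::Version::Multimodal);

  // After this commit the version is selected at compile time; the default
  // template argument picks Version::Default when no version is given.
  auto default_tok = example::get_tiktoken_for_llama();
  auto multimodal_tok =
      example::get_tiktoken_for_llama<example::Version::Multimodal>();
  (void)default_tok;
  (void)multimodal_tok;
  return 0;
}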

examples/models/llava/runner/llava_image_prefiller.h

+1-1
@@ -18,7 +18,7 @@ namespace example {
 class ET_EXPERIMENTAL LlavaImagePrefiller
     : public ::executorch::extension::llm::ImagePrefiller {
  public:
-  LlavaImagePrefiller(::executorch::extension::Module* module)
+  explicit LlavaImagePrefiller(::executorch::extension::Module* module)
       : ImagePrefiller(module){};
   /**
    * Prefill an LLM Module with the given image input.
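The added explicit keyword only rules out implicit conversions from a raw Module pointer; direct construction is unchanged. A minimal sketch of the difference, using stand-in types rather than the real classes.

struct Module {};

struct Prefiller {
  explicit Prefiller(Module* module) : module_(module) {}
  Module* module_;  // not owned
};

void consume(Prefiller prefiller) { (void)prefiller; }

int main() {
  Module module;
  Prefiller ok(&module);        // direct initialization still works
  // consume(&module);          // ill-formed with explicit: no implicit Module* -> Prefiller
  consume(Prefiller(&module));  // the conversion now has to be spelled out
  (void)ok;
  return 0;
}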

examples/models/llava/runner/llava_runner.cpp

+11-11
@@ -40,7 +40,7 @@ Error LlavaRunner::load() {
   if (is_loaded()) {
     return Error::Ok;
   }
-  stats_.model_load_start_ms = llm::time_in_ms();
+  stats_->model_load_start_ms = llm::time_in_ms();

   // Load the tokenizer
   tokenizer_ = std::make_unique<tokenizers::Llama2cTokenizer>();
@@ -71,9 +71,9 @@ Error LlavaRunner::load() {
       /*use_kv_cache=*/true,
       std::make_unique<std::unordered_set<uint64_t>>(
           std::unordered_set<uint64_t>{tokenizer_->eos_tok()}),
-      &stats_);
+      stats_.get());

-  stats_.model_load_end_ms = llm::time_in_ms();
+  stats_->model_load_end_ms = llm::time_in_ms();
   return Error::Ok;
 }

@@ -113,9 +113,9 @@ Error LlavaRunner::generate_from_pos(

   uint64_t prefill_next_token =
       ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
-  stats_.num_prompt_tokens = start_pos;
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();
+  stats_->num_prompt_tokens = start_pos;

   // Generate tokens
   int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
@@ -126,9 +126,9 @@ Error LlavaRunner::generate_from_pos(
       /*token_callback=*/token_callback));

   // Bookkeeping
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }
   return Error::Ok;
 }
@@ -161,7 +161,7 @@ Error LlavaRunner::generate(
   };

   int64_t pos = 0;
-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();

   // prefill preset prompt
   prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0);
@@ -178,8 +178,8 @@ Error LlavaRunner::generate(
   Error err = generate_from_pos(
       prompt, seq_len, pos, wrapped_callback, stats_callback, echo);

-  stats_.inference_end_ms = llm::time_in_ms();
-  ::executorch::llm::print_report(stats_);
+  stats_->inference_end_ms = llm::time_in_ms();
+  ::executorch::llm::print_report(*stats_);

   ET_LOG(
       Info,

extension/llm/runner/multimodal_runner.h

+3-2
@@ -41,7 +41,8 @@ class ET_EXPERIMENTAL MultimodalRunner {
       const float temperature = 0.8f)
       : temperature_(temperature),
         module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
-        tokenizer_path_(tokenizer_path) {
+        tokenizer_path_(tokenizer_path),
+        stats_(std::make_unique<llm::Stats>()) {
     ET_LOG(
         Info,
         "Creating Multimodal LLM runner: model_path=%s, tokenizer_path=%s",
@@ -132,7 +133,7 @@ class ET_EXPERIMENTAL MultimodalRunner {
   std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;

   // stats
-  Stats stats_;
+  std::unique_ptr<Stats> stats_;
 };

 } // namespace llm

extension/llm/runner/text_decoder_runner.h

+7-2
@@ -14,7 +14,6 @@
 #include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
 #include <executorch/runtime/platform/compiler.h>
-#include <functional>

 namespace executorch {
 namespace extension {
@@ -94,7 +93,13 @@ class ET_EXPERIMENTAL TextDecoderRunner {
   }

  protected:
-  // TODO: use shared_ptr for module
+  /**
+   * Note: TextDecoderRunner does not own the Module instance. It is expected
+   * that the outer class (likely Runner) manages the lifecycle of the Module.
+   * This means that the responsibility for creating, maintaining, and
+   * destroying the Module lies outside of TextDecoderRunner. Ensure that the
+   * Module remains valid for the duration of TextDecoderRunner's usage.
+   */
   Module* module_;
   bool use_kv_cache_;
   bool should_stop_{false};
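The new comment documents a lifetime contract rather than enforcing one: module_ stays a raw, non-owning pointer, so the owner must keep the Module alive for as long as the runner uses it. A minimal sketch of the failure mode it warns about, with stand-in types.

#include <memory>

struct Module {
  int value = 7;
};

struct TextDecoderRunner {
  explicit TextDecoderRunner(Module* module) : module_(module) {}
  int step() { return module_->value; }  // assumes module_ is still alive
  Module* module_;  // non-owning, per the comment above
};

int main() {
  auto module = std::make_unique<Module>();
  TextDecoderRunner runner(module.get());
  runner.step();     // fine: the owning unique_ptr is still alive
  module.reset();    // the owner destroys the Module first...
  // runner.step();  // ...so this call would now be a use-after-free; keep the
  //                 // Module alive for the whole lifetime of the runner instead.
  return 0;
}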

extension/llm/runner/text_prefiller.h

+9-1
@@ -24,6 +24,8 @@ class ET_EXPERIMENTAL TextPrefiller {
       bool use_kv_cache_,
       bool enable_parallel_prefill,
       int64_t max_seq_len = 128);
+
+  virtual ~TextPrefiller() = default;
   /**
    * Prefill an LLM Module with the given text input.
    * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
@@ -32,7 +34,7 @@ class ET_EXPERIMENTAL TextPrefiller {
    * Module.
    * @return The next token of the LLM Module after prefill.
    */
-  ::executorch::runtime::Result<uint64_t> prefill(
+  virtual ::executorch::runtime::Result<uint64_t> prefill(
       std::vector<uint64_t>& prompt_tokens,
       int64_t& start_pos);

@@ -48,6 +50,12 @@ class ET_EXPERIMENTAL TextPrefiller {
       int64_t& start_pos);

  private:
+  /**
+   * Note: TextPrefiller does not own the TextDecoderRunner instance.
+   * The responsibility of managing the lifecycle of TextDecoderRunner
+   * lies with the outer class or entity (likely Runner) that creates
+   * and passes the TextDecoderRunner instance to TextPrefiller.
+   */
   TextDecoderRunner* text_decoder_runner_;
   bool use_kv_cache_;
   bool enable_parallel_prefill_;
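Making prefill virtual and adding a virtual destructor lets the prefiller be subclassed and deleted through a base pointer, which is what allows other runners or tests to swap in their own implementation. A hedged sketch with simplified signatures: the real method returns a Result<uint64_t>, and FakePrefiller is a hypothetical test double, not part of this commit.

#include <cstdint>
#include <memory>
#include <vector>

// Simplified stand-in for the real interface.
struct TextPrefiller {
  virtual ~TextPrefiller() = default;
  virtual uint64_t prefill(std::vector<uint64_t>& prompt_tokens, int64_t& start_pos) {
    start_pos += static_cast<int64_t>(prompt_tokens.size());
    return prompt_tokens.empty() ? 0 : prompt_tokens.back();
  }
};

// Hypothetical test double: returns a fixed token and never touches a Module.
struct FakePrefiller : TextPrefiller {
  uint64_t prefill(std::vector<uint64_t>& prompt_tokens, int64_t& start_pos) override {
    start_pos += static_cast<int64_t>(prompt_tokens.size());
    return 42;
  }
};

int main() {
  // Destroying through the base pointer is well defined because the base
  // class destructor is virtual.
  std::unique_ptr<TextPrefiller> prefiller = std::make_unique<FakePrefiller>();
  std::vector<uint64_t> tokens{1, 2, 3};
  int64_t pos = 0;
  uint64_t next = prefiller->prefill(tokens, pos);  // dispatches to FakePrefiller
  (void)next;
  return 0;
}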
