Commit 61c1b63

larryliu0820 authored and facebook-github-bot committed
Use unique_ptr and shared_ptr properly for runner components (#10338)
Summary: The ownership of these components needs some clarification:

* `Module` should be shared by `TextDecoderRunner` and potentially `TextPrefiller` (or `ImagePrefiller` in a multimodal runner).
* `TextDecoderRunner` should be shared by the `TextPrefiller` and `TextTokenGenerator`.
* `Tokenizer` should be shared by the `Runner` and the `TextTokenGenerator`.

Reviewed By: kirklandsign

Differential Revision: D73399600
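A minimal sketch of the resulting ownership graph (the `model_path`, `use_kv_cache`, `enable_dynamic_shape`, `max_seq_len`, and `eos_ids` locals are placeholders for the metadata lookups in the real runner):

// Sketch only: who shares what after this change.
auto module = std::make_shared<Module>(model_path, Module::LoadMode::File);
auto tokenizer = get_tiktoken_for_llama<std::shared_ptr<Tokenizer>>();
auto stats = std::make_shared<llm::Stats>();

// The decoder runner shares the Module; the prefiller and token generator
// share the decoder runner; the generator also shares the tokenizer and stats.
auto text_decoder_runner =
    std::make_shared<llm::TextDecoderRunner>(module, use_kv_cache);
auto text_prefiller = std::make_unique<llm::TextPrefiller>(
    text_decoder_runner, use_kv_cache, enable_dynamic_shape, max_seq_len);
auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
    tokenizer, text_decoder_runner, use_kv_cache, std::move(eos_ids), stats);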
1 parent 647e1f1 commit 61c1b63

18 files changed: +190 -105 lines

examples/mediatek/executor_runner/mtk_llama_executor_runner.cpp (+1 -1)

@@ -287,7 +287,7 @@ std::unique_ptr<Tokenizer> load_tokenizer() {
   if (FLAGS_tokenizer_type == "bpe") {
     tokenizer = std::make_unique<BPETokenizer>();
   } else if (FLAGS_tokenizer_type == "tiktoken") {
-    tokenizer = example::get_tiktoken_for_llama();
+    tokenizer = example::get_tiktoken_for_llama<decltype(tokenizer)>();
   }
   ET_CHECK_MSG(
       tokenizer, "Invalid tokenizer type: %s", FLAGS_tokenizer_type.c_str());

examples/mediatek/executor_runner/mtk_llama_runner.cpp (+1 -1)

@@ -292,7 +292,7 @@ Error MTKLlamaRunner::inference(
 std::unique_ptr<Tokenizer> MTKLlamaRunner::load_tokenizer() {
   std::unique_ptr<Tokenizer> tokenizer;
   // Assumes that tokenizer type is Tiktoken
-  tokenizer = example::get_tiktoken_for_llama();
+  tokenizer = example::get_tiktoken_for_llama<decltype(tokenizer)>();
   tokenizer->load(modelpaths_.tokenizer_path);
   return tokenizer;
 }

examples/models/llama/runner/runner.cpp (+22 -21)

@@ -54,10 +54,10 @@ Runner::Runner(
           {kUseSDPAWithKVCache, false},
       }) {
   if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
+    module_ = std::make_shared<Module>(
         model_path, data_path.value(), Module::LoadMode::File);
   } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
+    module_ = std::make_shared<Module>(model_path, Module::LoadMode::File);
   }
   ET_LOG(
       Info,
@@ -89,7 +89,7 @@ Error Runner::load() {
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
   // load tokenizer. Assuming tiktoken is the default tokenizer
   tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
+  tokenizer_ = get_tiktoken_for_llama<decltype(tokenizer_)>();
   ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
   // Rely on tiktoken to throw error if the artifact is incompatible. Then we
   // fallback to BPE tokenizer.
@@ -99,7 +99,7 @@
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
         tokenizer_path_.c_str());
     tokenizer_.reset();
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    tokenizer_ = std::make_shared<::tokenizers::Llama2cTokenizer>();
     err = tokenizer_->load(tokenizer_path_);
     ET_CHECK_TK_OK_OR_RETURN_ERROR(
         err,
@@ -143,20 +143,21 @@ Error Runner::load() {
     }
   }
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
+  text_decoder_runner_ = std::make_shared<llm::TextDecoderRunner>(
+      module_, metadata_.at(kUseKVCache));
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
+      text_decoder_runner_,
       metadata_.at(kUseKVCache),
       metadata_.at(kEnableDynamicShape),
       metadata_.at(kMaxSeqLen));

+  stats_ = std::make_shared<llm::Stats>();
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
+      tokenizer_,
+      text_decoder_runner_,
       metadata_.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats_);

   return Error::Ok;
 }
@@ -178,9 +179,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -206,7 +207,7 @@
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -247,8 +248,8 @@
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -269,7 +270,7 @@
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -282,17 +283,17 @@
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;
@@ -307,7 +308,7 @@ Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   Error err = generate(prompt, config);

   // Reset stats after warmup
-  stats_.reset();
+  stats_->reset();
   return err;
 }
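`stats_` moves behind a `std::shared_ptr` because `TextTokenGenerator` previously captured a raw `&stats_` into the runner; now both components hold the same object. A sketch of the shared access (field and function names as in the diff above):

auto stats = std::make_shared<llm::Stats>();
stats->inference_start_ms = llm::time_in_ms();
// ... TextTokenGenerator, holding the same shared_ptr, fills in the
// remaining timing fields during generation ...
stats->inference_end_ms = llm::time_in_ms();
// print_report still takes a Stats, hence the dereference:
::executorch::llm::print_report(*stats);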

examples/models/llama/runner/runner.h (+4 -4)

@@ -60,18 +60,18 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
   bool shouldStop_{false};

   // model
-  std::unique_ptr<::executorch::extension::Module> module_;
+  std::shared_ptr<::executorch::extension::Module> module_;
   std::string tokenizer_path_;
-  std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
+  std::shared_ptr<::tokenizers::Tokenizer> tokenizer_;
   std::unordered_map<std::string, int64_t> metadata_;
-  std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
+  std::shared_ptr<::executorch::extension::llm::TextDecoderRunner>
       text_decoder_runner_;
   std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_;
   std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
       text_token_generator_;

   // stats
-  ::executorch::extension::llm::Stats stats_;
+  std::shared_ptr<::executorch::extension::llm::Stats> stats_;

   // temperature.
   // Deprecated, we should rely on the temperature in GenerationConfig instead.

examples/models/llama/tokenizer/llama_tiktoken.cpp (+92 -33)

@@ -11,51 +11,85 @@
 namespace example {

 using ::tokenizers::Tiktoken;
+using ::tokenizers::Tokenizer;

 namespace {
 static constexpr int32_t kSpecialTokensSize = 256;
 static constexpr size_t kBOSTokenIndex = 0;
 static constexpr size_t kEOSTokenIndex = 1;

-static inline std::unique_ptr<std::vector<std::string>>
-_get_default_special_tokens() {
-  auto special_tokens =
-      std::make_unique<std::vector<std::string>>(std::vector<std::string>{
-          "<|begin_of_text|>",
-          "<|end_of_text|>",
-          "<|reserved_special_token_0|>",
-          "<|reserved_special_token_1|>",
-          "<|finetune_right_pad_id|>",
-          "<|step_id|>",
-          "<|start_header_id|>",
-          "<|end_header_id|>",
-          "<|eom_id|>",
-          "<|eot_id|>",
-          "<|python_tag|>"});
-  // pad the rest of the special tokens with reserved tokens
-  ssize_t reserved_special_token_num = 2;
-  while (special_tokens->size() < kSpecialTokensSize) {
-    special_tokens->emplace_back(
-        "<|reserved_special_token_" +
-        std::to_string(reserved_special_token_num++) + "|>");
+// Compile-time special tokens selection using templates
+template <Version V>
+struct SpecialTokensSelector {
+  static std::unique_ptr<std::vector<std::string>> get();
+};
+
+// Specialization for Default version
+template <>
+struct SpecialTokensSelector<Version::Default> {
+  static std::unique_ptr<std::vector<std::string>> get() {
+    auto special_tokens =
+        std::make_unique<std::vector<std::string>>(std::vector<std::string>{
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|finetune_right_pad_id|>",
+            "<|step_id|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|eom_id|>",
+            "<|eot_id|>",
+            "<|python_tag|>"});
+    // pad the rest of the special tokens with reserved tokens
+    ssize_t reserved_special_token_num = 2;
+    while (special_tokens->size() < kSpecialTokensSize) {
+      special_tokens->emplace_back(
+          "<|reserved_special_token_" +
+          std::to_string(reserved_special_token_num++) + "|>");
+    }
+    return special_tokens;
   }
-  return special_tokens;
-}
+};

-std::unique_ptr<std::vector<std::string>> _get_special_tokens(Version version) {
-  switch (version) {
-    case Version::Multimodal:
-      return get_multimodal_special_tokens();
-    default:
-      return _get_default_special_tokens();
+// Specialization for Multimodal version
+template <>
+struct SpecialTokensSelector<Version::Multimodal> {
+  static std::unique_ptr<std::vector<std::string>> get() {
+    return get_multimodal_special_tokens();
   }
-}
+};

 } // namespace

-std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
-  return std::make_unique<Tiktoken>(
-      _get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex);
+namespace detail {
+// Helper function to create a Tiktoken with the given version
+template <typename PtrType, Version V>
+PtrType create_tiktoken() {
+  std::unique_ptr<std::vector<std::string>> special_tokens =
+      example::SpecialTokensSelector<V>::get();
+  if constexpr (is_shared_ptr_of_tokenizer<PtrType>()) {
+    return std::make_shared<Tiktoken>(
+        std::move(special_tokens), kBOSTokenIndex, kEOSTokenIndex);
+  } else if constexpr (is_unique_ptr_of_tokenizer<PtrType>()) {
+    return std::make_unique<Tiktoken>(
+        std::move(special_tokens), kBOSTokenIndex, kEOSTokenIndex);
+  } else {
+    static_assert(
+        is_shared_ptr_of_tokenizer<PtrType>() ||
+            is_unique_ptr_of_tokenizer<PtrType>(),
+        "PtrType must be either std::shared_ptr<Tiktoken> or std::unique_ptr<Tiktoken>");
+    // This line is never reached due to the static_assert, but needed for
+    // compilation
+    return PtrType{};
+  }
+}
+} // namespace detail
+
+// Returns the tokenizer wrapped in the smart-pointer type named by PtrType
+template <typename PtrType, Version V>
+PtrType get_tiktoken_for_llama() {
+  return detail::create_tiktoken<PtrType, V>();
 }

 std::unique_ptr<std::vector<std::string>> get_multimodal_special_tokens() {
@@ -87,4 +121,29 @@ std::unique_ptr<std::vector<std::string>> get_multimodal_special_tokens() {
   return special_tokens;
 }

+// Explicit template instantiations
+
+template std::shared_ptr<Tiktoken>
+get_tiktoken_for_llama<std::shared_ptr<Tiktoken>, Version::Multimodal>();
+
+template std::unique_ptr<Tiktoken>
+get_tiktoken_for_llama<std::unique_ptr<Tiktoken>, Version::Multimodal>();
+
+template std::shared_ptr<Tiktoken>
+get_tiktoken_for_llama<std::shared_ptr<Tiktoken>, Version::Default>();
+
+template std::unique_ptr<Tiktoken>
+get_tiktoken_for_llama<std::unique_ptr<Tiktoken>, Version::Default>();
+
+template std::shared_ptr<Tokenizer>
+get_tiktoken_for_llama<std::shared_ptr<Tokenizer>, Version::Multimodal>();
+
+template std::unique_ptr<Tokenizer>
+get_tiktoken_for_llama<std::unique_ptr<Tokenizer>, Version::Multimodal>();
+
+template std::shared_ptr<Tokenizer>
+get_tiktoken_for_llama<std::shared_ptr<Tokenizer>, Version::Default>();
+
+template std::unique_ptr<Tokenizer>
+get_tiktoken_for_llama<std::unique_ptr<Tokenizer>, Version::Default>();
 } // namespace example
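Because the template is defined in this .cpp file, only the eight pointer/version combinations instantiated above are usable; any other `PtrType` fails at link time. A hypothetical call site for each pointer flavor:

using ::tokenizers::Tiktoken;
using ::tokenizers::Tokenizer;

auto t1 = example::get_tiktoken_for_llama<std::unique_ptr<Tiktoken>>();
auto t2 = example::get_tiktoken_for_llama<std::shared_ptr<Tokenizer>>();
auto t3 = example::get_tiktoken_for_llama<
    std::unique_ptr<Tokenizer>,
    example::Version::Multimodal>();
auto t4 = example::get_tiktoken_for_llama<
    std::shared_ptr<Tiktoken>,
    example::Version::Multimodal>();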

examples/models/llama/tokenizer/llama_tiktoken.h (+20 -2)

@@ -17,9 +17,27 @@ enum class Version {
   Multimodal,
 };

-std::unique_ptr<::tokenizers::Tiktoken> get_tiktoken_for_llama(
-    Version version = Version::Default);
+// Type traits to check if a type is a shared_ptr or unique_ptr of Tokenizer or
+// a derived class
+template <typename T>
+struct is_shared_ptr_of_tokenizer : std::false_type {};

+template <typename T>
+struct is_shared_ptr_of_tokenizer<std::shared_ptr<T>>
+    : std::is_base_of<::tokenizers::Tokenizer, T> {};
+
+template <typename T>
+struct is_unique_ptr_of_tokenizer : std::false_type {};
+
+template <typename T>
+struct is_unique_ptr_of_tokenizer<std::unique_ptr<T>>
+    : std::is_base_of<::tokenizers::Tokenizer, T> {};
+
+// Template version that can return either shared_ptr or unique_ptr
+template <typename PtrType, Version V = Version::Default>
+PtrType get_tiktoken_for_llama();
+
+// For backward compatibility
 std::unique_ptr<std::vector<std::string>> get_multimodal_special_tokens();

 } // namespace example
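How the traits resolve, shown as a few illustrative static_asserts (`Tiktoken` derives from `Tokenizer`; `<memory>` and `<type_traits>` are assumed included):

static_assert(
    example::is_shared_ptr_of_tokenizer<
        std::shared_ptr<::tokenizers::Tiktoken>>(),
    "shared_ptr of a Tokenizer subclass matches");
static_assert(
    example::is_unique_ptr_of_tokenizer<
        std::unique_ptr<::tokenizers::Tokenizer>>(),
    "unique_ptr of Tokenizer itself matches");
static_assert(
    !example::is_shared_ptr_of_tokenizer<std::shared_ptr<int>>(),
    "shared_ptr of an unrelated type does not");
static_assert(
    !example::is_unique_ptr_of_tokenizer<::tokenizers::Tokenizer*>(),
    "a raw pointer does not");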

examples/models/llava/runner/llava_image_prefiller.h (+1 -1)

@@ -18,7 +18,7 @@ namespace example {
 class ET_EXPERIMENTAL LlavaImagePrefiller
     : public ::executorch::extension::llm::ImagePrefiller {
  public:
-  LlavaImagePrefiller(::executorch::extension::Module* module)
+  LlavaImagePrefiller(std::shared_ptr<::executorch::extension::Module> module)
       : ImagePrefiller(module){};
   /**
    * Prefill an LLM Module with the given image input.
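With the constructor now taking `std::shared_ptr<Module>`, a multimodal runner can hand one `Module` to both the image prefiller and its text components; a hypothetical construction (`model_path` is a placeholder):

auto module = std::make_shared<::executorch::extension::Module>(
    model_path, ::executorch::extension::Module::LoadMode::File);
auto image_prefiller = std::make_unique<example::LlavaImagePrefiller>(module);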
