Skip to content

Commit d94237a

Browse files
larryliu0820 and facebook-github-bot
authored and committed
Use unique_ptr and shared_ptr properly for runner components (#10338)
Summary: The ownership of these components need some clarification. * `Module` should be solely owned by `TextDecoderRunner` * `TextDecoderRunner` should be shared by the `TextPrefiller` and `TextTokenGenerator`. * `Tokenizer` should be owned by the `Runner` as well as `TextTokenGenerator` Reviewed By: kirklandsign Differential Revision: D73399600
1 parent c053712 commit d94237a

File tree

9 files changed

+67
-27
lines changed

9 files changed

+67
-27
lines changed

examples/models/llama/runner/runner.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ Error Runner::load() {
9999
"Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
100100
tokenizer_path_.c_str());
101101
tokenizer_.reset();
102-
tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
102+
tokenizer_ = std::make_shared<::tokenizers::Llama2cTokenizer>();
103103
err = tokenizer_->load(tokenizer_path_);
104104
ET_CHECK_TK_OK_OR_RETURN_ERROR(
105105
err,
@@ -143,17 +143,17 @@ Error Runner::load() {
143143
}
144144
}
145145
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
146-
text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
147-
module_.get(), metadata_.at(kUseKVCache));
146+
text_decoder_runner_ = std::make_shared<llm::TextDecoderRunner>(
147+
std::move(module_), metadata_.at(kUseKVCache));
148148
text_prefiller_ = std::make_unique<llm::TextPrefiller>(
149-
text_decoder_runner_.get(),
149+
text_decoder_runner_,
150150
metadata_.at(kUseKVCache),
151151
metadata_.at(kEnableDynamicShape),
152152
metadata_.at(kMaxSeqLen));
153153

154154
text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
155-
tokenizer_.get(),
156-
text_decoder_runner_.get(),
155+
tokenizer_,
156+
text_decoder_runner_,
157157
metadata_.at(kUseKVCache),
158158
std::move(eos_ids),
159159
&stats_);

examples/models/llama/runner/runner.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
6262
// model
6363
std::unique_ptr<::executorch::extension::Module> module_;
6464
std::string tokenizer_path_;
65-
std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
65+
std::shared_ptr<::tokenizers::Tokenizer> tokenizer_;
6666
std::unordered_map<std::string, int64_t> metadata_;
67-
std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
67+
std::shared_ptr<::executorch::extension::llm::TextDecoderRunner>
6868
text_decoder_runner_;
6969
std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_;
7070
std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>

examples/models/llama/tokenizer/llama_tiktoken.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ std::unique_ptr<std::vector<std::string>> _get_special_tokens(Version version) {
5353

5454
} // namespace
5555

56-
std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
57-
return std::make_unique<Tiktoken>(
56+
std::shared_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
57+
return std::make_shared<Tiktoken>(
5858
_get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex);
5959
}
6060

examples/models/llama/tokenizer/llama_tiktoken.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ enum class Version {
1717
Multimodal,
1818
};
1919

20-
std::unique_ptr<::tokenizers::Tiktoken> get_tiktoken_for_llama(
20+
std::shared_ptr<::tokenizers::Tiktoken> get_tiktoken_for_llama(
2121
Version version = Version::Default);
2222

2323
std::unique_ptr<std::vector<std::string>> get_multimodal_special_tokens();

extension/llm/runner/text_decoder_runner.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ namespace llm {
2121
// NOTE: we observed ~2x loading performance increase on iPhone 15
2222
// and a ~5% improvement on Galaxy S22 by switching to
2323
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
24-
TextDecoderRunner::TextDecoderRunner(Module* module, bool use_kv_cache)
25-
: module_(module), use_kv_cache_(use_kv_cache) {}
24+
TextDecoderRunner::TextDecoderRunner(
25+
std::unique_ptr<Module> module,
26+
bool use_kv_cache)
27+
: module_(std::move(module)), use_kv_cache_(use_kv_cache) {}
2628

2729
// This function is functional, meaning it shouldn't modify any state of the
2830
// input. It should be safe to call multiple times with the same inputs. The

extension/llm/runner/text_decoder_runner.h

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,14 @@
1414
#include <executorch/extension/module/module.h>
1515
#include <executorch/extension/tensor/tensor.h>
1616
#include <executorch/runtime/platform/compiler.h>
17-
#include <functional>
1817

1918
namespace executorch {
2019
namespace extension {
2120
namespace llm {
2221

2322
class ET_EXPERIMENTAL TextDecoderRunner {
2423
public:
25-
TextDecoderRunner(Module* module, bool use_kv_cache);
24+
TextDecoderRunner(std::unique_ptr<Module> module, bool use_kv_cache);
2625

2726
virtual ~TextDecoderRunner() = default;
2827

@@ -95,7 +94,7 @@ class ET_EXPERIMENTAL TextDecoderRunner {
9594

9695
protected:
9796
// TODO: use shared_ptr for module
98-
Module* module_;
97+
std::unique_ptr<Module> module_;
9998
bool use_kv_cache_;
10099
bool should_stop_{false};
101100
};

extension/llm/runner/text_prefiller.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace extension {
1717
namespace llm {
1818

1919
TextPrefiller::TextPrefiller(
20-
TextDecoderRunner* text_decoder_runner,
20+
std::shared_ptr<TextDecoderRunner> text_decoder_runner,
2121
bool use_kv_cache,
2222
bool enable_parallel_prefill,
2323
int64_t max_seq_len)

extension/llm/runner/text_prefiller.h

+22-3
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ namespace llm {
2020
class ET_EXPERIMENTAL TextPrefiller {
2121
public:
2222
TextPrefiller(
23-
TextDecoderRunner* text_decoder_runner,
23+
std::shared_ptr<TextDecoderRunner> text_decoder_runner,
2424
bool use_kv_cache_,
2525
bool enable_parallel_prefill,
2626
int64_t max_seq_len = 128);
27+
28+
virtual ~TextPrefiller() = default;
2729
/**
2830
* Prefill an LLM Module with the given text input.
2931
* @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
@@ -32,7 +34,7 @@ class ET_EXPERIMENTAL TextPrefiller {
3234
* Module.
3335
* @return The next token of the LLM Module after prefill.
3436
*/
35-
::executorch::runtime::Result<uint64_t> prefill(
37+
virtual ::executorch::runtime::Result<uint64_t> prefill(
3638
std::vector<uint64_t>& prompt_tokens,
3739
int64_t& start_pos);
3840

@@ -47,8 +49,25 @@ class ET_EXPERIMENTAL TextPrefiller {
4749
std::vector<uint64_t>& prompt_tokens,
4850
int64_t& start_pos);
4951

52+
/**
53+
* Load the necessary resources for the TextPrefiller.
54+
* This method should be called before using the prefill methods.
55+
*/
56+
::executorch::runtime::Error load() {
57+
return text_decoder_runner_->load();
58+
}
59+
60+
/**
61+
* Check if the TextPrefiller has been successfully loaded.
62+
* @return True if the resources are loaded, false otherwise.
63+
*/
64+
bool inline is_loaded() const {
65+
// Implementation to check if resources are loaded
66+
return text_decoder_runner_->is_method_loaded();
67+
}
68+
5069
private:
51-
TextDecoderRunner* text_decoder_runner_;
70+
std::shared_ptr<TextDecoderRunner> text_decoder_runner_;
5271
bool use_kv_cache_;
5372
bool enable_parallel_prefill_;
5473
int64_t max_seq_len_;

extension/llm/runner/text_token_generator.h

+27-7
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ namespace llm {
2121
class ET_EXPERIMENTAL TextTokenGenerator {
2222
public:
2323
TextTokenGenerator(
24-
::tokenizers::Tokenizer* tokenizer,
25-
TextDecoderRunner* text_decoder_runner,
24+
std::shared_ptr<::tokenizers::Tokenizer> tokenizer,
25+
std::shared_ptr<TextDecoderRunner> text_decoder_runner,
2626
bool use_kv_cache,
2727
std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
28-
Stats* stats)
28+
std::unique_ptr<Stats> stats)
2929
: tokenizer_(tokenizer),
3030
text_decoder_runner_(text_decoder_runner),
3131
eos_ids_(std::move(eos_ids)),
3232
use_kv_cache_(use_kv_cache),
33-
stats_(stats) {}
33+
stats_(std::move(stats)) {}
34+
35+
virtual ~TextTokenGenerator() = default;
3436

3537
/**
3638
* Token generation loop.
@@ -135,17 +137,35 @@ class ET_EXPERIMENTAL TextTokenGenerator {
135137
should_stop_ = true;
136138
}
137139

140+
/**
141+
* Load the necessary resources for TextTokenGenerator.
142+
* This method should be called before using the generate() method.
143+
*/
144+
::executorch::runtime::Error load() {
145+
return text_decoder_runner_->load();
146+
}
147+
148+
/**
149+
* Check if the TextTokenGenerator has been successfully loaded.
150+
* @return True if the resources are loaded, false otherwise.
151+
*/
152+
bool inline is_loaded() const {
153+
// Implementation to check if resources are loaded
154+
return tokenizer_->is_initialized() &&
155+
text_decoder_runner_->is_method_loaded();
156+
}
157+
138158
private:
139-
::tokenizers::Tokenizer* tokenizer_;
140-
TextDecoderRunner* text_decoder_runner_;
159+
std::shared_ptr<::tokenizers::Tokenizer> tokenizer_;
160+
std::shared_ptr<TextDecoderRunner> text_decoder_runner_;
141161
std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
142162
bool use_kv_cache_;
143163

144164
// state machine
145165
bool should_stop_ = false;
146166

147167
// stats
148-
Stats* stats_;
168+
std::unique_ptr<Stats> stats_;
149169
};
150170

151171
} // namespace llm

0 commit comments

Comments
 (0)