Commit c4bf4be

larryliu0820 authored and facebook-github-bot committed
Use dependency injection for runner (#10326)
Summary:
Pull Request resolved: #10326
X-link: pytorch-labs/tokenizers#53

Pass in runner components and move most of the instantiation logic from `load()` to a new static `create()` API. This makes the runner components testable.

Differential Revision: D73165546
1 parent c3dc721 commit c4bf4be
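
Because `create()` now owns the component wiring and returns `nullptr` when a component cannot be constructed (for example when the tokenizer artifact fails to load, see the runner.cpp diff below), the factory path itself becomes easy to unit test. A minimal GTest-style sketch, assuming a test target that links the example runner library; the test name and file paths are hypothetical and not part of this commit:

#include <gtest/gtest.h>

#include <executorch/examples/models/llama/runner/runner.h>

TEST(LlamaRunnerCreateTest, ReturnsNullptrOnBadTokenizerPath) {
  // Per runner.cpp below, create() returns nullptr when neither the tiktoken
  // nor the llama2.c BPE tokenizer can load the artifact, so a bogus
  // tokenizer path should make create() fail cleanly.
  std::unique_ptr<example::Runner> runner = example::Runner::create(
      "/does/not/exist/model.pte", "/does/not/exist/tokenizer.bin");
  EXPECT_EQ(runner, nullptr);
}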

11 files changed, +553 −122 lines changed

examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm

+1 −1
@@ -31,7 +31,7 @@ - (instancetype)initWithModelPath:(NSString*)modelPath
   self = [super init];
   if (self) {
     [ExecuTorchLog.sharedLog addSink:self];
-    _runner = std::make_unique<example::Runner>(
+    _runner = example::Runner::create(
         modelPath.UTF8String, tokenizerPath.UTF8String);
   }
   return self;

examples/models/llama/main.cpp

+4 −3
@@ -74,17 +74,18 @@ int32_t main(int32_t argc, char** argv) {
 #endif
   // create llama runner
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  example::Runner runner(model_path, tokenizer_path);
+  std::unique_ptr<example::Runner> runner =
+      example::Runner::create(model_path, tokenizer_path);

   if (warmup) {
     // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    runner.warmup(prompt, /*max_new_tokens=*/seq_len);
+    runner->warmup(prompt, /*max_new_tokens=*/seq_len);
   }
   // generate
   executorch::extension::llm::GenerationConfig config{
       .seq_len = seq_len, .temperature = temperature};
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  runner.generate(prompt, config);
+  runner->generate(prompt, config);

   return 0;
 }
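
Note that `create()` can return `nullptr` (see the tokenizer and metadata error paths in the runner.cpp diff below), while the updated main.cpp dereferences the result directly. A hedged sketch of a more defensive call site; the guard, message, and exit code are illustrative additions, not part of this commit, and assume `ET_LOG` is available here as it is elsewhere in the example:

  std::unique_ptr<example::Runner> runner =
      example::Runner::create(model_path, tokenizer_path);
  if (runner == nullptr) {
    // create() already logs the specific failure; bail out instead of
    // dereferencing a null runner.
    ET_LOG(Error, "Failed to create llama runner");
    return 1;
  }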

examples/models/llama/runner/runner.cpp

+127 −95
@@ -11,9 +11,6 @@

 #include <executorch/examples/models/llama/runner/runner.h>

-#include <algorithm>
-#include <ctime>
-
 #include <executorch/extension/llm/runner/util.h>

 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -35,130 +32,165 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
-static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 } // namespace

-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
+    std::optional<const std::string> data_path,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
       model_path.c_str(),
       tokenizer_path.c_str());
-}

-[[deprecated(
-    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
-Runner::Runner(
-    const std::string& model_path,
-    const std::string& tokenizer_path,
-    const float temperature,
-    std::optional<const std::string> data_path)
-    : Runner(model_path, tokenizer_path, std::move(data_path)) {
-  temperature_ = temperature;
-}
+  // Create the Module
+  std::unique_ptr<Module> module;
+  if (data_path.has_value()) {
+    module = std::make_unique<Module>(
+        model_path, data_path.value(), Module::LoadMode::File);
+  } else {
+    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
+  }

-bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
-      text_prefiller_ && text_token_generator_;
-}
+  // Initialize metadata with default values
+  std::unordered_map<std::string, int64_t> metadata({
+      {kEnableDynamicShape, false},
+      {kMaxSeqLen, 128},
+      {kMaxContextLen, 128},
+      {kUseKVCache, true},
+  });

-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
-  }
-  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
-  ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != ::tokenizers::Error::Ok) {
+  // Create and load tokenizer
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = get_tiktoken_for_llama();
+  ::tokenizers::Error tk_err = tokenizer->load(tokenizer_path);
+
+  // Fallback to BPE tokenizer if tiktoken fails
+  if (tk_err != ::tokenizers::Error::Ok) {
     ET_LOG(
         Info,
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    ET_CHECK_TK_OK_OR_RETURN_ERROR(
-        err,
-        "Failed to load %s as a llama2.c tokenizer artifact",
-        tokenizer_path_.c_str());
+        tokenizer_path.c_str());
+    tokenizer.reset();
+    tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    tk_err = tokenizer->load(tokenizer_path);
+    if (tk_err != ::tokenizers::Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to load %s as a llama2.c tokenizer artifact",
+          tokenizer_path.c_str());
+      return nullptr;
+    }
   }

   ET_LOG(Info, "Reading metadata from model");

-  metadata_[kBosId] = tokenizer_->bos_tok();
+  // Set tokenizer-related metadata
+  metadata[kBosId] = tokenizer->bos_tok();
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
-  metadata_[kVocabSize] = tokenizer_->vocab_size();
-
-  const auto method_names =
-      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
+  metadata[kVocabSize] = tokenizer->vocab_size();
+
+  // Read metadata from the model
+  auto method_names_result = module->method_names();
+  if (method_names_result.error() != Error::Ok) {
+    ET_LOG(Error, "Failed reading method names");
+    return nullptr;
+  }
+  const auto method_names = method_names_result.get();

-  for (auto& pair : metadata_) {
+  for (auto& pair : metadata) {
     const auto& method_name = pair.first;
     auto& value = pair.second;

     if (method_names.count(method_name)) {
-      value = ET_UNWRAP(module_->get(method_name))
-                  .toScalar()
-                  .to<decltype(metadata_)::mapped_type>();
+      auto get_result = module->get(method_name);
+      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
     } else {
       ET_LOG(
           Info,
-          "Methond %s not found, using the default value %" PRId64,
+          "Method %s not found, using the default value %" PRId64,
           method_name.c_str(),
           value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Get EOS IDs if available
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
-    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+    auto execute_result = module->execute(kEosIds);
+    if (execute_result.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to execute %s", kEosIds);
+      return nullptr;
+    }
+    for (const auto& eos_id : execute_result.get()) {
       auto value = eos_id.toScalar().to<int64_t>();
       eos_ids->emplace(value);
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
-  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape),
-      metadata_.at(kMaxSeqLen));
-
-  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
+
+  // Create text_decoder_runner. Use a shared_ptr so that it can be shared with
+  // TextPrefiller and TextTokenGenerator
+  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
+      module.get(), metadata.at(kUseKVCache));
+
+  // Create text_prefiller
+  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
+      metadata.at(kEnableDynamicShape),
+      metadata.at(kMaxSeqLen));
+
+  // Create text_token_generator with stats
+  auto stats = std::make_unique<llm::Stats>();
+  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats.get());
+
+  // Create and return the Runner instance
+  return std::make_unique<Runner>(
+      std::move(metadata),
+      std::move(tokenizer),
+      std::move(text_prefiller),
+      std::move(text_token_generator),
+      std::move(stats),
+      temperature);
+}

+Runner::Runner(
+    std::unordered_map<std::string, int64_t> metadata,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
+    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+        text_token_generator,
+    std::unique_ptr<::executorch::extension::llm::Stats> stats,
+    float temperature)
+    : tokenizer_(std::move(tokenizer)),
+      metadata_(std::move(metadata)),
+      text_prefiller_(std::move(text_prefiller)),
+      text_token_generator_(std::move(text_token_generator)),
+      stats_(std::move(stats)),
+      temperature_(temperature) {
+  // Note: This constructor assumes that text_prefiller and text_token_generator
+  // already have references to the Module and TextDecoderRunner they need
+}
+
+bool Runner::is_loaded() const {
+  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
+  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
   return Error::Ok;
 }

@@ -179,9 +211,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -207,7 +239,7 @@ Error Runner::generate(
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.

-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -248,8 +280,8 @@ Error Runner::generate(
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -270,7 +302,7 @@ Error Runner::generate(
           temperature_ == -1.0f ? config.temperature : temperature_,
           wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -283,17 +315,17 @@ Error Runner::generate(
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;

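The new public constructor is what makes the dependency injection visible to callers: anything that can build the components can assemble a Runner without going through `create()`, for example in a component-level test against a small model. A hedged sketch of such wiring, mirroring the component setup inside `create()` above; the file paths, helper name, include set, and the literal key for kEnableDynamicShape are assumptions, not part of this commit:

#include <cstdint>
#include <memory>
#include <unordered_map>
#include <unordered_set>

#include <executorch/examples/models/llama/runner/runner.h>
#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/extension/module/module.h>

using ::executorch::extension::Module;
namespace llm = ::executorch::extension::llm;

// Assemble a Runner from caller-built components (hypothetical helper name);
// the wiring mirrors Runner::create() above, but against test artifacts.
void run_injected_runner_smoke_test() {
  std::unordered_map<std::string, int64_t> metadata = {
      {"enable_dynamic_shape", false},  // key string assumed for kEnableDynamicShape
      {"get_max_seq_len", 128},
      {"get_max_context_len", 128},
      {"use_kv_cache", true},
  };

  auto module = std::make_unique<Module>("tiny_llama.pte", Module::LoadMode::File);
  std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
      example::get_tiktoken_for_llama();
  tokenizer->load("tokenizer.model");

  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
      module.get(), metadata.at("use_kv_cache"));
  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
      text_decoder_runner.get(),
      metadata.at("use_kv_cache"),
      metadata.at("enable_dynamic_shape"),
      metadata.at("get_max_seq_len"));

  auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
  auto stats = std::make_unique<llm::Stats>();
  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
      tokenizer.get(),
      text_decoder_runner.get(),
      metadata.at("use_kv_cache"),
      std::move(eos_ids),
      stats.get());

  // -1.0f defers to GenerationConfig::temperature at generate() time, matching
  // the temperature_ == -1.0f check in runner.cpp.
  auto runner = std::make_unique<example::Runner>(
      std::move(metadata),
      std::move(tokenizer),
      std::move(text_prefiller),
      std::move(text_token_generator),
      std::move(stats),
      /*temperature=*/-1.0f);

  // Use the runner while `module` and `text_decoder_runner` are still alive;
  // as in create(), only raw pointers to them are handed to the components.
  llm::GenerationConfig config{.seq_len = 32};
  auto err = runner->generate("Hello", config);
  (void)err;
}
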