
Commit 16e3256

larryliu0820 authored and facebook-github-bot committed
Use dependency injection for runner (#10326)
Summary:

X-link: pytorch-labs/tokenizers#53

Pass in runner components, and move most of the instantiation logic from `load()` to a new static API `create()`. This adds testability to runner components.

Differential Revision: D73165546
1 parent ad7cd2b commit 16e3256

12 files changed: +543 -116 lines changed
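
For orientation before the per-file diffs: the call-site pattern after this change is to construct through the factory and check the result before use. A minimal sketch of that flow, assuming a valid model and tokenizer on disk (the file names are placeholders, not from this commit):

#include <executorch/examples/models/llama/runner/runner.h>

int main() {
  // create() builds the Module, tokenizer, prefiller, and token generator
  // internally, then injects them into the Runner constructor.
  std::unique_ptr<example::Runner> runner =
      example::Runner::create("llama3.pte", "tokenizer.model");
  if (runner == nullptr) {
    return 1; // create() returns nullptr when a component fails to load.
  }
  executorch::extension::llm::GenerationConfig config{
      .seq_len = 128, .temperature = 0.8f};
  runner->generate("Hello", config);
  return 0;
}

Note that `create()` reports failure by returning a null pointer rather than an error code, so callers must check the result before dereferencing.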

examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm

+1 -1

@@ -31,7 +31,7 @@ - (instancetype)initWithModelPath:(NSString*)modelPath
   self = [super init];
   if (self) {
     [ExecuTorchLog.sharedLog addSink:self];
-    _runner = std::make_unique<example::Runner>(
+    _runner = example::Runner::create(
         modelPath.UTF8String, tokenizerPath.UTF8String);
   }
   return self;

examples/models/llama/main.cpp

+4 -3

@@ -74,17 +74,18 @@ int32_t main(int32_t argc, char** argv) {
 #endif
   // create llama runner
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  example::Runner runner(model_path, tokenizer_path);
+  std::unique_ptr<example::Runner> runner =
+      example::Runner::create(model_path, tokenizer_path);
 
   if (warmup) {
     // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    runner.warmup(prompt, /*max_new_tokens=*/seq_len);
+    runner->warmup(prompt, /*max_new_tokens=*/seq_len);
   }
   // generate
   executorch::extension::llm::GenerationConfig config{
       .seq_len = seq_len, .temperature = temperature};
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  runner.generate(prompt, config);
+  runner->generate(prompt, config);
 
   return 0;
 }

examples/models/llama/runner/runner.cpp

+113 -84

@@ -11,9 +11,6 @@
 
 #include <executorch/examples/models/llama/runner/runner.h>
 
-#include <algorithm>
-#include <ctime>
-
 #include <executorch/extension/llm/runner/util.h>
 
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -35,129 +32,161 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
-static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 } // namespace
 
-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<const std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
+    std::optional<const std::string> data_path,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
       model_path.c_str(),
       tokenizer_path.c_str());
-}
 
-[[deprecated(
-    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
-Runner::Runner(
-    const std::string& model_path,
-    const std::string& tokenizer_path,
-    const float temperature,
-    std::optional<const std::string> data_path)
-    : Runner(model_path, tokenizer_path, std::move(data_path)) {
-  temperature_ = temperature;
-}
+  // Create the Module
+  std::unique_ptr<Module> module;
+  if (data_path.has_value()) {
+    module = std::make_unique<Module>(
+        model_path, data_path.value(), Module::LoadMode::File);
+  } else {
+    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
+  }
 
-bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
-      text_prefiller_ && text_token_generator_;
-}
+  // Initialize metadata with default values
+  std::unordered_map<std::string, int64_t> metadata({
+      {kEnableDynamicShape, false},
+      {kMaxSeqLen, 128},
+      {kMaxContextLen, 128},
+      {kUseKVCache, true},
+  });
 
-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
-  }
-  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
-  // load tokenizer. Assuming tiktoken is the default tokenizer
-  tokenizer_ = nullptr;
-  tokenizer_ = get_tiktoken_for_llama();
-  ::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
-  // Rely on tiktoken to throw error if the artifact is incompatible. Then we
-  // fallback to BPE tokenizer.
-  if (err != ::tokenizers::Error::Ok) {
+  // Create and load tokenizer
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer = get_tiktoken_for_llama();
+  ::tokenizers::Error tk_err = tokenizer->load(tokenizer_path);
+
+  // Fallback to BPE tokenizer if tiktoken fails
+  if (tk_err != ::tokenizers::Error::Ok) {
     ET_LOG(
         Info,
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-    err = tokenizer_->load(tokenizer_path_);
-    ET_CHECK_TK_OK_OR_RETURN_ERROR(
-        err,
-        "Failed to load %s as a llama2.c tokenizer artifact",
-        tokenizer_path_.c_str());
+        tokenizer_path.c_str());
+    tokenizer.reset();
+    tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
+    tk_err = tokenizer->load(tokenizer_path);
+    if (tk_err != ::tokenizers::Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to load %s as a llama2.c tokenizer artifact",
+          tokenizer_path.c_str());
+      return nullptr;
+    }
   }
 
   ET_LOG(Info, "Reading metadata from model");
 
-  metadata_[kBosId] = tokenizer_->bos_tok();
+  // Set tokenizer-related metadata
+  metadata[kBosId] = tokenizer->bos_tok();
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
-  metadata_[kVocabSize] = tokenizer_->vocab_size();
-
-  const auto method_names =
-      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
+  metadata[kVocabSize] = tokenizer->vocab_size();
+
+  // Read metadata from the model
+  auto method_names_result = module->method_names();
+  if (method_names_result.error() != Error::Ok) {
+    ET_LOG(Error, "Failed reading method names");
+    return nullptr;
+  }
+  const auto method_names = method_names_result.get();
 
-  for (auto& pair : metadata_) {
+  for (auto& pair : metadata) {
     const auto& method_name = pair.first;
     auto& value = pair.second;
 
     if (method_names.count(method_name)) {
-      value = ET_UNWRAP(module_->get(method_name))
-                  .toScalar()
-                  .to<decltype(metadata_)::mapped_type>();
+      auto get_result = module->get(method_name);
+      value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
     } else {
       ET_LOG(
           Info,
-          "Methond %s not found, using the default value %" PRId64,
+          "Method %s not found, using the default value %" PRId64,
           method_name.c_str(),
           value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Get EOS IDs if available
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
-    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+    auto execute_result = module->execute(kEosIds);
+    if (execute_result.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to execute %s", kEosIds);
+      return nullptr;
+    }
+    for (const auto& eos_id : execute_result.get()) {
       auto value = eos_id.toScalar().to<int64_t>();
       eos_ids->emplace(value);
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
-  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape),
-      metadata_.at(kMaxSeqLen));
-
-  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
+
+  // Create text_decoder_runner
+  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
+      module.get(), metadata.at(kUseKVCache));
+
+  // Create text_prefiller
+  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
+      metadata.at(kEnableDynamicShape),
+      metadata.at(kMaxSeqLen));
+
+  // Create text_token_generator with stats
+  auto stats = new llm::Stats();
+  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats);
+
+  // Create and return the Runner instance
+  return std::make_unique<Runner>(
+      std::move(metadata),
+      std::move(tokenizer),
+      std::move(text_prefiller),
+      std::move(text_token_generator),
+      temperature);
+}
 
+Runner::Runner(
+    std::unordered_map<std::string, int64_t> metadata,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
+    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+        text_token_generator,
+    float temperature)
+    : tokenizer_(std::move(tokenizer)),
+      metadata_(std::move(metadata)),
+      text_prefiller_(std::move(text_prefiller)),
+      text_token_generator_(std::move(text_token_generator)),
+      temperature_(temperature) {
+  // Note: This constructor assumes that text_prefiller and
+  // text_token_generator already have references to the Module and
+  // TextDecoderRunner they need
+}
+
+bool Runner::is_loaded() const {
+  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
+  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
   return Error::Ok;
 }
 
examples/models/llama/runner/runner.h

+14 -14

@@ -30,18 +30,22 @@ namespace example {
 
 class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
  public:
-  explicit Runner(
+  // Static factory method to create a Runner instance
+  static std::unique_ptr<Runner> create(
       const std::string& model_path,
       const std::string& tokenizer_path,
-      std::optional<const std::string> data_path = std::nullopt);
+      std::optional<const std::string> data_path = std::nullopt,
+      float temperature = -1.0f);
 
-  [[deprecated(
-      "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
+  // Constructor with dependency injection
   explicit Runner(
-      const std::string& model_path,
-      const std::string& tokenizer_path,
-      const float temperature,
-      std::optional<const std::string> data_path = std::nullopt);
+      std::unordered_map<std::string, int64_t> metadata,
+      std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+      std::unique_ptr<::executorch::extension::llm::TextPrefiller>
+          text_prefiller,
+      std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+          text_token_generator,
+      float temperature = -1.0f);
 
   bool is_loaded() const override;
   ::executorch::runtime::Error load() override;
@@ -59,18 +63,14 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
  private:
   bool shouldStop_{false};
 
-  // model
-  std::unique_ptr<::executorch::extension::Module> module_;
-  std::string tokenizer_path_;
+  // Components
   std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
   std::unordered_map<std::string, int64_t> metadata_;
-  std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
-      text_decoder_runner_;
   std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_;
   std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
       text_token_generator_;
 
-  // stats
+  // Stats
   ::executorch::extension::llm::Stats stats_;
 
   // temperature.
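
The new injected constructor is the test seam this commit is after: it touches no files, so a test can assemble a Runner from components it controls instead of calling `create()`. A hypothetical sketch of that wiring, using only the types and constructor signature declared above (the helper itself is illustrative, not part of this commit; how the component instances are built is left to the test harness):

#include <executorch/examples/models/llama/runner/runner.h>

namespace llm = ::executorch::extension::llm;

// Assemble a Runner from prebuilt collaborators, bypassing file loading.
std::unique_ptr<example::Runner> make_test_runner(
    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
    std::unique_ptr<llm::TextPrefiller> text_prefiller,
    std::unique_ptr<llm::TextTokenGenerator> text_token_generator) {
  // Metadata keys match the constants used in runner.cpp.
  std::unordered_map<std::string, int64_t> metadata = {
      {"use_kv_cache", true},
      {"get_max_seq_len", 128},
  };
  return std::make_unique<example::Runner>(
      std::move(metadata),
      std::move(tokenizer),
      std::move(text_prefiller),
      std::move(text_token_generator),
      /*temperature=*/0.0f);
}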
CMakeLists.txt (new file)

+28

@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file should be formatted with
+# ~~~
+# cmake-format -i CMakeLists.txt
+# ~~~
+# It should also be cmake-lint clean.
+#
+
+cmake_minimum_required(VERSION 3.19)
+
+set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
+
+set(_test_srcs runner_test.cpp)
+
+et_cxx_test(
+  runner_test
+  SOURCES
+  ${_test_srcs}
+  EXTRA_LIBS
+  executorch
+)
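
The CMake file above registers `runner_test.cpp` through the `et_cxx_test` helper; the test source itself is not shown in this diff. One hypothetical case it could contain, exercising the failure path that `create()` now signals by returning `nullptr` (the paths below are invented):

#include <gtest/gtest.h>

#include <executorch/examples/models/llama/runner/runner.h>

TEST(RunnerTest, CreateReturnsNullptrWhenTokenizerIsMissing) {
  // Per runner.cpp above, create() tries Tiktoken, falls back to the
  // llama2.c BPE tokenizer, and returns nullptr when both loads fail.
  std::unique_ptr<example::Runner> runner = example::Runner::create(
      "nonexistent_model.pte", "nonexistent_tokenizer.bin");
  EXPECT_EQ(runner, nullptr);
}
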
TARGETS (new file)

+14

@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain fbcode-only targets.
+
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
