Use dependency injection for runner #10326

Merged
merged 1 commit on May 22, 2025
@@ -31,7 +31,7 @@ - (instancetype)initWithModelPath:(NSString*)modelPath
self = [super init];
if (self) {
[ExecuTorchLog.sharedLog addSink:self];
_runner = std::make_unique<example::Runner>(
_runner = example::Runner::create(
modelPath.UTF8String, tokenizerPath.UTF8String);
}
return self;
11 changes: 5 additions & 6 deletions examples/models/llama/main.cpp
@@ -4,6 +4,7 @@
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
*/

#include <gflags/gflags.h>
@@ -80,18 +81,16 @@ int32_t main(int32_t argc, char** argv) {
}
#endif
// create llama runner
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
example::Runner runner(model_path, tokenizer_path, data_path);
std::unique_ptr<example::Runner> runner =
example::Runner::create(model_path, tokenizer_path, data_path);

if (warmup) {
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
runner.warmup(prompt, /*max_new_tokens=*/seq_len);
runner->warmup(prompt, /*max_new_tokens=*/seq_len);
}
// generate
executorch::extension::llm::GenerationConfig config{
.seq_len = seq_len, .temperature = temperature};
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
runner.generate(prompt, config);
runner->generate(prompt, config);

return 0;
}
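
Note that Runner::create() returns nullptr when setup fails (for example, when the tokenizer cannot be loaded; see runner.cpp below), so a caller may want to guard the result before dereferencing. A minimal fragment sketching that guard; it assumes the includes, flag variables, and ET_LOG macro already present in main.cpp above, and the error message and exit code are illustrative:

```cpp
  // Hedged fragment: relies on main.cpp's existing includes and variables;
  // not a standalone program.
  std::unique_ptr<example::Runner> runner =
      example::Runner::create(model_path, tokenizer_path, data_path);
  if (runner == nullptr) {
    ET_LOG(Error, "Failed to create llama runner");
    return 1;  // illustrative exit code
  }
```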
225 changes: 130 additions & 95 deletions examples/models/llama/runner/runner.cpp
@@ -4,16 +4,14 @@
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
* @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
*/

// A simple llama2 runner that includes preprocessing and post processing logic.
// The module takes in a string as input and emits a string as output.

#include <executorch/examples/models/llama/runner/runner.h>

#include <algorithm>
#include <ctime>

#include <executorch/extension/llm/runner/util.h>

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -62,125 +60,162 @@ std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
}
} // namespace

Runner::Runner(
std::unique_ptr<Runner> Runner::create(
Contributor:

Specifically, it looks like create() could just be another Runner constructor; it's not clear to me why it has to be a static method instead.
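
For context, one common reason to prefer a static factory over another constructor in a codebase that avoids exceptions is that the factory can do the fallible work (loading the tokenizer, reading metadata) and return nullptr or an error before any object exists, leaving the constructor with nothing but member initialization. A minimal sketch of that trade-off, with hypothetical names (Tokenizer, tokenizer_path) that are not the ExecuTorch API:

```cpp
#include <memory>
#include <string>
#include <utility>

struct Tokenizer {
  // Stand-in for fallible loading.
  bool load(const std::string& path) { return !path.empty(); }
};

class Runner {
 public:
  // Constructor only stores already-built dependencies; it cannot fail.
  explicit Runner(std::unique_ptr<Tokenizer> tokenizer)
      : tokenizer_(std::move(tokenizer)) {}

  // Factory does the fallible setup up front.
  static std::unique_ptr<Runner> create(const std::string& tokenizer_path) {
    auto tokenizer = std::make_unique<Tokenizer>();
    if (!tokenizer->load(tokenizer_path)) {
      return nullptr;  // report failure without exceptions
    }
    return std::make_unique<Runner>(std::move(tokenizer));
  }

 private:
  std::unique_ptr<Tokenizer> tokenizer_;
};
```

That is also the shape the diff takes: all fallible loading happens in create(), and the new injecting Runner constructor further down only moves the pre-built pieces into members.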

const std::string& model_path,
const std::string& tokenizer_path,
std::optional<const std::string> data_path)
// NOTE: we observed ~2x loading performance increase on iPhone 15
// and a ~5% improvement on Galaxy S22 by switching to
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
: tokenizer_path_(tokenizer_path),
metadata_({
{kEnableDynamicShape, false},
{kMaxSeqLen, 128},
{kMaxContextLen, 128},
{kUseKVCache, true},
{kUseSDPAWithKVCache, false},
}) {
if (data_path.has_value()) {
module_ = std::make_unique<Module>(
model_path, data_path.value(), Module::LoadMode::File);
} else {
module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
}
std::optional<const std::string> data_path,
float temperature) {
ET_LOG(
Info,
"Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
model_path.c_str(),
tokenizer_path.c_str());
}

[[deprecated(
"This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
Runner::Runner(
const std::string& model_path,
const std::string& tokenizer_path,
const float temperature,
std::optional<const std::string> data_path)
: Runner(model_path, tokenizer_path, std::move(data_path)) {
temperature_ = temperature;
}

bool Runner::is_loaded() const {
return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
text_prefiller_ && text_token_generator_;
}

Error Runner::load() {
if (is_loaded()) {
return Error::Ok;
// Create the Module
std::unique_ptr<Module> module;
if (data_path.has_value()) {
module = std::make_unique<Module>(
model_path, data_path.value(), Module::LoadMode::File);
} else {
module = std::make_unique<Module>(model_path, Module::LoadMode::File);
}
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

// Load tokenizer.
tokenizer_ = load_tokenizer(tokenizer_path_);
if (tokenizer_ == nullptr) {
// Initialize metadata with default values
std::unordered_map<std::string, int64_t> metadata({
{kEnableDynamicShape, false},
{kMaxSeqLen, 128},
{kMaxContextLen, 128},
{kUseKVCache, true},
{kUseSDPAWithKVCache, false},
});

// Create and load tokenizer
std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
load_tokenizer(tokenizer_path);

// Fallback to BPE tokenizer if tiktoken fails
if (tokenizer == nullptr) {
ET_LOG(
Info,
"Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
tokenizer_path_.c_str());
tokenizer_.reset();
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
auto err = tokenizer_->load(tokenizer_path_);
ET_CHECK_TK_OK_OR_RETURN_ERROR(
err,
"Failed to load %s as a llama2.c tokenizer artifact",
tokenizer_path_.c_str());
return ::executorch::runtime::Error::InvalidArgument;
"Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c tokenizer, make sure the artifact is one of these types",
tokenizer_path.c_str());
return nullptr;
}

ET_LOG(Info, "Reading metadata from model");

metadata_[kBosId] = tokenizer_->bos_tok();
// Set tokenizer-related metadata
metadata[kBosId] = tokenizer->bos_tok();
auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
metadata_[kVocabSize] = tokenizer_->vocab_size();

const auto method_names =
ET_UNWRAP(module_->method_names(), "Failed reading method names");
std::unordered_set<uint64_t>{tokenizer->eos_tok()});
metadata[kVocabSize] = tokenizer->vocab_size();

// Read metadata from the model
auto method_names_result = module->method_names();
if (method_names_result.error() != Error::Ok) {
ET_LOG(Error, "Failed reading method names");
return nullptr;
}
const auto method_names = method_names_result.get();

for (auto& pair : metadata_) {
for (auto& pair : metadata) {
const auto& method_name = pair.first;
auto& value = pair.second;

if (method_names.count(method_name)) {
value = ET_UNWRAP(module_->get(method_name))
.toScalar()
.to<decltype(metadata_)::mapped_type>();
auto get_result = module->get(method_name);
value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
} else {
ET_LOG(
Info,
"Methond %s not found, using the default value %" PRId64,
"Method %s not found, using the default value %" PRId64,
method_name.c_str(),
value);
}
ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
}

// Get EOS IDs if available
if (method_names.count(kEosIds)) {
eos_ids->clear();
for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
auto execute_result = module->execute(kEosIds);
if (execute_result.error() != Error::Ok) {
ET_LOG(Error, "Failed to execute %s", kEosIds);
return nullptr;
}
for (const auto& eos_id : execute_result.get()) {
auto value = eos_id.toScalar().to<int64_t>();
eos_ids->emplace(value);
ET_LOG(Info, "eos_id = %" PRId64, value);
}
}
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
module_.get(), metadata_.at(kUseKVCache));
text_prefiller_ = std::make_unique<llm::TextPrefiller>(
text_decoder_runner_.get(),
metadata_.at(kUseKVCache),
metadata_.at(kEnableDynamicShape),
metadata_.at(kMaxSeqLen));

text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
tokenizer_.get(),
text_decoder_runner_.get(),
metadata_.at(kUseKVCache),

// Create text_decoder_runner. Use a shared_ptr so that it can be shared with
// TextPrefiller and TextTokenGenerator
auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
module.get(), metadata.at(kUseKVCache));

// Create text_prefiller
auto text_prefiller = std::make_unique<llm::TextPrefiller>(
text_decoder_runner.get(),
metadata.at(kUseKVCache),
metadata.at(kEnableDynamicShape),
metadata.at(kMaxSeqLen));

// Create text_token_generator with stats
auto stats = std::make_unique<llm::Stats>();
auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
tokenizer.get(),
text_decoder_runner.get(),
metadata.at(kUseKVCache),
std::move(eos_ids),
&stats_);
stats.get());

// Create and return the Runner instance
return std::make_unique<Runner>(
std::move(metadata),
std::move(tokenizer),
std::move(module),
std::move(text_decoder_runner),
std::move(text_prefiller),
std::move(text_token_generator),
std::move(stats),
temperature);
}

Runner::Runner(
std::unordered_map<std::string, int64_t> metadata,
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
std::unique_ptr<::executorch::extension::Module> module,
std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
text_decoder_runner,
std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
text_token_generator,
std::unique_ptr<::executorch::extension::llm::Stats> stats,
float temperature)
: tokenizer_(std::move(tokenizer)),
metadata_(std::move(metadata)),
module_(std::move(module)),
text_decoder_runner_(std::move(text_decoder_runner)),
text_prefiller_(std::move(text_prefiller)),
text_token_generator_(std::move(text_token_generator)),
stats_(std::move(stats)),
temperature_(temperature) {
// Note: This constructor assumes that text_prefiller and text_token_generator
// already have references to the Module and TextDecoderRunner they need
}
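
The note in this constructor describes the ownership layout create() sets up: the Runner owns the Module and TextDecoderRunner through unique_ptr members, while TextPrefiller and TextTokenGenerator hold non-owning raw pointers obtained in create(), so the owned members must outlive the collaborators that point at them. A minimal sketch of that wiring, with hypothetical names (Engine, Worker, Owner) that are not the ExecuTorch classes:

```cpp
#include <memory>
#include <utility>

struct Engine {           // stand-in for Module / TextDecoderRunner
  int step() { return 0; }
};

struct Worker {           // stand-in for TextPrefiller / TextTokenGenerator
  explicit Worker(Engine* engine) : engine_(engine) {}  // non-owning pointer
  int run() { return engine_->step(); }
 private:
  Engine* engine_;
};

class Owner {             // stand-in for Runner
 public:
  Owner(std::unique_ptr<Engine> engine, std::unique_ptr<Worker> worker)
      : engine_(std::move(engine)), worker_(std::move(worker)) {}
  int run() { return worker_->run(); }
 private:
  // Members are destroyed in reverse declaration order, so the non-owning
  // Worker goes away before the Engine it points to.
  std::unique_ptr<Engine> engine_;
  std::unique_ptr<Worker> worker_;
};

// Wiring done by a factory, mirroring the create()/constructor split above.
std::unique_ptr<Owner> make_owner() {
  auto engine = std::make_unique<Engine>();
  auto worker = std::make_unique<Worker>(engine.get());
  return std::make_unique<Owner>(std::move(engine), std::move(worker));
}
```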

bool Runner::is_loaded() const {
return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
}

Error Runner::load() {
if (is_loaded()) {
return Error::Ok;
}
ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
return Error::Ok;
}

@@ -201,9 +236,9 @@ Error Runner::generate(
// Use ones-initialized inputs.
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
if (!is_loaded()) {
stats_.model_load_start_ms = llm::time_in_ms();
stats_->model_load_start_ms = llm::time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(load());
stats_.model_load_end_ms = llm::time_in_ms();
stats_->model_load_end_ms = llm::time_in_ms();
}

if (config.warming) {
Expand All @@ -229,7 +264,7 @@ Error Runner::generate(
// First token time only measures the time it takes to encode the prompt and
// return a response token.

stats_.inference_start_ms = llm::time_in_ms();
stats_->inference_start_ms = llm::time_in_ms();
shouldStop_ = false;

::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -270,8 +305,8 @@ Error Runner::generate(
auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
uint64_t cur_token = prefill_res.get();
stats_.first_token_ms = llm::time_in_ms();
stats_.prompt_eval_end_ms = llm::time_in_ms();
stats_->first_token_ms = llm::time_in_ms();
stats_->prompt_eval_end_ms = llm::time_in_ms();

// print the first token from prefill. No prev_token so use cur_token for it.
wrapped_callback(
@@ -292,7 +327,7 @@ Error Runner::generate(
temperature_ == -1.0f ? config.temperature : temperature_,
wrapped_callback));

stats_.inference_end_ms = llm::time_in_ms();
stats_->inference_end_ms = llm::time_in_ms();
if (!config.warming) {
printf("\n");
}
@@ -305,17 +340,17 @@
RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
}

stats_.num_prompt_tokens = num_prompt_tokens;
stats_.num_generated_tokens = num_generated_tokens;
stats_->num_prompt_tokens = num_prompt_tokens;
stats_->num_generated_tokens = num_generated_tokens;

if (config.warming) {
ET_LOG(Info, "Warmup run finished!");
} else {
// Do not print report during warmup
::executorch::llm::print_report(stats_);
::executorch::llm::print_report(*stats_);
}
if (stats_callback) {
stats_callback(stats_);
stats_callback(*stats_);
}

return Error::Ok;
@@ -329,8 +364,8 @@ Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
// Call generate with the warmup config
Error err = generate(prompt, config);

// Reset stats after warmup
stats_.reset();
// Reset stats after warmup, not resetting the std::unique_ptr!
stats_->reset();
return err;
}
