Use dependency injection for runner #10326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account


Open · wants to merge 2 commits into main
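
In outline, the PR replaces a Runner constructor that built its own module, tokenizer, prefiller, and token generator with a static create() factory that assembles those collaborators and injects them into a constructor that only stores them. A minimal sketch of the injected-constructor shape, using simplified placeholder types rather than the actual ExecuTorch classes:

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Simplified stand-ins for the real collaborators; illustrative placeholders only,
// not the ExecuTorch Tokenizer/TextPrefiller types.
struct Tokenizer {};
struct TextPrefiller {};

class Runner {
 public:
  // Dependency injection: the constructor only stores ready-made parts;
  // all fallible construction and loading is done by whoever calls it
  // (in the PR, a static create() factory).
  Runner(std::unordered_map<std::string, int64_t> metadata,
         std::unique_ptr<Tokenizer> tokenizer,
         std::unique_ptr<TextPrefiller> prefiller)
      : metadata_(std::move(metadata)),
        tokenizer_(std::move(tokenizer)),
        prefiller_(std::move(prefiller)) {}

 private:
  std::unordered_map<std::string, int64_t> metadata_;
  std::unique_ptr<Tokenizer> tokenizer_;
  std::unique_ptr<TextPrefiller> prefiller_;
};

The actual constructor in runner.cpp additionally receives the text token generator, a Stats object, and the temperature.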
@@ -31,7 +31,7 @@ - (instancetype)initWithModelPath:(NSString*)modelPath
self = [super init];
if (self) {
[ExecuTorchLog.sharedLog addSink:self];
_runner = std::make_unique<example::Runner>(
_runner = example::Runner::create(
modelPath.UTF8String, tokenizerPath.UTF8String);
}
return self;
7 changes: 4 additions & 3 deletions examples/models/llama/main.cpp
@@ -74,17 +74,18 @@ int32_t main(int32_t argc, char** argv) {
#endif
// create llama runner
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
example::Runner runner(model_path, tokenizer_path);
std::unique_ptr<example::Runner> runner =
example::Runner::create(model_path, tokenizer_path);

if (warmup) {
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
runner.warmup(prompt, /*max_new_tokens=*/seq_len);
runner->warmup(prompt, /*max_new_tokens=*/seq_len);
}
// generate
executorch::extension::llm::GenerationConfig config{
.seq_len = seq_len, .temperature = temperature};
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
runner.generate(prompt, config);
runner->generate(prompt, config);

return 0;
}
221 changes: 127 additions & 94 deletions examples/models/llama/runner/runner.cpp
@@ -11,9 +11,6 @@

#include <executorch/examples/models/llama/runner/runner.h>

#include <algorithm>
#include <ctime>

#include <executorch/extension/llm/runner/util.h>

#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -35,129 +32,165 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
static constexpr auto kMaxContextLen = "get_max_context_len";
static constexpr auto kVocabSize = "get_vocab_size";
static constexpr auto kUseKVCache = "use_kv_cache";
static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
} // namespace

Runner::Runner(
std::unique_ptr<Runner> Runner::create(
Contributor review comment: Specifically, it looks like create() could just be another Runner constructor; it's not clear to me why it has to be a static method instead.
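
Context for the question: in this diff, create() reports failure by returning nullptr when neither tokenizer format loads, which a constructor could only express by throwing. A minimal sketch of that factory-versus-constructor distinction, using hypothetical simplified types rather than the real ExecuTorch classes:

#include <memory>
#include <string>
#include <utility>

// Hypothetical simplified types for illustration only.
struct Tokenizer {
  // Stand-in for the real load(), which returns a tokenizers::Error.
  bool load(const std::string& /*path*/) { return false; }
};

struct ExampleRunner {
  // A static factory can run fallible setup before any object exists and
  // report failure by returning nullptr; a constructor would have to throw.
  static std::unique_ptr<ExampleRunner> create(const std::string& tokenizer_path) {
    auto tokenizer = std::make_unique<Tokenizer>();
    if (!tokenizer->load(tokenizer_path)) {
      return nullptr;
    }
    return std::make_unique<ExampleRunner>(std::move(tokenizer));
  }

  explicit ExampleRunner(std::unique_ptr<Tokenizer> tokenizer)
      : tokenizer_(std::move(tokenizer)) {}

  std::unique_ptr<Tokenizer> tokenizer_;
};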

const std::string& model_path,
const std::string& tokenizer_path,
std::optional<const std::string> data_path)
// NOTE: we observed ~2x loading performance increase on iPhone 15
// and a ~5% improvement on Galaxy S22 by switching to
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
: tokenizer_path_(tokenizer_path),
metadata_({
{kEnableDynamicShape, false},
{kMaxSeqLen, 128},
{kMaxContextLen, 128},
{kUseKVCache, true},
{kUseSDPAWithKVCache, false},
}) {
if (data_path.has_value()) {
module_ = std::make_unique<Module>(
model_path, data_path.value(), Module::LoadMode::File);
} else {
module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
}
std::optional<const std::string> data_path,
float temperature) {
ET_LOG(
Info,
"Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
model_path.c_str(),
tokenizer_path.c_str());
}

[[deprecated(
"This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
Runner::Runner(
const std::string& model_path,
const std::string& tokenizer_path,
const float temperature,
std::optional<const std::string> data_path)
: Runner(model_path, tokenizer_path, std::move(data_path)) {
temperature_ = temperature;
}
// Create the Module
std::unique_ptr<Module> module;
if (data_path.has_value()) {
module = std::make_unique<Module>(
model_path, data_path.value(), Module::LoadMode::File);
} else {
module = std::make_unique<Module>(model_path, Module::LoadMode::File);
}

bool Runner::is_loaded() const {
return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
text_prefiller_ && text_token_generator_;
}
// Initialize metadata with default values
std::unordered_map<std::string, int64_t> metadata({
{kEnableDynamicShape, false},
{kMaxSeqLen, 128},
{kMaxContextLen, 128},
{kUseKVCache, true},
});

Error Runner::load() {
if (is_loaded()) {
return Error::Ok;
}
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
// load tokenizer. Assuming tiktoken is the default tokenizer
tokenizer_ = nullptr;
tokenizer_ = get_tiktoken_for_llama();
::tokenizers::Error err = tokenizer_->load(tokenizer_path_);
// Rely on tiktoken to throw error if the artifact is incompatible. Then we
// fallback to BPE tokenizer.
if (err != ::tokenizers::Error::Ok) {
// Create and load tokenizer
std::unique_ptr<::tokenizers::Tokenizer> tokenizer = get_tiktoken_for_llama();
::tokenizers::Error tk_err = tokenizer->load(tokenizer_path);

// Fallback to BPE tokenizer if tiktoken fails
if (tk_err != ::tokenizers::Error::Ok) {
ET_LOG(
Info,
"Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
tokenizer_path_.c_str());
tokenizer_.reset();
tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
err = tokenizer_->load(tokenizer_path_);
ET_CHECK_TK_OK_OR_RETURN_ERROR(
err,
"Failed to load %s as a llama2.c tokenizer artifact",
tokenizer_path_.c_str());
tokenizer_path.c_str());
tokenizer.reset();
tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>();
tk_err = tokenizer->load(tokenizer_path);
if (tk_err != ::tokenizers::Error::Ok) {
ET_LOG(
Error,
"Failed to load %s as a llama2.c tokenizer artifact",
tokenizer_path.c_str());
return nullptr;
}
}

ET_LOG(Info, "Reading metadata from model");

metadata_[kBosId] = tokenizer_->bos_tok();
// Set tokenizer-related metadata
metadata[kBosId] = tokenizer->bos_tok();
auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
metadata_[kVocabSize] = tokenizer_->vocab_size();

const auto method_names =
ET_UNWRAP(module_->method_names(), "Failed reading method names");
std::unordered_set<uint64_t>{tokenizer->eos_tok()});
metadata[kVocabSize] = tokenizer->vocab_size();

// Read metadata from the model
auto method_names_result = module->method_names();
if (method_names_result.error() != Error::Ok) {
ET_LOG(Error, "Failed reading method names");
return nullptr;
}
const auto method_names = method_names_result.get();

for (auto& pair : metadata_) {
for (auto& pair : metadata) {
const auto& method_name = pair.first;
auto& value = pair.second;

if (method_names.count(method_name)) {
value = ET_UNWRAP(module_->get(method_name))
.toScalar()
.to<decltype(metadata_)::mapped_type>();
auto get_result = module->get(method_name);
value = get_result.get().toScalar().to<decltype(metadata)::mapped_type>();
} else {
ET_LOG(
Info,
"Methond %s not found, using the default value %" PRId64,
"Method %s not found, using the default value %" PRId64,
method_name.c_str(),
value);
}
ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
}

// Get EOS IDs if available
if (method_names.count(kEosIds)) {
eos_ids->clear();
for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
auto execute_result = module->execute(kEosIds);
if (execute_result.error() != Error::Ok) {
ET_LOG(Error, "Failed to execute %s", kEosIds);
return nullptr;
}
for (const auto& eos_id : execute_result.get()) {
auto value = eos_id.toScalar().to<int64_t>();
eos_ids->emplace(value);
ET_LOG(Info, "eos_id = %" PRId64, value);
}
}
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
module_.get(), metadata_.at(kUseKVCache));
text_prefiller_ = std::make_unique<llm::TextPrefiller>(
text_decoder_runner_.get(),
metadata_.at(kUseKVCache),
metadata_.at(kEnableDynamicShape),
metadata_.at(kMaxSeqLen));

text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
tokenizer_.get(),
text_decoder_runner_.get(),
metadata_.at(kUseKVCache),

// Create text_decoder_runner. Use a shared_ptr so that it can be shared with
// TextPrefiller and TextTokenGenerator
auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
module.get(), metadata.at(kUseKVCache));

// Create text_prefiller
auto text_prefiller = std::make_unique<llm::TextPrefiller>(
text_decoder_runner.get(),
metadata.at(kUseKVCache),
metadata.at(kEnableDynamicShape),
metadata.at(kMaxSeqLen));

// Create text_token_generator with stats
auto stats = std::make_unique<llm::Stats>();
auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
tokenizer.get(),
text_decoder_runner.get(),
metadata.at(kUseKVCache),
std::move(eos_ids),
&stats_);
stats.get());

// Create and return the Runner instance
return std::make_unique<Runner>(
std::move(metadata),
std::move(tokenizer),
std::move(text_prefiller),
std::move(text_token_generator),
std::move(stats),
temperature);
}

Runner::Runner(
std::unordered_map<std::string, int64_t> metadata,
std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
text_token_generator,
std::unique_ptr<::executorch::extension::llm::Stats> stats,
float temperature)
: tokenizer_(std::move(tokenizer)),
metadata_(std::move(metadata)),
text_prefiller_(std::move(text_prefiller)),
text_token_generator_(std::move(text_token_generator)),
stats_(std::move(stats)),
temperature_(temperature) {
// Note: This constructor assumes that text_prefiller and text_token_generator
// already have references to the Module and TextDecoderRunner they need
}

bool Runner::is_loaded() const {
return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
}

Error Runner::load() {
if (is_loaded()) {
return Error::Ok;
}
ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
return Error::Ok;
}

@@ -178,9 +211,9 @@ Error Runner::generate(
// Use ones-initialized inputs.
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
if (!is_loaded()) {
stats_.model_load_start_ms = llm::time_in_ms();
stats_->model_load_start_ms = llm::time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(load());
stats_.model_load_end_ms = llm::time_in_ms();
stats_->model_load_end_ms = llm::time_in_ms();
}

if (config.warming) {
@@ -206,7 +239,7 @@ Error Runner::generate(
// First token time only measures the time it takes to encode the prompt and
// return a response token.

stats_.inference_start_ms = llm::time_in_ms();
stats_->inference_start_ms = llm::time_in_ms();
shouldStop_ = false;

::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -247,8 +280,8 @@ Error Runner::generate(
auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
uint64_t cur_token = prefill_res.get();
stats_.first_token_ms = llm::time_in_ms();
stats_.prompt_eval_end_ms = llm::time_in_ms();
stats_->first_token_ms = llm::time_in_ms();
stats_->prompt_eval_end_ms = llm::time_in_ms();

// print the first token from prefill. No prev_token so use cur_token for it.
wrapped_callback(
@@ -269,7 +302,7 @@ Error Runner::generate(
temperature_ == -1.0f ? config.temperature : temperature_,
wrapped_callback));

stats_.inference_end_ms = llm::time_in_ms();
stats_->inference_end_ms = llm::time_in_ms();
if (!config.warming) {
printf("\n");
}
@@ -282,17 +315,17 @@
RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
}

stats_.num_prompt_tokens = num_prompt_tokens;
stats_.num_generated_tokens = num_generated_tokens;
stats_->num_prompt_tokens = num_prompt_tokens;
stats_->num_generated_tokens = num_generated_tokens;

if (config.warming) {
ET_LOG(Info, "Warmup run finished!");
} else {
// Do not print report during warmup
::executorch::llm::print_report(stats_);
::executorch::llm::print_report(*stats_);
}
if (stats_callback) {
stats_callback(stats_);
stats_callback(*stats_);
}

return Error::Ok;