From c3dc721181d7bbf9ccdba113f74824157ffce121 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 22 Apr 2025 15:08:06 -0700
Subject: [PATCH 1/2] Use unique_ptr and shared_ptr properly for runner components (#10338)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/10338

The ownership of these components needs some clarification:

* `Module` should be shared by `TextDecoderRunner` and potentially
  `TextPrefiller` (or `ImagePrefiller` in the multimodal runner).
* `TextDecoderRunner` should be shared by the `TextPrefiller` and
  `TextTokenGenerator`.
* `Tokenizer` should be owned by the `Runner` as well as `TextTokenGenerator`.

Differential Revision: D73399600

Reviewed By: kirklandsign
---
 examples/models/llama/runner/runner.cpp               |  1 +
 examples/models/llava/runner/llava_image_prefiller.h  |  2 +-
 extension/llm/runner/text_decoder_runner.h            |  9 +++++++--
 extension/llm/runner/text_prefiller.h                 | 10 +++++++++-
 extension/llm/runner/text_token_generator.h           |  8 ++++++++
 5 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp
index 53c777fa80b..186c2013616 100644
--- a/examples/models/llama/runner/runner.cpp
+++ b/examples/models/llama/runner/runner.cpp
@@ -99,6 +99,7 @@ Error Runner::load() {
         "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
         tokenizer_path_.c_str());
     tokenizer_.reset();
+    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
     tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
     err = tokenizer_->load(tokenizer_path_);
     ET_CHECK_TK_OK_OR_RETURN_ERROR(
diff --git a/examples/models/llava/runner/llava_image_prefiller.h b/examples/models/llava/runner/llava_image_prefiller.h
index c48fe2b1fe7..762a28d0d07 100644
--- a/examples/models/llava/runner/llava_image_prefiller.h
+++ b/examples/models/llava/runner/llava_image_prefiller.h
@@ -18,7 +18,7 @@ namespace example {
 class ET_EXPERIMENTAL LlavaImagePrefiller
     : public ::executorch::extension::llm::ImagePrefiller {
  public:
-  LlavaImagePrefiller(::executorch::extension::Module* module)
+  explicit LlavaImagePrefiller(::executorch::extension::Module* module)
       : ImagePrefiller(module){};
   /**
    * Prefill an LLM Module with the given image input.
diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h
index b0db48ee75e..6c1256c6b90 100644
--- a/extension/llm/runner/text_decoder_runner.h
+++ b/extension/llm/runner/text_decoder_runner.h
@@ -14,7 +14,6 @@
 #include
 #include
 #include
-#include

 namespace executorch {
 namespace extension {
@@ -94,7 +93,13 @@ class ET_EXPERIMENTAL TextDecoderRunner {
   }

  protected:
-  // TODO: use shared_ptr for module
+  /**
+   * Note: TextDecoderRunner does not own the Module instance. It is expected
+   * that the outer class (likely Runner) manages the lifecycle of the Module.
+   * This means that the responsibility for creating, maintaining, and
+   * destroying the Module lies outside of TextDecoderRunner. Ensure that the
+   * Module remains valid for the duration of TextDecoderRunner's usage.
+   */
   Module* module_;
   bool use_kv_cache_;
   bool should_stop_{false};

diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h
index 0620eadfe9f..28632ad856a 100644
--- a/extension/llm/runner/text_prefiller.h
+++ b/extension/llm/runner/text_prefiller.h
@@ -24,6 +24,8 @@ class ET_EXPERIMENTAL TextPrefiller {
       bool use_kv_cache_,
       bool enable_parallel_prefill,
       int64_t max_seq_len = 128);
+
+  virtual ~TextPrefiller() = default;
   /**
    * Prefill an LLM Module with the given text input.
    * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
    * the tokenizer.
    * @param start_pos The starting position in KV cache of the input in the LLM
    * Module.
    * @return The next token of the LLM Module after prefill.
    */
-  ::executorch::runtime::Result<uint64_t> prefill(
+  virtual ::executorch::runtime::Result<uint64_t> prefill(
       std::vector<uint64_t>& prompt_tokens,
       int64_t& start_pos);

@@ -48,6 +50,12 @@ class ET_EXPERIMENTAL TextPrefiller {
       int64_t& start_pos);

  private:
+  /**
+   * Note: TextPrefiller does not own the TextDecoderRunner instance.
+   * The responsibility of managing the lifecycle of TextDecoderRunner
+   * lies with the outer class or entity (likely Runner) that creates
+   * and passes the TextDecoderRunner instance to TextPrefiller.
+   */
   TextDecoderRunner* text_decoder_runner_;
   bool use_kv_cache_;
   bool enable_parallel_prefill_;
diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h
index 1b928de1717..38873e25fc1 100644
--- a/extension/llm/runner/text_token_generator.h
+++ b/extension/llm/runner/text_token_generator.h
@@ -32,6 +32,8 @@ class ET_EXPERIMENTAL TextTokenGenerator {
         use_kv_cache_(use_kv_cache),
         stats_(stats) {}

+  virtual ~TextTokenGenerator() = default;
+
   /**
    * Token generation loop.
    * @param tokens prompt tokens as well as the first token generated by
@@ -136,6 +138,12 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   }

  private:
+  /**
+   * Note: TextTokenGenerator does not own the tokenizer_ and
+   * text_decoder_runner_. The lifecycle of these objects should be managed
+   * externally, likely in the Runner. This class assumes that the provided
+   * pointers remain valid for the duration of its use.
+   */
   ::tokenizers::Tokenizer* tokenizer_;
   TextDecoderRunner* text_decoder_runner_;
   std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;

From c4bf4be6835d8ea725df59ea46e6efcfbbb1e7d8 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 22 Apr 2025 16:09:38 -0700
Subject: [PATCH 2/2] Use dependency injection for runner (#10326)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/10326

X-link: https://github.com/pytorch-labs/tokenizers/pull/53

Pass in runner components, move most of the instantiation logic from `load()`
to a new static API `create()`. This adds testability to runner components.
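
For illustration, here is a minimal, hypothetical caller-side sketch of the
new `create()` factory (the include path, file paths, and config values are
placeholders; error handling is reduced to the nullptr check that `create()`
performs when a component such as the tokenizer fails to load):

```cpp
#include <executorch/examples/models/llama/runner/runner.h>

int main() {
  // create() builds the Module, tokenizer, prefiller, and token generator
  // internally; it returns nullptr if, e.g., the tokenizer cannot be loaded.
  std::unique_ptr<example::Runner> runner = example::Runner::create(
      "/path/to/llama.pte", "/path/to/tokenizer.model");
  if (!runner) {
    return 1;
  }

  // Generation parameters are supplied via GenerationConfig.
  executorch::extension::llm::GenerationConfig config{
      .seq_len = 128, .temperature = 0.8f};
  runner->generate("Hello, world!", config);
  return 0;
}
```

Callers that previously constructed `example::Runner` directly (see `main.cpp`
and the iOS/Android wrappers below) switch to this factory, while tests can use
the new dependency-injection constructor to pass in mock components.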
Differential Revision: D73165546
---
 .../LLaMARunner/Exported/LLaMARunner.mm       |   2 +-
 examples/models/llama/main.cpp                |   7 +-
 examples/models/llama/runner/runner.cpp       | 222 +++++++------
 examples/models/llama/runner/runner.h         |  29 +-
 .../models/llama/runner/test/CMakeLists.txt   |  28 ++
 examples/models/llama/runner/test/TARGETS     |  14 +
 .../models/llama/runner/test/runner_test.cpp  | 297 ++++++++++++++++++
 examples/models/llama/runner/test/targets.bzl |  25 ++
 extension/android/jni/jni_layer_llama.cpp     |  16 +-
 extension/llm/runner/text_prefiller.h         |  17 +
 extension/llm/runner/text_token_generator.h   |  18 ++
 11 files changed, 553 insertions(+), 122 deletions(-)
 create mode 100644 examples/models/llama/runner/test/CMakeLists.txt
 create mode 100644 examples/models/llama/runner/test/TARGETS
 create mode 100644 examples/models/llama/runner/test/runner_test.cpp
 create mode 100644 examples/models/llama/runner/test/targets.bzl

diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm b/examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm
index 3618d05ec6c..c2f01bf17b1 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm
@@ -31,7 +31,7 @@ - (instancetype)initWithModelPath:(NSString*)modelPath
   self = [super init];
   if (self) {
     [ExecuTorchLog.sharedLog addSink:self];
-    _runner = std::make_unique<example::Runner>(
+    _runner = example::Runner::create(
         modelPath.UTF8String, tokenizerPath.UTF8String);
   }
   return self;
diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp
index 5179bf28fc7..67152ec190b 100644
--- a/examples/models/llama/main.cpp
+++ b/examples/models/llama/main.cpp
@@ -74,17 +74,18 @@ int32_t main(int32_t argc, char** argv) {
 #endif
   // create llama runner
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  example::Runner runner(model_path, tokenizer_path);
+  std::unique_ptr<example::Runner> runner =
+      example::Runner::create(model_path, tokenizer_path);

   if (warmup) {
     // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    runner.warmup(prompt, /*max_new_tokens=*/seq_len);
+    runner->warmup(prompt, /*max_new_tokens=*/seq_len);
   }
   // generate
   executorch::extension::llm::GenerationConfig config{
       .seq_len = seq_len, .temperature = temperature};
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  runner.generate(prompt, config);
+  runner->generate(prompt, config);

   return 0;
 }
diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp
index 186c2013616..69c573bf216 100644
--- a/examples/models/llama/runner/runner.cpp
+++ b/examples/models/llama/runner/runner.cpp
@@ -11,9 +11,6 @@

 #include

-#include
-#include
-
 #include
 #include

@@ -35,130 +32,165 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kMaxContextLen = "get_max_context_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
-static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
 } // namespace

-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
- : tokenizer_path_(tokenizer_path), - metadata_({ - {kEnableDynamicShape, false}, - {kMaxSeqLen, 128}, - {kMaxContextLen, 128}, - {kUseKVCache, true}, - {kUseSDPAWithKVCache, false}, - }) { - if (data_path.has_value()) { - module_ = std::make_unique( - model_path, data_path.value(), Module::LoadMode::File); - } else { - module_ = std::make_unique(model_path, Module::LoadMode::File); - } + std::optional data_path, + float temperature) { ET_LOG( Info, "Creating LLaMa runner: model_path=%s, tokenizer_path=%s", model_path.c_str(), tokenizer_path.c_str()); -} -[[deprecated( - "This constructor is deprecated. Use the constructor without temperature parameter instead.")]] -Runner::Runner( - const std::string& model_path, - const std::string& tokenizer_path, - const float temperature, - std::optional data_path) - : Runner(model_path, tokenizer_path, std::move(data_path)) { - temperature_ = temperature; -} + // Create the Module + std::unique_ptr module; + if (data_path.has_value()) { + module = std::make_unique( + model_path, data_path.value(), Module::LoadMode::File); + } else { + module = std::make_unique(model_path, Module::LoadMode::File); + } -bool Runner::is_loaded() const { - return module_->is_loaded() && tokenizer_ && text_decoder_runner_ && - text_prefiller_ && text_token_generator_; -} + // Initialize metadata with default values + std::unordered_map metadata({ + {kEnableDynamicShape, false}, + {kMaxSeqLen, 128}, + {kMaxContextLen, 128}, + {kUseKVCache, true}, + }); -Error Runner::load() { - if (is_loaded()) { - return Error::Ok; - } - ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); - // load tokenizer. Assuming tiktoken is the default tokenizer - tokenizer_ = nullptr; - tokenizer_ = get_tiktoken_for_llama(); - ::tokenizers::Error err = tokenizer_->load(tokenizer_path_); - // Rely on tiktoken to throw error if the artifact is incompatible. Then we - // fallback to BPE tokenizer. 
- if (err != ::tokenizers::Error::Ok) { + // Create and load tokenizer + std::unique_ptr<::tokenizers::Tokenizer> tokenizer = get_tiktoken_for_llama(); + ::tokenizers::Error tk_err = tokenizer->load(tokenizer_path); + + // Fallback to BPE tokenizer if tiktoken fails + if (tk_err != ::tokenizers::Error::Ok) { ET_LOG( Info, "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", - tokenizer_path_.c_str()); - tokenizer_.reset(); - // @lint-ignore CLANGTIDY facebook-hte-Deprecated - tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>(); - err = tokenizer_->load(tokenizer_path_); - ET_CHECK_TK_OK_OR_RETURN_ERROR( - err, - "Failed to load %s as a llama2.c tokenizer artifact", - tokenizer_path_.c_str()); + tokenizer_path.c_str()); + tokenizer.reset(); + tokenizer = std::make_unique<::tokenizers::Llama2cTokenizer>(); + tk_err = tokenizer->load(tokenizer_path); + if (tk_err != ::tokenizers::Error::Ok) { + ET_LOG( + Error, + "Failed to load %s as a llama2.c tokenizer artifact", + tokenizer_path.c_str()); + return nullptr; + } } ET_LOG(Info, "Reading metadata from model"); - metadata_[kBosId] = tokenizer_->bos_tok(); + // Set tokenizer-related metadata + metadata[kBosId] = tokenizer->bos_tok(); auto eos_ids = std::make_unique>( - std::unordered_set{tokenizer_->eos_tok()}); - metadata_[kVocabSize] = tokenizer_->vocab_size(); - - const auto method_names = - ET_UNWRAP(module_->method_names(), "Failed reading method names"); + std::unordered_set{tokenizer->eos_tok()}); + metadata[kVocabSize] = tokenizer->vocab_size(); + + // Read metadata from the model + auto method_names_result = module->method_names(); + if (method_names_result.error() != Error::Ok) { + ET_LOG(Error, "Failed reading method names"); + return nullptr; + } + const auto method_names = method_names_result.get(); - for (auto& pair : metadata_) { + for (auto& pair : metadata) { const auto& method_name = pair.first; auto& value = pair.second; if (method_names.count(method_name)) { - value = ET_UNWRAP(module_->get(method_name)) - .toScalar() - .to(); + auto get_result = module->get(method_name); + value = get_result.get().toScalar().to(); } else { ET_LOG( Info, - "Methond %s not found, using the default value %" PRId64, + "Method %s not found, using the default value %" PRId64, method_name.c_str(), value); } ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value); } + + // Get EOS IDs if available if (method_names.count(kEosIds)) { eos_ids->clear(); - for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) { + auto execute_result = module->execute(kEosIds); + if (execute_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to execute %s", kEosIds); + return nullptr; + } + for (const auto& eos_id : execute_result.get()) { auto value = eos_id.toScalar().to(); eos_ids->emplace(value); ET_LOG(Info, "eos_id = %" PRId64, value); } } - // @lint-ignore CLANGTIDY facebook-hte-Deprecated - text_decoder_runner_ = std::make_unique( - module_.get(), metadata_.at(kUseKVCache)); - text_prefiller_ = std::make_unique( - text_decoder_runner_.get(), - metadata_.at(kUseKVCache), - metadata_.at(kEnableDynamicShape), - metadata_.at(kMaxSeqLen)); - - text_token_generator_ = std::make_unique( - tokenizer_.get(), - text_decoder_runner_.get(), - metadata_.at(kUseKVCache), + + // Create text_decoder_runner. 
Use a shared_ptr so that it can be shared with + // TextPrefiller and TextTokenGenerator + auto text_decoder_runner = std::make_unique( + module.get(), metadata.at(kUseKVCache)); + + // Create text_prefiller + auto text_prefiller = std::make_unique( + text_decoder_runner.get(), + metadata.at(kUseKVCache), + metadata.at(kEnableDynamicShape), + metadata.at(kMaxSeqLen)); + + // Create text_token_generator with stats + auto stats = std::make_unique(); + auto text_token_generator = std::make_unique( + tokenizer.get(), + text_decoder_runner.get(), + metadata.at(kUseKVCache), std::move(eos_ids), - &stats_); + stats.get()); + + // Create and return the Runner instance + return std::make_unique( + std::move(metadata), + std::move(tokenizer), + std::move(text_prefiller), + std::move(text_token_generator), + std::move(stats), + temperature); +} +Runner::Runner( + std::unordered_map metadata, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller, + std::unique_ptr<::executorch::extension::llm::TextTokenGenerator> + text_token_generator, + std::unique_ptr<::executorch::extension::llm::Stats> stats, + float temperature) + : tokenizer_(std::move(tokenizer)), + metadata_(std::move(metadata)), + text_prefiller_(std::move(text_prefiller)), + text_token_generator_(std::move(text_token_generator)), + stats_(std::move(stats)), + temperature_(temperature) { + // Note: This constructor assumes that text_prefiller and text_token_generator + // already have references to the Module and TextDecoderRunner they need +} + +bool Runner::is_loaded() const { + return text_prefiller_->is_loaded() && text_token_generator_->is_loaded(); +} + +Error Runner::load() { + if (is_loaded()) { + return Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); + ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); return Error::Ok; } @@ -179,9 +211,9 @@ Error Runner::generate( // Use ones-initialized inputs. ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); if (!is_loaded()) { - stats_.model_load_start_ms = llm::time_in_ms(); + stats_->model_load_start_ms = llm::time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); - stats_.model_load_end_ms = llm::time_in_ms(); + stats_->model_load_end_ms = llm::time_in_ms(); } if (config.warming) { @@ -207,7 +239,7 @@ Error Runner::generate( // First token time only measures the time it takes to encode the prompt and // return a response token. - stats_.inference_start_ms = llm::time_in_ms(); + stats_->inference_start_ms = llm::time_in_ms(); shouldStop_ = false; ::tokenizers::Result> encode_res = tokenizer_->encode( @@ -248,8 +280,8 @@ Error Runner::generate( auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); uint64_t cur_token = prefill_res.get(); - stats_.first_token_ms = llm::time_in_ms(); - stats_.prompt_eval_end_ms = llm::time_in_ms(); + stats_->first_token_ms = llm::time_in_ms(); + stats_->prompt_eval_end_ms = llm::time_in_ms(); // print the first token from prefill. No prev_token so use cur_token for it. wrapped_callback( @@ -270,7 +302,7 @@ Error Runner::generate( temperature_ == -1.0f ? 
config.temperature : temperature_, wrapped_callback)); - stats_.inference_end_ms = llm::time_in_ms(); + stats_->inference_end_ms = llm::time_in_ms(); if (!config.warming) { printf("\n"); } @@ -283,17 +315,17 @@ Error Runner::generate( RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens); } - stats_.num_prompt_tokens = num_prompt_tokens; - stats_.num_generated_tokens = num_generated_tokens; + stats_->num_prompt_tokens = num_prompt_tokens; + stats_->num_generated_tokens = num_generated_tokens; if (config.warming) { ET_LOG(Info, "Warmup run finished!"); } else { // Do not print report during warmup - ::executorch::llm::print_report(stats_); + ::executorch::llm::print_report(*stats_); } if (stats_callback) { - stats_callback(stats_); + stats_callback(*stats_); } return Error::Ok; diff --git a/examples/models/llama/runner/runner.h b/examples/models/llama/runner/runner.h index 97ffe4b98b7..4ac1e9148da 100644 --- a/examples/models/llama/runner/runner.h +++ b/examples/models/llama/runner/runner.h @@ -30,18 +30,23 @@ namespace example { class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner { public: - explicit Runner( + // Static factory method to create a Runner instance + static std::unique_ptr create( const std::string& model_path, const std::string& tokenizer_path, - std::optional data_path = std::nullopt); + std::optional data_path = std::nullopt, + float temperature = -1.0f); - [[deprecated( - "This constructor is deprecated. Use the constructor without temperature parameter instead.")]] + // Constructor with dependency injection explicit Runner( - const std::string& model_path, - const std::string& tokenizer_path, - const float temperature, - std::optional data_path = std::nullopt); + std::unordered_map metadata, + std::unique_ptr<::tokenizers::Tokenizer> tokenizer, + std::unique_ptr<::executorch::extension::llm::TextPrefiller> + text_prefiller, + std::unique_ptr<::executorch::extension::llm::TextTokenGenerator> + text_token_generator, + std::unique_ptr<::executorch::extension::llm::Stats> stats, + float temperature = -1.0f); bool is_loaded() const override; ::executorch::runtime::Error load() override; @@ -59,9 +64,7 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner { private: bool shouldStop_{false}; - // model - std::unique_ptr<::executorch::extension::Module> module_; - std::string tokenizer_path_; + // Components std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; std::unordered_map metadata_; std::unique_ptr<::executorch::extension::llm::TextDecoderRunner> @@ -70,8 +73,8 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner { std::unique_ptr<::executorch::extension::llm::TextTokenGenerator> text_token_generator_; - // stats - ::executorch::extension::llm::Stats stats_; + // Stats + std::unique_ptr<::executorch::extension::llm::Stats> stats_; // temperature. // Deprecated, we should rely on the temperature in GenerationConfig instead. diff --git a/examples/models/llama/runner/test/CMakeLists.txt b/examples/models/llama/runner/test/CMakeLists.txt new file mode 100644 index 00000000000..39abbf86aab --- /dev/null +++ b/examples/models/llama/runner/test/CMakeLists.txt @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# + +cmake_minimum_required(VERSION 3.19) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) + +include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) + +set(_test_srcs runner_test.cpp) + +et_cxx_test( + runner_test + SOURCES + ${_test_srcs} + EXTRA_LIBS + executorch +) diff --git a/examples/models/llama/runner/test/TARGETS b/examples/models/llama/runner/test/TARGETS new file mode 100644 index 00000000000..97de7abe9b1 --- /dev/null +++ b/examples/models/llama/runner/test/TARGETS @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/examples/models/llama/runner/test/runner_test.cpp b/examples/models/llama/runner/test/runner_test.cpp new file mode 100644 index 00000000000..c782c5f31f9 --- /dev/null +++ b/examples/models/llama/runner/test/runner_test.cpp @@ -0,0 +1,297 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using namespace ::testing; +using namespace example; +using executorch::extension::llm::GenerationConfig; +using executorch::extension::llm::Stats; +using executorch::extension::llm::TextDecoderRunner; +using executorch::extension::llm::TextPrefiller; +using executorch::extension::llm::TextTokenGenerator; +using executorch::runtime::Error; +using executorch::runtime::Result; +using executorch::runtime::testing::TensorFactory; +// Mock classes for dependencies +class MockTokenizer : public ::tokenizers::Tokenizer { + public: + MOCK_METHOD(::tokenizers::Error, load, (const std::string&), ()); + MOCK_METHOD(bool, is_initialized, (), (const)); + MOCK_METHOD( + ::tokenizers::Result>, + encode, + (const std::string&, int8_t, int8_t), + (const)); + MOCK_METHOD( + ::tokenizers::Result, + decode, + (uint64_t, uint64_t), + (const)); + MOCK_METHOD(uint64_t, bos_tok, (), (const)); + MOCK_METHOD(uint64_t, eos_tok, (), (const)); + MOCK_METHOD(uint64_t, vocab_size, (), (const)); +}; + +class MockTextDecoderRunner : public TextDecoderRunner { + public: + MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {} + MOCK_METHOD( + Result, + step, + (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&), + ()); + MOCK_METHOD(bool, is_method_loaded, (), ()); + MOCK_METHOD(Result, prefill, (std::vector&, int64_t), ()); + MOCK_METHOD(::executorch::runtime::Error, load, (), ()); + // Implement logits_to_token directly since it's not virtual in the parent + // class + int32_t logits_to_token( + const executorch::aten::Tensor& logits_tensor, + const float temperature = 0.0f) { + return 42; // Return a fixed value for testing + } +}; + +class MockTextPrefiller : public TextPrefiller { + public: + MockTextPrefiller(TextDecoderRunner* text_decoder_runner) + : TextPrefiller(text_decoder_runner, false, false, 0) {} + MOCK_METHOD( + Result, + prefill, + (std::vector&, int64_t&), + ()); + 
MOCK_METHOD(::executorch::runtime::Error, load, (), ()); + MOCK_METHOD(bool, is_loaded, (), ()); +}; + +// Callback counter class for tests +class CallbackCounter { + public: + CallbackCounter() : count_(0) {} + + void callback(const std::string& token) { + count_++; + } + + int getCount() const { + return count_; + } + + private: + int count_; +}; + +// Test fixture for Runner tests - minimal setup +class RunnerTest : public Test { + protected: + // Helper functions to create and set up mock objects + std::unique_ptr createMockTokenizer() { + auto tokenizer = std::make_unique(); + + // Set up default behavior for the tokenizer + ON_CALL(*tokenizer, is_initialized).WillByDefault(Return(true)); + ON_CALL(*tokenizer, encode) + .WillByDefault([](const std::string&, int8_t, int8_t) { + return ::tokenizers::Result>( + std::vector{1, 2, 3}); + }); + + ON_CALL(*tokenizer, decode).WillByDefault([](uint64_t, uint64_t) { + return ::tokenizers::Result("token"); + }); + + ON_CALL(*tokenizer, bos_tok()).WillByDefault(Return(1)); + ON_CALL(*tokenizer, eos_tok()).WillByDefault(Return(2)); + ON_CALL(*tokenizer, vocab_size()).WillByDefault(Return(100)); + + return tokenizer; + } + + std::unique_ptr createMockTextDecoderRunner() { + auto text_decoder_runner = std::make_unique(); + ON_CALL(*text_decoder_runner, step) + .WillByDefault([&](executorch::extension::TensorPtr&, + executorch::extension::TensorPtr&) { + return Result(tensor); + }); + ON_CALL(*text_decoder_runner, is_method_loaded()) + .WillByDefault(Return(true)); + return text_decoder_runner; + } + + std::unique_ptr createMockTextPrefiller( + TextDecoderRunner* text_decoder_runner) { + auto text_prefiller = + std::make_unique(text_decoder_runner); + ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true)); + // Set up default behavior for the text prefiller + ON_CALL(*text_prefiller, prefill) + .WillByDefault([](const std::vector&, int64_t) { + return Result(4); + }); + + return text_prefiller; + } + + std::unique_ptr createTextTokenGenerator( + ::tokenizers::Tokenizer* tokenizer, + TextDecoderRunner* text_decoder_runner, + Stats* stats) { + auto eos_ids = std::make_unique>( + std::unordered_set{100}); + return std::make_unique( + tokenizer, + text_decoder_runner, + true, // use_kv_cache + std::move(eos_ids), + stats); + } + + std::unordered_map createDefaultMetadata() { + return { + {"enable_dynamic_shape", false}, + {"get_max_seq_len", 128}, + {"get_max_context_len", 128}, + {"use_kv_cache", true}, + }; + } + + protected: + Stats stats_; + std::vector return_logits_ = {0.1f, 0.2f, 0.3f, 0.4f}; + TensorFactory tf; + executorch::aten::Tensor tensor = tf.make({1, 4}, return_logits_); +}; + +// Test that generate() calls the token callback exactly max_new_tokens times +TEST_F(RunnerTest, GenerateCallsCallbackExactlyMaxNewTokensTimes) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Set up expectations for the tokenizer encode method + EXPECT_CALL(*tokenizer, encode(_, _, _)) + .WillOnce(Return(::tokenizers::Result>( + std::vector{1, 2, 3}))); + + // Set up expectations for the text prefiller + EXPECT_CALL(*text_prefiller, prefill(_, _)) + .WillOnce(Return(Result(4))); + + // Set up expectations for load methods + EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true)); + + // Create a real TextTokenGenerator + auto text_token_generator = 
createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), &stats_); + + // Create a Runner with our mocked components + Runner runner( + createDefaultMetadata(), + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(text_token_generator)); + + // Load + runner.load(); + + // Set up the generation config with a specific max_new_tokens value + GenerationConfig config; + config.max_new_tokens = 10; + config.echo = false; + + // Create a callback counter + CallbackCounter counter; + + // Call generate with our callback + Error err = runner.generate( + "test prompt", config, [&counter](const std::string& token) { + counter.callback(token); + }); + + // Verify the callback was called exactly max_new_tokens times + // The first token is generated by prefill, and the rest by the token + // generator + EXPECT_EQ(counter.getCount(), config.max_new_tokens); + EXPECT_EQ(err, Error::Ok); +} + +// Test that warmup() calls generate with the warming flag set +TEST_F(RunnerTest, WarmupCallsGenerateWithWarmingFlag) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Set up expectations for the tokenizer encode method + EXPECT_CALL(*tokenizer, encode(_, _, _)) + .WillOnce(Return(::tokenizers::Result>( + std::vector{1, 2, 3}))); + + // Set up expectations for the text prefiller + EXPECT_CALL(*text_prefiller, prefill(_, _)) + .WillOnce(Return(Result(4))); + + EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true)); + + // Create a TextTokenGenerator + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), &stats_); + + // Create a Runner with our mocked components + Runner runner( + createDefaultMetadata(), + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(text_token_generator)); + + // Load + runner.load(); + + // Call warmup + Error err = runner.warmup("test prompt", 5); + + // Verify the result + EXPECT_EQ(err, Error::Ok); +} + +// Test that is_loaded() returns true when components are initialized +TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) { + // Create mock instances using helper functions + auto tokenizer = createMockTokenizer(); + auto text_decoder_runner = createMockTextDecoderRunner(); + auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get()); + + // Create a real TextTokenGenerator + auto text_token_generator = createTextTokenGenerator( + tokenizer.get(), text_decoder_runner.get(), &stats_); + + // Create a Runner with our mocked components + Runner runner( + createDefaultMetadata(), + std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()), + std::unique_ptr<::executorch::extension::llm::TextPrefiller>( + text_prefiller.release()), + std::move(text_token_generator)); + + // Load + runner.load(); + + // Verify is_loaded returns true + EXPECT_TRUE(runner.is_loaded()); +} diff --git a/examples/models/llama/runner/test/targets.bzl b/examples/models/llama/runner/test/targets.bzl new file mode 100644 index 00000000000..5fd995f1a83 --- /dev/null +++ b/examples/models/llama/runner/test/targets.bzl @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + runtime.cxx_test( + name = "runner_test", + srcs = ["runner_test.cpp"], + deps = [ + "//executorch/examples/models/llama/runner:runner", + "//executorch/extension/llm/runner:irunner", + "//executorch/extension/llm/runner:stats", + "//executorch/extension/llm/runner:text_token_generator", + "//executorch/extension/llm/runner:text_decoder_runner", + "//executorch/extension/llm/runner:text_prefiller", + "//executorch/extension/module:module", + "//executorch/runtime/core:core", + "//executorch/runtime/platform:platform", + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + ], + ) diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 5e730c559d1..54140755369 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -164,16 +164,12 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { tokenizer_path->toStdString().c_str(), temperature); } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { - if (data_path != nullptr) { - runner_ = std::make_unique( - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - data_path->toStdString().c_str()); - } else { - runner_ = std::make_unique( - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str()); - } + std::optional data_path_str = + data_path == nullptr ? std::nullopt : data_path->toStdString(); + runner_ = example::Runner::create( + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + data_path_str); #if defined(EXECUTORCH_BUILD_MEDIATEK) } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) { runner_ = std::make_unique( diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h index 28632ad856a..49b2c867167 100644 --- a/extension/llm/runner/text_prefiller.h +++ b/extension/llm/runner/text_prefiller.h @@ -49,6 +49,23 @@ class ET_EXPERIMENTAL TextPrefiller { std::vector& prompt_tokens, int64_t& start_pos); + /** + * Load the necessary resources for the TextPrefiller. + * This method should be called before using the prefill methods. + */ + ::executorch::runtime::Error load() { + return text_decoder_runner_->load(); + } + + /** + * Check if the TextPrefiller has been successfully loaded. + * @return True if the resources are loaded, false otherwise. + */ + bool inline is_loaded() const { + // Implementation to check if resources are loaded + return text_decoder_runner_->is_method_loaded(); + } + private: /** * Note: TextPrefiller does not own the TextDecoderRunner instance. diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h index 38873e25fc1..b5001495763 100644 --- a/extension/llm/runner/text_token_generator.h +++ b/extension/llm/runner/text_token_generator.h @@ -137,6 +137,24 @@ class ET_EXPERIMENTAL TextTokenGenerator { should_stop_ = true; } + /** + * Load the necessary resources for TextTokenGenerator. + * This method should be called before using the generate() method. + */ + ::executorch::runtime::Error load() { + return text_decoder_runner_->load(); + } + + /** + * Check if the TextTokenGenerator has been successfully loaded. + * @return True if the resources are loaded, false otherwise. 
+ */ + bool inline is_loaded() const { + // Implementation to check if resources are loaded + return tokenizer_->is_initialized() && + text_decoder_runner_->is_method_loaded(); + } + private: /** * Note: TextTokenGenerator does not own the tokenizer_ and