Skip to content

Commit 7b2b3aa

Browse files
committed
Use merged data map in module
Pull Request resolved: #14767 ghstack-source-id: 315148398 Differential Revision: [D83799869](https://our.internmc.facebook.com/intern/diff/D83799869/)
1 parent 2eb8994 commit 7b2b3aa

File tree

10 files changed

+63
-26
lines changed

10 files changed

+63
-26
lines changed

extension/module/CMakeLists.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ else()
2929
endif()
3030
target_link_libraries(
3131
extension_module PRIVATE executorch_core extension_data_loader
32-
extension_flat_tensor
32+
extension_flat_tensor extension_named_data_map
3333
)
3434
target_include_directories(
3535
extension_module PUBLIC ${_common_include_directories}
@@ -42,8 +42,9 @@ target_compile_options(
4242
# after cleaning up CMake targets.
4343
add_library(extension_module_static STATIC ${_extension_module__srcs})
4444
target_link_libraries(
45-
extension_module_static PRIVATE executorch_core extension_data_loader
46-
extension_flat_tensor
45+
extension_module_static
46+
PRIVATE executorch_core extension_data_loader extension_flat_tensor
47+
extension_named_data_map
4748
)
4849
target_include_directories(
4950
extension_module_static PUBLIC ${_common_include_directories}

extension/module/module.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <executorch/extension/data_loader/mmap_data_loader.h>
1313
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>
1414
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
15+
#include <executorch/extension/named_data_map/merged_data_map.h>
1516
#include <executorch/runtime/platform/runtime.h>
1617

1718
/**
@@ -38,6 +39,7 @@ namespace executorch {
3839
namespace extension {
3940
namespace ET_MODULE_NAMESPACE {
4041

42+
using ET_MERGED_DATA_MAP_NAMESPACE::MergedDataMap;
4143
using ET_RUNTIME_NAMESPACE::MethodMeta;
4244
using ET_RUNTIME_NAMESPACE::Program;
4345

@@ -155,24 +157,27 @@ runtime::Error Module::load(const Program::Verification verification) {
155157
data_loader_ = ET_UNWRAP(make_data_loader(file_path_, load_mode_));
156158
}
157159
if (data_files_.size() > 0) {
158-
ET_CHECK_OR_RETURN_ERROR(
159-
data_files_.size() == 1,
160-
NotImplemented,
161-
"Multiple named data map paths are not supported yet.");
162160
for (const auto& data_file : data_files_) {
163161
data_map_loaders_.push_back(
164162
ET_UNWRAP(make_data_loader(data_file, load_mode_)));
165163
}
166164
}
167165

168166
if (data_map_loaders_.size() > 0) {
169-
ET_CHECK_OR_RETURN_ERROR(
170-
data_map_loaders_.size() == 1 && merged_data_map_ == nullptr,
171-
NotImplemented,
172-
"Multiple named data map loaders are not supported yet.");
173-
// TODO(lfq): support multiple named data map loaders.
174-
merged_data_map_ =
175-
ET_UNWRAP_UNIQUE(FlatTensorDataMap::load(data_map_loaders_[0].get()));
167+
for (auto i = 0; i < data_map_loaders_.size(); ++i) {
168+
named_data_maps_.push_back(ET_UNWRAP_UNIQUE(
169+
FlatTensorDataMap::load(data_map_loaders_[i].get())));
170+
}
171+
172+
// Extract raw pointers from unique_ptrs to pass to MergedDataMap::load()
173+
std::vector<const NamedDataMap*> raw_data_maps;
174+
raw_data_maps.reserve(named_data_maps_.size());
175+
for (const auto& data_map : named_data_maps_) {
176+
raw_data_maps.push_back(data_map.get());
177+
}
178+
merged_data_map_ = ET_UNWRAP_UNIQUE(
179+
MergedDataMap::load(runtime::Span<const NamedDataMap*>(
180+
raw_data_maps.data(), raw_data_maps.size())));
176181
}
177182

178183
auto program =

extension/module/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def define_common_targets():
2626
"//executorch/extension/data_loader:file_data_loader",
2727
"//executorch/extension/data_loader:mmap_data_loader",
2828
"//executorch/extension/flat_tensor:flat_tensor_data_map" + aten_suffix,
29+
"//executorch/extension/named_data_map:merged_data_map" + aten_suffix,
2930
],
3031
exported_deps = [
3132
"//executorch/runtime/executor:program_no_prim_ops" + aten_suffix,

extension/module/test/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,14 @@ add_custom_command(
2323
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte"
2424
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
2525
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
26+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
27+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
2628
COMMAND ${PYTHON_EXECUTABLE} -m test.models.export_program --modules
2729
"ModuleAdd" --outdir "${CMAKE_CURRENT_BINARY_DIR}"
2830
COMMAND
29-
${PYTHON_EXECUTABLE} -m test.models.export_program --modules "ModuleAddMul"
30-
--external-constants --outdir "${CMAKE_CURRENT_BINARY_DIR}"
31+
${PYTHON_EXECUTABLE} -m test.models.export_program --modules
32+
"ModuleAddMul,ModuleLinear" --external-constants --outdir
33+
"${CMAKE_CURRENT_BINARY_DIR}"
3134
WORKING_DIRECTORY ${EXECUTORCH_ROOT}
3235
)
3336

@@ -36,12 +39,16 @@ add_custom_target(
3639
DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte"
3740
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
3841
"${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
42+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
43+
"${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
3944
)
4045

4146
set(test_env
4247
"ET_MODULE_ADD_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAdd.pte"
4348
"ET_MODULE_ADD_MUL_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.pte"
4449
"ET_MODULE_ADD_MUL_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleAddMulProgram.ptd"
50+
"ET_MODULE_LINEAR_PROGRAM_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.pte"
51+
"ET_MODULE_LINEAR_DATA_PATH=${CMAKE_CURRENT_BINARY_DIR}/ModuleLinearProgram.ptd"
4552
)
4653

4754
et_cxx_test(

extension/module/test/module_test.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,15 @@ class ModuleTest : public ::testing::Test {
2626
model_path_ = std::getenv("ET_MODULE_ADD_PATH");
2727
add_mul_path_ = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH");
2828
add_mul_data_path_ = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
29+
linear_path_ = std::getenv("ET_MODULE_LINEAR_PROGRAM_PATH");
30+
linear_data_path_ = std::getenv("ET_MODULE_LINEAR_DATA_PATH");
2931
}
3032

3133
static inline std::string model_path_;
3234
static inline std::string add_mul_path_;
3335
static inline std::string add_mul_data_path_;
36+
static inline std::string linear_path_;
37+
static inline std::string linear_data_path_;
3438
};
3539

3640
TEST_F(ModuleTest, TestLoad) {
@@ -532,16 +536,21 @@ TEST_F(ModuleTest, TestPTD) {
532536
}
533537

534538
TEST_F(ModuleTest, TestPTD_Multiple) {
535-
std::vector<std::string> data_files = {add_mul_data_path_};
536-
Module module(add_mul_path_, data_files);
537-
538-
ASSERT_EQ(module.load_method("forward"), Error::Ok);
539+
std::vector<std::string> data_files = {add_mul_data_path_, linear_data_path_};
539540

541+
// Create module with add mul.
542+
Module module_add_mul(add_mul_path_, data_files);
543+
ASSERT_EQ(module_add_mul.load_method("forward"), Error::Ok);
540544
auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f});
541-
ASSERT_EQ(module.forward(tensor).error(), Error::Ok);
545+
ASSERT_EQ(module_add_mul.forward(tensor).error(), Error::Ok);
542546

543547
// Confirm that the data_file is not std::move'd away.
544548
ASSERT_EQ(std::strcmp(data_files[0].c_str(), add_mul_data_path_.c_str()), 0);
549+
ASSERT_EQ(std::strcmp(data_files[1].c_str(), linear_data_path_.c_str()), 0);
545550

546-
// TODO(lfq): add test when merge capability is supported.
551+
// Create module with linear.
552+
Module module_linear(linear_path_, data_files);
553+
ASSERT_EQ(module_linear.load_method("forward"), Error::Ok);
554+
auto tensor2 = make_tensor_ptr({3}, {2.f, 3.f, 4.f});
555+
ASSERT_EQ(module_linear.forward(tensor2).error(), Error::Ok);
547556
}

extension/module/test/targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ def define_common_targets(is_fbcode=False):
1919
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
2020
"ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])",
2121
"ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])",
22+
"ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])",
23+
"ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
2224
"ET_MODULE_SHARED_STATE": "$(location fbcode//executorch/test/models:exported_programs[ModuleSharedState.pte])",
2325
}
2426

extension/named_data_map/merged_data_map.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using executorch::runtime::Result;
2121
using executorch::runtime::Span;
2222

2323
namespace executorch::extension {
24-
24+
namespace ET_MERGED_DATA_MAP_NAMESPACE {
2525
/*static*/ Result<MergedDataMap> MergedDataMap::load(
2626
Span<const NamedDataMap*> named_data_maps) {
2727
std::vector<const NamedDataMap*> valid_data_maps;
@@ -38,7 +38,7 @@ namespace executorch::extension {
3838

3939
// Check for duplicate keys.
4040
std::unordered_map<std::string, uint32_t> key_to_map_index;
41-
for (auto i : c10::irange(valid_data_maps.size())) {
41+
for (const uint32_t i : c10::irange(valid_data_maps.size())) {
4242
const auto cur_map = valid_data_maps[i];
4343
uint32_t num_keys = cur_map->get_num_keys().get();
4444
for (auto j : c10::irange(num_keys)) {
@@ -47,7 +47,7 @@ namespace executorch::extension {
4747
ET_CHECK_OR_RETURN_ERROR(
4848
inserted,
4949
InvalidArgument,
50-
"Duplicate key %s in named data maps at index %u and %lu",
50+
"Duplicate key %s in named data maps at index %u and %" PRIu32,
5151
cur_key,
5252
it->second,
5353
i);
@@ -114,4 +114,6 @@ ET_NODISCARD Result<const char*> MergedDataMap::get_key(uint32_t index) const {
114114
// Shouldn't reach here.
115115
return Error::Internal;
116116
}
117+
118+
} // namespace ET_MERGED_DATA_MAP_NAMESPACE
117119
} // namespace executorch::extension

extension/named_data_map/merged_data_map.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,15 @@
1313
#include <unordered_map>
1414
#include <vector>
1515

16+
#ifdef USE_ATEN_LIB
17+
#define ET_MERGED_DATA_MAP_NAMESPACE merged_data_map::aten
18+
#else // !USE_ATEN_LIB
19+
#define ET_MERGED_DATA_MAP_NAMESPACE merged_data_map
20+
#endif // USE_ATEN_LIB
21+
1622
namespace executorch::extension {
23+
24+
namespace ET_MERGED_DATA_MAP_NAMESPACE {
1725
/**
1826
* A NamedDataMap implementation that wraps other NamedDataMaps.
1927
*/
@@ -103,4 +111,5 @@ class MergedDataMap final
103111
std::unordered_map<std::string, uint32_t> key_to_map_index_;
104112
};
105113

114+
} // namespace ET_MERGED_DATA_MAP_NAMESPACE
106115
} // namespace executorch::extension

extension/named_data_map/test/merged_data_map_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
using namespace ::testing;
2424
using executorch::extension::FileDataLoader;
2525
using executorch::extension::FlatTensorDataMap;
26-
using executorch::extension::MergedDataMap;
26+
using executorch::extension::merged_data_map::MergedDataMap;
2727
using executorch::runtime::DataLoader;
2828
using executorch::runtime::Error;
2929
using executorch::runtime::NamedDataMap;

scripts/build_apple_frameworks.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ libextension_apple.a,\
3131
libextension_data_loader.a,\
3232
libextension_flat_tensor.a,\
3333
libextension_module.a,\
34+
libextension_named_data_map.a,\
3435
libextension_tensor.a,\
3536
:${FRAMEWORK_EXECUTORCH_HEADERS_DIR}:${FRAMEWORK_EXECUTORCH_MODULE_NAME}"
3637

0 commit comments

Comments
 (0)