Skip to content

Commit abc1c02

Browse files
lucylqfacebook-github-bot
authored andcommitted
Extend data file support for other pybindings
Summary: This diff extends the support for external data_file to other executorch runtime pybindings. Reviewed By: LeeOHzzZ Differential Revision: D82046648
1 parent a1ed4ed commit abc1c02

File tree

3 files changed

+113
-13
lines changed

3 files changed

+113
-13
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,24 @@ void setup_output_storage(
161161
inline std::unique_ptr<Module> load_module_from_buffer(
162162
const void* ptr,
163163
size_t ptr_len,
164+
std::optional<const void*> data_map_ptr,
165+
std::optional<size_t> data_map_len,
164166
std::unique_ptr<runtime::EventTracer> event_tracer,
165167
Program::Verification program_verification) {
166168
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
167169
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
170+
171+
if (data_map_ptr.has_value() && data_map_len.has_value()) {
172+
auto data_map_loader = std::make_unique<BufferDataLoader>(
173+
data_map_ptr.value(), data_map_len.value());
174+
return std::make_unique<Module>(
175+
std::move(loader),
176+
nullptr, // memory_allocator
177+
nullptr, // temp_allocator
178+
std::move(event_tracer), // event_tracer
179+
std::move(data_map_loader)); // data_map_loader
180+
}
181+
168182
return std::make_unique<Module>(
169183
std::move(loader),
170184
nullptr, // memory_allocator
@@ -504,6 +518,7 @@ struct PyMethodMeta final {
504518
struct PyModule final {
505519
explicit PyModule(
506520
const py::bytes& buffer,
521+
std::optional<const py::bytes> data_map_buffer,
507522
bool enable_etdump,
508523
size_t debug_buffer_size = 0,
509524
Program::Verification program_verification =
@@ -512,12 +527,21 @@ struct PyModule final {
512527
module_(load_module_from_buffer(
513528
buffer.cast<std::string_view>().data(),
514529
py::len(buffer),
530+
data_map_buffer.has_value()
531+
? std::optional<const void*>(
532+
data_map_buffer.value().cast<std::string_view>().data())
533+
: std::nullopt,
534+
data_map_buffer.has_value()
535+
? std::optional<size_t>(py::len(data_map_buffer.value()))
536+
: std::nullopt,
515537
setup_event_tracer(enable_etdump, debug_buffer_size),
516538
program_verification)) {}
517539

518540
explicit PyModule(
519541
const void* ptr,
520542
size_t ptr_len,
543+
std::optional<const void*> data_map_ptr,
544+
std::optional<size_t> data_map_ptr_len,
521545
bool enable_etdump,
522546
size_t debug_buffer_size = 0,
523547
Program::Verification program_verification =
@@ -526,6 +550,8 @@ struct PyModule final {
526550
module_(load_module_from_buffer(
527551
ptr,
528552
ptr_len,
553+
data_map_ptr,
554+
data_map_ptr_len,
529555
setup_event_tracer(enable_etdump, debug_buffer_size),
530556
program_verification)) {}
531557

@@ -551,12 +577,17 @@ struct PyModule final {
551577
// Module is only valid as long as the python buffer is alive.
552578
static std::unique_ptr<PyModule> load_from_buffer(
553579
const py::bytes& buffer,
580+
std::optional<const py::bytes> data_map_buffer,
554581
bool enable_etdump,
555582
size_t debug_buffer_size = 0,
556583
Program::Verification program_verification =
557584
Program::Verification::InternalConsistency) {
558585
return std::make_unique<PyModule>(
559-
buffer, enable_etdump, debug_buffer_size, program_verification);
586+
buffer,
587+
data_map_buffer,
588+
enable_etdump,
589+
debug_buffer_size,
590+
program_verification);
560591
}
561592

562593
static std::unique_ptr<PyModule> load_from_file(
@@ -576,13 +607,25 @@ struct PyModule final {
576607

577608
static std::unique_ptr<PyModule> load_from_bundled_program(
578609
PyBundledModule& m,
610+
std::optional<const py::bytes> data_map_buffer,
579611
bool enable_etdump,
580612
size_t debug_buffer_size = 0) {
613+
std::optional<const void*> data_map_ptr = std::nullopt;
614+
std::optional<size_t> data_map_len = std::nullopt;
615+
616+
if (data_map_buffer.has_value()) {
617+
data_map_ptr = data_map_buffer.value().cast<std::string_view>().data();
618+
data_map_len = py::len(data_map_buffer.value());
619+
}
620+
581621
return std::make_unique<PyModule>(
582622
m.get_program_ptr(),
583623
m.get_program_len(),
624+
data_map_ptr,
625+
data_map_len,
584626
enable_etdump,
585-
debug_buffer_size);
627+
debug_buffer_size,
628+
Program::Verification::InternalConsistency);
586629
}
587630

588631
py::list run_method(
@@ -1423,6 +1466,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
14231466
"_load_for_executorch_from_buffer",
14241467
&PyModule::load_from_buffer,
14251468
py::arg("buffer"),
1469+
py::arg("data_buffer") = std::nullopt,
14261470
py::arg("enable_etdump") = false,
14271471
py::arg("debug_buffer_size") = 0,
14281472
py::arg("program_verification") =
@@ -1432,6 +1476,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
14321476
"_load_for_executorch_from_bundled_program",
14331477
&PyModule::load_from_bundled_program,
14341478
py::arg("ptr"),
1479+
py::arg("data_map_ptr") = std::nullopt,
14351480
py::arg("enable_etdump") = false,
14361481
py::arg("debug_buffer_size") = 0,
14371482
call_guard);

extension/pybindings/test/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ runtime.python_library(
1717
deps = [
1818
"//caffe2:torch",
1919
"//caffe2:torch_fx",
20+
"//executorch/devtools/bundled_program:config",
21+
"//executorch/devtools/bundled_program:core",
22+
"//executorch/devtools/bundled_program/serialize:lib",
2023
"//executorch/exir:lib",
2124
"//executorch/exir:pass_manager",
2225
"//executorch/exir:scalar_type",

extension/pybindings/test/test_pybindings.py

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -635,24 +635,76 @@ def test_program_data_separation(self) -> None:
635635
external_constants=True,
636636
)
637637
)
638+
program_buffer = exec_program.buffer
639+
assert(len(exec_program._tensor_data) == 1)
640+
data_buffer = bytes(exec_program._tensor_data.pop("_default_external_constant"))
638641

639642
import os
640643
import tempfile
641-
642644
with tempfile.TemporaryDirectory() as tmpdir:
643645
pte_file = os.path.join(tmpdir, "linear.pte")
644646
with open(pte_file, "wb") as f:
645-
f.write(exec_program.buffer)
646-
647+
f.write(program_buffer)
647648
ptd_file = os.path.join(tmpdir, "linear.ptd")
648649
with open(ptd_file, "wb") as ptd:
649-
tensor_data = bytes(
650-
exec_program._tensor_data.pop("_default_external_constant")
651-
)
652-
ptd.write(tensor_data)
650+
ptd.write(data_buffer)
651+
expected = eager_module(inputs[0])
652+
# Test 1: File-based loading with external data file
653+
executorch_module_file = self.runtime._load_for_executorch(pte_file, ptd_file)
654+
executorch_output_file = executorch_module_file.forward(inputs)[0]
655+
self.assertTrue(torch.allclose(expected, executorch_output_file))
656+
657+
# Test 2: Buffer-based loading with external data buffer
658+
executorch_module_buffer = self.load_fn(program_buffer, data_buffer)
659+
executorch_output_buffer = executorch_module_buffer.forward(inputs)[0]
660+
self.assertTrue(torch.allclose(expected, executorch_output_buffer))
661+
662+
# Test 3: Buffer-based loading without external data file (should fail or work differently)
663+
# This should fail because the program expects external data
664+
with self.assertRaises(RuntimeError):
665+
executorch_module_no_data = self.load_fn(program_buffer)
666+
executorch_module_no_data.forward(inputs)
653667

654-
executorch_program = self.runtime._load_for_executorch(pte_file, ptd_file)
668+
# Test 4: Test with invalid data buffer (should fail)
669+
invalid_bytes = b"invalid bytes"
670+
with self.assertRaises(RuntimeError):
671+
executorch_module_invalid_data = self.load_fn(program_buffer, invalid_bytes)
672+
executorch_module_invalid_data.forward(inputs)
673+
674+
# Test 5: Test bundled program loading with external data
675+
# First create a bundled program with external constants
676+
from executorch.devtools.bundled_program.core import BundledProgram
677+
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
678+
from executorch.devtools.bundled_program.serialize import (
679+
serialize_from_bundled_program_to_flatbuffer,
680+
)
655681

656-
expected = eager_module(inputs[0])
657-
executorch_output = executorch_program.forward(inputs)[0]
658-
self.assertTrue(torch.allclose(expected, executorch_output))
682+
method_test_suites = [
683+
MethodTestSuite(
684+
method_name="forward",
685+
test_cases=[
686+
MethodTestCase(
687+
inputs=input,
688+
expected_outputs=expected,
689+
)
690+
for input in inputs
691+
],
692+
),
693+
]
694+
bundled_program = BundledProgram(exec_program, method_test_suites)
695+
bundled_buffer = serialize_from_bundled_program_to_flatbuffer(bundled_program)
696+
bundled_module = self.runtime._load_bundled_program_from_buffer(bundled_buffer)
697+
698+
# Load module from bundled program with external data
699+
executorch_module_bundled = self.runtime._load_for_executorch_from_bundled_program(
700+
bundled_module, data_buffer
701+
)
702+
executorch_output_bundled = executorch_module_bundled.forward(inputs)[0]
703+
self.assertTrue(torch.allclose(expected, executorch_output_bundled))
704+
705+
# Test 6: Bundled program without external data should fail
706+
with self.assertRaises(RuntimeError):
707+
executorch_module_bundled_no_data = self.runtime._load_for_executorch_from_bundled_program(
708+
bundled_module
709+
)
710+
executorch_module_bundled_no_data.forward(inputs)

0 commit comments

Comments
 (0)