Skip to content

Commit 2be5fb0

Browse files
lucylq and facebook-github-bot
authored and committed
Extend data file support for other pybindings (#14413)
Summary: This diff extends external data-file support to the buffer-based pybindings (loading from a buffer and from a bundled program). Reviewed By: JacobSzwejbka, LeeOHzzZ. Differential Revision: D82046648
1 parent c9f46e2 commit 2be5fb0

File tree

3 files changed

+120
-12
lines changed

3 files changed

+120
-12
lines changed

extension/pybindings/pybindings.cpp

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,24 @@ void setup_output_storage(
161161
inline std::unique_ptr<Module> load_module_from_buffer(
162162
const void* ptr,
163163
size_t ptr_len,
164+
std::optional<const void*> data_map_ptr,
165+
std::optional<size_t> data_map_len,
164166
std::unique_ptr<runtime::EventTracer> event_tracer,
165167
Program::Verification program_verification) {
166168
EXECUTORCH_SCOPE_PROF("load_module_from_buffer");
167169
auto loader = std::make_unique<BufferDataLoader>(ptr, ptr_len);
170+
171+
if (data_map_ptr.has_value() && data_map_len.has_value()) {
172+
auto data_map_loader = std::make_unique<BufferDataLoader>(
173+
data_map_ptr.value(), data_map_len.value());
174+
return std::make_unique<Module>(
175+
std::move(loader),
176+
nullptr, // memory_allocator
177+
nullptr, // temp_allocator
178+
std::move(event_tracer), // event_tracer
179+
std::move(data_map_loader)); // data_map_loader
180+
}
181+
168182
return std::make_unique<Module>(
169183
std::move(loader),
170184
nullptr, // memory_allocator
@@ -504,6 +518,7 @@ struct PyMethodMeta final {
504518
struct PyModule final {
505519
explicit PyModule(
506520
const py::bytes& buffer,
521+
std::optional<const py::bytes> data_map_buffer,
507522
bool enable_etdump,
508523
size_t debug_buffer_size = 0,
509524
Program::Verification program_verification =
@@ -512,12 +527,21 @@ struct PyModule final {
512527
module_(load_module_from_buffer(
513528
buffer.cast<std::string_view>().data(),
514529
py::len(buffer),
530+
data_map_buffer.has_value()
531+
? std::optional<const void*>(
532+
data_map_buffer.value().cast<std::string_view>().data())
533+
: std::nullopt,
534+
data_map_buffer.has_value()
535+
? std::optional<size_t>(py::len(data_map_buffer.value()))
536+
: std::nullopt,
515537
setup_event_tracer(enable_etdump, debug_buffer_size),
516538
program_verification)) {}
517539

518540
explicit PyModule(
519541
const void* ptr,
520542
size_t ptr_len,
543+
std::optional<const void*> data_map_ptr,
544+
std::optional<size_t> data_map_ptr_len,
521545
bool enable_etdump,
522546
size_t debug_buffer_size = 0,
523547
Program::Verification program_verification =
@@ -526,6 +550,8 @@ struct PyModule final {
526550
module_(load_module_from_buffer(
527551
ptr,
528552
ptr_len,
553+
data_map_ptr,
554+
data_map_ptr_len,
529555
setup_event_tracer(enable_etdump, debug_buffer_size),
530556
program_verification)) {}
531557

@@ -551,12 +577,17 @@ struct PyModule final {
551577
// Module is only valid as long as the python buffer is alive.
552578
static std::unique_ptr<PyModule> load_from_buffer(
553579
const py::bytes& buffer,
580+
std::optional<const py::bytes> data_map_buffer,
554581
bool enable_etdump,
555582
size_t debug_buffer_size = 0,
556583
Program::Verification program_verification =
557584
Program::Verification::InternalConsistency) {
558585
return std::make_unique<PyModule>(
559-
buffer, enable_etdump, debug_buffer_size, program_verification);
586+
buffer,
587+
data_map_buffer,
588+
enable_etdump,
589+
debug_buffer_size,
590+
program_verification);
560591
}
561592

562593
static std::unique_ptr<PyModule> load_from_file(
@@ -576,13 +607,25 @@ struct PyModule final {
576607

577608
static std::unique_ptr<PyModule> load_from_bundled_program(
578609
PyBundledModule& m,
610+
std::optional<const py::bytes> data_map_buffer,
579611
bool enable_etdump,
580612
size_t debug_buffer_size = 0) {
613+
std::optional<const void*> data_map_ptr = std::nullopt;
614+
std::optional<size_t> data_map_len = std::nullopt;
615+
616+
if (data_map_buffer.has_value()) {
617+
data_map_ptr = data_map_buffer.value().cast<std::string_view>().data();
618+
data_map_len = py::len(data_map_buffer.value());
619+
}
620+
581621
return std::make_unique<PyModule>(
582622
m.get_program_ptr(),
583623
m.get_program_len(),
624+
data_map_ptr,
625+
data_map_len,
584626
enable_etdump,
585-
debug_buffer_size);
627+
debug_buffer_size,
628+
Program::Verification::InternalConsistency);
586629
}
587630

588631
py::list run_method(
@@ -1423,6 +1466,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
14231466
"_load_for_executorch_from_buffer",
14241467
&PyModule::load_from_buffer,
14251468
py::arg("buffer"),
1469+
py::arg("data_buffer") = std::nullopt,
14261470
py::arg("enable_etdump") = false,
14271471
py::arg("debug_buffer_size") = 0,
14281472
py::arg("program_verification") =
@@ -1432,6 +1476,7 @@ PYBIND11_MODULE(EXECUTORCH_PYTHON_MODULE_NAME, m) {
14321476
"_load_for_executorch_from_bundled_program",
14331477
&PyModule::load_from_bundled_program,
14341478
py::arg("ptr"),
1479+
py::arg("data_map_ptr") = std::nullopt,
14351480
py::arg("enable_etdump") = false,
14361481
py::arg("debug_buffer_size") = 0,
14371482
call_guard);

extension/pybindings/test/TARGETS

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ runtime.python_library(
1717
deps = [
1818
"//caffe2:torch",
1919
"//caffe2:torch_fx",
20+
"//executorch/devtools/bundled_program:config",
21+
"//executorch/devtools/bundled_program:core",
22+
"//executorch/devtools/bundled_program/serialize:lib",
2023
"//executorch/exir:lib",
2124
"//executorch/exir:pass_manager",
2225
"//executorch/exir:scalar_type",

extension/pybindings/test/test_pybindings.py

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -635,24 +635,84 @@ def test_program_data_separation(self) -> None:
635635
external_constants=True,
636636
)
637637
)
638+
program_buffer = exec_program.buffer
639+
assert len(exec_program._tensor_data) == 1
640+
data_buffer = bytes(exec_program._tensor_data.pop("_default_external_constant"))
638641

639642
import os
640643
import tempfile
641644

642645
with tempfile.TemporaryDirectory() as tmpdir:
643646
pte_file = os.path.join(tmpdir, "linear.pte")
644647
with open(pte_file, "wb") as f:
645-
f.write(exec_program.buffer)
646-
648+
f.write(program_buffer)
647649
ptd_file = os.path.join(tmpdir, "linear.ptd")
648650
with open(ptd_file, "wb") as ptd:
649-
tensor_data = bytes(
650-
exec_program._tensor_data.pop("_default_external_constant")
651-
)
652-
ptd.write(tensor_data)
651+
ptd.write(data_buffer)
652+
expected = eager_module(inputs[0])
653+
# Test 1: File-based loading with external data file
654+
executorch_module_file = self.runtime._load_for_executorch(
655+
pte_file, ptd_file
656+
)
657+
executorch_output_file = executorch_module_file.forward(inputs)[0]
658+
self.assertTrue(torch.allclose(expected, executorch_output_file))
653659

654-
executorch_program = self.runtime._load_for_executorch(pte_file, ptd_file)
660+
# Test 2: Buffer-based loading with external data buffer
661+
executorch_module_buffer = self.load_fn(program_buffer, data_buffer)
662+
executorch_output_buffer = executorch_module_buffer.forward(inputs)[0]
663+
self.assertTrue(torch.allclose(expected, executorch_output_buffer))
655664

656-
expected = eager_module(inputs[0])
657-
executorch_output = executorch_program.forward(inputs)[0]
658-
self.assertTrue(torch.allclose(expected, executorch_output))
665+
# Test 3: Buffer-based loading without an external data buffer (should fail)
666+
# This should fail because the program expects external data
667+
executorch_module_no_data = self.load_fn(program_buffer)
668+
with self.assertRaises(RuntimeError):
669+
executorch_module_no_data.forward(inputs)
670+
671+
# Test 4: Test with invalid data buffer (should fail)
672+
invalid_bytes = b"invalid bytes"
673+
executorch_module_invalid_data = self.load_fn(program_buffer, invalid_bytes)
674+
with self.assertRaises(RuntimeError):
675+
executorch_module_invalid_data.forward(inputs)
676+
677+
# Test 5: Test bundled program loading with external data
678+
# First create a bundled program with external constants
679+
from executorch.devtools.bundled_program.config import (
680+
MethodTestCase,
681+
MethodTestSuite,
682+
)
683+
from executorch.devtools.bundled_program.core import BundledProgram
684+
from executorch.devtools.bundled_program.serialize import (
685+
serialize_from_bundled_program_to_flatbuffer,
686+
)
687+
688+
method_test_suites = [
689+
MethodTestSuite(
690+
method_name="forward",
691+
test_cases=[
692+
MethodTestCase(
693+
inputs=input,
694+
expected_outputs=expected,
695+
)
696+
for input in inputs
697+
],
698+
),
699+
]
700+
bundled_program = BundledProgram(exec_program, method_test_suites)
701+
bundled_buffer = serialize_from_bundled_program_to_flatbuffer(bundled_program)
702+
bundled_module = self.runtime._load_bundled_program_from_buffer(bundled_buffer)
703+
704+
# Load module from bundled program with external data
705+
executorch_module_bundled = (
706+
self.runtime._load_for_executorch_from_bundled_program(
707+
bundled_module, data_buffer
708+
)
709+
)
710+
executorch_output_bundled = executorch_module_bundled.forward(inputs)[0]
711+
self.assertTrue(torch.allclose(expected, executorch_output_bundled))
712+
713+
# Test 6: Bundled program without external data should fail
714+
executorch_module_bundled_no_data = (
715+
self.runtime._load_for_executorch_from_bundled_program(bundled_module)
716+
)
717+
with self.assertRaises(RuntimeError):
718+
executorch_module_bundled_no_data.forward(inputs)

0 commit comments

Comments
 (0)