Skip to content

Commit f77bb5e

Browse files
authored
Fix delegate node metadata
Differential Revision: D78350040 Pull Request resolved: #12504
1 parent 0a038a7 commit f77bb5e

File tree

3 files changed

+61
-5
lines changed

3 files changed

+61
-5
lines changed

backends/arm/test/tester/arm_tester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def _get_dtype_distribution(
726726
if node.op == "placeholder":
727727
placeholder_dtypes.append(str(node.meta["val"].dtype))
728728
if node.op == "call_function":
729-
if "val" in node.meta:
729+
if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor):
730730
dtype, _, _ = extract_tensor_meta(node.meta, tosa_spec)
731731
call_function_dtypes.append(ts.DTypeNames[dtype])
732732
return Counter(placeholder_dtypes), Counter(call_function_dtypes)

exir/backend/backend_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ def generate_debug_handle(ep: ExportedProgram) -> int:
235235
call_submodule_node.kwargs,
236236
)
237237
call_delegate_node.meta["debug_handle"] = generate_debug_handle(owning_program)
238-
call_delegate_node.meta["val"] = submodule_output_node.meta["val"]
238+
call_delegate_node.meta["val"] = [
239+
out_arg.meta["val"] for out_arg in submodule_output_node.args[0]
240+
]
239241
call_submodule_node.replace_all_uses_with(call_delegate_node)
240242
owning_graph_module.graph.erase_node(call_submodule_node)
241243
if is_submodule:
@@ -472,11 +474,9 @@ def _create_partitions_in_graph_module(
472474
tagged_graph_module, node_list, tag
473475
)
474476

475-
tagged_graph_module_output_node = tagged_graph_module.graph.output_node()
476477
submodule_output_node = submodule.graph.output_node()
477478
# Copy the output node meta from the original output node, because
478479
# create_submodule_from_nodes doesn't cover the meta field
479-
submodule_output_node.meta = tagged_graph_module_output_node.meta
480480
logging.debug(f"Partitioned graph module: {tagged_graph_module}")
481481
(
482482
submodule_program,

exir/backend/test/demos/test_xnnpack_qnnpack.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import tempfile
78
import unittest
89

10+
from typing import Tuple
11+
912
import executorch.exir as exir
1013

1114
import torch
@@ -20,7 +23,13 @@
2023
# import the xnnpack backend implementation
2124
from executorch.backends.xnnpack.xnnpack_preprocess import XnnpackBackend
2225

23-
from executorch.exir import CaptureConfig
26+
from executorch.exir import (
27+
CaptureConfig,
28+
EdgeCompileConfig,
29+
EdgeProgramManager,
30+
to_edge_transform_and_lower,
31+
)
32+
2433
from executorch.exir.backend.backend_api import to_backend, validation_disabled
2534
from executorch.exir.passes.spec_prop_pass import SpecPropPass
2635

@@ -132,3 +141,50 @@ def forward(self, x, y):
132141
self.assertTrue(
133142
torch.allclose(model_output[0], ref_output, atol=1e-03, rtol=1e-03)
134143
)
144+
145+
def test_serde(self):
146+
# The module with blank_logprobs() function
147+
class BlankLogProbsModule(torch.nn.Module):
148+
def __init__(self) -> None:
149+
super().__init__()
150+
self.linear = torch.nn.Linear(768, 1)
151+
self.log_sigmoid = torch.nn.LogSigmoid()
152+
153+
def forward(self, joint_encodings: torch.Tensor) -> torch.Tensor:
154+
tanh_out = torch.tanh(joint_encodings)
155+
linear_out = self.linear(tanh_out)
156+
blank_output = self.log_sigmoid(linear_out)
157+
return blank_output
158+
159+
def get_blank_logprobs_inputs_fn() -> Tuple[torch.Tensor, ...]:
160+
"""
161+
Get the input to the blank_logprobs() and nonblank_logprobs() functions.
162+
"""
163+
return (torch.randn(1, 1, 1, 768),)
164+
165+
model = BlankLogProbsModule()
166+
# Get the inputs for the logprobs function
167+
logprobs_fake_inputs = get_blank_logprobs_inputs_fn()
168+
169+
# Export and partition
170+
aten_prog = torch.export.export(model, logprobs_fake_inputs, strict=True)
171+
partitioned_prog: EdgeProgramManager = to_edge_transform_and_lower(
172+
aten_prog,
173+
partitioner=[XnnpackFloatingPointPartitioner()],
174+
compile_config=EdgeCompileConfig(
175+
_check_ir_validity=False,
176+
_use_edge_ops=True,
177+
),
178+
)
179+
180+
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
181+
exir.save(partitioned_prog.exported_program(), f.name)
182+
f.seek(0)
183+
loaded_model = exir.load(f.name)
184+
185+
self.assertTrue(
186+
torch.allclose(
187+
model(*logprobs_fake_inputs),
188+
loaded_model.module()(*logprobs_fake_inputs),
189+
)
190+
)

0 commit comments

Comments
 (0)