Skip to content

Commit 236c07d

Browse files
committed
Fix BatchNorm tests by passing relevant input buffers.
Signed-off-by: zjgarvey <[email protected]>
1 parent 95d0eb0 commit 236c07d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

projects/e2e/torch_mlir_e2e_test/configs/fx_importer_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,12 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
149149
)
150150
module = self._backend.compile(module)
151151
backend_module = self._backend.load(module)
152+
input_buffers = prog.graph_signature.inputs_to_buffers.values()
152153
params = {
153154
# **dict(artifact.named_parameters(remove_duplicate=False)),
154-
**dict(artifact.named_buffers(remove_duplicate=False)),
155+
name: value
156+
for (name, value) in artifact.named_buffers(remove_duplicate=False)
157+
if name in input_buffers
155158
}
156159
params_flat, params_spec = pytree.tree_flatten(params)
157160
params_flat = list(params_flat)

0 commit comments

Comments
 (0)