Skip to content

Commit 7123463

Browse files
authored
[MLIR][Python] Add the ability to signal pass failures in python-defined passes (#157613)
This is a follow-up PR for #156000. In this PR we add the ability to signal pass failures (`signal_pass_failure()`) in python-defined passes. To achieve this, we expose `MlirExternalPass` via `nb::class_` with a method `signal_pass_failure()`, and the callable passed to `pm.add(..)` now accepts two arguments (`op: MlirOperation, pass_: MlirExternalPass`). For example: ```python def custom_pass_that_fails(op, pass_): if some_condition: pass_.signal_pass_failure() # do something ```
1 parent 406d6bd commit 7123463

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ class PyPassManager {
5656

5757
/// Create the `mlir.passmanager` here.
5858
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
59+
//----------------------------------------------------------------------------
60+
// Mapping of MlirExternalPass
61+
//----------------------------------------------------------------------------
62+
nb::class_<MlirExternalPass>(m, "ExternalPass")
63+
.def("signal_pass_failure",
64+
[](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });
65+
5966
//----------------------------------------------------------------------------
6067
// Mapping of the top-level PassManager
6168
//----------------------------------------------------------------------------
@@ -182,9 +189,9 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
182189
callbacks.clone = [](void *) -> void * {
183190
throw std::runtime_error("Cloning Python passes not supported");
184191
};
185-
callbacks.run = [](MlirOperation op, MlirExternalPass,
192+
callbacks.run = [](MlirOperation op, MlirExternalPass pass,
186193
void *userData) {
187-
nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
194+
nb::handle(static_cast<PyObject *>(userData))(op, pass);
188195
};
189196
auto externalPass = mlirCreateExternalPass(
190197
passID, mlirStringRefCreate(name->data(), name->length()),

mlir/test/python/python_pass.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,12 @@ def testCustomPass():
6464
"""
6565
)
6666

67-
def custom_pass_1(op):
67+
def custom_pass_1(op, pass_):
6868
print("hello from pass 1!!!", file=sys.stderr)
6969

7070
class CustomPass2:
71-
def __call__(self, m):
72-
apply_patterns_and_fold_greedily(m, frozen)
71+
def __call__(self, op, pass_):
72+
apply_patterns_and_fold_greedily(op, frozen)
7373

7474
custom_pass_2 = CustomPass2()
7575

@@ -86,3 +86,17 @@ def __call__(self, m):
8686
# CHECK: llvm.mul
8787
pm.add("convert-arith-to-llvm")
8888
pm.run(module)
89+
90+
# test signal_pass_failure
91+
def custom_pass_that_fails(op, pass_):
92+
print("hello from pass that fails")
93+
pass_.signal_pass_failure()
94+
95+
pm = PassManager("any")
96+
pm.add(custom_pass_that_fails, "CustomPassThatFails")
97+
# CHECK: hello from pass that fails
98+
# CHECK: caught exception: Failure while executing pass pipeline
99+
try:
100+
pm.run(module)
101+
except Exception as e:
102+
print(f"caught exception: {e}")

0 commit comments

Comments
 (0)