From 8d17e5b831de56237324c34aefabf608b6d639b2 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Tue, 9 Sep 2025 12:13:16 +0800 Subject: [PATCH 1/5] [MLIR][Python] Add the ability to signal pass failures in python-defined passes --- mlir/lib/Bindings/Python/Pass.cpp | 18 ++++++++++++++++-- mlir/test/python/python_pass.py | 17 +++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 6ee85e8a31492..fb7dc2705b3ce 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -56,6 +56,8 @@ class PyPassManager { /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { + constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__"; + //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- @@ -182,10 +184,22 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { callbacks.clone = [](void *) -> void * { throw std::runtime_error("Cloning Python passes not supported"); }; - callbacks.run = [](MlirOperation op, MlirExternalPass, + callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) { - nb::borrow(static_cast(userData))(op); + auto callable = + nb::borrow(static_cast(userData)); + nb::setattr(callable, mlirExternalPassAttr, + nb::capsule(pass.ptr)); + callable(op); + // delete it to avoid that it is used after + // the external pass is freed by the pass manager + nb::delattr(callable, mlirExternalPassAttr); }; + nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() { + nb::capsule cap = run.attr(mlirExternalPassAttr); + mlirExternalPassSignalFailure( + MlirExternalPass{cap.data()}); + })); auto externalPass = mlirCreateExternalPass( passID, mlirStringRefCreate(name->data(), name->length()), mlirStringRefCreate(argument.data(), argument.length()), diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py index c94f96e20966f..7734d76fcba94 100644 --- a/mlir/test/python/python_pass.py +++ b/mlir/test/python/python_pass.py @@ -86,3 +86,20 @@ def __call__(self, m): # CHECK: llvm.mul pm.add("convert-arith-to-llvm") pm.run(module) + + # test signal_pass_failure + class CustomPassThatFails: + def __call__(self, m): + print("hello from pass that fails") + self.signal_pass_failure() + + custom_pass_that_fails = CustomPassThatFails() + + pm = PassManager("any") + pm.add(custom_pass_that_fails, "CustomPassThatFails") + # CHECK: hello from pass that fails + # CHECK: caught exception: Failure while executing pass pipeline + try: + pm.run(module) + except Exception as e: + print(f"caught exception: {e}") From 2241e278856ed01ab912da5eb5567b13f02358d8 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Tue, 9 Sep 2025 13:03:13 +0800 Subject: [PATCH 2/5] refine the bad path --- mlir/lib/Bindings/Python/Pass.cpp | 9 ++++++++- mlir/test/python/python_pass.py | 6 ++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index fb7dc2705b3ce..c5fe7bda4a680 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -196,7 +196,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { nb::delattr(callable, mlirExternalPassAttr); }; nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() { - nb::capsule cap = run.attr(mlirExternalPassAttr); + nb::capsule cap; + try { + cap = run.attr(mlirExternalPassAttr); + } catch (nb::python_error &e) { + throw std::runtime_error( + "signal_pass_failure() should always be called " + "from the __call__ method"); + } mlirExternalPassSignalFailure( MlirExternalPass{cap.data()}); })); diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py index 7734d76fcba94..4784e073fef0a 100644 --- a/mlir/test/python/python_pass.py +++ b/mlir/test/python/python_pass.py @@ -103,3 +103,9 @@ def __call__(self, m): pm.run(module) except Exception as e: print(f"caught exception: {e}") + + # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method + try: + custom_pass_that_fails.signal_pass_failure() + except Exception as e: + print(f"caught exception: {e}") From 28638ab495cd0cde1487d5540819aea64f57d014 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Tue, 9 Sep 2025 21:19:11 +0800 Subject: [PATCH 3/5] drop the setattr design --- mlir/lib/Bindings/Python/Pass.cpp | 28 +++++++--------------------- mlir/test/python/python_pass.py | 16 +++++----------- 2 files changed, 12 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index c5fe7bda4a680..ef606431fbd5e 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -56,7 +56,12 @@ class PyPassManager { /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { - constexpr const char *mlirExternalPassAttr = "__mlir_external_pass__"; + //---------------------------------------------------------------------------- + // Mapping of MlirExternalPass + //---------------------------------------------------------------------------- + nb::class_(m, "ExternalPass") + .def("signal_failure", + [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); }); //---------------------------------------------------------------------------- // Mapping of the top-level PassManager @@ -186,27 +191,8 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { }; callbacks.run = [](MlirOperation op, MlirExternalPass pass, void *userData) { - auto callable = - nb::borrow(static_cast(userData)); - nb::setattr(callable, mlirExternalPassAttr, - nb::capsule(pass.ptr)); - callable(op); - // delete it to avoid that it is used after - // the external pass is freed by the pass manager - nb::delattr(callable, mlirExternalPassAttr); + nb::handle(static_cast(userData))(op, pass); }; - nb::setattr(run, "signal_pass_failure", nb::cpp_function([run]() { - nb::capsule cap; - try { - cap = run.attr(mlirExternalPassAttr); - } catch (nb::python_error &e) { - throw std::runtime_error( - "signal_pass_failure() should always be called " - "from the __call__ method"); - } - mlirExternalPassSignalFailure( - MlirExternalPass{cap.data()}); - })); auto externalPass = mlirCreateExternalPass( passID, mlirStringRefCreate(name->data(), name->length()), mlirStringRefCreate(argument.data(), argument.length()), diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py index 4784e073fef0a..10b449f9b1ef8 100644 --- a/mlir/test/python/python_pass.py +++ b/mlir/test/python/python_pass.py @@ -64,12 +64,12 @@ def testCustomPass(): """ ) - def custom_pass_1(op): + def custom_pass_1(op, pass_): print("hello from pass 1!!!", file=sys.stderr) class CustomPass2: - def __call__(self, m): - apply_patterns_and_fold_greedily(m, frozen) + def __call__(self, op, pass_): + apply_patterns_and_fold_greedily(op, frozen) custom_pass_2 = CustomPass2() @@ -89,9 +89,9 @@ def __call__(self, m): # test signal_pass_failure class CustomPassThatFails: - def __call__(self, m): + def __call__(self, op, pass_): print("hello from pass that fails") - self.signal_pass_failure() + pass_.signal_failure() custom_pass_that_fails = CustomPassThatFails() @@ -103,9 +103,3 @@ def __call__(self, m): pm.run(module) except Exception as e: print(f"caught exception: {e}") - - # CHECK: caught exception: signal_pass_failure() should always be called from the __call__ method - try: - custom_pass_that_fails.signal_pass_failure() - except Exception as e: - print(f"caught exception: {e}") From 77c2f6cd4a60dc60f1d266d53177b908e7c97abd Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Tue, 9 Sep 2025 21:23:43 +0800 Subject: [PATCH 4/5] fix style --- mlir/test/python/python_pass.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py index 10b449f9b1ef8..72af1d93cd5db 100644 --- a/mlir/test/python/python_pass.py +++ b/mlir/test/python/python_pass.py @@ -88,12 +88,9 @@ def __call__(self, op, pass_): pm.run(module) # test signal_pass_failure - class CustomPassThatFails: - def __call__(self, op, pass_): - print("hello from pass that fails") - pass_.signal_failure() - - custom_pass_that_fails = CustomPassThatFails() + def custom_pass_that_fails(op, pass_): + print("hello from pass that fails") + pass_.signal_failure() pm = PassManager("any") pm.add(custom_pass_that_fails, "CustomPassThatFails") From 2a99c8c3d3d2b2eb52bf457be398ade2d5493d7f Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Tue, 9 Sep 2025 21:40:43 +0800 Subject: [PATCH 5/5] fix name --- mlir/lib/Bindings/Python/Pass.cpp | 2 +- mlir/test/python/python_pass.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index ef606431fbd5e..47ef5d8e9dd3b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -60,7 +60,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { // Mapping of MlirExternalPass //---------------------------------------------------------------------------- nb::class_(m, "ExternalPass") - .def("signal_failure", + .def("signal_pass_failure", [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); }); //---------------------------------------------------------------------------- diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py index 72af1d93cd5db..50c42102f66d3 100644 --- a/mlir/test/python/python_pass.py +++ b/mlir/test/python/python_pass.py @@ -90,7 +90,7 @@ def __call__(self, op, pass_): # test signal_pass_failure def custom_pass_that_fails(op, pass_): print("hello from pass that fails") - pass_.signal_failure() + pass_.signal_pass_failure() pm = PassManager("any") pm.add(custom_pass_that_fails, "CustomPassThatFails")