Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions mlir/lib/Bindings/Python/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ class PyPassManager {

/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of MlirExternalPass
//----------------------------------------------------------------------------
nb::class_<MlirExternalPass>(m, "ExternalPass")
.def("signal_pass_failure",
[](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); });

//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
Expand Down Expand Up @@ -182,9 +189,9 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
callbacks.clone = [](void *) -> void * {
throw std::runtime_error("Cloning Python passes not supported");
};
callbacks.run = [](MlirOperation op, MlirExternalPass,
callbacks.run = [](MlirOperation op, MlirExternalPass pass,
void *userData) {
nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
nb::handle(static_cast<PyObject *>(userData))(op, pass);
};
auto externalPass = mlirCreateExternalPass(
passID, mlirStringRefCreate(name->data(), name->length()),
Expand Down
20 changes: 17 additions & 3 deletions mlir/test/python/python_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def testCustomPass():
"""
)

def custom_pass_1(op):
def custom_pass_1(op, pass_):
print("hello from pass 1!!!", file=sys.stderr)

class CustomPass2:
def __call__(self, m):
apply_patterns_and_fold_greedily(m, frozen)
def __call__(self, op, pass_):
apply_patterns_and_fold_greedily(op, frozen)

custom_pass_2 = CustomPass2()

Expand All @@ -86,3 +86,17 @@ def __call__(self, m):
# CHECK: llvm.mul
pm.add("convert-arith-to-llvm")
pm.run(module)

# test signal_pass_failure
def custom_pass_that_fails(op, pass_):
print("hello from pass that fails")
pass_.signal_pass_failure()

pm = PassManager("any")
pm.add(custom_pass_that_fails, "CustomPassThatFails")
# CHECK: hello from pass that fails
# CHECK: caught exception: Failure while executing pass pipeline
try:
pm.run(module)
except Exception as e:
print(f"caught exception: {e}")