Skip to content

Commit 7d04e37

Browse files
PragmaTwicecnb.bsD2OPwAgEAmakslevental
authored
[MLIR][Python] Support Python-defined passes in MLIR (#156000)
It closes #155996. This PR added a method `add(callable, ..)` to `mlir.passmanager.PassManager` to accept a callable object for defining passes in the Python side. This is a simple example of a Python-defined pass. ```python from mlir.passmanager import PassManager def demo_pass_1(op): # do something with op pass class DemoPass: def __init__(self, ...): pass def __call__(op): # do something pass demo_pass_2 = DemoPass(..) pm = PassManager('any', ctx) pm.add(demo_pass_1) pm.add(demo_pass_2) pm.add("registered-passes") pm.run(..) ``` --------- Co-authored-by: cnb.bsD2OPwAgEA <[email protected]> Co-authored-by: Maksim Levental <[email protected]>
1 parent 82ef4ee commit 7d04e37

File tree

4 files changed

+137
-4
lines changed

4 files changed

+137
-4
lines changed

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ NB_MODULE(_mlir, m) {
136136
populateRewriteSubmodule(rewriteModule);
137137

138138
// Define and populate PassManager submodule.
139-
auto passModule =
139+
auto passManagerModule =
140140
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
141-
populatePassManagerSubmodule(passModule);
141+
populatePassManagerSubmodule(passManagerModule);
142142
}

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
#include "IRModule.h"
1212
#include "mlir-c/Pass.h"
13+
// clang-format off
1314
#include "mlir/Bindings/Python/Nanobind.h"
1415
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
16+
// clang-format on
1517

1618
namespace nb = nanobind;
1719
using namespace nb::literals;
@@ -157,6 +159,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
157159
"pipeline"_a,
158160
"Add textual pipeline elements to the pass manager. Throws a "
159161
"ValueError if the pipeline can't be parsed.")
162+
.def(
163+
"add",
164+
[](PyPassManager &passManager, const nb::callable &run,
165+
std::optional<std::string> &name, const std::string &argument,
166+
const std::string &description, const std::string &opName) {
167+
if (!name.has_value()) {
168+
name = nb::cast<std::string>(
169+
nb::borrow<nb::str>(run.attr("__name__")));
170+
}
171+
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
172+
MlirTypeID passID =
173+
mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
174+
MlirExternalPassCallbacks callbacks;
175+
callbacks.construct = [](void *obj) {
176+
(void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
177+
};
178+
callbacks.destruct = [](void *obj) {
179+
(void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
180+
};
181+
callbacks.initialize = nullptr;
182+
callbacks.clone = [](void *) -> void * {
183+
throw std::runtime_error("Cloning Python passes not supported");
184+
};
185+
callbacks.run = [](MlirOperation op, MlirExternalPass,
186+
void *userData) {
187+
nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
188+
};
189+
auto externalPass = mlirCreateExternalPass(
190+
passID, mlirStringRefCreate(name->data(), name->length()),
191+
mlirStringRefCreate(argument.data(), argument.length()),
192+
mlirStringRefCreate(description.data(), description.length()),
193+
mlirStringRefCreate(opName.data(), opName.size()),
194+
/*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
195+
callbacks, /*userData*/ run.ptr());
196+
mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
197+
},
198+
"run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
199+
"description"_a.none() = "", "op_name"_a.none() = "",
200+
"Add a python-defined pass to the pass manager.")
160201
.def(
161202
"run",
162203
[](PyPassManager &passManager, PyOperationBase &op) {

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
145145
: Pass(passID, opName), id(passID), name(name), argument(argument),
146146
description(description), dependentDialects(dependentDialects),
147147
callbacks(callbacks), userData(userData) {
148-
callbacks.construct(userData);
148+
if (callbacks.construct)
149+
callbacks.construct(userData);
149150
}
150151

151-
~ExternalPass() override { callbacks.destruct(userData); }
152+
~ExternalPass() override {
153+
if (callbacks.destruct)
154+
callbacks.destruct(userData);
155+
}
152156

153157
StringRef getName() const override { return name; }
154158
StringRef getArgument() const override { return argument; }

mlir/test/python/python_pass.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# RUN: %PYTHON %s 2>&1 | FileCheck %s
2+
3+
import gc, sys
4+
from mlir.ir import *
5+
from mlir.passmanager import *
6+
from mlir.dialects.builtin import ModuleOp
7+
from mlir.dialects import pdl
8+
from mlir.rewrite import *
9+
10+
11+
def log(*args):
12+
print(*args, file=sys.stderr)
13+
sys.stderr.flush()
14+
15+
16+
def run(f):
17+
log("\nTEST:", f.__name__)
18+
f()
19+
gc.collect()
20+
assert Context._get_live_count() == 0
21+
22+
23+
def make_pdl_module():
24+
with Location.unknown():
25+
pdl_module = Module.create()
26+
with InsertionPoint(pdl_module.body):
27+
# Change all arith.addi with index types to arith.muli.
28+
@pdl.pattern(benefit=1, sym_name="addi_to_mul")
29+
def pat():
30+
# Match arith.addi with index types.
31+
i64_type = pdl.TypeOp(IntegerType.get_signless(64))
32+
operand0 = pdl.OperandOp(i64_type)
33+
operand1 = pdl.OperandOp(i64_type)
34+
op0 = pdl.OperationOp(
35+
name="arith.addi", args=[operand0, operand1], types=[i64_type]
36+
)
37+
38+
# Replace the matched op with arith.muli.
39+
@pdl.rewrite()
40+
def rew():
41+
newOp = pdl.OperationOp(
42+
name="arith.muli", args=[operand0, operand1], types=[i64_type]
43+
)
44+
pdl.ReplaceOp(op0, with_op=newOp)
45+
46+
return pdl_module
47+
48+
49+
# CHECK-LABEL: TEST: testCustomPass
50+
@run
51+
def testCustomPass():
52+
with Context():
53+
pdl_module = make_pdl_module()
54+
frozen = PDLModule(pdl_module).freeze()
55+
56+
module = ModuleOp.parse(
57+
r"""
58+
module {
59+
func.func @add(%a: i64, %b: i64) -> i64 {
60+
%sum = arith.addi %a, %b : i64
61+
return %sum : i64
62+
}
63+
}
64+
"""
65+
)
66+
67+
def custom_pass_1(op):
68+
print("hello from pass 1!!!", file=sys.stderr)
69+
70+
class CustomPass2:
71+
def __call__(self, m):
72+
apply_patterns_and_fold_greedily(m, frozen)
73+
74+
custom_pass_2 = CustomPass2()
75+
76+
pm = PassManager("any")
77+
pm.enable_ir_printing()
78+
79+
# CHECK: hello from pass 1!!!
80+
# CHECK-LABEL: Dump After custom_pass_1
81+
pm.add(custom_pass_1)
82+
# CHECK-LABEL: Dump After CustomPass2
83+
# CHECK: arith.muli
84+
pm.add(custom_pass_2, "CustomPass2")
85+
# CHECK-LABEL: Dump After ArithToLLVMConversionPass
86+
# CHECK: llvm.mul
87+
pm.add("convert-arith-to-llvm")
88+
pm.run(module)

0 commit comments

Comments
 (0)