Skip to content

Commit 982a52f

Browse files
committed
[MLIR][Python] Support Python-defined passes in MLIR
based heavily on #156000
1 parent 1d848cf commit 982a52f

File tree

7 files changed

+168
-15
lines changed

7 files changed

+168
-15
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
301301
MLIR_CAPI_EXPORTED void
302302
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
303303

304+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
305+
MlirOperation op, MlirFrozenRewritePatternSet patterns,
306+
MlirGreedyRewriteDriverConfig);
307+
304308
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
305309
MlirModule op, MlirFrozenRewritePatternSet patterns,
306310
MlirGreedyRewriteDriverConfig);

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ NB_MODULE(_mlir, m) {
136136
populateRewriteSubmodule(rewriteModule);
137137

138138
// Define and populate PassManager submodule.
139-
auto passModule =
139+
auto passManagerModule =
140140
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
141-
populatePassManagerSubmodule(passModule);
141+
populatePassManagerSubmodule(passManagerModule);
142142
}

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
#include "IRModule.h"
1212
#include "mlir-c/Pass.h"
13+
// clang-format off
1314
#include "mlir/Bindings/Python/Nanobind.h"
1415
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
16+
// clang-format on
1517

1618
namespace nb = nanobind;
1719
using namespace nb::literals;
@@ -157,6 +159,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
157159
"pipeline"_a,
158160
"Add textual pipeline elements to the pass manager. Throws a "
159161
"ValueError if the pipeline can't be parsed.")
162+
.def(
163+
"add_python_pass",
164+
[](PyPassManager &passManager, const nb::callable &run,
165+
std::optional<std::string> &name, const std::string &argument,
166+
const std::string &description, const std::string &opName) {
167+
if (!name.has_value()) {
168+
name = nb::cast<std::string>(
169+
nb::borrow<nb::str>(run.attr("__name__")));
170+
}
171+
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
172+
MlirTypeID passID =
173+
mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
174+
MlirExternalPassCallbacks callbacks;
175+
callbacks.construct = [](void *obj) {
176+
(void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
177+
};
178+
callbacks.destruct = [](void *obj) {
179+
(void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
180+
};
181+
callbacks.initialize = nullptr;
182+
callbacks.clone = [](void *) -> void * {
183+
throw std::runtime_error("Cloning Python passes not supported");
184+
};
185+
callbacks.run = [](MlirOperation op, MlirExternalPass,
186+
void *userData) {
187+
nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op);
188+
};
189+
auto externalPass = mlirCreateExternalPass(
190+
passID, mlirStringRefCreate(name->data(), name->length()),
191+
mlirStringRefCreate(argument.data(), argument.length()),
192+
mlirStringRefCreate(description.data(), description.length()),
193+
mlirStringRefCreate(opName.data(), opName.size()),
194+
/*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
195+
callbacks, /*userData*/ run.ptr());
196+
mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
197+
},
198+
"run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
199+
"description"_a.none() = "", "op_name"_a.none() = "",
200+
"Add a python-defined pass to the pass manager.")
160201
.def(
161202
"run",
162203
[](PyPassManager &passManager, PyOperationBase &op) {

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
#include "IRModule.h"
1212
#include "mlir-c/Rewrite.h"
13+
// clang-format off
1314
#include "mlir/Bindings/Python/Nanobind.h"
1415
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
16+
// clang-format on
1517
#include "mlir/Config/mlir-config.h"
1618

1719
namespace nb = nanobind;
@@ -99,14 +101,26 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
99101
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
100102
&PyFrozenRewritePatternSet::createFromCapsule);
101103
m.def(
102-
"apply_patterns_and_fold_greedily",
103-
[](MlirModule module, MlirFrozenRewritePatternSet set) {
104-
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105-
if (mlirLogicalResultIsFailure(status))
106-
// FIXME: Not sure this is the right error to throw here.
107-
throw nb::value_error("pattern application failed to converge");
108-
},
109-
"module"_a, "set"_a,
110-
"Applys the given patterns to the given module greedily while folding "
111-
"results.");
104+
"apply_patterns_and_fold_greedily",
105+
[](MlirModule module, MlirFrozenRewritePatternSet set) {
106+
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
107+
if (mlirLogicalResultIsFailure(status))
108+
// FIXME: Not sure this is the right error to throw here.
109+
throw nb::value_error("pattern application failed to converge");
110+
},
111+
"module"_a, "set"_a,
112+
"Applys the given patterns to the given module greedily while folding "
113+
"results.")
114+
.def(
115+
"apply_patterns_and_fold_greedily_with_op",
116+
[](MlirOperation op, MlirFrozenRewritePatternSet set) {
117+
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
118+
if (mlirLogicalResultIsFailure(status)) {
119+
throw std::runtime_error(
120+
"pattern application failed to converge");
121+
}
122+
},
123+
"op"_a, "set"_a,
124+
"Applys the given patterns to the given op greedily while folding "
125+
"results.");
112126
}

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
145145
: Pass(passID, opName), id(passID), name(name), argument(argument),
146146
description(description), dependentDialects(dependentDialects),
147147
callbacks(callbacks), userData(userData) {
148-
callbacks.construct(userData);
148+
if (callbacks.construct)
149+
callbacks.construct(userData);
149150
}
150151

151-
~ExternalPass() override { callbacks.destruct(userData); }
152+
~ExternalPass() override {
153+
if (callbacks.destruct)
154+
callbacks.destruct(userData);
155+
}
152156

153157
StringRef getName() const override { return name; }
154158
StringRef getArgument() const override { return argument; }
@@ -180,7 +184,9 @@ class ExternalPass : public Pass {
180184
}
181185

182186
std::unique_ptr<Pass> clonePass() const override {
183-
void *clonedUserData = callbacks.clone(userData);
187+
void *clonedUserData;
188+
if (callbacks.clone)
189+
clonedUserData = callbacks.clone(userData);
184190
return std::make_unique<ExternalPass>(id, name, argument, description,
185191
getOpName(), dependentDialects,
186192
callbacks, clonedUserData);

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
294294
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
295295
}
296296

297+
MlirLogicalResult
298+
mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
299+
MlirFrozenRewritePatternSet patterns,
300+
MlirGreedyRewriteDriverConfig) {
301+
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
302+
}
303+
297304
//===----------------------------------------------------------------------===//
298305
/// PDLPatternModule API
299306
//===----------------------------------------------------------------------===//

mlir/test/python/pass.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# RUN: %PYTHON %s 2>&1 | FileCheck %s
2+
3+
import gc, sys
4+
from mlir.ir import *
5+
from mlir.passmanager import *
6+
from mlir.dialects.builtin import ModuleOp
7+
from mlir.dialects import pdl
8+
from mlir.rewrite import *
9+
10+
11+
def run(f):
12+
# Note, everything in this file is dumped to stderr because that's where
13+
# `IR Dump After` dumps too (so we can't cross the "streams")
14+
print("\nTEST:", f.__name__, file=sys.stderr)
15+
f()
16+
gc.collect()
17+
18+
19+
def make_pdl_module():
20+
with Location.unknown():
21+
pdl_module = Module.create()
22+
with InsertionPoint(pdl_module.body):
23+
# Change all arith.addi with index types to arith.muli.
24+
@pdl.pattern(benefit=1, sym_name="addi_to_mul")
25+
def pat():
26+
# Match arith.addi with index types.
27+
i64_type = pdl.TypeOp(IntegerType.get_signless(64))
28+
operand0 = pdl.OperandOp(i64_type)
29+
operand1 = pdl.OperandOp(i64_type)
30+
op0 = pdl.OperationOp(
31+
name="arith.addi", args=[operand0, operand1], types=[i64_type]
32+
)
33+
34+
# Replace the matched op with arith.muli.
35+
@pdl.rewrite()
36+
def rew():
37+
newOp = pdl.OperationOp(
38+
name="arith.muli", args=[operand0, operand1], types=[i64_type]
39+
)
40+
pdl.ReplaceOp(op0, with_op=newOp)
41+
42+
return pdl_module
43+
44+
45+
# CHECK-LABEL: TEST: testCustomPass
46+
@run
47+
def testCustomPass():
48+
with Context():
49+
pdl_module = make_pdl_module()
50+
frozen = PDLModule(pdl_module).freeze()
51+
52+
module = ModuleOp.parse(
53+
r"""
54+
module {
55+
func.func @add(%a: i64, %b: i64) -> i64 {
56+
%sum = arith.addi %a, %b : i64
57+
return %sum : i64
58+
}
59+
}
60+
"""
61+
)
62+
63+
def custom_pass_1(op):
64+
print("hello from pass 1!!!", file=sys.stderr)
65+
66+
def custom_pass_2(op):
67+
apply_patterns_and_fold_greedily_with_op(op, frozen)
68+
69+
pm = PassManager("any")
70+
pm.enable_ir_printing()
71+
72+
# CHECK: hello from pass 1!!!
73+
# CHECK-LABEL: Dump After custom_pass_1
74+
# CHECK-LABEL: Dump After CustomPass2
75+
# CHECK: arith.muli
76+
pm.add_python_pass(custom_pass_1)
77+
pm.add_python_pass(custom_pass_2, "CustomPass2")
78+
# # CHECK-LABEL: Dump After ArithToLLVMConversionPass
79+
# # CHECK: llvm.mul
80+
pm.add("convert-arith-to-llvm")
81+
pm.run(module)

0 commit comments

Comments
 (0)