Skip to content

Commit 5289c12

Browse files
committed
[MLIR][Python] Support Python-defined passes in MLIR
based heavily on #156000
1 parent 64b9896 commit 5289c12

File tree

7 files changed

+160
-16
lines changed

7 files changed

+160
-16
lines changed

mlir/include/mlir-c/Rewrite.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op);
301301
MLIR_CAPI_EXPORTED void
302302
mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op);
303303

304+
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
305+
MlirOperation op, MlirFrozenRewritePatternSet patterns,
306+
MlirGreedyRewriteDriverConfig);
307+
304308
MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
305309
MlirModule op, MlirFrozenRewritePatternSet patterns,
306310
MlirGreedyRewriteDriverConfig);

mlir/lib/Bindings/Python/MainModule.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ NB_MODULE(_mlir, m) {
136136
populateRewriteSubmodule(rewriteModule);
137137

138138
// Define and populate PassManager submodule.
139-
auto passModule =
139+
auto passManagerModule =
140140
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
141-
populatePassManagerSubmodule(passModule);
141+
populatePassManagerSubmodule(passManagerModule);
142142
}

mlir/lib/Bindings/Python/Pass.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
#include "IRModule.h"
1212
#include "mlir-c/Pass.h"
13+
// clang-format off
1314
#include "mlir/Bindings/Python/Nanobind.h"
1415
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
16+
// clang-format on
1517

1618
namespace nb = nanobind;
1719
using namespace nb::literals;
@@ -157,6 +159,38 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
157159
"pipeline"_a,
158160
"Add textual pipeline elements to the pass manager. Throws a "
159161
"ValueError if the pipeline can't be parsed.")
162+
.def(
163+
"add_python_pass",
164+
[](PyPassManager &passManager, const std::string &name,
165+
const std::string &argument, const std::string &description,
166+
const std::string &opName, const nb::callable &run) {
167+
MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate();
168+
MlirTypeID passID =
169+
mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator);
170+
MlirExternalPassCallbacks callbacks;
171+
callbacks.construct = [](void *obj) {
172+
(void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
173+
};
174+
callbacks.destruct = [](void *obj) {
175+
(void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
176+
};
177+
callbacks.run = [](MlirOperation op, MlirExternalPass,
178+
void *userData) {
179+
nb::steal<nb::callable>(static_cast<PyObject *>(userData))(op);
180+
};
181+
callbacks.clone = nullptr;
182+
callbacks.initialize = nullptr;
183+
auto externalPass = mlirCreateExternalPass(
184+
passID, mlirStringRefCreate(name.data(), name.length()),
185+
mlirStringRefCreate(argument.data(), argument.length()),
186+
mlirStringRefCreate(description.data(), description.length()),
187+
mlirStringRefCreate(opName.data(), opName.size()),
188+
/*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
189+
callbacks, /*userData*/ run.ptr());
190+
mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
191+
},
192+
"name"_a, "argument"_a, "description"_a, "op_name"_a, "run"_a,
193+
"Add a python-defined pass to the pass manager.")
160194
.def(
161195
"run",
162196
[](PyPassManager &passManager, PyOperationBase &op,

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
#include "Rewrite.h"
1010

1111
#include "IRModule.h"
12+
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1213
#include "mlir-c/Rewrite.h"
1314
#include "mlir/Bindings/Python/Nanobind.h"
14-
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1515
#include "mlir/Config/mlir-config.h"
1616

1717
namespace nb = nanobind;
@@ -99,14 +99,26 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
9999
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
100100
&PyFrozenRewritePatternSet::createFromCapsule);
101101
m.def(
102-
"apply_patterns_and_fold_greedily",
103-
[](MlirModule module, MlirFrozenRewritePatternSet set) {
104-
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105-
if (mlirLogicalResultIsFailure(status))
106-
// FIXME: Not sure this is the right error to throw here.
107-
throw nb::value_error("pattern application failed to converge");
108-
},
109-
"module"_a, "set"_a,
110-
"Applys the given patterns to the given module greedily while folding "
111-
"results.");
102+
"apply_patterns_and_fold_greedily",
103+
[](MlirModule module, MlirFrozenRewritePatternSet set) {
104+
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
105+
if (mlirLogicalResultIsFailure(status))
106+
// FIXME: Not sure this is the right error to throw here.
107+
throw nb::value_error("pattern application failed to converge");
108+
},
109+
"module"_a, "set"_a,
110+
"Applys the given patterns to the given module greedily while folding "
111+
"results.")
112+
.def(
113+
"apply_patterns_and_fold_greedily_with_op",
114+
[](MlirOperation op, MlirFrozenRewritePatternSet set) {
115+
auto status = mlirApplyPatternsAndFoldGreedilyWithOp(op, set, {});
116+
if (mlirLogicalResultIsFailure(status)) {
117+
throw std::runtime_error(
118+
"pattern application failed to converge");
119+
}
120+
},
121+
"op"_a, "set"_a,
122+
"Applys the given patterns to the given op greedily while folding "
123+
"results.");
112124
}

mlir/lib/CAPI/IR/Pass.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,14 @@ class ExternalPass : public Pass {
145145
: Pass(passID, opName), id(passID), name(name), argument(argument),
146146
description(description), dependentDialects(dependentDialects),
147147
callbacks(callbacks), userData(userData) {
148-
callbacks.construct(userData);
148+
if (callbacks.construct)
149+
callbacks.construct(userData);
149150
}
150151

151-
~ExternalPass() override { callbacks.destruct(userData); }
152+
~ExternalPass() override {
153+
if (callbacks.destruct)
154+
callbacks.destruct(userData);
155+
}
152156

153157
StringRef getName() const override { return name; }
154158
StringRef getArgument() const override { return argument; }
@@ -180,7 +184,9 @@ class ExternalPass : public Pass {
180184
}
181185

182186
std::unique_ptr<Pass> clonePass() const override {
183-
void *clonedUserData = callbacks.clone(userData);
187+
void *clonedUserData;
188+
if (callbacks.clone)
189+
clonedUserData = callbacks.clone(userData);
184190
return std::make_unique<ExternalPass>(id, name, argument, description,
185191
getOpName(), dependentDialects,
186192
callbacks, clonedUserData);

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,13 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op,
294294
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
295295
}
296296

297+
MlirLogicalResult
298+
mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
299+
MlirFrozenRewritePatternSet patterns,
300+
MlirGreedyRewriteDriverConfig) {
301+
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
302+
}
303+
297304
//===----------------------------------------------------------------------===//
298305
/// PDLPatternModule API
299306
//===----------------------------------------------------------------------===//

mlir/test/python/pass.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# RUN: %PYTHON %s 2>&1 | FileCheck %s
2+
3+
import gc, sys
4+
from mlir.ir import *
5+
from mlir.passmanager import *
6+
from mlir.dialects.builtin import ModuleOp
7+
from mlir.dialects import pdl
8+
from mlir.rewrite import *
9+
10+
11+
def run(f):
12+
# Note, everything in this file is dumped to stderr because that's where
13+
# `IR Dump After` dumps too (so we can't cross the "streams")
14+
print("\nTEST:", f.__name__, file=sys.stderr)
15+
f()
16+
gc.collect()
17+
assert Context._get_live_count() == 0
18+
19+
20+
def make_pdl_module():
21+
with Location.unknown():
22+
pdl_module = Module.create()
23+
with InsertionPoint(pdl_module.body):
24+
# Change all arith.addi with index types to arith.muli.
25+
@pdl.pattern(benefit=1, sym_name="addi_to_mul")
26+
def pat():
27+
# Match arith.addi with index types.
28+
i64_type = pdl.TypeOp(IntegerType.get_signless(64))
29+
operand0 = pdl.OperandOp(i64_type)
30+
operand1 = pdl.OperandOp(i64_type)
31+
op0 = pdl.OperationOp(
32+
name="arith.addi", args=[operand0, operand1], types=[i64_type]
33+
)
34+
35+
# Replace the matched op with arith.muli.
36+
@pdl.rewrite()
37+
def rew():
38+
newOp = pdl.OperationOp(
39+
name="arith.muli", args=[operand0, operand1], types=[i64_type]
40+
)
41+
pdl.ReplaceOp(op0, with_op=newOp)
42+
43+
return pdl_module
44+
45+
46+
# CHECK-LABEL: TEST: testCustomPass
47+
@run
48+
def testCustomPass():
49+
with Context():
50+
pdl_module = make_pdl_module()
51+
frozen = PDLModule(pdl_module).freeze()
52+
53+
module = ModuleOp.parse(
54+
r"""
55+
module {
56+
func.func @add(%a: i64, %b: i64) -> i64 {
57+
%sum = arith.addi %a, %b : i64
58+
return %sum : i64
59+
}
60+
}
61+
"""
62+
)
63+
64+
def run1(op):
65+
print("hello from pass 1!!!", file=sys.stderr)
66+
67+
def run2(op):
68+
apply_patterns_and_fold_greedily_with_op(op, frozen)
69+
70+
pm = PassManager("any")
71+
pm.enable_ir_printing()
72+
73+
# CHECK: hello from pass 1!!!
74+
# CHECK-LABEL: Dump After CustomPass
75+
# CHECK: arith.muli
76+
pm.add_python_pass("CustomPass1", "", "", "", run1)
77+
pm.add_python_pass("CustomPass2", "", "", "", run2)
78+
# # CHECK-LABEL: Dump After ArithToLLVMConversionPass
79+
# # CHECK: llvm.mul
80+
pm.add("convert-arith-to-llvm")
81+
pm.run(module)

0 commit comments

Comments
 (0)