|
10 | 10 |
|
11 | 11 | #include "IRModule.h"
|
12 | 12 | #include "mlir-c/Pass.h"
|
| 13 | +// clang-format off |
13 | 14 | #include "mlir/Bindings/Python/Nanobind.h"
|
14 | 15 | #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
|
| 16 | +// clang-format on |
15 | 17 |
|
16 | 18 | namespace nb = nanobind;
|
17 | 19 | using namespace nb::literals;
|
@@ -157,6 +159,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
|
157 | 159 | "pipeline"_a,
|
158 | 160 | "Add textual pipeline elements to the pass manager. Throws a "
|
159 | 161 | "ValueError if the pipeline can't be parsed.")
|
| 162 | + .def( |
| 163 | + "add_python_pass", |
| 164 | + [](PyPassManager &passManager, const nb::callable &run, |
| 165 | + std::optional<std::string> &name, const std::string &argument, |
| 166 | + const std::string &description, const std::string &opName) { |
| 167 | + if (!name.has_value()) { |
| 168 | + name = nb::cast<std::string>( |
| 169 | + nb::borrow<nb::str>(run.attr("__name__"))); |
| 170 | + } |
| 171 | + MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); |
| 172 | + MlirTypeID passID = |
| 173 | + mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); |
| 174 | + MlirExternalPassCallbacks callbacks; |
| 175 | + callbacks.construct = [](void *obj) { |
| 176 | + (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref(); |
| 177 | + }; |
| 178 | + callbacks.destruct = [](void *obj) { |
| 179 | + (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref(); |
| 180 | + }; |
| 181 | + callbacks.initialize = nullptr; |
| 182 | + callbacks.clone = [](void *) -> void * { |
| 183 | + throw std::runtime_error("Cloning Python passes not supported"); |
| 184 | + }; |
| 185 | + callbacks.run = [](MlirOperation op, MlirExternalPass, |
| 186 | + void *userData) { |
| 187 | + nb::borrow<nb::callable>(static_cast<PyObject *>(userData))(op); |
| 188 | + }; |
| 189 | + auto externalPass = mlirCreateExternalPass( |
| 190 | + passID, mlirStringRefCreate(name->data(), name->length()), |
| 191 | + mlirStringRefCreate(argument.data(), argument.length()), |
| 192 | + mlirStringRefCreate(description.data(), description.length()), |
| 193 | + mlirStringRefCreate(opName.data(), opName.size()), |
| 194 | + /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, |
| 195 | + callbacks, /*userData*/ run.ptr()); |
| 196 | + mlirPassManagerAddOwnedPass(passManager.get(), externalPass); |
| 197 | + }, |
| 198 | + "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "", |
| 199 | + "description"_a.none() = "", "op_name"_a.none() = "", |
| 200 | + "Add a python-defined pass to the pass manager.") |
160 | 201 | .def(
|
161 | 202 | "run",
|
162 | 203 | [](PyPassManager &passManager, PyOperationBase &op) {
|
|
0 commit comments