diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 278847e7ac7f5..d7282b3d6f713 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -136,7 +136,7 @@ NB_MODULE(_mlir, m) { populateRewriteSubmodule(rewriteModule); // Define and populate PassManager submodule. - auto passModule = + auto passManagerModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); - populatePassManagerSubmodule(passModule); + populatePassManagerSubmodule(passManagerModule); } diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 88e28dca76bb9..2f9030239cd87 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -10,8 +10,10 @@ #include "IRModule.h" #include "mlir-c/Pass.h" +// clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +// clang-format on namespace nb = nanobind; using namespace nb::literals; @@ -157,6 +159,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "pipeline"_a, "Add textual pipeline elements to the pass manager. Throws a " "ValueError if the pipeline can't be parsed.") + .def( + "add", + [](PyPassManager &passManager, const nb::callable &run, + std::optional &name, const std::string &argument, + const std::string &description, const std::string &opName) { + if (!name.has_value()) { + name = nb::cast( + nb::borrow(run.attr("__name__"))); + } + MlirTypeIDAllocator typeIDAllocator = mlirTypeIDAllocatorCreate(); + MlirTypeID passID = + mlirTypeIDAllocatorAllocateTypeID(typeIDAllocator); + MlirExternalPassCallbacks callbacks; + callbacks.construct = [](void *obj) { + (void)nb::handle(static_cast(obj)).inc_ref(); + }; + callbacks.destruct = [](void *obj) { + (void)nb::handle(static_cast(obj)).dec_ref(); + }; + callbacks.initialize = nullptr; + callbacks.clone = [](void *) -> void * { + throw std::runtime_error("Cloning Python passes not supported"); + }; + callbacks.run = [](MlirOperation op, MlirExternalPass, + void *userData) { + nb::borrow(static_cast(userData))(op); + }; + auto externalPass = mlirCreateExternalPass( + passID, mlirStringRefCreate(name->data(), name->length()), + mlirStringRefCreate(argument.data(), argument.length()), + mlirStringRefCreate(description.data(), description.length()), + mlirStringRefCreate(opName.data(), opName.size()), + /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, + callbacks, /*userData*/ run.ptr()); + mlirPassManagerAddOwnedPass(passManager.get(), externalPass); + }, + "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "", + "description"_a.none() = "", "op_name"_a.none() = "", + "Add a python-defined pass to the pass manager.") .def( "run", [](PyPassManager &passManager, PyOperationBase &op) { diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 3c499c3e4974d..b0a6ec1ace3cc 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -145,10 +145,14 @@ class ExternalPass : public Pass { : Pass(passID, opName), id(passID), name(name), argument(argument), description(description), dependentDialects(dependentDialects), callbacks(callbacks), userData(userData) { - callbacks.construct(userData); + if (callbacks.construct) + callbacks.construct(userData); } - ~ExternalPass() override { callbacks.destruct(userData); } + ~ExternalPass() override { + if (callbacks.destruct) + callbacks.destruct(userData); + } StringRef getName() const override { return name; } StringRef getArgument() const override { return argument; } diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py new file mode 100644 index 0000000000000..c94f96e20966f --- /dev/null +++ b/mlir/test/python/python_pass.py @@ -0,0 +1,88 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +import gc, sys +from mlir.ir import * +from mlir.passmanager import * +from mlir.dialects.builtin import ModuleOp +from mlir.dialects import pdl +from mlir.rewrite import * + + +def log(*args): + print(*args, file=sys.stderr) + sys.stderr.flush() + + +def run(f): + log("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +def make_pdl_module(): + with Location.unknown(): + pdl_module = Module.create() + with InsertionPoint(pdl_module.body): + # Change all arith.addi with index types to arith.muli. + @pdl.pattern(benefit=1, sym_name="addi_to_mul") + def pat(): + # Match arith.addi with index types. + i64_type = pdl.TypeOp(IntegerType.get_signless(64)) + operand0 = pdl.OperandOp(i64_type) + operand1 = pdl.OperandOp(i64_type) + op0 = pdl.OperationOp( + name="arith.addi", args=[operand0, operand1], types=[i64_type] + ) + + # Replace the matched op with arith.muli. + @pdl.rewrite() + def rew(): + newOp = pdl.OperationOp( + name="arith.muli", args=[operand0, operand1], types=[i64_type] + ) + pdl.ReplaceOp(op0, with_op=newOp) + + return pdl_module + + +# CHECK-LABEL: TEST: testCustomPass +@run +def testCustomPass(): + with Context(): + pdl_module = make_pdl_module() + frozen = PDLModule(pdl_module).freeze() + + module = ModuleOp.parse( + r""" + module { + func.func @add(%a: i64, %b: i64) -> i64 { + %sum = arith.addi %a, %b : i64 + return %sum : i64 + } + } + """ + ) + + def custom_pass_1(op): + print("hello from pass 1!!!", file=sys.stderr) + + class CustomPass2: + def __call__(self, m): + apply_patterns_and_fold_greedily(m, frozen) + + custom_pass_2 = CustomPass2() + + pm = PassManager("any") + pm.enable_ir_printing() + + # CHECK: hello from pass 1!!! + # CHECK-LABEL: Dump After custom_pass_1 + pm.add(custom_pass_1) + # CHECK-LABEL: Dump After CustomPass2 + # CHECK: arith.muli + pm.add(custom_pass_2, "CustomPass2") + # CHECK-LABEL: Dump After ArithToLLVMConversionPass + # CHECK: llvm.mul + pm.add("convert-arith-to-llvm") + pm.run(module)