diff --git a/lib/Transform/Arith/MulToAddPdll.cpp b/lib/Transform/Arith/MulToAddPdll.cpp index 36cc48a..f76c834 100644 --- a/lib/Transform/Arith/MulToAddPdll.cpp +++ b/lib/Transform/Arith/MulToAddPdll.cpp @@ -1,5 +1,6 @@ #include "lib/Transform/Arith/MulToAddPdll.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/include/mlir/Pass/Pass.h" @@ -10,15 +11,18 @@ namespace tutorial { #define GEN_PASS_DEF_MULTOADDPDLL #include "lib/Transform/Arith/Passes.h.inc" -Attribute halveImpl(PatternRewriter &rewriter, Attribute attr) { - IntegerAttr cAttr = ::llvm::cast<::mlir::IntegerAttr>(attr); + +LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results, + ArrayRef args) { + Attribute attr = args[0].cast(); + IntegerAttr cAttr = cast(attr); int64_t value = cAttr.getValue().getSExtValue(); - return rewriter.getIntegerAttr(cAttr.getType(), value / 2); + results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value / 2)); + return success(); } void registerNativeConstraints(RewritePatternSet &patterns) { - patterns.getPDLPatterns().registerConstraintFunction( - "Halve", halveImpl); + patterns.getPDLPatterns().registerConstraintFunction("Halve", halveImpl); } struct MulToAddPdll : impl::MulToAddPdllBase {