Skip to content

Commit

Permalink
Fix function signature for constraints with return values
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Jul 22, 2024
1 parent f33c843 commit b8051b5
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions lib/Transform/Arith/MulToAddPdll.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "lib/Transform/Arith/MulToAddPdll.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/include/mlir/Pass/Pass.h"
Expand All @@ -10,15 +11,18 @@ namespace tutorial {
#define GEN_PASS_DEF_MULTOADDPDLL
#include "lib/Transform/Arith/Passes.h.inc"

Attribute halveImpl(PatternRewriter &rewriter, Attribute attr) {
IntegerAttr cAttr = ::llvm::cast<::mlir::IntegerAttr>(attr);

LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
Attribute attr = args[0].cast<Attribute>();
IntegerAttr cAttr = cast<IntegerAttr>(attr);
int64_t value = cAttr.getValue().getSExtValue();
return rewriter.getIntegerAttr(cAttr.getType(), value / 2);
results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value / 2));
return success();
}

void registerNativeConstraints(RewritePatternSet &patterns) {
patterns.getPDLPatterns().registerConstraintFunction(
"Halve", halveImpl);
patterns.getPDLPatterns().registerConstraintFunction("Halve", halveImpl);
}

struct MulToAddPdll : impl::MulToAddPdllBase<MulToAddPdll> {
Expand Down

0 comments on commit b8051b5

Please sign in to comment.