Skip to content

Commit e031649

Browse files
Added lit tests
1 parent e4591a1 commit e031649

File tree

6 files changed

+139
-4
lines changed

6 files changed

+139
-4
lines changed

src/enzyme_ad/jax/Dialect/Tessera/Ops.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ def CallOp : TesseraOp<"call",
119119

120120
let results = (outs Variadic<AnyType>);
121121

122+
let assemblyFormat = [{
123+
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
124+
}];
125+
122126
let builders = [
123127
OpBuilder<(ins "DefineOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
124128
$_state.addOperands(operands);

src/enzyme_ad/jax/Passes/Tessera/FuncToTessera.cpp

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "src/enzyme_ad/jax/Dialect/Tessera/Dialect.h"
1919
#include "src/enzyme_ad/jax/Passes/Tessera/Passes.h"
2020
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21+
#include "mlir/IR/BuiltinDialect.h"
2122

2223
using namespace mlir;
2324
using namespace mlir::enzyme;
@@ -59,6 +60,20 @@ class FuncOpRewrite final : public OpRewritePattern<func::FuncOp> {
5960
funcOp.getBody().cloneInto(&tesseraDefineOp.getBody(),
6061
tesseraDefineOp.getBody().end(),
6162
mapper);
63+
64+
// Now walk through the cloned operations and convert func.return to tessera.return
65+
tesseraDefineOp.walk([&](func::ReturnOp returnOp) {
66+
rewriter.setInsertionPoint(returnOp);
67+
rewriter.replaceOpWithNewOp<tessera::ReturnOp>(returnOp, returnOp.getOperands());
68+
});
69+
70+
// Convert func.call to tessera.call
71+
tesseraDefineOp.walk([&](func::CallOp callOp) {
72+
rewriter.setInsertionPoint(callOp);
73+
rewriter.replaceOpWithNewOp<tessera::CallOp>(callOp, callOp.getResultTypes(),
74+
callOp.getOperands(),
75+
callOp->getAttrs());
76+
});
6277
}
6378

6479
rewriter.eraseOp(funcOp);
@@ -81,7 +96,7 @@ class CallOpRewrite final : public OpRewritePattern<func::CallOp> {
8196
Operation *calleeOp = SymbolTable::lookupSymbolIn(moduleOp, calleeAttr);
8297

8398
// Only convert if the callee is a Tessera DefineOp
84-
if (isa<tessera::DefineOp>(calleeOp))
99+
if (!isa<tessera::DefineOp>(calleeOp))
85100
return rewriter.notifyMatchFailure(callOp, "Callee is not a Tessera DefineOp");
86101

87102
rewriter.replaceOpWithNewOp<tessera::CallOp>(callOp, callOp.getResultTypes(),
@@ -122,15 +137,28 @@ namespace mlir::enzyme::tessera {
122137
struct FuncToTesseraPass
123138
: public PassWrapper<FuncToTesseraPass, OperationPass<ModuleOp>> {
124139

140+
StringRef getArgument() const final { return "func-to-tessera"; }
141+
StringRef getDescription() const final { return "Convert func dialect to tessera dialect."; }
142+
143+
void getDependentDialects(DialectRegistry &registry) const override {
144+
registry.insert<tessera::TesseraDialect>();
145+
}
146+
125147
void runOnOperation() override {
126148
MLIRContext *ctx = &getContext();
149+
150+
ConversionTarget target(*ctx);
151+
target.addLegalDialect<tessera::TesseraDialect>();
152+
target.addLegalDialect<BuiltinDialect>();
153+
target.addIllegalDialect<func::FuncDialect>();
154+
127155
RewritePatternSet patterns(ctx);
128156

129157
patterns.add<FuncOpRewrite, CallOpRewrite, ReturnOpRewrite>(ctx);
130158

131-
if (failed(applyPatternsAndFoldGreedily(getOperation(),
132-
std::move(patterns))))
159+
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
133160
signalPassFailure();
161+
134162
}
135163
};
136164

src/enzyme_ad/jax/Passes/Tessera/TesseraToFunc.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class DefineOpRewrite final : public OpRewritePattern<tessera::DefineOp> {
4444

4545

4646
// Create the `func.func` op
47-
auto funcOp = rewriter.create<tessera::DefineOp>(
47+
auto funcOp = rewriter.create<func::FuncOp>(
4848
defineOp.getLoc(), defineOp.getName(), fnType);
4949

5050

@@ -111,6 +111,13 @@ namespace mlir::enzyme::tessera {
111111
struct TesseraToFuncPass
112112
: public PassWrapper<TesseraToFuncPass, OperationPass<ModuleOp>> {
113113

114+
StringRef getArgument() const final { return "tessera-to-func"; }
115+
StringRef getDescription() const final { return "Convert tessera dialect to func dialect."; }
116+
117+
void getDependentDialects(DialectRegistry &registry) const override {
118+
registry.insert<func::FuncDialect>();
119+
}
120+
114121
void runOnOperation() override {
115122
MLIRContext *ctx = &getContext();
116123
RewritePatternSet patterns(ctx);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: enzymexlamlir-opt %s -func-to-tessera | FileCheck %s
2+
3+
// CHECK-LABEL: tessera.define @simple_func
4+
func.func @simple_func() {
5+
// CHECK: tessera.return
6+
func.return
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: tessera.define @func_with_args
12+
func.func @func_with_args(%arg0: i32, %arg1: f32) -> i32 {
13+
// CHECK: tessera.return %arg0 : i32
14+
func.return %arg0 : i32
15+
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: tessera.define @helper
20+
func.func @helper() {
21+
func.return
22+
}
23+
24+
// CHECK-LABEL: tessera.define @func_with_call
25+
func.func @func_with_call() {
26+
// CHECK: tessera.call @helper() : () -> ()
27+
func.call @helper() : () -> ()
28+
// CHECK: tessera.return
29+
func.return
30+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: enzymexlamlir-opt %s | FileCheck %s
2+
3+
// CHECK-LABEL: tessera.define @foo
4+
tessera.define @foo() {
5+
// CHECK: tessera.return
6+
tessera.return
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: tessera.define @bar
12+
tessera.define @bar() -> i32 {
13+
%c42_i32 = arith.constant 42 : i32
14+
// CHECK: tessera.return %{{.*}} : i32
15+
tessera.return %c42_i32 : i32
16+
}
17+
18+
// -----
19+
20+
// CHECK-LABEL: tessera.define @caller
21+
tessera.define @caller() {
22+
// CHECK: tessera.call @foo() : () -> ()
23+
tessera.call @foo() : () -> ()
24+
// CHECK: tessera.return
25+
tessera.return
26+
}
27+
28+
// -----
29+
30+
// CHECK-LABEL: tessera.define @with_args
31+
tessera.define @with_args(%arg0: i32, %arg1: f32) -> i32 {
32+
// CHECK: %[[V0:.*]] = tessera.call @bar() : () -> i32
33+
%0 = tessera.call @bar() : () -> i32
34+
// CHECK: tessera.return %[[V0]] : i32
35+
tessera.return %0 : i32
36+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: enzymexlamlir-opt %s -tessera-to-func | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @simple_func
4+
tessera.define @simple_func() {
5+
// CHECK: func.return
6+
tessera.return
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: func.func @func_with_args
12+
tessera.define @func_with_args(%arg0: i32, %arg1: f32) -> i32 {
13+
// CHECK: func.return %arg0 : i32
14+
tessera.return %arg0 : i32
15+
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: func.func @helper
20+
tessera.define @helper() {
21+
tessera.return
22+
}
23+
24+
// CHECK-LABEL: func.func @func_with_call
25+
tessera.define @func_with_call() {
26+
// CHECK: func.call @helper() : () -> ()
27+
tessera.call @helper() : () -> ()
28+
// CHECK: func.return
29+
tessera.return
30+
}

0 commit comments

Comments
 (0)