Skip to content

Commit 98f3f35

Browse files
authored
feat: partial support for broadcast elementwise trait (#1602)
1 parent 47d57c1 commit 98f3f35

File tree

4 files changed

+67
-10
lines changed

4 files changed

+67
-10
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3086,7 +3086,7 @@ struct SliceElementwise final
30863086
auto elem = op.getOperand().getDefiningOp();
30873087
if (!elem)
30883088
return failure();
3089-
if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
3089+
if (!stablehlo::hasTraitElementwise(elem))
30903090
return failure();
30913091
if (llvm::hasSingleElement(elem->getUsers())) {
30923092
SmallVector<Value> ops;
@@ -3772,7 +3772,7 @@ struct FullReduceReshapeOrTranspose final
37723772
reshapeOrTransposes.push_back(rs);
37733773
continue;
37743774
}
3775-
if (!curOp->hasTrait<mlir::OpTrait::Elementwise>())
3775+
if (!hasTraitElementwise(curOp))
37763776
return failure();
37773777
if (!isMemoryEffectFree(curOp))
37783778
return failure();
@@ -10469,7 +10469,7 @@ struct SliceReshapeElementwise final
1046910469
auto elem = reshape.getOperand().getDefiningOp();
1047010470
if (!elem)
1047110471
return failure();
10472-
if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
10472+
if (!stablehlo::hasTraitElementwise(elem))
1047310473
return failure();
1047410474
if (!llvm::hasSingleElement(elem->getUsers()))
1047510475
return failure();
@@ -10515,7 +10515,7 @@ struct TransposeElementwise final
1051510515
if (!elem)
1051610516
return failure();
1051710517

10518-
if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
10518+
if (!stablehlo::hasTraitElementwise(elem))
1051910519
return failure();
1052010520

1052110521
bool singleUser = llvm::hasSingleElement(elem->getUsers());
@@ -10913,7 +10913,7 @@ struct ReshapeElementwise final
1091310913
if (onlySingleUser && !llvm::hasSingleElement(elem->getUsers()))
1091410914
return failure();
1091510915

10916-
if (!elem->hasTrait<mlir::OpTrait::Elementwise>())
10916+
if (!stablehlo::hasTraitElementwise(elem))
1091710917
return failure();
1091810918

1091910919
SmallVector<Value> ops;
@@ -18958,8 +18958,8 @@ LogicalResult commUnaryOpElementwise(bool onlySingleUser, EnzymeOp op,
1895818958

1895918959
bool anyModified = false;
1896018960
for (auto elem : llvm::make_early_inc_range(op->getUsers())) {
18961-
if (!elem->template hasTrait<mlir::OpTrait::Elementwise>() ||
18962-
elem->getNumResults() != 1 || elem->getNumOperands() != 1)
18961+
if (!hasTraitElementwise(elem) || elem->getNumResults() != 1 ||
18962+
elem->getNumOperands() != 1)
1896318963
continue;
1896418964

1896518965
auto newOp = rewriter.create(
@@ -20270,7 +20270,7 @@ struct ConcatElementwise final
2027020270
if (isa<stablehlo::ConvertOp>(vdefOp)) // Conflicts with ConvertConcat
2027120271
return failure();
2027220272

20273-
if (vdefOp->hasTrait<mlir::OpTrait::Elementwise>()) {
20273+
if (hasTraitElementwise(vdefOp)) {
2027420274
if (concatOpOperands.size() != 0) {
2027520275
if (!OperationEquivalence::isEquivalentTo(
2027620276
concatOpOperands[0], vdefOp,
@@ -20671,7 +20671,7 @@ struct ConcatReshapeElementwise final
2067120671
if (!vdefOp)
2067220672
return failure();
2067320673

20674-
if (vdefOp->hasTrait<mlir::OpTrait::Elementwise>()) {
20674+
if (hasTraitElementwise(vdefOp)) {
2067520675
if (concatOpOperands.size() != 0) {
2067620676
if (!OperationEquivalence::isEquivalentTo(
2067720677
concatOpOperands[0], vdefOp,
@@ -21259,7 +21259,7 @@ struct GatherElementwise
2125921259
PatternRewriter &rewriter) const {
2126021260
auto gatherInput = op.getOperand();
2126121261
auto defOp = gatherInput.getDefiningOp();
21262-
if (!defOp || !defOp->hasTrait<mlir::OpTrait::Elementwise>())
21262+
if (!defOp || !hasTraitElementwise(defOp))
2126321263
return rewriter.notifyMatchFailure(op,
2126421264
"GatherOp with non-elementwise input");
2126521265

src/enzyme_ad/jax/Utils.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,6 +1380,38 @@ Value reshapeAxisOutOf(OpBuilder &builder, Value input,
13801380
builder.getDenseI64ArrayAttr(permutation));
13811381
}
13821382

1383+
bool hasTraitElementwise(Operation *op) {
1384+
if (op->hasTrait<OpTrait::Elementwise>())
1385+
return true;
1386+
1387+
if (op->hasTrait<hlo::OpTrait::BroadcastingElementwise>()) {
1388+
// Check sizes (shapes) match across operands, not the exact types.
1389+
auto refShapedTy = dyn_cast<ShapedType>(op->getOperand(0).getType());
1390+
if (!refShapedTy)
1391+
return false;
1392+
1393+
for (auto operand : op->getOperands()) {
1394+
auto curShapedTy = dyn_cast<ShapedType>(operand.getType());
1395+
if (!curShapedTy)
1396+
return false;
1397+
1398+
if (curShapedTy.getRank() != refShapedTy.getRank())
1399+
return false;
1400+
1401+
for (int64_t i = 0; i < curShapedTy.getRank(); ++i) {
1402+
int64_t a = curShapedTy.getDimSize(i);
1403+
int64_t b = refShapedTy.getDimSize(i);
1404+
// If both are static and different, sizes don't match.
1405+
if (a != ShapedType::kDynamic && b != ShapedType::kDynamic && a != b)
1406+
return false;
1407+
}
1408+
}
1409+
return true;
1410+
}
1411+
1412+
return false;
1413+
}
1414+
13831415
} // namespace stablehlo
13841416

13851417
} // namespace mlir

src/enzyme_ad/jax/Utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,11 @@ mlir::Value reshapeAxisInto(OpBuilder &builder, Value input,
798798
mlir::Value reshapeAxisOutOf(OpBuilder &builder, Value input,
799799
ArrayRef<int64_t> &batchSizes, int64_t dim);
800800

801+
// matches for hasTrait<OpTrait::Elementwise>. Additionally matches for
802+
// hasTrait<OpTrait::HLOBroadcastingElementwise> if all of the operands are
803+
// of the same shape.
804+
bool hasTraitElementwise(Operation *op);
805+
801806
} // namespace stablehlo
802807

803808
} // namespace mlir
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=transpose_elementwise(0);transpose_simplify;transpose_transpose" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
2+
3+
module {
4+
func.func @main(%arg0: tensor<4x3xf32>) -> tensor<4x3xf32> {
5+
%cst = stablehlo.constant dense<6.000000e+00> : tensor<3x4xf32>
6+
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<3x4xf32>
7+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32>
8+
%1 = stablehlo.clamp %cst_0, %0, %cst : tensor<3x4xf32>
9+
%2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<3x4xf32>) -> tensor<4x3xf32>
10+
return %2 : tensor<4x3xf32>
11+
}
12+
}
13+
14+
// CHECK: func.func @main(%arg0: tensor<4x3xf32>) -> tensor<4x3xf32> {
15+
// CHECK-NEXT: %cst = stablehlo.constant dense<6.000000e+00> : tensor<4x3xf32>
16+
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<4x3xf32>
17+
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [0, 1] : (tensor<4x3xf32>) -> tensor<4x3xf32>
18+
// CHECK-NEXT: %1 = stablehlo.clamp %cst_0, %0, %cst : tensor<4x3xf32>
19+
// CHECK-NEXT: return %1 : tensor<4x3xf32>
20+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)