Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 70 additions & 34 deletions src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,21 @@ bool isZero(Value v) {
return false;
}

SplatElementsAttr isSplat(ElementsAttr v) {
return dyn_cast<SplatElementsAttr>(v);
}

SplatElementsAttr isSplat(Value v) {
SplatElementsAttr elem;
if (matchPattern(v, m_Constant(&elem))) {
return elem;
}
if (auto sdyConstant = v.getDefiningOp<sdy::ConstantOp>()) {
return isSplat(sdyConstant.getValue());
}
return nullptr;
}

// TODO: we might need to update this to use the generalized version for the
// generateShiftPairs function
std::tuple<Value, Value, Value, Value, Value, Value>
Expand Down Expand Up @@ -2219,6 +2234,8 @@ struct DUSToPadComm : public OpRewritePattern<stablehlo::DynamicUpdateSliceOp> {
dus.getLoc(), rewriter.getZeroAttr(elementType));
auto one = rewriter.create<stablehlo::ConstantOp>(
dus.getLoc(), rewriter.getOneAttr(elementType));
auto oneI1 = rewriter.create<stablehlo::ConstantOp>(
dus.getLoc(), rewriter.getOneAttr(rewriter.getI1Type()));

SmallVector<int64_t> padInner(ndims, 0);

Expand All @@ -2230,49 +2247,68 @@ struct DUSToPadComm : public OpRewritePattern<stablehlo::DynamicUpdateSliceOp> {
operandShape[i] - updateShape[i] - constantStartIndices[i];
}
Value updatePad = nullptr;
if (!isZero(update)) {
auto updatePadOp = rewriter.create<stablehlo::PadOp>(
dus.getLoc(), update, zero, updatePadLow, updatePadHigh, padInner);
sdy::setSharding(updatePadOp, sharding);
updatePad = updatePadOp;
}

Value maskedOperand = nullptr;
if (!isZero(operand)) {
auto updateType = cast<RankedTensorType>(update.getType());
auto zeroAttr =
DenseElementsAttr::get(updateType, rewriter.getZeroAttr(elementType));

auto splatUpdate = isSplat(update);
auto splatOperand = isSplat(update);

auto updateType = cast<RankedTensorType>(update.getType());
auto updateI1Type =
RankedTensorType::get(updateType.getShape(), rewriter.getI1Type());
auto zeroAttr = DenseElementsAttr::get(
updateI1Type, rewriter.getZeroAttr(rewriter.getI1Type()));
auto zeroUpdateOp = rewriter.create<stablehlo::ConstantOp>(
dus.getLoc(), updateI1Type, zeroAttr);
sdy::setSharding(zeroUpdateOp, sharding);

auto maskOp = rewriter.create<stablehlo::PadOp>(dus.getLoc(), zeroUpdateOp,
oneI1, updatePadLow,
updatePadHigh, padInner);

Value resultV = nullptr;
if (splatOperand) {
auto padTy =
RankedTensorType::get({}, operand.getType().getElementType());
auto newOperand = rewriter.create<stablehlo::ConstantOp>(
dus.getLoc(), padTy, splatOperand.resizeSplat(padTy));
auto maskedOperandOp = rewriter.create<stablehlo::PadOp>(
dus.getLoc(), update, newOperand, updatePadLow, updatePadHigh,
padInner);
sdy::setSharding(maskedOperandOp, sharding);
}
{
Value newUpdate;
if (splatUpdate) {
newUpdate = rewriter.create<stablehlo::ConstantOp>(
dus.getLoc(), operand.getType(),
splatUpdate.resizeSplat(operand.getType()));
} else {
auto newUpdateOp = rewriter.create<stablehlo::PadOp>(
dus.getLoc(), operand.getType(), update, zero, updatePadLow,
updatePadHigh, padInner);
sdy::setSharding(newUpdateOp, sharding);
newUpdate = newUpdateOp;
}

auto updateI1Type =
RankedTensorType::get(updateType.getShape(), rewriter.getI1Type());

auto zeroAttr = DenseElementsAttr::get(
updateI1Type, rewriter.getZeroAttr(rewriter.getI1Type()));
auto zeroUpdateOp = rewriter.create<stablehlo::ConstantOp>(
dus.getLoc(), updateType, zeroAttr);
dus.getLoc(), updateI1Type, zeroAttr);
sdy::setSharding(zeroUpdateOp, sharding);

auto maskOp = rewriter.create<stablehlo::PadOp>(
dus.getLoc(), zeroUpdateOp, one, updatePadLow, updatePadHigh,
dus.getLoc(), zeroUpdateOp, oneI1, updatePadLow, updatePadHigh,
padInner);

sdy::setSharding(maskOp, sharding);

auto maskedOperandOp =
rewriter.create<stablehlo::MulOp>(dus.getLoc(), operand, maskOp);
auto maskedOperandOp = rewriter.create<stablehlo::SelectOp>(
dus.getLoc(), maskOp, operand, newUpdate);
sdy::setSharding(maskedOperandOp, sharding);
maskedOperand = maskedOperandOp;
}

Value resultV = nullptr;
if (maskedOperand && updatePad) {
auto result = rewriter.create<stablehlo::AddOp>(dus.getLoc(),
maskedOperand, updatePad);
sdy::setSharding(result, sharding);
resultV = result;
} else if (maskedOperand) {
resultV = maskedOperand;
} else if (updatePad) {
resultV = updatePad;
} else {
auto cst = rewriter.create<stablehlo::ConstantOp>(
dus.getLoc(), dus.getType(),
cast<ElementsAttr>(rewriter.getZeroAttr(dus.getType())));
sdy::setSharding(cst, sharding);
resultV = cst;
resultV = maskedOperandOp;
}

rewriter.replaceOp(dus, resultV);
Expand Down
52 changes: 26 additions & 26 deletions test/lit_tests/communication/dus.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ func.func @constantUpdate1D(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.sh
}

// PAD: func.func @constantUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) {
// PAD-NEXT: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<f64>
// PAD-NEXT: %[[cst0:.+]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<20x24x80xf64>
// PAD-NEXT: %[[p0:.+]] = stablehlo.pad %[[cst0]], %[[cst1]], low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor<f64>) -> tensor<20x24x96xf64>
// PAD-NEXT: %[[m0:.+]] = stablehlo.multiply %arg0, %[[p0]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64>
// PAD-NEXT: return %[[m0]] : tensor<20x24x96xf64>
// PAD-NEXT: %c = stablehlo.constant dense<true> : tensor<i1>
// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<20x24x96xf64>
// PAD-NEXT: %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<false> : tensor<20x24x80xi1>
// PAD-NEXT: %0 = stablehlo.pad %c_0, %c, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xi1>, tensor<i1>) -> tensor<20x24x96xi1>
// PAD-NEXT: %1 = stablehlo.select %0, %arg0, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64>
// PAD-NEXT: return %1 : tensor<20x24x96xf64>
// PAD-NEXT: }

// CHECK: func.func @constantUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) {
Expand Down Expand Up @@ -59,11 +60,12 @@ func.func @constantUpdate(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.shar
}

// PAD: func.func @constantUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) {
// PAD-NEXT: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<f64>
// PAD-NEXT: %[[cst0:.+]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x1x80xf64>
// PAD-NEXT: %[[mask:.+]] = stablehlo.pad %[[cst0]], %[[cst1]], low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor<f64>) -> tensor<20x24x96xf64>
// PAD-NEXT: %[[mul:.+]] = stablehlo.multiply %arg0, %[[mask]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64>
// PAD-NEXT: return %[[mul]] : tensor<20x24x96xf64>
// PAD-NEXT: %c = stablehlo.constant dense<true> : tensor<i1>
// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<20x24x96xf64>
// PAD-NEXT: %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<false> : tensor<4x1x80xi1>
// PAD-NEXT: %0 = stablehlo.pad %c_0, %c, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xi1>, tensor<i1>) -> tensor<20x24x96xi1>
// PAD-NEXT: %1 = stablehlo.select %0, %arg0, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64>
// PAD-NEXT: return %1 : tensor<20x24x96xf64>
// PAD-NEXT: }

// CHECK: func.func @constantUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) {
Expand Down Expand Up @@ -108,14 +110,13 @@ func.func @argUpdate1D(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.shardin
}

// PAD: func.func @argUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) {
// PAD-NEXT: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<20x24x80xf64>
// PAD-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
// PAD-NEXT: %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst_0, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor<f64>) -> tensor<20x24x96xf64>
// PAD-NEXT: %1 = stablehlo.pad %cst, %cst_1, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor<f64>) -> tensor<20x24x96xf64>
// PAD-NEXT: %2 = stablehlo.multiply %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64>
// PAD-NEXT: %3 = stablehlo.add %2, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64>
// PAD-NEXT: return %3 : tensor<20x24x96xf64>
// PAD-NEXT: %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<false> : tensor<20x24x80xi1>
// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
// PAD-NEXT: %c_0 = stablehlo.constant dense<true> : tensor<i1>
// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor<f64>) -> tensor<20x24x96xf64>
// PAD-NEXT: %1 = stablehlo.pad %c, %c_0, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xi1>, tensor<i1>) -> tensor<20x24x96xi1>
// PAD-NEXT: %2 = stablehlo.select %1, %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64>
// PAD-NEXT: return %2 : tensor<20x24x96xf64>
// PAD-NEXT: }

// CHECK: func.func @argUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) {
Expand Down Expand Up @@ -158,14 +159,13 @@ func.func @argUpdate(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<
}

// PAD: func.func @argUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x1x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) {
// PAD-NEXT: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x1x80xf64>
// PAD-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
// PAD-NEXT: %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst_0, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor<f64>) -> tensor<20x24x96xf64>
// PAD-NEXT: %1 = stablehlo.pad %cst, %cst_1, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor<f64>) -> tensor<20x24x96xf64>
// PAD-NEXT: %2 = stablehlo.multiply %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64>
// PAD-NEXT: %3 = stablehlo.add %2, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64>
// PAD-NEXT: return %3 : tensor<20x24x96xf64>
// PAD-NEXT: %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<false> : tensor<4x1x80xi1>
// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
// PAD-NEXT: %c_0 = stablehlo.constant dense<true> : tensor<i1>
// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor<f64>) -> tensor<20x24x96xf64>
// PAD-NEXT: %1 = stablehlo.pad %c, %c_0, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xi1>, tensor<i1>) -> tensor<20x24x96xi1>
// PAD-NEXT: %2 = stablehlo.select %1, %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64>
// PAD-NEXT: return %2 : tensor<20x24x96xf64>
// PAD-NEXT: }

// CHECK: func.func @argUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x1x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) {
Expand Down
Loading
Loading