From e9c621652c85628ccd6dc4f8f281561fdfa97164 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 3 Sep 2025 18:12:54 -0500 Subject: [PATCH 1/8] Try using select instead of add/mul --- .../jax/Passes/OptimizeCommunication.cpp | 93 ++++++++++++------- 1 file changed, 61 insertions(+), 32 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp b/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp index 060013828f..13bc9be6ff 100644 --- a/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp +++ b/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp @@ -667,6 +667,21 @@ bool isZero(Value v) { return false; } +SplatElementsAttr isSplat(ElementsAttr v) { + return dyn-cast(v); +} + +SplatElementsAttr isSplat(Value v) { + SplatElementsAttr elem; + if (matchPattern(v, m_Constant(&elem))) { + return elem; + } + if (auto sdyConstant = v.getDefiningOp()) { + return isZero(sdyConstant.getValue()); + } + return nullptr; +} + // TODO: we might need to update this to use the generalized version for the // generateShiftPairs function std::tuple @@ -2219,6 +2234,8 @@ struct DUSToPadComm : public OpRewritePattern { dus.getLoc(), rewriter.getZeroAttr(elementType)); auto one = rewriter.create( dus.getLoc(), rewriter.getOneAttr(elementType)); + auto oneI1 = rewriter.create( + dus.getLoc(), rewriter.getOneAttr(rewriter.getI1Type())); SmallVector padInner(ndims, 0); @@ -2230,51 +2247,63 @@ struct DUSToPadComm : public OpRewritePattern { operandShape[i] - updateShape[i] - constantStartIndices[i]; } Value updatePad = nullptr; - if (!isZero(update)) { - auto updatePadOp = rewriter.create( - dus.getLoc(), update, zero, updatePadLow, updatePadHigh, padInner); - sdy::setSharding(updatePadOp, sharding); - updatePad = updatePadOp; - } - Value maskedOperand = nullptr; - if (!isZero(operand)) { - auto updateType = cast(update.getType()); + auto splatUpdate = isSplat(update); + auto splatOperand = isSplat(update); + + + auto updateType = cast(update.getType()); + auto updateI1Type = RankedTensorType::Get(updateType.getShape(), rewriter.getI1Type()); + auto zeroAttr = + DenseElementsAttr::get(updateI1Type, rewriter.getZeroAttr(rewriter.getI1Type())); + auto zeroUpdateOp = rewriter.create( + dus.getLoc(), updateI1Type, zeroAttr); + sdy::setSharding(zeroUpdateOp, sharding); + + auto maskOp = rewriter.create( + dus.getLoc(), zeroUpdateOp, oneI1, updatePadLow, updatePadHigh, + padInner); + + Value resultV = nullptr; + if (splatOperand) { + auto padTy = RankedTensorType::get({}, operand.getType().getElementType()); + auto newOperand = rewriter.create(padTy, splatOperand.resizeSplat(padTy)); + auto maskedOperandOp = newOperand.create( + dus.getLoc(), update, newUpdate, updatePadLow, updatePadHigh, + padInner); + sdy::setSharding(maskedOperandOp, sharding); + } { + Value newUpdate; + if (splatUpdate) { + newUpdate = rewriter.create(operand.getType(), splatUpdate.resizeSplat(operand.getType())); + } else { + auto newUpdateOp = rewriter.create(operand.getType(), update, zero, updatePadLow, updatePadHigh, + padInner); + sdy::setSharding(newUpdateOp, sharding); + newUpdate = newUpdateOp; + } + + auto updateI1Type = RankedTensorType::get(updateType.getShape(), rewriter.getI1Type()); + auto zeroAttr = - DenseElementsAttr::get(updateType, rewriter.getZeroAttr(elementType)); + DenseElementsAttr::get(updateI1Type, rewriter.getZeroAttr(rewriter.getI1Type())); auto zeroUpdateOp = rewriter.create( - dus.getLoc(), updateType, zeroAttr); + dus.getLoc(), updateI1Type, zeroAttr); sdy::setSharding(zeroUpdateOp, sharding); auto 
maskOp = rewriter.create( - dus.getLoc(), zeroUpdateOp, one, updatePadLow, updatePadHigh, + dus.getLoc(), zeroUpdateOp, oneI1, updatePadLow, updatePadHigh, padInner); + sdy::setSharding(maskOp, sharding); auto maskedOperandOp = - rewriter.create(dus.getLoc(), operand, maskOp); + rewriter.create(dus.getLoc(), maskOp, operand, newUpdate); sdy::setSharding(maskedOperandOp, sharding); - maskedOperand = maskedOperandOp; - } - Value resultV = nullptr; - if (maskedOperand && updatePad) { - auto result = rewriter.create(dus.getLoc(), - maskedOperand, updatePad); - sdy::setSharding(result, sharding); - resultV = result; - } else if (maskedOperand) { - resultV = maskedOperand; - } else if (updatePad) { - resultV = updatePad; - } else { - auto cst = rewriter.create( - dus.getLoc(), dus.getType(), - cast(rewriter.getZeroAttr(dus.getType()))); - sdy::setSharding(cst, sharding); - resultV = cst; + resultV = maskedOperandOp; } - + rewriter.replaceOp(dus, resultV); return success(); } From 37f8f25623ea3efd52205aa7c1ff6892798b7a27 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 3 Sep 2025 18:20:07 -0500 Subject: [PATCH 2/8] fix --- .../jax/Passes/OptimizeCommunication.cpp | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp b/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp index 13bc9be6ff..39ff26a2ed 100644 --- a/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp +++ b/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp @@ -668,7 +668,7 @@ bool isZero(Value v) { } SplatElementsAttr isSplat(ElementsAttr v) { - return dyn-cast(v); + return dyn_cast(v); } SplatElementsAttr isSplat(Value v) { @@ -677,7 +677,7 @@ SplatElementsAttr isSplat(Value v) { return elem; } if (auto sdyConstant = v.getDefiningOp()) { - return isZero(sdyConstant.getValue()); + return isSplat(sdyConstant.getValue()); } return nullptr; } @@ -2253,7 +2253,8 @@ struct DUSToPadComm : public OpRewritePattern { auto updateType = cast(update.getType()); - auto updateI1Type = RankedTensorType::Get(updateType.getShape(), rewriter.getI1Type()); + auto updateI1Type = + RankedTensorType::get(updateType.getShape(), rewriter.getI1Type()); auto zeroAttr = DenseElementsAttr::get(updateI1Type, rewriter.getZeroAttr(rewriter.getI1Type())); auto zeroUpdateOp = rewriter.create( @@ -2267,18 +2268,22 @@ struct DUSToPadComm : public OpRewritePattern { Value resultV = nullptr; if (splatOperand) { auto padTy = RankedTensorType::get({}, operand.getType().getElementType()); - auto newOperand = rewriter.create(padTy, splatOperand.resizeSplat(padTy)); - auto maskedOperandOp = newOperand.create( - dus.getLoc(), update, newUpdate, updatePadLow, updatePadHigh, - padInner); + auto newOperand = rewriter.create( + dus.getLoc(), padTy, splatOperand.resizeSplat(padTy)); + auto maskedOperandOp = rewriter.create( + dus.getLoc(), update, newOperand, updatePadLow, updatePadHigh, + padInner); sdy::setSharding(maskedOperandOp, sharding); } { Value newUpdate; if (splatUpdate) { - newUpdate = rewriter.create(operand.getType(), splatUpdate.resizeSplat(operand.getType())); + newUpdate = rewriter.create( + dus.getLoc(), operand.getType(), + splatUpdate.resizeSplat(operand.getType())); } else { - auto newUpdateOp = rewriter.create(operand.getType(), update, zero, updatePadLow, updatePadHigh, - padInner); + auto newUpdateOp = rewriter.create( + dus.getLoc(), operand.getType(), update, zero, updatePadLow, + updatePadHigh, padInner); sdy::setSharding(newUpdateOp, sharding); newUpdate = 
newUpdateOp; } From 158a6d37a01b401938f1d3bf4555275ba624b104 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 3 Sep 2025 18:21:35 -0500 Subject: [PATCH 3/8] oom test --- .github/workflows/test-gb-25.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test-gb-25.yml b/.github/workflows/test-gb-25.yml index 70ce23af4c..978c10c676 100644 --- a/.github/workflows/test-gb-25.yml +++ b/.github/workflows/test-gb-25.yml @@ -51,6 +51,7 @@ jobs: # - 'b25f3cbed2bc88c8ffef85f6a5319e2cf7b0454c' gb25_commit: - 'main' + - 'mg/oom-reproducer' # - '0123456789abcdef0123456789abcdef01234567' reactant_commit: - 'main' From fc3a4112f7ea6dca0787b12ab8f7d3e672ed24d2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 3 Sep 2025 18:27:42 -0500 Subject: [PATCH 4/8] ftest --- test/lit_tests/communication/dus.mlir | 52 +++++++++++++-------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/test/lit_tests/communication/dus.mlir b/test/lit_tests/communication/dus.mlir index e47a0e5460..23ee597c65 100644 --- a/test/lit_tests/communication/dus.mlir +++ b/test/lit_tests/communication/dus.mlir @@ -12,11 +12,12 @@ func.func @constantUpdate1D(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.sh } // PAD: func.func @constantUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %[[cst0:.+]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<20x24x80xf64> -// PAD-NEXT: %[[p0:.+]] = stablehlo.pad %[[cst0]], %[[cst1]], low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %[[m0:.+]] = stablehlo.multiply %arg0, %[[p0]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %[[m0]] : tensor<20x24x96xf64> +// PAD-NEXT: %c = stablehlo.constant dense : tensor +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<20x24x96xf64> +// PAD-NEXT: %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<20x24x80xi1> +// PAD-NEXT: %0 = stablehlo.pad %c_0, %c, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT: %1 = stablehlo.select %0, %arg0, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT: return %1 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func @constantUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -59,11 +60,12 @@ func.func @constantUpdate(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.shar } // PAD: func.func @constantUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %[[cst0:.+]] = 
stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x1x80xf64> -// PAD-NEXT: %[[mask:.+]] = stablehlo.pad %[[cst0]], %[[cst1]], low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %[[mul:.+]] = stablehlo.multiply %arg0, %[[mask]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %[[mul]] : tensor<20x24x96xf64> +// PAD-NEXT: %c = stablehlo.constant dense : tensor +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<20x24x96xf64> +// PAD-NEXT: %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<4x1x80xi1> +// PAD-NEXT: %0 = stablehlo.pad %c_0, %c, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT: %1 = stablehlo.select %0, %arg0, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT: return %1 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func @constantUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -108,14 +110,13 @@ func.func @argUpdate1D(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.shardin } // PAD: func.func @argUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<20x24x80xf64> -// PAD-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor -// PAD-NEXT: %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst_0, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %1 = stablehlo.pad %cst, %cst_1, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %2 = stablehlo.multiply %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: %3 = stablehlo.add %2, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %3 : tensor<20x24x96xf64> +// PAD-NEXT: %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<20x24x80xi1> +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// PAD-NEXT: %c_0 = stablehlo.constant dense : tensor +// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> 
tensor<20x24x96xf64> +// PAD-NEXT: %1 = stablehlo.pad %c, %c_0, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT: %2 = stablehlo.select %1, %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT: return %2 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func @argUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -158,14 +159,13 @@ func.func @argUpdate(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding< } // PAD: func.func @argUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x1x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x1x80xf64> -// PAD-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor -// PAD-NEXT: %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst_0, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %1 = stablehlo.pad %cst, %cst_1, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %2 = stablehlo.multiply %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: %3 = stablehlo.add %2, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %3 : tensor<20x24x96xf64> +// PAD-NEXT: %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<4x1x80xi1> +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// PAD-NEXT: %c_0 = stablehlo.constant dense : tensor +// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x96xf64> +// PAD-NEXT: %1 = stablehlo.pad %c, %c_0, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT: %2 = stablehlo.select %1, %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT: return %2 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func @argUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x1x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { From 
858f79c673f7b3e0a285d4c1fcd46f5adeeeb720 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= <765740+giordano@users.noreply.github.com> Date: Thu, 4 Sep 2025 16:07:38 +0200 Subject: [PATCH 5/8] Increase grid size in GB-25 simulation --- .github/workflows/test-gb-25.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-gb-25.yml b/.github/workflows/test-gb-25.yml index 978c10c676..419cae6313 100644 --- a/.github/workflows/test-gb-25.yml +++ b/.github/workflows/test-gb-25.yml @@ -188,7 +188,7 @@ jobs: timeout-minutes: 60 run: | export XLA_FLAGS='--xla_dump_to=${{ env.GB25_DIR }}/xla_dump' - timeout --signal=TERM --verbose 59m mpiexecjl -np 1 julia --color=yes --project -O0 --startup-file=no --threads=16 --compiled-modules=strict sharding/sharded_baroclinic_instability_simulation_run.jl + timeout --signal=TERM --verbose 59m mpiexecjl -np 1 julia --color=yes --project -O0 --startup-file=no --threads=16 --compiled-modules=strict sharding/sharded_baroclinic_instability_simulation_run.jl --grid-x=6144 --grid-y=1536 --grid-z=4 working-directory: ${{ env.GB25_DIR }} - name: Test correctness in GB-25 code timeout-minutes: 20 From a21b5a1cfe358941617ecff502f5ff1719b1fcad Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 4 Sep 2025 09:13:39 -0500 Subject: [PATCH 6/8] test fix --- test/lit_tests/communication/dus2D.mlir | 78 ++++++++++----------- test/lit_tests/communication/dusNonDiv.mlir | 52 +++++++------- 2 files changed, 65 insertions(+), 65 deletions(-) diff --git a/test/lit_tests/communication/dus2D.mlir b/test/lit_tests/communication/dus2D.mlir index 757f1c6756..8507013224 100644 --- a/test/lit_tests/communication/dus2D.mlir +++ b/test/lit_tests/communication/dus2D.mlir @@ -12,11 +12,12 @@ func.func @constantUpdate1D(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.sh } // PAD: func.func @constantUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %[[cst0:.+]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<20x24x80xf64> -// PAD-NEXT: %[[p0:.+]] = stablehlo.pad %[[cst0]], %[[cst1]], low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %[[m0:.+]] = stablehlo.multiply %arg0, %[[p0]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %[[m0]] : tensor<20x24x96xf64> +// PAD-NEXT: %c = stablehlo.constant dense : tensor +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<20x24x96xf64> +// PAD-NEXT: %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<20x24x80xi1> +// PAD-NEXT: %0 = stablehlo.pad %c_0, %c, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT: %1 = stablehlo.select %0, %arg0, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT: return %1 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func 
@constantUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -61,11 +62,12 @@ func.func @constantUpdateOver(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy. } // PAD: func.func @constantUpdateOver(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-DAG: %[[cst0:.+]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x1x80xf64> -// PAD-DAG: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %[[mask:.+]] = stablehlo.pad %[[cst0]], %[[cst1]], low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %[[mul:.+]] = stablehlo.multiply %arg0, %[[mask]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %[[mul]] : tensor<20x24x96xf64> +// PAD-NEXT: %c = stablehlo.constant dense : tensor +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<20x24x96xf64> +// PAD-NEXT: %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<4x1x80xi1> +// PAD-NEXT: %0 = stablehlo.pad %c_0, %c, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT: %1 = stablehlo.select %0, %arg0, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT: return %1 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func @constantUpdateOver(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -132,11 +134,12 @@ func.func @constantUpdate(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.shar } // PAD: func.func @constantUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-DAG: %[[c0:.+]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x8x80xf64> -// PAD-DAG: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %[[mask:.+]] = stablehlo.pad %[[c0]], %[[cst1]], low = [8, 8, 8], high = [8, 8, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x8x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %[[mul:.+]] = stablehlo.multiply %arg0, %[[mask]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %[[mul]] : tensor<20x24x96xf64> +// PAD-NEXT: %c = stablehlo.constant dense : tensor +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<20x24x96xf64> +// PAD-NEXT: %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<4x8x80xi1> +// PAD-NEXT: %0 = stablehlo.pad %c_0, %c, low = [8, 8, 8], high = [8, 8, 8], 
interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x8x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT: %1 = stablehlo.select %0, %arg0, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT: return %1 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func @constantUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -200,14 +203,13 @@ func.func @argUpdate1D(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.shardin } // PAD: func.func @argUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<20x24x80xf64> -// PAD-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor -// PAD-NEXT: %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst_0, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %1 = stablehlo.pad %cst, %cst_1, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %2 = stablehlo.multiply %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: %3 = stablehlo.add %2, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %3 : tensor<20x24x96xf64> +// PAD-NEXT: %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<20x24x80xi1> +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// PAD-NEXT: %c_0 = stablehlo.constant dense : tensor +// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x96xf64> +// PAD-NEXT: %1 = stablehlo.pad %c, %c_0, low = [0, 0, 8], high = [0, 0, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT: %2 = stablehlo.select %1, %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT: return %2 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func @argUpdate1D(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -252,14 +254,13 @@ func.func @argUpdateOver(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.shard } // PAD: func.func @argUpdateOver(%arg0: tensor<20x24x96xf64> {sdy.sharding = 
#sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x1x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-DAG: %[[c0:.+]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x1x80xf64> -// PAD-DAG: %[[cst0:.+]] = stablehlo.constant dense<0.000000e+00> : tensor -// PAD-DAG: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %0 = stablehlo.pad %arg1, %[[cst0]], low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %1 = stablehlo.pad %[[c0]], %[[cst1]], low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %2 = stablehlo.multiply %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: %3 = stablehlo.add %2, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %3 : tensor<20x24x96xf64> +// PAD-NEXT %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<4x1x80xi1> +// PAD-NEXT %cst = stablehlo.constant dense<0.000000e+00> : tensor +// PAD-NEXT %c_0 = stablehlo.constant dense : tensor +// PAD-NEXT %0 = stablehlo.pad %arg1, %cst, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x96xf64> +// PAD-NEXT %1 = stablehlo.pad %c, %c_0, low = [8, 8, 8], high = [8, 15, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT %2 = stablehlo.select %1, %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT return %2 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func @argUpdateOver(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x1x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -326,14 +327,13 @@ func.func @argUpdate(%arg21: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding< } // PAD: func.func @argUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x8x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x8x80xf64> -// PAD-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor -// PAD-NEXT: %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst_0, low = [8, 8, 8], high = [8, 8, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x8x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %1 = 
stablehlo.pad %cst, %cst_1, low = [8, 8, 8], high = [8, 8, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x8x80xf64>, tensor) -> tensor<20x24x96xf64> -// PAD-NEXT: %2 = stablehlo.multiply %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: %3 = stablehlo.add %2, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xf64> -// PAD-NEXT: return %3 : tensor<20x24x96xf64> +// PAD-NEXT: %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<4x8x80xi1> +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// PAD-NEXT: %c_0 = stablehlo.constant dense : tensor +// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst, low = [8, 8, 8], high = [8, 8, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x8x80xf64>, tensor) -> tensor<20x24x96xf64> +// PAD-NEXT: %1 = stablehlo.pad %c, %c_0, low = [8, 8, 8], high = [8, 8, 8], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x8x80xi1>, tensor) -> tensor<20x24x96xi1> +// PAD-NEXT: %2 = stablehlo.select %1, %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x96xi1>, tensor<20x24x96xf64> +// PAD-NEXT: return %2 : tensor<20x24x96xf64> // PAD-NEXT: } // CHECK: func.func @argUpdate(%arg0: tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x8x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x96xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { diff --git a/test/lit_tests/communication/dusNonDiv.mlir b/test/lit_tests/communication/dusNonDiv.mlir index 84623e3029..1bde0dcc32 100644 --- a/test/lit_tests/communication/dusNonDiv.mlir +++ b/test/lit_tests/communication/dusNonDiv.mlir @@ -12,11 +12,12 @@ func.func @constantUpdate1D(%arg21: tensor<20x24x97xf64> {sdy.sharding = #sdy.sh } // PAD: func.func @constantUpdate1D(%arg0: tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %[[cst0:.+]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<20x24x80xf64> -// PAD-NEXT: %[[mask:.+]] = stablehlo.pad %[[cst0]], %[[cst1]], low = [0, 0, 8], high = [0, 0, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x97xf64> -// PAD-NEXT: %[[res:.+]] = stablehlo.multiply %arg0, %[[mask]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xf64> -// PAD-NEXT: return %[[res]] : tensor<20x24x97xf64> +// PAD-NEXT: %c = stablehlo.constant dense : tensor +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<20x24x97xf64> +// PAD-NEXT: %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<20x24x80xi1> +// PAD-NEXT: %0 = stablehlo.pad %c_0, %c, low = [0, 0, 8], high = [0, 0, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xi1>, tensor) -> 
tensor<20x24x97xi1> +// PAD-NEXT: %1 = stablehlo.select %0, %arg0, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xi1>, tensor<20x24x97xf64> +// PAD-NEXT: return %1 : tensor<20x24x97xf64> // PAD-NEXT: } // CHECK: func.func @constantUpdate1D(%arg0: tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -63,11 +64,12 @@ func.func @constantUpdate(%arg21: tensor<20x24x97xf64> {sdy.sharding = #sdy.shar } // PAD: func.func @constantUpdate(%arg0: tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %[[cst1:.+]] = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %[[cst0:.+]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x1x80xf64> -// PAD-NEXT: %[[mask:.+]] = stablehlo.pad %[[cst0]], %[[cst1]], low = [8, 8, 8], high = [8, 15, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x97xf64> -// PAD-NEXT: %[[res:.+]] = stablehlo.multiply %arg0, %[[mask]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xf64> -// PAD-NEXT: return %[[res]] : tensor<20x24x97xf64> +// PAD-NEXT: %c = stablehlo.constant dense : tensor +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<20x24x97xf64> +// PAD-NEXT: %c_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<4x1x80xi1> +// PAD-NEXT: %0 = stablehlo.pad %c_0, %c, low = [8, 8, 8], high = [8, 15, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xi1>, tensor) -> tensor<20x24x97xi1> +// PAD-NEXT: %1 = stablehlo.select %0, %arg0, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xi1>, tensor<20x24x97xf64> +// PAD-NEXT: return %1 : tensor<20x24x97xf64> // PAD-NEXT: } // CHECK: func.func @constantUpdate(%arg0: tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -115,14 +117,13 @@ func.func @argUpdate1D(%arg21: tensor<20x24x97xf64> {sdy.sharding = #sdy.shardin } // PAD: func.func @argUpdate1D(%arg0: tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<20x24x80xf64> -// PAD-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor -// PAD-NEXT: %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst_0, low = [0, 0, 8], high = [0, 0, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x97xf64> -// PAD-NEXT: %1 = stablehlo.pad %cst, %cst_1, low = [0, 0, 8], high = [0, 0, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, 
{"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x97xf64> -// PAD-NEXT: %2 = stablehlo.multiply %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xf64> -// PAD-NEXT: %3 = stablehlo.add %2, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xf64> -// PAD-NEXT: return %3 : tensor<20x24x97xf64> +// PAD-NEXT: %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<20x24x80xi1> +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// PAD-NEXT: %c_0 = stablehlo.constant dense : tensor +// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst, low = [0, 0, 8], high = [0, 0, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xf64>, tensor) -> tensor<20x24x97xf64> +// PAD-NEXT: %1 = stablehlo.pad %c, %c_0, low = [0, 0, 8], high = [0, 0, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<20x24x80xi1>, tensor) -> tensor<20x24x97xi1> +// PAD-NEXT: %2 = stablehlo.select %1, %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xi1>, tensor<20x24x97xf64> +// PAD-NEXT: return %2 : tensor<20x24x97xf64> // PAD-NEXT: } // CHECK: func.func @argUpdate1D(%arg0: tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<20x24x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { @@ -168,14 +169,13 @@ func.func @argUpdate(%arg21: tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding< } // PAD: func.func @argUpdate(%arg0: tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x1x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { -// PAD-NEXT: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense<0.000000e+00> : tensor<4x1x80xf64> -// PAD-NEXT: %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor -// PAD-NEXT: %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor -// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst_0, low = [8, 8, 8], high = [8, 15, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x97xf64> -// PAD-NEXT: %1 = stablehlo.pad %cst, %cst_1, low = [8, 8, 8], high = [8, 15, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x97xf64> -// PAD-NEXT: %2 = stablehlo.multiply %arg0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xf64> -// PAD-NEXT: %3 = stablehlo.add %2, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xf64> -// PAD-NEXT: return %3 : tensor<20x24x97xf64> +// PAD-NEXT: %c = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} dense : tensor<4x1x80xi1> +// PAD-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// PAD-NEXT: %c_0 = stablehlo.constant dense : tensor +// PAD-NEXT: %0 = stablehlo.pad %arg1, %cst, low = [8, 8, 8], high = [8, 15, 9], interior = [0, 0, 0] {sdy.sharding = 
#sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xf64>, tensor) -> tensor<20x24x97xf64> +// PAD-NEXT: %1 = stablehlo.pad %c, %c_0, low = [8, 8, 8], high = [8, 15, 9], interior = [0, 0, 0] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : (tensor<4x1x80xi1>, tensor) -> tensor<20x24x97xi1> +// PAD-NEXT: %2 = stablehlo.select %1, %arg0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}, {"y"}, {"x"}]>]>} : tensor<20x24x97xi1>, tensor<20x24x97xf64> +// PAD-NEXT: return %2 : tensor<20x24x97xf64> // PAD-NEXT: } // CHECK: func.func @argUpdate(%arg0: tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}, %arg1: tensor<4x1x80xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) -> (tensor<20x24x97xf64> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}, {"y"}, {"x"}]>}) { From 27dc9aa9e63dafc8d6c93ea554c49ea688e0e73e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Thu, 4 Sep 2025 10:13:55 -0400 Subject: [PATCH 7/8] fmt --- .../jax/Passes/OptimizeCommunication.cpp | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp b/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp index 39ff26a2ed..cefee78e19 100644 --- a/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp +++ b/src/enzyme_ad/jax/Passes/OptimizeCommunication.cpp @@ -2251,30 +2251,31 @@ struct DUSToPadComm : public OpRewritePattern { auto splatUpdate = isSplat(update); auto splatOperand = isSplat(update); - auto updateType = cast(update.getType()); auto updateI1Type = RankedTensorType::get(updateType.getShape(), rewriter.getI1Type()); - auto zeroAttr = - DenseElementsAttr::get(updateI1Type, rewriter.getZeroAttr(rewriter.getI1Type())); + auto zeroAttr = DenseElementsAttr::get( + updateI1Type, rewriter.getZeroAttr(rewriter.getI1Type())); auto zeroUpdateOp = rewriter.create( dus.getLoc(), updateI1Type, zeroAttr); sdy::setSharding(zeroUpdateOp, sharding); - auto maskOp = rewriter.create( - dus.getLoc(), zeroUpdateOp, oneI1, updatePadLow, updatePadHigh, - padInner); + auto maskOp = rewriter.create(dus.getLoc(), zeroUpdateOp, + oneI1, updatePadLow, + updatePadHigh, padInner); Value resultV = nullptr; if (splatOperand) { - auto padTy = RankedTensorType::get({}, operand.getType().getElementType()); + auto padTy = + RankedTensorType::get({}, operand.getType().getElementType()); auto newOperand = rewriter.create( dus.getLoc(), padTy, splatOperand.resizeSplat(padTy)); auto maskedOperandOp = rewriter.create( dus.getLoc(), update, newOperand, updatePadLow, updatePadHigh, padInner); sdy::setSharding(maskedOperandOp, sharding); - } { + } + { Value newUpdate; if (splatUpdate) { newUpdate = rewriter.create( @@ -2288,10 +2289,11 @@ struct DUSToPadComm : public OpRewritePattern { newUpdate = newUpdateOp; } - auto updateI1Type = RankedTensorType::get(updateType.getShape(), rewriter.getI1Type()); + auto updateI1Type = + RankedTensorType::get(updateType.getShape(), rewriter.getI1Type()); - auto zeroAttr = - DenseElementsAttr::get(updateI1Type, rewriter.getZeroAttr(rewriter.getI1Type())); + auto zeroAttr = DenseElementsAttr::get( + updateI1Type, rewriter.getZeroAttr(rewriter.getI1Type())); auto zeroUpdateOp = rewriter.create( dus.getLoc(), updateI1Type, zeroAttr); sdy::setSharding(zeroUpdateOp, sharding); @@ -2302,13 +2304,13 @@ struct DUSToPadComm : public OpRewritePattern { sdy::setSharding(maskOp, sharding); - auto maskedOperandOp = - rewriter.create(dus.getLoc(), maskOp, 
operand, newUpdate); + auto maskedOperandOp = rewriter.create( + dus.getLoc(), maskOp, operand, newUpdate); sdy::setSharding(maskedOperandOp, sharding); resultV = maskedOperandOp; } - + rewriter.replaceOp(dus, resultV); return success(); } From 99ab797b23c99291a54438e60f0181f87924c804 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= <765740+giordano@users.noreply.github.com> Date: Thu, 4 Sep 2025 17:43:21 +0200 Subject: [PATCH 8/8] Revert GB-25 workflow configuration --- .github/workflows/test-gb-25.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test-gb-25.yml b/.github/workflows/test-gb-25.yml index 419cae6313..70ce23af4c 100644 --- a/.github/workflows/test-gb-25.yml +++ b/.github/workflows/test-gb-25.yml @@ -51,7 +51,6 @@ jobs: # - 'b25f3cbed2bc88c8ffef85f6a5319e2cf7b0454c' gb25_commit: - 'main' - - 'mg/oom-reproducer' # - '0123456789abcdef0123456789abcdef01234567' reactant_commit: - 'main' @@ -188,7 +187,7 @@ jobs: timeout-minutes: 60 run: | export XLA_FLAGS='--xla_dump_to=${{ env.GB25_DIR }}/xla_dump' - timeout --signal=TERM --verbose 59m mpiexecjl -np 1 julia --color=yes --project -O0 --startup-file=no --threads=16 --compiled-modules=strict sharding/sharded_baroclinic_instability_simulation_run.jl --grid-x=6144 --grid-y=1536 --grid-z=4 + timeout --signal=TERM --verbose 59m mpiexecjl -np 1 julia --color=yes --project -O0 --startup-file=no --threads=16 --compiled-modules=strict sharding/sharded_baroclinic_instability_simulation_run.jl working-directory: ${{ env.GB25_DIR }} - name: Test correctness in GB-25 code timeout-minutes: 20