
Commit 51d94b3

feat: sort autodiff rules (#1584)
* feat: sort forward mode AD
* feat: reverse mode
* refactor: move the common function
* fix: derivative rule
* test: update
* fix: replace in cacheValues
1 parent f647a93 · commit 51d94b3
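At a glance, the new rules treat stablehlo.sort as applying a data-dependent permutation: forward mode sorts each tangent alongside its primal operand, so tangents ride through the same comparator, while reverse mode sorts an iota alongside the inputs, caches the resulting indices, and scatter-adds the result cotangent back through them. As a sketch of the underlying identities (standard AD facts, not text from this commit), writing the sorted output as a permutation matrix P(x) selected by the comparator applied to the input:

$$ y = \mathrm{sort}(x) = P(x)\,x, \qquad \dot{y} = P(x)\,\dot{x}, \qquad \bar{x} \mathrel{+}= P(x)^{\top}\,\bar{y} $$

where P(x) is piecewise constant in x (away from ties), so it is treated as independent of x when differentiating.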


3 files changed: +337 −14 lines changed


src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 225 additions & 14 deletions
@@ -2279,26 +2279,237 @@ class AutoDiffHLOReturn
   }
 };
 
-class AutoDiffSort
-    : public AutoDiffOpInterface::ExternalModel<AutoDiffSort, SortOp> {
+stablehlo::SortOp
+constructSortOpWithExtraOperands(OpBuilder &builder, stablehlo::SortOp original,
+                                 SmallVectorImpl<Value> &newOperands) {
+  auto newSortOp = stablehlo::SortOp::create(
+      builder, original.getLoc(), newOperands, original.getDimensionAttr(),
+      original.getIsStableAttr());
+
+  IRMapping regionMapper;
+  auto &newComparator = newSortOp.getComparator();
+  auto *newBlock = new Block();
+  newComparator.push_back(newBlock);
+
+  {
+    SmallVector<Type> scalarArgTys;
+    for (auto arg : newOperands) {
+      auto elTy = RankedTensorType::get(
+          {}, cast<TensorType>(arg.getType()).getElementType());
+      scalarArgTys.push_back(elTy);
+      scalarArgTys.push_back(elTy);
+    }
+    newBlock->addArguments(
+        scalarArgTys,
+        SmallVector<Location>(scalarArgTys.size(), original.getLoc()));
+  }
+
+  auto &origComparator = original.getComparator();
+  auto &origBlock = origComparator.front();
+
+  IRMapping mapper;
+  for (int64_t i = 0; i < origBlock.getNumArguments(); i++)
+    mapper.map(origBlock.getArgument(i), newBlock->getArgument(i));
+
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(newBlock);
+    for (Operation &origOpInside : origBlock) {
+      builder.clone(origOpInside, mapper);
+    }
+  }
+
+  return newSortOp;
+}
+
+class AutoDiffSortFwd
+    : public AutoDiffOpInterface::ExternalModel<AutoDiffSortFwd, SortOp> {
 public:
   LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
                                          MGradientUtils *gutils) const {
+    if (gutils->width > 1) {
+      op->emitError(
+          "TODO: AutoDiffSortFwd does not support batched forward mode");
+      return failure();
+    }
 
-    // TODO: we may need to record, for every successor, which of its inputs
-    // need a shadow to recreate the body correctly.
-    llvm::SmallDenseSet<unsigned> operandPositionsToShadow;
-    llvm::SmallDenseSet<unsigned> resultPositionsToShadow;
+    auto sortOp = cast<stablehlo::SortOp>(op);
 
-    for (auto res : op->getResults())
-      if (!gutils->isConstantValue(res)) {
-        operandPositionsToShadow.insert(res.getResultNumber());
-        resultPositionsToShadow.insert(res.getResultNumber());
+    DenseMap<int32_t, int32_t> gradMapping;
+
+    SmallVector<Value> newOperands;
+    for (auto operand : sortOp.getInputs()) {
+      newOperands.push_back(gutils->getNewFromOriginal(operand));
+    }
+    for (auto [i, operand] : llvm::enumerate(sortOp.getInputs())) {
+      if (!gutils->isConstantValue(operand)) {
+        newOperands.push_back(gutils->invertPointerM(operand, builder));
+        gradMapping[i] = newOperands.size() - 1;
       }
+    }
 
-    return mlir::enzyme::detail::controlFlowForwardHandler(
-        op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow);
+    auto newSortOp =
+        constructSortOpWithExtraOperands(builder, sortOp, newOperands);
+
+    SmallVector<Value> replacementResults(sortOp.getNumResults());
+    for (int32_t i = 0; i < sortOp.getNumResults(); i++) {
+      replacementResults[i] = newSortOp.getResults()[i];
+      auto origRes = sortOp.getResults()[i];
+      if (!gutils->isConstantValue(origRes)) {
+        int32_t j = gradMapping[i];
+        gutils->setDiffe(origRes, newSortOp.getResults()[j], builder);
+      }
+    }
+
+    gutils->replaceOrigOpWith(op, replacementResults);
+    gutils->originalToNewFnOps[op] = newSortOp;
+    gutils->eraseIfUnused(op);
+    return success();
+  }
+};
+
+class AutoDiffSortRev
+    : public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffSortRev,
+                                                       stablehlo::SortOp> {
+public:
+  LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector<Value> caches) const {
+    auto sortOp = cast<stablehlo::SortOp>(orig);
+
+    if (gutils->width > 1) {
+      orig->emitError(
+          "TODO: AutoDiffSortRev does not support batched reverse mode");
+      return failure();
+    }
+
+    auto indices = gutils->popCache(caches[0], builder);
+    auto indicesTy = cast<RankedTensorType>(indices.getType());
+
+    SmallVector<int64_t> newIndicesShape(indicesTy.getShape().begin(),
+                                         indicesTy.getShape().end());
+    newIndicesShape.push_back(1);
+
+    indices = stablehlo::ReshapeOp::create(
+        builder, orig->getLoc(),
+        RankedTensorType::get(newIndicesShape, indicesTy.getElementType()),
+        indices);
+
+    auto inTy = cast<RankedTensorType>(orig->getOperand(0).getType());
+    auto inRank = inTy.getRank();
+    auto inShape = inTy.getShape();
+
+    SmallVector<int64_t> batchingDims;
+    for (int32_t d = 0; d < inRank; d++) {
+      if (d != sortOp.getDimension()) {
+        batchingDims.push_back(d);
+      }
+    }
+
+    auto scatterDims = stablehlo::ScatterDimensionNumbersAttr::get(
+        orig->getContext(), SmallVector<int64_t>(),
+        SmallVector<int64_t>{static_cast<int64_t>(sortOp.getDimension())},
+        batchingDims, batchingDims,
+        SmallVector<int64_t>{static_cast<int64_t>(sortOp.getDimension())},
+        indicesTy.getRank());
+
+    for (size_t i = 0; i < orig->getNumResults(); i++) {
+      if (gutils->isConstantValue(orig->getResult(i)) ||
+          gutils->isConstantValue(orig->getOperand(i)))
+        continue;
+
+      // we compute the gradients with scatter_add and then set the original
+      auto inDiffe = gutils->diffe(orig->getResult(i), builder);
+      auto inDiffeTy = cast<RankedTensorType>(inDiffe.getType());
+      gutils->zeroDiffe(orig->getResult(i), builder);
+
+      auto outDiffe = gutils->diffe(orig->getOperand(i), builder);
+
+      Region combiner;
+      {
+        Block *block = new Block();
+        combiner.push_back(block);
+        block->addArgument(
+            RankedTensorType::get({}, inDiffeTy.getElementType()),
+            orig->getLoc());
+        block->addArgument(
+            RankedTensorType::get({}, inDiffeTy.getElementType()),
+            orig->getLoc());
+        OpBuilder::InsertionGuard guard(builder);
+        builder.setInsertionPointToStart(block);
+        stablehlo::ReturnOp::create(
+            builder, orig->getLoc(),
+            ValueRange{stablehlo::AddOp::create(builder, orig->getLoc(),
+                                                block->getArgument(0),
+                                                block->getArgument(1))});
+      }
+
+      auto scatterOp = stablehlo::ScatterOp::create(
+          builder, orig->getLoc(), outDiffe, indices, inDiffe, scatterDims,
+          builder.getBoolAttr(false), builder.getBoolAttr(true));
+      scatterOp.getUpdateComputation().takeBody(combiner);
+
+      gutils->setDiffe(orig->getOperand(i), scatterOp.getResults()[0], builder);
+    }
+
+    return success();
+  }
+
+  SmallVector<Value> cacheValues(Operation *orig,
+                                 MGradientUtilsReverse *gutils) const {
+    auto sortOp = cast<stablehlo::SortOp>(orig);
+
+    if (gutils->width > 1)
+      return {};
+
+    bool allConstant = true;
+    for (auto input : sortOp.getInputs()) {
+      if (!gutils->isConstantValue(input)) {
+        allConstant = false;
+        break;
+      }
+    }
+
+    if (allConstant)
+      return {};
+
+    auto newOp = gutils->getNewFromOriginal(orig);
+    OpBuilder cacheBuilder(newOp);
+
+    SmallVector<Value> newOperands(sortOp.getInputs().size() + 1);
+    for (auto [i, operand] : llvm::enumerate(sortOp.getInputs())) {
+      newOperands[i] = gutils->getNewFromOriginal(operand);
+    }
+    auto OpTy = cast<TensorType>(newOperands[0].getType());
+    auto iotaOp = stablehlo::IotaOp::create(
+        cacheBuilder, orig->getLoc(),
+        RankedTensorType::get(OpTy.getShape(),
+                              cacheBuilder.getIntegerType(32, false)),
+        sortOp.getDimensionAttr());
+    newOperands[newOperands.size() - 1] = iotaOp.getResult();
+
+    auto newSortOp =
+        constructSortOpWithExtraOperands(cacheBuilder, sortOp, newOperands);
+    auto newResults = newSortOp.getResults();
+
+    SmallVector<Value> caches;
+    caches.push_back(gutils->initAndPushCache(newResults[newResults.size() - 1],
+                                              cacheBuilder));
+
+    SmallVector<Value> replacements;
+    for (size_t i = 0; i < newResults.size() - 1; i++) {
+      replacements.push_back(newResults[i]);
+    }
+
+    gutils->replaceOrigOpWith(orig, replacements);
+    gutils->eraseIfUnused(orig);
+    gutils->originalToNewFnOps[orig] = newSortOp;
+
+    return caches;
   }
+
+  void createShadowValues(Operation *op, OpBuilder &builder,
+                          MGradientUtilsReverse *gutils) const {}
 };
 
 class AutoDiffBatchNormTrainingRev
@@ -3701,8 +3912,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
     stablehlo::StablehloDialect *) {
   registerInterfaces(context);
 
-  // SortOp::attachInterface<AutoDiffSort>(*context);
-
   WhileOp::attachInterface<WhileOpEnzymeOpsRemover>(*context);
   IfOp::attachInterface<IfOpEnzymeOpsRemover>(*context);
 
@@ -3722,6 +3931,8 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
   IfOp::attachInterface<AutoDiffIfFwd>(*context);
   IfOp::attachInterface<AutoDiffIfCF>(*context);
 
+  SortOp::attachInterface<AutoDiffSortFwd>(*context);
+  SortOp::attachInterface<AutoDiffSortRev>(*context);
   WhileOp::attachInterface<AutoDiffWhileFwd>(*context);
   WhileOp::attachInterface<AutoDiffWhileRev>(*context);
   ReduceOp::attachInterface<AutoDiffReduceCF<ReduceOp>>(*context);
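To make the two interfaces above concrete, here is a minimal, self-contained sketch of the same rule on plain arrays (the file name and variable layout are hypothetical, and std::vector stands in for StableHLO tensors): forward mode permutes tangents with the primal's sort order, cacheValues corresponds to sorting an iota alongside the data, and the reverse rule scatter-adds cotangents at the cached indices.

// sort_ad_sketch.cpp -- hypothetical illustration, not code from this commit.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<double> x = {3.0, 1.0, 2.0};     // primal input
  std::vector<double> dx = {0.3, 0.1, 0.2};    // forward-mode tangent of x
  std::vector<double> dy = {10.0, 20.0, 30.0}; // reverse-mode cotangent of sort(x)

  // Mirror of cacheValues: sort an iota alongside the data so the permutation
  // chosen by the comparator is recorded as indices.
  std::vector<unsigned> idx(x.size());
  std::iota(idx.begin(), idx.end(), 0u);
  std::stable_sort(idx.begin(), idx.end(),
                   [&](unsigned a, unsigned b) { return x[a] < x[b]; });

  // Forward mode (AutoDiffSortFwd): the tangent of sort(x) is the tangent of x
  // permuted by the same permutation, i.e. tangents ride along as extra operands.
  std::vector<double> y(x.size()), dy_fwd(x.size());
  for (size_t i = 0; i < idx.size(); ++i) {
    y[i] = x[idx[i]];
    dy_fwd[i] = dx[idx[i]];
  }

  // Reverse mode (AutoDiffSortRev): scatter-add the output cotangent back to
  // the operand positions recorded by the cached indices.
  std::vector<double> dx_rev(x.size(), 0.0);
  for (size_t i = 0; i < idx.size(); ++i)
    dx_rev[idx[i]] += dy[i];

  for (size_t i = 0; i < x.size(); ++i)
    std::printf("y=%g dy_fwd=%g dx_rev=%g\n", y[i], dy_fwd[i], dx_rev[i]);
}

Because the cached indices form a permutation, the scatter touches each destination exactly once, which is why the real implementation can set unique_indices = true on the stablehlo.scatter.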
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_dup argTys=enzyme_dup mode=ForwardMode" --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math | FileCheck %s --check-prefix=FORWARD
+// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= retTys=enzyme_active argTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math | FileCheck %s --check-prefix=REVERSE
+
+func.func @main(%arg0: tensor<8x4xf64>) -> (tensor<8x4xf64>) {
+  %0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = false}> ({
+  ^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
+    %1 = stablehlo.compare LT, %arg1, %arg2 : (tensor<f64>, tensor<f64>) -> tensor<i1>
+    stablehlo.return %1 : tensor<i1>
+  }) : (tensor<8x4xf64>) -> tensor<8x4xf64>
+  return %0 : tensor<8x4xf64>
+}
+
+// FORWARD: func.func @main(%arg0: tensor<8x4xf64>, %arg1: tensor<8x4xf64>) -> (tensor<8x4xf64>, tensor<8x4xf64>) {
+// FORWARD-NEXT: %0:2 = "stablehlo.sort"(%arg0, %arg1) <{dimension = 0 : i64, is_stable = false}> ({
+// FORWARD-NEXT: ^bb0(%arg2: tensor<f64>, %arg3: tensor<f64>, %arg4: tensor<f64>, %arg5: tensor<f64>):
+// FORWARD-NEXT: %1 = stablehlo.compare LT, %arg2, %arg3 : (tensor<f64>, tensor<f64>) -> tensor<i1>
+// FORWARD-NEXT: stablehlo.return %1 : tensor<i1>
+// FORWARD-NEXT: }) : (tensor<8x4xf64>, tensor<8x4xf64>) -> (tensor<8x4xf64>, tensor<8x4xf64>)
+// FORWARD-NEXT: return %0#0, %0#1 : tensor<8x4xf64>, tensor<8x4xf64>
+// FORWARD-NEXT: }
+
+// REVERSE: func.func @main(%arg0: tensor<8x4xf64>, %arg1: tensor<8x4xf64>) -> tensor<8x4xf64> {
+// REVERSE-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<8x4xf64>
+// REVERSE-NEXT: %0 = stablehlo.iota dim = 0 : tensor<8x4xui32>
+// REVERSE-NEXT: %1:2 = "stablehlo.sort"(%arg0, %0) <{dimension = 0 : i64, is_stable = false}> ({
+// REVERSE-NEXT: ^bb0(%arg2: tensor<f64>, %arg3: tensor<f64>, %arg4: tensor<ui32>, %arg5: tensor<ui32>):
+// REVERSE-NEXT: %4 = stablehlo.compare LT, %arg2, %arg3 : (tensor<f64>, tensor<f64>) -> tensor<i1>
+// REVERSE-NEXT: stablehlo.return %4 : tensor<i1>
+// REVERSE-NEXT: }) : (tensor<8x4xf64>, tensor<8x4xui32>) -> (tensor<8x4xf64>, tensor<8x4xui32>)
+// REVERSE-NEXT: %2 = stablehlo.reshape %1#1 : (tensor<8x4xui32>) -> tensor<8x4x1xui32>
+// REVERSE-NEXT: %3 = "stablehlo.scatter"(%cst, %2, %arg1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], input_batching_dims = [1], scatter_indices_batching_dims = [1], scatter_dims_to_operand_dims = [0], index_vector_dim = 2>, unique_indices = true}> ({
+// REVERSE-NEXT: ^bb0(%arg2: tensor<f64>, %arg3: tensor<f64>):
+// REVERSE-NEXT: %4 = stablehlo.add %arg2, %arg3 : tensor<f64>
+// REVERSE-NEXT: stablehlo.return %4 : tensor<f64>
+// REVERSE-NEXT: }) : (tensor<8x4xf64>, tensor<8x4x1xui32>, tensor<8x4xf64>) -> tensor<8x4xf64>
+// REVERSE-NEXT: return %3 : tensor<8x4xf64>
+// REVERSE-NEXT: }
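In the REVERSE output above, every dimension other than the sorted one becomes a batching dimension of the scatter (input_batching_dims = [1] for the 8x4 case), so cotangents never mix across columns. A small sketch of that batching behavior for a 2-D array sorted along dimension 0 (hypothetical code, with plain loops standing in for the batched stablehlo.scatter):

// scatter_batch_sketch.cpp -- hypothetical illustration: for a 2-D tensor
// sorted along dim 0, each column is an independent sort, so the reverse-mode
// scatter treats dim 1 as a batching dimension and never mixes columns.
#include <algorithm>
#include <array>
#include <cstdio>
#include <numeric>

int main() {
  constexpr int R = 3, C = 2;
  double x[R][C] = {{3, 40}, {1, 60}, {2, 50}};
  double dy[R][C] = {{0.1, 0.4}, {0.2, 0.5}, {0.3, 0.6}}; // cotangent of sort(x)
  double dx[R][C] = {};                                   // gradient, zero-init

  for (int c = 0; c < C; ++c) {          // batching dimension: dim 1
    std::array<int, R> idx;
    std::iota(idx.begin(), idx.end(), 0);
    std::stable_sort(idx.begin(), idx.end(),
                     [&](int a, int b) { return x[a][c] < x[b][c]; });
    for (int r = 0; r < R; ++r)          // scatter along the sorted dim only
      dx[idx[r]][c] += dy[r][c];
  }

  for (int r = 0; r < R; ++r)
    std::printf("%g %g\n", dx[r][0], dx[r][1]);
}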
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+// RUN: enzymexlamlir-opt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math --inline --enzyme-hlo-opt %s | FileCheck %s
+
+module {
+  func.func private @sort(%arg0: tensor<2xf64> {enzymexla.memory_effects = []}) -> (tensor<2xf64>, tensor<2xf64>) attributes {enzymexla.memory_effects = []} {
+    %0 = "stablehlo.sort"(%arg0) <{dimension = 0 : i64, is_stable = false}> ({
+    ^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
+      %1 = stablehlo.compare LT, %arg1, %arg2 : (tensor<f64>, tensor<f64>) -> tensor<i1>
+      stablehlo.return %1 : tensor<i1>
+    }) : (tensor<2xf64>) -> tensor<2xf64>
+    return %0, %arg0 : tensor<2xf64>, tensor<2xf64>
+  }
+  func.func @main(%arg0: tensor<2xf64>) -> (tensor<2xf64>, tensor<2xf64>) {
+    %cst = stablehlo.constant dense<1.000000e+00> : tensor<2xf64>
+    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<2xf64>
+    %0:2 = enzyme.autodiff @sort(%arg0, %cst, %cst_0) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>]} : (tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) -> (tensor<2xf64>, tensor<2xf64>)
+    return %0#1, %0#0 : tensor<2xf64>, tensor<2xf64>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<2xf64>) -> (tensor<2xf64>, tensor<2xf64>) {
+// CHECK-NEXT: %cst = stablehlo.constant dense<1.000000e+00> : tensor<f64>
+// CHECK-NEXT: %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<2xf64>
+// CHECK-NEXT: %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<2xf64>
+// CHECK-NEXT: %c = stablehlo.constant dense<[0, 1]> : tensor<2xui32>
+// CHECK-NEXT: %0:2 = "stablehlo.sort"(%arg0, %c) <{dimension = 0 : i64, is_stable = false}> ({
+// CHECK-NEXT: ^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>, %arg3: tensor<ui32>, %arg4: tensor<ui32>):
+// CHECK-NEXT: %3 = stablehlo.compare LT, %arg1, %arg2 : (tensor<f64>, tensor<f64>) -> tensor<i1>
+// CHECK-NEXT: stablehlo.return %3 : tensor<i1>
+// CHECK-NEXT: }) : (tensor<2xf64>, tensor<2xui32>) -> (tensor<2xf64>, tensor<2xui32>)
+// CHECK-NEXT: %1 = stablehlo.reshape %0#1 : (tensor<2xui32>) -> tensor<2x1xui32>
+// CHECK-NEXT: %2 = "stablehlo.scatter"(%cst_1, %1, %cst_0) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
+// CHECK-NEXT: ^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
+// CHECK-NEXT: stablehlo.return %cst : tensor<f64>
+// CHECK-NEXT: }) : (tensor<2xf64>, tensor<2x1xui32>, tensor<2xf64>) -> tensor<2xf64>
+// CHECK-NEXT: return %2, %arg0 : tensor<2xf64>, tensor<2xf64>
+// CHECK-NEXT: }
+
+module {
+  func.func private @sort(%arg0: tensor<5x4x3x2xf32>) -> (tensor<f32>, tensor<5x4x3x2xf32>) {
+    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<5x4x3x2xf32>) -> tensor<2x3x4x5xf32>
+    %1 = "stablehlo.sort"(%0) <{dimension = 2 : i64, is_stable = false}> ({
+    ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+      %4 = stablehlo.compare LT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+      stablehlo.return %4 : tensor<i1>
+    }) : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
+    %2 = stablehlo.multiply %1, %1 : tensor<2x3x4x5xf32>
+    %3 = stablehlo.reduce(%2 init: %cst) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<2x3x4x5xf32>, tensor<f32>) -> tensor<f32>
+    return %3, %arg0 : tensor<f32>, tensor<5x4x3x2xf32>
+  }
+  func.func @main(%arg0: tensor<5x4x3x2xf32>) -> (tensor<5x4x3x2xf32>, tensor<5x4x3x2xf32>) {
+    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
+    %0:2 = enzyme.autodiff @sort(%arg0, %cst) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>]} : (tensor<5x4x3x2xf32>, tensor<f32>) -> (tensor<5x4x3x2xf32>, tensor<5x4x3x2xf32>)
+    return %0#1, %0#0 : tensor<5x4x3x2xf32>, tensor<5x4x3x2xf32>
+  }
+}
+
+// CHECK: func.func @main(%arg0: tensor<5x4x3x2xf32>) -> (tensor<5x4x3x2xf32>, tensor<5x4x3x2xf32>) {
+// CHECK-NEXT: %c = stablehlo.constant dense<"0x000000000000000000000000000000000000000001000000010000000100000001000000010000000200000002000000020000000200000002000000030000000300000003000000030000000300000000000000000000000000000000000000000000000100000001000000010000000100000001000000020000000200000002000000020000000200000003000000030000000300000003000000030000000000000000000000000000000000000000000000010000000100000001000000010000000100000002000000020000000200000002000000020000000300000003000000030000000300000003000000000000000000000000000000000000000000000001000000010000000100000001000000010000000200000002000000020000000200000002000000030000000300000003000000030000000300000000000000000000000000000000000000000000000100000001000000010000000100000001000000020000000200000002000000020000000200000003000000030000000300000003000000030000000000000000000000000000000000000000000000010000000100000001000000010000000100000002000000020000000200000002000000020000000300000003000000030000000300000003000000"> : tensor<2x3x4x5xui32>
+// CHECK-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor<2x3x4x5xf32>
+// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<5x4x3x2xf32>) -> tensor<2x3x4x5xf32>
+// CHECK-NEXT: %1:2 = "stablehlo.sort"(%0, %c) <{dimension = 2 : i64, is_stable = false}> ({
+// CHECK-NEXT: ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<ui32>, %arg4: tensor<ui32>):
+// CHECK-NEXT: %6 = stablehlo.compare LT, %arg1, %arg2 : (tensor<f32>, tensor<f32>) -> tensor<i1>
+// CHECK-NEXT: stablehlo.return %6 : tensor<i1>
+// CHECK-NEXT: }) : (tensor<2x3x4x5xf32>, tensor<2x3x4x5xui32>) -> (tensor<2x3x4x5xf32>, tensor<2x3x4x5xui32>)
+// CHECK-NEXT: %2 = arith.addf %1#0, %1#0 : tensor<2x3x4x5xf32>
+// CHECK-NEXT: %3 = stablehlo.reshape %1#1 : (tensor<2x3x4x5xui32>) -> tensor<2x3x4x5x1xui32>
+// CHECK-NEXT: %4 = "stablehlo.scatter"(%cst, %3, %2) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [2], input_batching_dims = [0, 1, 3], scatter_indices_batching_dims = [0, 1, 3], scatter_dims_to_operand_dims = [2], index_vector_dim = 4>, unique_indices = true}> ({
+// CHECK-NEXT: ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
+// CHECK-NEXT: stablehlo.return %arg2 : tensor<f32>
+// CHECK-NEXT: }) : (tensor<2x3x4x5xf32>, tensor<2x3x4x5x1xui32>, tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32>
+// CHECK-NEXT: %5 = stablehlo.transpose %4, dims = [3, 2, 1, 0] : (tensor<2x3x4x5xf32>) -> tensor<5x4x3x2xf32>
+// CHECK-NEXT: return %5, %arg0 : tensor<5x4x3x2xf32>, tensor<5x4x3x2xf32>
+// CHECK-NEXT: }
