@@ -2279,26 +2279,237 @@ class AutoDiffHLOReturn
   }
 };
 
-class AutoDiffSort
-    : public AutoDiffOpInterface::ExternalModel<AutoDiffSort, SortOp> {
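+// Rebuilds `original` as a new stablehlo.sort over `newOperands` (the
+// original inputs plus any extra values appended by the caller). The
+// comparator body is cloned unchanged; each extra operand contributes a pair
+// of unused scalar block arguments, so extra values are permuted in lockstep
+// with the inputs without influencing the ordering.
+//
+// Sketch (hypothetical shapes): sorting one tensor<4xf32> with one appended
+// tangent produces a comparator of type
+//   (tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<i1>
+// whose cloned body reads only the first two arguments.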
+stablehlo::SortOp
+constructSortOpWithExtraOperands(OpBuilder &builder, stablehlo::SortOp original,
+                                 SmallVectorImpl<Value> &newOperands) {
+  auto newSortOp = stablehlo::SortOp::create(
+      builder, original.getLoc(), newOperands, original.getDimensionAttr(),
+      original.getIsStableAttr());
+
+  auto &newComparator = newSortOp.getComparator();
+  auto *newBlock = new Block();
+  newComparator.push_back(newBlock);
+
+  {
+    // The comparator receives two scalar arguments per sort operand.
+    SmallVector<Type> scalarArgTys;
+    for (auto arg : newOperands) {
+      auto elTy = RankedTensorType::get(
+          {}, cast<TensorType>(arg.getType()).getElementType());
+      scalarArgTys.push_back(elTy);
+      scalarArgTys.push_back(elTy);
+    }
+    newBlock->addArguments(
+        scalarArgTys,
+        SmallVector<Location>(scalarArgTys.size(), original.getLoc()));
+  }
+
+  auto &origComparator = original.getComparator();
+  auto &origBlock = origComparator.front();
+
+  // Map the original comparator arguments onto the leading arguments of the
+  // new block, then clone the body.
+  IRMapping mapper;
+  for (int64_t i = 0; i < origBlock.getNumArguments(); i++)
+    mapper.map(origBlock.getArgument(i), newBlock->getArgument(i));
+
+  {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(newBlock);
+    for (Operation &origOpInside : origBlock) {
+      builder.clone(origOpInside, mapper);
+    }
+  }
+
+  return newSortOp;
+}
+
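+// Forward mode: the tangents of the active inputs are appended to the sort as
+// extra operands, so they undergo exactly the permutation applied to the
+// primal values; the corresponding extra results become the tangents of the
+// sorted outputs.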
+class AutoDiffSortFwd
+    : public AutoDiffOpInterface::ExternalModel<AutoDiffSortFwd, SortOp> {
 public:
   LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
                                          MGradientUtils *gutils) const {
+    if (gutils->width > 1) {
+      op->emitError(
+          "TODO: AutoDiffSortFwd does not support batched forward mode");
+      return failure();
+    }
 
-    // TODO: we may need to record, for every successor, which of its inputs
-    // need a shadow to recreate the body correctly.
-    llvm::SmallDenseSet<unsigned> operandPositionsToShadow;
-    llvm::SmallDenseSet<unsigned> resultPositionsToShadow;
+    auto sortOp = cast<stablehlo::SortOp>(op);
 
-    for (auto res : op->getResults())
-      if (!gutils->isConstantValue(res)) {
-        operandPositionsToShadow.insert(res.getResultNumber());
-        resultPositionsToShadow.insert(res.getResultNumber());
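+    // Maps each active input index to the position of its appended tangent in
+    // `newOperands` (and hence in the new sort's results).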
+    DenseMap<int32_t, int32_t> gradMapping;
+
+    SmallVector<Value> newOperands;
+    for (auto operand : sortOp.getInputs()) {
+      newOperands.push_back(gutils->getNewFromOriginal(operand));
+    }
+    for (auto [i, operand] : llvm::enumerate(sortOp.getInputs())) {
+      if (!gutils->isConstantValue(operand)) {
+        newOperands.push_back(gutils->invertPointerM(operand, builder));
+        gradMapping[i] = newOperands.size() - 1;
       }
+    }
 
-    return mlir::enzyme::detail::controlFlowForwardHandler(
-        op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow);
+    auto newSortOp =
+        constructSortOpWithExtraOperands(builder, sortOp, newOperands);
+
+    SmallVector<Value> replacementResults(sortOp.getNumResults());
+    for (int32_t i = 0; i < sortOp.getNumResults(); i++) {
+      replacementResults[i] = newSortOp.getResults()[i];
+      auto origRes = sortOp.getResults()[i];
+      if (!gutils->isConstantValue(origRes)) {
+        // The tangent of result i is the sorted tangent operand recorded in
+        // gradMapping.
+        int32_t j = gradMapping[i];
+        gutils->setDiffe(origRes, newSortOp.getResults()[j], builder);
+      }
+    }
+
+    gutils->replaceOrigOpWith(op, replacementResults);
+    gutils->originalToNewFnOps[op] = newSortOp;
+    gutils->eraseIfUnused(op);
+    return success();
+  }
+};
+
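+// Reverse mode: cacheValues sorts an iota alongside the inputs and caches the
+// resulting permutation; the adjoint then scatter-adds each result cotangent
+// back to its pre-sort position.
+//
+// Worked example (hypothetical values): sorting x = [3, 1, 2] ascending gives
+// [1, 2, 3] with cached indices [1, 2, 0]. A cotangent [a, b, c] on the sorted
+// result is scattered to positions [1, 2, 0], so dx = [c, a, b].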
+class AutoDiffSortRev
+    : public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffSortRev,
+                                                       stablehlo::SortOp> {
+public:
+  LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
+                                         MGradientUtilsReverse *gutils,
+                                         SmallVector<Value> caches) const {
+    auto sortOp = cast<stablehlo::SortOp>(orig);
+
+    if (gutils->width > 1) {
+      orig->emitError(
+          "TODO: AutoDiffSortRev does not support batched reverse mode");
+      return failure();
+    }
+
+    auto indices = gutils->popCache(caches[0], builder);
+    auto indicesTy = cast<RankedTensorType>(indices.getType());
+
+    SmallVector<int64_t> newIndicesShape(indicesTy.getShape().begin(),
+                                         indicesTy.getShape().end());
+    newIndicesShape.push_back(1);
+
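+    // Scatter consumes indices with a trailing index_vector dimension; append
+    // a unit dim so each index is a rank-1 vector into the sort dimension.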
+    indices = stablehlo::ReshapeOp::create(
+        builder, orig->getLoc(),
+        RankedTensorType::get(newIndicesShape, indicesTy.getElementType()),
+        indices);
+
+    auto inTy = cast<RankedTensorType>(orig->getOperand(0).getType());
+    auto inRank = inTy.getRank();
+
+    SmallVector<int64_t> batchingDims;
+    for (int32_t d = 0; d < inRank; d++) {
+      if (d != sortOp.getDimension()) {
+        batchingDims.push_back(d);
+      }
+    }
+
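+    // Dimension numbers: no update_window_dims; the sort dimension is both the
+    // inserted window dim and the scatter-dims-to-operand-dims target; every
+    // other dimension is a batching dim shared between operand and indices;
+    // the trailing unit dim added above is the index vector.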
+    auto scatterDims = stablehlo::ScatterDimensionNumbersAttr::get(
+        orig->getContext(), SmallVector<int64_t>(),
+        SmallVector<int64_t>{static_cast<int64_t>(sortOp.getDimension())},
+        batchingDims, batchingDims,
+        SmallVector<int64_t>{static_cast<int64_t>(sortOp.getDimension())},
+        indicesTy.getRank());
+
+    for (size_t i = 0; i < orig->getNumResults(); i++) {
+      if (gutils->isConstantValue(orig->getResult(i)) ||
+          gutils->isConstantValue(orig->getOperand(i)))
+        continue;
+
+      // Scatter-add the cotangent of sorted result i through the cached
+      // permutation into the diffe of operand i, then zero the result's
+      // diffe.
+      auto inDiffe = gutils->diffe(orig->getResult(i), builder);
+      auto inDiffeTy = cast<RankedTensorType>(inDiffe.getType());
+      gutils->zeroDiffe(orig->getResult(i), builder);
+
+      auto outDiffe = gutils->diffe(orig->getOperand(i), builder);
+
+      Region combiner;
+      {
+        // Update computation: add the scattered update to the existing value.
+        Block *block = new Block();
+        combiner.push_back(block);
+        block->addArgument(
+            RankedTensorType::get({}, inDiffeTy.getElementType()),
+            orig->getLoc());
+        block->addArgument(
+            RankedTensorType::get({}, inDiffeTy.getElementType()),
+            orig->getLoc());
+        OpBuilder::InsertionGuard guard(builder);
+        builder.setInsertionPointToStart(block);
+        stablehlo::ReturnOp::create(
+            builder, orig->getLoc(),
+            ValueRange{stablehlo::AddOp::create(builder, orig->getLoc(),
+                                                block->getArgument(0),
+                                                block->getArgument(1))});
+      }
+
+      auto scatterOp = stablehlo::ScatterOp::create(
+          builder, orig->getLoc(), outDiffe, indices, inDiffe, scatterDims,
+          builder.getBoolAttr(false), builder.getBoolAttr(true));
+      scatterOp.getUpdateComputation().takeBody(combiner);
+
+      gutils->setDiffe(orig->getOperand(i), scatterOp.getResults()[0],
+                       builder);
+    }
+
+    return success();
+  }
+
+  SmallVector<Value> cacheValues(Operation *orig,
+                                 MGradientUtilsReverse *gutils) const {
+    auto sortOp = cast<stablehlo::SortOp>(orig);
+
+    if (gutils->width > 1)
+      return {};
+
+    bool allConstant = true;
+    for (auto input : sortOp.getInputs()) {
+      if (!gutils->isConstantValue(input)) {
+        allConstant = false;
+        break;
+      }
+    }
+
+    if (allConstant)
+      return {};
+
+    auto newOp = gutils->getNewFromOriginal(orig);
+    OpBuilder cacheBuilder(newOp);
+
+    SmallVector<Value> newOperands(sortOp.getInputs().size() + 1);
+    for (auto [i, operand] : llvm::enumerate(sortOp.getInputs())) {
+      newOperands[i] = gutils->getNewFromOriginal(operand);
+    }
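+    // Sort an iota along the sort dimension together with the new inputs: the
+    // extra result is the permutation the sort applies, and is the only value
+    // that needs to be cached for the reverse pass.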
+    auto OpTy = cast<TensorType>(newOperands[0].getType());
+    auto iotaOp = stablehlo::IotaOp::create(
+        cacheBuilder, orig->getLoc(),
+        RankedTensorType::get(OpTy.getShape(),
+                              cacheBuilder.getIntegerType(32, false)),
+        sortOp.getDimensionAttr());
+    newOperands[newOperands.size() - 1] = iotaOp.getResult();
+
+    auto newSortOp =
+        constructSortOpWithExtraOperands(cacheBuilder, sortOp, newOperands);
+    auto newResults = newSortOp.getResults();
+
+    SmallVector<Value> caches;
+    caches.push_back(gutils->initAndPushCache(newResults[newResults.size() - 1],
+                                              cacheBuilder));
+
+    SmallVector<Value> replacements;
+    for (size_t i = 0; i < newResults.size() - 1; i++) {
+      replacements.push_back(newResults[i]);
+    }
+
+    gutils->replaceOrigOpWith(orig, replacements);
+    gutils->eraseIfUnused(orig);
+    gutils->originalToNewFnOps[orig] = newSortOp;
+
+    return caches;
   }
+
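+  // No shadow values are needed; the adjoint is materialized directly in
+  // createReverseModeAdjoint.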
+  void createShadowValues(Operation *op, OpBuilder &builder,
+                          MGradientUtilsReverse *gutils) const {}
 };
 
 class AutoDiffBatchNormTrainingRev
@@ -3701,8 +3912,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
     stablehlo::StablehloDialect *) {
   registerInterfaces(context);
 
-  // SortOp::attachInterface<AutoDiffSort>(*context);
-
   WhileOp::attachInterface<WhileOpEnzymeOpsRemover>(*context);
   IfOp::attachInterface<IfOpEnzymeOpsRemover>(*context);
 
@@ -3722,6 +3931,8 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
   IfOp::attachInterface<AutoDiffIfFwd>(*context);
   IfOp::attachInterface<AutoDiffIfCF>(*context);
 
+  SortOp::attachInterface<AutoDiffSortFwd>(*context);
+  SortOp::attachInterface<AutoDiffSortRev>(*context);
   WhileOp::attachInterface<AutoDiffWhileFwd>(*context);
   WhileOp::attachInterface<AutoDiffWhileRev>(*context);
   ReduceOp::attachInterface<AutoDiffReduceCF<ReduceOp>>(*context);