[MLIR][OpenMP] Add codegen for teams reductions #133310

Open: wants to merge 9 commits into base: main

Conversation

jsjodin (Contributor) commented Mar 27, 2025

This patch adds the lowering of teams reductions from the omp dialect to LLVM-IR. Some minor cleanup was done in clang to remove an unused parameter.
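
For reference, here is a minimal sketch (not part of this patch) of how a frontend could call the updated createReductions entry point once the IsTeamsReduction flag is in place. The helper name emitTeamsReductions is hypothetical, and the builder, location, insert point, and reduction descriptors are assumed to already be set up on the caller's side:

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/Support/Error.h"

// Hypothetical helper, not from this PR: forwards pre-built reduction info to
// the extended createReductions overload and requests the teams-reduction path.
static llvm::Error emitTeamsReductions(
    llvm::OpenMPIRBuilder &OMPBuilder,
    const llvm::OpenMPIRBuilder::LocationDescription &Loc,
    llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
    llvm::ArrayRef<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos,
    llvm::ArrayRef<bool> IsByRef) {
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
      OMPBuilder.createReductions(Loc, AllocaIP, ReductionInfos, IsByRef,
                                  /*IsNoWait=*/false,
                                  /*IsTeamsReduction=*/true);
  if (!AfterIP)
    return AfterIP.takeError();
  // Continue emitting code at the continuation point returned by the builder.
  OMPBuilder.Builder.restoreIP(*AfterIP);
  return llvm::Error::success();
}

Callers that never emit teams reductions keep the previous behaviour, since both new parameters default to false in the updated signature.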

jsjodin requested review from ergawy, skatrak, tblah, and TIFitis on Mar 27, 2025 at 20:35
llvmbot added the clang, clang:codegen, mlir:llvm, mlir, mlir:openmp, flang:openmp, clang:openmp, and offload labels on Mar 27, 2025
llvmbot (Member) commented Mar 27, 2025

@llvm/pr-subscribers-offload
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-mlir-llvm

Author: Jan Leyonberg (jsjodin)

Changes

This patch adds the lowering of teams reductions from the omp dialect to LLVM-IR. Some minor cleanup was done in clang to remove an unused parameter.


Patch is 39.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133310.diff

11 Files Affected:

  • (modified) clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp (+1-2)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+4-4)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+135-56)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+185-23)
  • (modified) mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir (+1-1)
  • (modified) mlir/test/Target/LLVMIR/omptarget-wsloop.mlir (+2-2)
  • (added) mlir/test/Target/LLVMIR/openmp-teams-reduction.mlir (+71)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-28)
  • (added) offload/test/offloading/fortran/basic-target-parallel-reduction.f90 (+27)
  • (added) offload/test/offloading/fortran/basic-target-teams-parallel-reduction.f90 (+27)
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index feb2448297542..d30bef9e7f0ba 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -1659,7 +1659,6 @@ void CGOpenMPRuntimeGPU::emitReduction(
     return;
 
   bool ParallelReduction = isOpenMPParallelDirective(Options.ReductionKind);
-  bool DistributeReduction = isOpenMPDistributeDirective(Options.ReductionKind);
   bool TeamsReduction = isOpenMPTeamsDirective(Options.ReductionKind);
 
   ASTContext &C = CGM.getContext();
@@ -1756,7 +1755,7 @@ void CGOpenMPRuntimeGPU::emitReduction(
   llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
       cantFail(OMPBuilder.createReductionsGPU(
           OmpLoc, AllocaIP, CodeGenIP, ReductionInfos, false, TeamsReduction,
-          DistributeReduction, llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
+          llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
           CGF.getTarget().getGridValue(),
           C.getLangOpts().OpenMPCUDAReductionBufNum, RTLoc));
   CGF.Builder.restoreIP(AfterIP);
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 28909cef4748d..a3a266e3f0a98 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1905,8 +1905,6 @@ class OpenMPIRBuilder {
   ///                           nowait.
   /// \param IsTeamsReduction   Optional flag set if it is a teams
   ///                           reduction.
-  /// \param HasDistribute      Optional flag set if it is a
-  ///                           distribute reduction.
   /// \param GridValue          Optional GPU grid value.
   /// \param ReductionBufNum    Optional OpenMPCUDAReductionBufNumValue to be
   /// used for teams reduction.
@@ -1915,7 +1913,6 @@ class OpenMPIRBuilder {
       const LocationDescription &Loc, InsertPointTy AllocaIP,
       InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
       bool IsNoWait = false, bool IsTeamsReduction = false,
-      bool HasDistribute = false,
       ReductionGenCBKind ReductionGenCBKind = ReductionGenCBKind::MLIR,
       std::optional<omp::GV> GridValue = {}, unsigned ReductionBufNum = 1024,
       Value *SrcLocInfo = nullptr);
@@ -1987,7 +1984,8 @@ class OpenMPIRBuilder {
                                         InsertPointTy AllocaIP,
                                         ArrayRef<ReductionInfo> ReductionInfos,
                                         ArrayRef<bool> IsByRef,
-                                        bool IsNoWait = false);
+                                        bool IsNoWait = false,
+                                        bool IsTeamsReduction = false);
 
   ///}
 
@@ -2271,6 +2269,8 @@ class OpenMPIRBuilder {
     int32_t MinTeams = 1;
     SmallVector<int32_t, 3> MaxThreads = {-1};
     int32_t MinThreads = 1;
+    int32_t ReductionDataSize = 0;
+    int32_t ReductionBufferLength = 0;
   };
 
   /// Container to pass LLVM IR runtime values or constants related to the
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 2e5ce5308eea5..b5e55dbccf464 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -3495,9 +3495,9 @@ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
     const LocationDescription &Loc, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
-    bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
-    ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
-    unsigned ReductionBufNum, Value *SrcLocInfo) {
+    bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind,
+    std::optional<omp::GV> GridValue, unsigned ReductionBufNum,
+    Value *SrcLocInfo) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
   Builder.restoreIP(CodeGenIP);
@@ -3514,6 +3514,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
   if (ReductionInfos.size() == 0)
     return Builder.saveIP();
 
+  BasicBlock *ContinuationBlock = nullptr;
+  if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
+    // Copied code from createReductions
+    BasicBlock *InsertBlock = Loc.IP.getBlock();
+    ContinuationBlock =
+        InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
+    InsertBlock->getTerminator()->eraseFromParent();
+    Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
+  }
+
   Function *CurFunc = Builder.GetInsertBlock()->getParent();
   AttributeList FuncAttrs;
   AttrBuilder AttrBldr(Ctx);
@@ -3669,11 +3679,21 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
                ReductionFunc;
       });
     } else {
-      assert(false && "Unhandled ReductionGenCBKind");
+      Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
+      Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
+      Value *Reduced;
+      InsertPointOrErrorTy AfterIP =
+          RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
+      if (!AfterIP)
+        return AfterIP.takeError();
+      Builder.CreateStore(Reduced, LHS, false);
     }
   }
   emitBlock(ExitBB, CurFunc);
-
+  if (ContinuationBlock) {
+    Builder.CreateBr(ContinuationBlock);
+    Builder.SetInsertPoint(ContinuationBlock);
+  }
   Config.setEmitLLVMUsed();
 
   return Builder.saveIP();
@@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) {
                           ".omp.reduction.func", &M);
 }
 
-OpenMPIRBuilder::InsertPointOrErrorTy
-OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
-                                  InsertPointTy AllocaIP,
-                                  ArrayRef<ReductionInfo> ReductionInfos,
-                                  ArrayRef<bool> IsByRef, bool IsNoWait) {
-  assert(ReductionInfos.size() == IsByRef.size());
-  for (const ReductionInfo &RI : ReductionInfos) {
-    (void)RI;
-    assert(RI.Variable && "expected non-null variable");
-    assert(RI.PrivateVariable && "expected non-null private variable");
-    assert(RI.ReductionGen && "expected non-null reduction generator callback");
-    assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
-           "expected variables and their private equivalents to have the same "
-           "type");
-    assert(RI.Variable->getType()->isPointerTy() &&
-           "expected variables to be pointers");
+static Error populateReductionFunction(
+    Function *ReductionFunc,
+    ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
+    IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
+  Module *Module = ReductionFunc->getParent();
+  BasicBlock *ReductionFuncBlock =
+      BasicBlock::Create(Module->getContext(), "", ReductionFunc);
+  Builder.SetInsertPoint(ReductionFuncBlock);
+  Value *LHSArrayPtr = nullptr;
+  Value *RHSArrayPtr = nullptr;
+  if (IsGPU) {
+    // Need to alloca memory here and deal with the pointers before getting
+    // LHS/RHS pointers out
+    //
+    Argument *Arg0 = ReductionFunc->getArg(0);
+    Argument *Arg1 = ReductionFunc->getArg(1);
+    Type *Arg0Type = Arg0->getType();
+    Type *Arg1Type = Arg1->getType();
+
+    Value *LHSAlloca =
+        Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
+    Value *RHSAlloca =
+        Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
+    Value *LHSAddrCast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type);
+    Value *RHSAddrCast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type);
+    Builder.CreateStore(Arg0, LHSAddrCast);
+    Builder.CreateStore(Arg1, RHSAddrCast);
+    LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
+    RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
+  } else {
+    LHSArrayPtr = ReductionFunc->getArg(0);
+    RHSArrayPtr = ReductionFunc->getArg(1);
   }
 
+  unsigned NumReductions = ReductionInfos.size();
+  Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
+
+  for (auto En : enumerate(ReductionInfos)) {
+    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
+    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
+        RedArrayTy, LHSArrayPtr, 0, En.index());
+    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
+    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        LHSI8Ptr, RI.Variable->getType());
+    Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
+    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
+        RedArrayTy, RHSArrayPtr, 0, En.index());
+    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
+    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        RHSI8Ptr, RI.PrivateVariable->getType());
+    Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
+    Value *Reduced;
+    OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+        RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
+    if (!AfterIP)
+      return AfterIP.takeError();
+
+    Builder.restoreIP(*AfterIP);
+    // TODO: Consider flagging an error.
+    if (!Builder.GetInsertBlock())
+      return Error::success();
+
+    // store is inside of the reduction region when using by-ref
+    if (!IsByRef[En.index()])
+      Builder.CreateStore(Reduced, LHSPtr);
+  }
+  Builder.CreateRetVoid();
+  return Error::success();
+}
+
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
+    const LocationDescription &Loc, InsertPointTy AllocaIP,
+    ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
+    bool IsNoWait, bool IsTeamsReduction) {
+  assert(ReductionInfos.size() == IsByRef.size());
+  if (Config.isGPU())
+    return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos,
+                               IsNoWait, IsTeamsReduction);
+
+  checkReductionInfos(ReductionInfos, /*IsGPU*/ false);
+
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
+  if (ReductionInfos.size() == 0)
+    return Builder.saveIP();
+
   BasicBlock *InsertBlock = Loc.IP.getBlock();
   BasicBlock *ContinuationBlock =
       InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
@@ -3832,38 +3920,13 @@ OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
   // Populate the outlined reduction function using the elementwise reduction
   // function. Partial values are extracted from the type-erased array of
   // pointers to private variables.
-  BasicBlock *ReductionFuncBlock =
-      BasicBlock::Create(Module->getContext(), "", ReductionFunc);
-  Builder.SetInsertPoint(ReductionFuncBlock);
-  Value *LHSArrayPtr = ReductionFunc->getArg(0);
-  Value *RHSArrayPtr = ReductionFunc->getArg(1);
+  Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
+                                        IsByRef, false);
+  if (Err)
+    return Err;
 
-  for (auto En : enumerate(ReductionInfos)) {
-    const ReductionInfo &RI = En.value();
-    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
-        RedArrayTy, LHSArrayPtr, 0, En.index());
-    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
-    Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
-    Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
-    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
-        RedArrayTy, RHSArrayPtr, 0, En.index());
-    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
-    Value *RHSPtr =
-        Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
-    Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
-    Value *Reduced;
-    InsertPointOrErrorTy AfterIP =
-        RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
-    if (!AfterIP)
-      return AfterIP.takeError();
-    Builder.restoreIP(*AfterIP);
-    if (!Builder.GetInsertBlock())
-      return InsertPointTy();
-    // store is inside of the reduction region when using by-ref
-    if (!IsByRef[En.index()])
-      Builder.CreateStore(Reduced, LHSPtr);
-  }
-  Builder.CreateRetVoid();
+  if (!Builder.GetInsertBlock())
+    return InsertPointTy();
 
   Builder.SetInsertPoint(ContinuationBlock);
   return Builder.saveIP();
@@ -4434,10 +4497,24 @@ getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
 static void createTargetLoopWorkshareCall(
     OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
     BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
-    Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
-  Type *TripCountTy = TripCount->getType();
+    Type *ParallelTaskPtr, Value *TripCountOrig, Function &LoopBodyFn) {
   Module &M = OMPBuilder->M;
   IRBuilder<> &Builder = OMPBuilder->Builder;
+  Value *TripCount = TripCountOrig;
+  // The trip count is 1 larger than it should be for GPU, this is because
+  // of how the deviceRTL functions work with clang. TODO: make the trip
+  // count consistent between both so we don't have to subtract one here.
+  if (OMPBuilder->Config.isGPU()) {
+    Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
+    LLVMContext &Ctx = M.getContext();
+    Type *IVTy = TripCountOrig->getType();
+    Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32
+                             ? Type::getInt32Ty(Ctx)
+                             : Type::getInt64Ty(Ctx);
+    Constant *One = ConstantInt::get(InternalIVTy, 1);
+    TripCount = Builder.CreateSub(TripCountOrig, One, "modified_trip_count");
+  }
+  Type *TripCountTy = TripCount->getType();
   FunctionCallee RTLFn =
       getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
   SmallVector<Value *, 8> RealArgs;
@@ -6239,8 +6316,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
   Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
   Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
-  Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
-  Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
+  Constant *ReductionDataSize =
+      ConstantInt::getSigned(Int32, Attrs.ReductionDataSize);
+  Constant *ReductionBufferLength =
+      ConstantInt::getSigned(Int32, Attrs.ReductionBufferLength);
 
   Function *Fn = getOrCreateRuntimeFunctionPtr(
       omp::RuntimeFunction::OMPRTL___kmpc_target_init);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..155ea3f920617 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -265,7 +265,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::TeamsOp op) {
         checkAllocate(op, result);
         checkPrivate(op, result);
-        checkReduction(op, result);
       })
       .Case([&](omp::TaskOp op) {
         checkAllocate(op, result);
@@ -1018,19 +1017,37 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
       // variable allocated in the inlined region)
       llvm::Value *var = builder.CreateAlloca(
           moduleTranslation.convertType(reductionDecls[i].getType()));
-      deferredStores.emplace_back(phis[0], var);
-
-      privateReductionVariables[i] = var;
-      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
-      reductionVariableMap.try_emplace(loop.getReductionVars()[i], phis[0]);
+      //      var->setName("private_redvar");
+
+      llvm::Type *ptrTy = llvm::PointerType::getUnqual(builder.getContext());
+      llvm::Value *castVar =
+          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+      // TODO: I (Sergio) just guessed casting phis[0] like it's done for var is
+      // what's supposed to happen with this code coming from a merge from main,
+      // but I don't actually know. Someone more familiar with it needs to check
+      // this.
+      llvm::Value *castPhi =
+          builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
+
+      deferredStores.emplace_back(castPhi, castVar);
+
+      privateReductionVariables[i] = castVar;
+      moduleTranslation.mapValue(reductionArgs[i], castPhi);
+      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
     } else {
       assert(allocRegion.empty() &&
              "allocaction is implicit for by-val reduction");
       llvm::Value *var = builder.CreateAlloca(
           moduleTranslation.convertType(reductionDecls[i].getType()));
-      moduleTranslation.mapValue(reductionArgs[i], var);
-      privateReductionVariables[i] = var;
-      reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
+      //      var->setName("private_redvar");
+
+      llvm::Type *ptrTy = llvm::PointerType::getUnqual(builder.getContext());
+      llvm::Value *castVar =
+          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+
+      moduleTranslation.mapValue(reductionArgs[i], castVar);
+      privateReductionVariables[i] = castVar;
+      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
     }
   }
 
@@ -1250,18 +1267,20 @@ static LogicalResult createReductionsAndCleanup(
     LLVM::ModuleTranslation &moduleTranslation,
     llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
     SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
-    ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef) {
+    ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
+    bool isNowait = false, bool isTeamsReduction = false) {
   // Process the reductions if required.
   if (op.getNumReductionVars() == 0)
     return success();
 
+  SmallVector<OwningReductionGen> owningReductionGens;
+  SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+  SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
 
   // Create the reduction generators. We need to own them here because
   // ReductionInfo only accepts references to the generators.
-  SmallVector<OwningReductionGen> owningReductionGens;
-  SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
-  SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
   collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
                        owningReductionGens, owningAtomicReductionGens,
                        privateReductionVariables, reductionInfos);
@@ -1273,7 +1292,7 @@ static LogicalResult createReductionsAndCleanup(
   builder.SetInsertPoint(tempTerminator);
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
       ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
-                                   isByRef, op.getNowait());
+                                   isByRef, isNowait, isTeamsReduction);
 
   if (failed(handleError(contInsertPoint, *op)))
     return failure();
@@ -1666,9 +1685,9 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
   builder.restoreIP(*afterIP);
 
   // Process the reductions if required.
-  return createReductionsAndCleanup(sectionsOp, builder, moduleTranslation,
-                                    allocaIP, reductionDecls,
-                                    privateReductionVariables, isByRef);
+  return createReductionsAndCleanup(
+      sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
+      privateReductionVariables, isByRef, sectionsOp.getNowait());
 }
 
 /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
@@ -1714,6 +1733,43 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
+  auto iface =
+      llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
+  // Check that all uses of the reduction block arg has the same distri...
[truncated]

llvmbot (Member) commented Mar 27, 2025

@llvm/pr-subscribers-clang


llvmbot (Member) commented Mar 27, 2025

@llvm/pr-subscribers-mlir-openmp

+  Constant *ReductionBufferLength =
+      ConstantInt::getSigned(Int32, Attrs.ReductionBufferLength);
 
   Function *Fn = getOrCreateRuntimeFunctionPtr(
       omp::RuntimeFunction::OMPRTL___kmpc_target_init);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..155ea3f920617 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -265,7 +265,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::TeamsOp op) {
         checkAllocate(op, result);
         checkPrivate(op, result);
-        checkReduction(op, result);
       })
       .Case([&](omp::TaskOp op) {
         checkAllocate(op, result);
@@ -1018,19 +1017,37 @@ allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
       // variable allocated in the inlined region)
       llvm::Value *var = builder.CreateAlloca(
           moduleTranslation.convertType(reductionDecls[i].getType()));
-      deferredStores.emplace_back(phis[0], var);
-
-      privateReductionVariables[i] = var;
-      moduleTranslation.mapValue(reductionArgs[i], phis[0]);
-      reductionVariableMap.try_emplace(loop.getReductionVars()[i], phis[0]);
+      //      var->setName("private_redvar");
+
+      llvm::Type *ptrTy = llvm::PointerType::getUnqual(builder.getContext());
+      llvm::Value *castVar =
+          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+      // TODO: I (Sergio) just guessed casting phis[0] like it's done for var is
+      // what's supposed to happen with this code coming from a merge from main,
+      // but I don't actually know. Someone more familiar with it needs to check
+      // this.
+      llvm::Value *castPhi =
+          builder.CreatePointerBitCastOrAddrSpaceCast(phis[0], ptrTy);
+
+      deferredStores.emplace_back(castPhi, castVar);
+
+      privateReductionVariables[i] = castVar;
+      moduleTranslation.mapValue(reductionArgs[i], castPhi);
+      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
     } else {
       assert(allocRegion.empty() &&
              "allocaction is implicit for by-val reduction");
       llvm::Value *var = builder.CreateAlloca(
           moduleTranslation.convertType(reductionDecls[i].getType()));
-      moduleTranslation.mapValue(reductionArgs[i], var);
-      privateReductionVariables[i] = var;
-      reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
+      //      var->setName("private_redvar");
+
+      llvm::Type *ptrTy = llvm::PointerType::getUnqual(builder.getContext());
+      llvm::Value *castVar =
+          builder.CreatePointerBitCastOrAddrSpaceCast(var, ptrTy);
+
+      moduleTranslation.mapValue(reductionArgs[i], castVar);
+      privateReductionVariables[i] = castVar;
+      reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
     }
   }
 
@@ -1250,18 +1267,20 @@ static LogicalResult createReductionsAndCleanup(
     LLVM::ModuleTranslation &moduleTranslation,
     llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
     SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
-    ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef) {
+    ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
+    bool isNowait = false, bool isTeamsReduction = false) {
   // Process the reductions if required.
   if (op.getNumReductionVars() == 0)
     return success();
 
+  SmallVector<OwningReductionGen> owningReductionGens;
+  SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+  SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
 
   // Create the reduction generators. We need to own them here because
   // ReductionInfo only accepts references to the generators.
-  SmallVector<OwningReductionGen> owningReductionGens;
-  SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
-  SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
   collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
                        owningReductionGens, owningAtomicReductionGens,
                        privateReductionVariables, reductionInfos);
@@ -1273,7 +1292,7 @@ static LogicalResult createReductionsAndCleanup(
   builder.SetInsertPoint(tempTerminator);
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
       ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
-                                   isByRef, op.getNowait());
+                                   isByRef, isNowait, isTeamsReduction);
 
   if (failed(handleError(contInsertPoint, *op)))
     return failure();
@@ -1666,9 +1685,9 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
   builder.restoreIP(*afterIP);
 
   // Process the reductions if required.
-  return createReductionsAndCleanup(sectionsOp, builder, moduleTranslation,
-                                    allocaIP, reductionDecls,
-                                    privateReductionVariables, isByRef);
+  return createReductionsAndCleanup(
+      sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
+      privateReductionVariables, isByRef, sectionsOp.getNowait());
 }
 
 /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
@@ -1714,6 +1733,43 @@ convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
   return success();
 }
 
+static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
+  auto iface =
+      llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
+  // Check that all uses of the reduction block arg has the same distri...
[truncated]
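
For readers skimming the truncated patch, a minimal Fortran reproducer of the kind of construct this lowering is meant to handle (illustrative only, written for this summary; it is not taken from the patch or its tests):

! Hedged example: a combined construct whose teams-level reduction goes
! through the new OpenMPIRBuilder path when compiled for offload.
program teams_reduction_demo
  implicit none
  integer :: i, x
  x = 0
  !$omp target teams distribute parallel do reduction(+:x)
  do i = 1, 1000
     x = x + i
  end do
  !$omp end target teams distribute parallel do
  print *, x   ! expected: 500500
end program teams_reduction_demo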

Member

@skatrak skatrak left a comment

Thank you Jan, I have a few superficial comments, but the approach LGTM.

bool doDistributeReduction =
teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;

DenseMap<Value, llvm::Value *> reductionVariableMap;
Member

Nit: This looks like it could be moved into the if.

It actually looks like it would make sense to make it local to allocAndInitializeReductionVars, since I can't see it being used anywhere outside of that function, but that can be looked into later on.

Contributor Author

I can make a separate cleanup PR for this to make this one smaller.

// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
attrs.MaxThreads.front() = combinedMaxThreadsVal;
attrs.ReductionDataSize = reductionDataSize;
if (attrs.ReductionDataSize != 0)
attrs.ReductionBufferLength = 1024;
Member

Nit: Could you add some comment to document the reasoning for this size, or add a TODO to actually calculate this based on the actual reductions?

Contributor Author

I looked into this a bit more and it is basically a fixed value in clang that can be modified with the -fopenmp-cuda-teams-reduction-recs-num flag. I added a TODO for now.
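
For completeness, a hedged pointer on the clang side of that flag (command line written for this reply and not re-verified against the current driver; the offload triple is just a placeholder):

clang -fopenmp -fopenmp-targets=nvptx64-nvidia-cuda \
      -fopenmp-cuda-teams-reduction-recs-num=2048 reduction.c

The MLIR translation in this patch hard-codes the matching ReductionBufferLength to 1024 with a TODO, as discussed above.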

@jsjodin jsjodin force-pushed the jleyonberg/reductions branch from d1ecf5f to bce0b41 on April 1, 2025 14:25
@jsjodin jsjodin requested a review from skatrak April 3, 2025 16:30
Member

@skatrak skatrak left a comment

Thanks, LGTM!

Comment on lines +1758 to +1762
// If we are going to use distribute reduction then remove any debug uses of
// the reduction parameters in teamsOp. Otherwise they will be left without
// any mapped value in moduleTranslation and will eventually error out.
for (auto use : debugUses)
use->erase();
Member

@abidh, can you take a look at whether this is safe/the right thing to do?

It seems to me, though I'm not familiar with debug information handling, that another way to deal with this would be to map these omp.teams entry block arguments to the corresponding LLVM values defined while handling the reduction inside the nested omp.distribute. Then we would just have to sink these debug ops into omp.loop_nest rather than deleting them, and we would be able to keep that information.

I don't actually know if we'd be losing any information by deleting these ops, this is just an idea in case we wanted to keep them.

Contributor

The omp.distribute has its own copy of the hlfir.declare for the reduction variable so we should not be losing this information even if we drop it from here.
