Skip to content

Commit 98f256c

Browse files
committed
[MLIR][OpenMP] Skip host omp ops when compiling for the target device
This patch separates the lowering dispatch for host and target devices. For the target device, if the current operation is not a top-level operation (e.g. omp.target) or is inside a target device code region it will be ignored, since it belongs to the host code.
1 parent 1bce411 commit 98f256c

8 files changed

+312
-118
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

+176-111
Original file line numberDiff line numberDiff line change
@@ -3116,6 +3116,172 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
31163116
return success();
31173117
}
31183118

3119+
static bool isTargetDeviceOp(Operation *op) {
3120+
// Assumes no reverse offloading
3121+
if (op->getParentOfType<omp::TargetOp>())
3122+
return true;
3123+
3124+
if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
3125+
if (auto declareTargetIface =
3126+
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3127+
parentFn.getOperation()))
3128+
if (declareTargetIface.isDeclareTarget() &&
3129+
declareTargetIface.getDeclareTargetDeviceType() !=
3130+
mlir::omp::DeclareTargetDeviceType::host)
3131+
return true;
3132+
3133+
return false;
3134+
}
3135+
3136+
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
3137+
/// (including OpenMP runtime calls).
3138+
static LogicalResult
3139+
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
3140+
LLVM::ModuleTranslation &moduleTranslation) {
3141+
3142+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3143+
3144+
return llvm::TypeSwitch<Operation *, LogicalResult>(op)
3145+
.Case([&](omp::BarrierOp) {
3146+
ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
3147+
return success();
3148+
})
3149+
.Case([&](omp::TaskwaitOp) {
3150+
ompBuilder->createTaskwait(builder.saveIP());
3151+
return success();
3152+
})
3153+
.Case([&](omp::TaskyieldOp) {
3154+
ompBuilder->createTaskyield(builder.saveIP());
3155+
return success();
3156+
})
3157+
.Case([&](omp::FlushOp) {
3158+
// No support in Openmp runtime function (__kmpc_flush) to accept
3159+
// the argument list.
3160+
// OpenMP standard states the following:
3161+
// "An implementation may implement a flush with a list by ignoring
3162+
// the list, and treating it the same as a flush without a list."
3163+
//
3164+
// The argument list is discarded so that, flush with a list is treated
3165+
// same as a flush without a list.
3166+
ompBuilder->createFlush(builder.saveIP());
3167+
return success();
3168+
})
3169+
.Case([&](omp::ParallelOp op) {
3170+
return convertOmpParallel(op, builder, moduleTranslation);
3171+
})
3172+
.Case([&](omp::ReductionOp reductionOp) {
3173+
return convertOmpReductionOp(reductionOp, builder, moduleTranslation);
3174+
})
3175+
.Case([&](omp::MasterOp) {
3176+
return convertOmpMaster(*op, builder, moduleTranslation);
3177+
})
3178+
.Case([&](omp::CriticalOp) {
3179+
return convertOmpCritical(*op, builder, moduleTranslation);
3180+
})
3181+
.Case([&](omp::OrderedRegionOp) {
3182+
return convertOmpOrderedRegion(*op, builder, moduleTranslation);
3183+
})
3184+
.Case([&](omp::OrderedOp) {
3185+
return convertOmpOrdered(*op, builder, moduleTranslation);
3186+
})
3187+
.Case([&](omp::WsloopOp) {
3188+
return convertOmpWsloop(*op, builder, moduleTranslation);
3189+
})
3190+
.Case([&](omp::SimdLoopOp) {
3191+
return convertOmpSimdLoop(*op, builder, moduleTranslation);
3192+
})
3193+
.Case([&](omp::AtomicReadOp) {
3194+
return convertOmpAtomicRead(*op, builder, moduleTranslation);
3195+
})
3196+
.Case([&](omp::AtomicWriteOp) {
3197+
return convertOmpAtomicWrite(*op, builder, moduleTranslation);
3198+
})
3199+
.Case([&](omp::AtomicUpdateOp op) {
3200+
return convertOmpAtomicUpdate(op, builder, moduleTranslation);
3201+
})
3202+
.Case([&](omp::AtomicCaptureOp op) {
3203+
return convertOmpAtomicCapture(op, builder, moduleTranslation);
3204+
})
3205+
.Case([&](omp::SectionsOp) {
3206+
return convertOmpSections(*op, builder, moduleTranslation);
3207+
})
3208+
.Case([&](omp::SingleOp op) {
3209+
return convertOmpSingle(op, builder, moduleTranslation);
3210+
})
3211+
.Case([&](omp::TeamsOp op) {
3212+
return convertOmpTeams(op, builder, moduleTranslation);
3213+
})
3214+
.Case([&](omp::TaskOp op) {
3215+
return convertOmpTaskOp(op, builder, moduleTranslation);
3216+
})
3217+
.Case([&](omp::TaskgroupOp op) {
3218+
return convertOmpTaskgroupOp(op, builder, moduleTranslation);
3219+
})
3220+
.Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
3221+
omp::CriticalDeclareOp>([](auto op) {
3222+
// `yield` and `terminator` can be just omitted. The block structure
3223+
// was created in the region that handles their parent operation.
3224+
// `declare_reduction` will be used by reductions and is not
3225+
// converted directly, skip it.
3226+
// `critical.declare` is only used to declare names of critical
3227+
// sections which will be used by `critical` ops and hence can be
3228+
// ignored for lowering. The OpenMP IRBuilder will create unique
3229+
// name for critical section names.
3230+
return success();
3231+
})
3232+
.Case([&](omp::ThreadprivateOp) {
3233+
return convertOmpThreadprivate(*op, builder, moduleTranslation);
3234+
})
3235+
.Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
3236+
omp::TargetUpdateOp>([&](auto op) {
3237+
return convertOmpTargetData(op, builder, moduleTranslation);
3238+
})
3239+
.Case([&](omp::TargetOp) {
3240+
return convertOmpTarget(*op, builder, moduleTranslation);
3241+
})
3242+
.Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
3243+
[&](auto op) {
3244+
// No-op, should be handled by relevant owning operations e.g.
3245+
// TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp etc.
3246+
// and then discarded
3247+
return success();
3248+
})
3249+
.Default([&](Operation *inst) {
3250+
return inst->emitError("unsupported OpenMP operation: ")
3251+
<< inst->getName();
3252+
});
3253+
}
3254+
3255+
static LogicalResult
3256+
convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
3257+
LLVM::ModuleTranslation &moduleTranslation) {
3258+
return convertHostOrTargetOperation(op, builder, moduleTranslation);
3259+
}
3260+
3261+
static LogicalResult
3262+
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
3263+
LLVM::ModuleTranslation &moduleTranslation) {
3264+
if (isa<omp::TargetOp>(op))
3265+
return convertOmpTarget(*op, builder, moduleTranslation);
3266+
if (isa<omp::TargetDataOp>(op))
3267+
return convertOmpTargetData(op, builder, moduleTranslation);
3268+
bool interrupted =
3269+
op->walk<WalkOrder::PreOrder>([&](Operation *oper) {
3270+
if (isa<omp::TargetOp>(oper)) {
3271+
if (failed(convertOmpTarget(*oper, builder, moduleTranslation)))
3272+
return WalkResult::interrupt();
3273+
return WalkResult::skip();
3274+
}
3275+
if (isa<omp::TargetDataOp>(oper)) {
3276+
if (failed(convertOmpTargetData(oper, builder, moduleTranslation)))
3277+
return WalkResult::interrupt();
3278+
return WalkResult::skip();
3279+
}
3280+
return WalkResult::advance();
3281+
}).wasInterrupted();
3282+
return failure(interrupted);
3283+
}
3284+
31193285
namespace {
31203286

31213287
/// Implementation of the dialect interface that converts operations belonging
@@ -3131,8 +3297,8 @@ class OpenMPDialectLLVMIRTranslationInterface
31313297
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
31323298
LLVM::ModuleTranslation &moduleTranslation) const final;
31333299

3134-
/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR, runtime
3135-
/// calls, or operation amendments
3300+
/// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
3301+
/// runtime calls, or operation amendments
31363302
LogicalResult
31373303
amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
31383304
NamedAttribute attribute,
@@ -3237,116 +3403,15 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
32373403
LLVM::ModuleTranslation &moduleTranslation) const {
32383404

32393405
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3406+
if (ompBuilder->Config.isTargetDevice()) {
3407+
if (isTargetDeviceOp(op)) {
3408+
return convertTargetDeviceOp(op, builder, moduleTranslation);
3409+
} else {
3410+
return convertTargetOpsInNest(op, builder, moduleTranslation);
3411+
}
3412+
}
32403413

3241-
return llvm::TypeSwitch<Operation *, LogicalResult>(op)
3242-
.Case([&](omp::BarrierOp) {
3243-
ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
3244-
return success();
3245-
})
3246-
.Case([&](omp::TaskwaitOp) {
3247-
ompBuilder->createTaskwait(builder.saveIP());
3248-
return success();
3249-
})
3250-
.Case([&](omp::TaskyieldOp) {
3251-
ompBuilder->createTaskyield(builder.saveIP());
3252-
return success();
3253-
})
3254-
.Case([&](omp::FlushOp) {
3255-
// No support in Openmp runtime function (__kmpc_flush) to accept
3256-
// the argument list.
3257-
// OpenMP standard states the following:
3258-
// "An implementation may implement a flush with a list by ignoring
3259-
// the list, and treating it the same as a flush without a list."
3260-
//
3261-
// The argument list is discarded so that, flush with a list is treated
3262-
// same as a flush without a list.
3263-
ompBuilder->createFlush(builder.saveIP());
3264-
return success();
3265-
})
3266-
.Case([&](omp::ParallelOp op) {
3267-
return convertOmpParallel(op, builder, moduleTranslation);
3268-
})
3269-
.Case([&](omp::ReductionOp reductionOp) {
3270-
return convertOmpReductionOp(reductionOp, builder, moduleTranslation);
3271-
})
3272-
.Case([&](omp::MasterOp) {
3273-
return convertOmpMaster(*op, builder, moduleTranslation);
3274-
})
3275-
.Case([&](omp::CriticalOp) {
3276-
return convertOmpCritical(*op, builder, moduleTranslation);
3277-
})
3278-
.Case([&](omp::OrderedRegionOp) {
3279-
return convertOmpOrderedRegion(*op, builder, moduleTranslation);
3280-
})
3281-
.Case([&](omp::OrderedOp) {
3282-
return convertOmpOrdered(*op, builder, moduleTranslation);
3283-
})
3284-
.Case([&](omp::WsloopOp) {
3285-
return convertOmpWsloop(*op, builder, moduleTranslation);
3286-
})
3287-
.Case([&](omp::SimdLoopOp) {
3288-
return convertOmpSimdLoop(*op, builder, moduleTranslation);
3289-
})
3290-
.Case([&](omp::AtomicReadOp) {
3291-
return convertOmpAtomicRead(*op, builder, moduleTranslation);
3292-
})
3293-
.Case([&](omp::AtomicWriteOp) {
3294-
return convertOmpAtomicWrite(*op, builder, moduleTranslation);
3295-
})
3296-
.Case([&](omp::AtomicUpdateOp op) {
3297-
return convertOmpAtomicUpdate(op, builder, moduleTranslation);
3298-
})
3299-
.Case([&](omp::AtomicCaptureOp op) {
3300-
return convertOmpAtomicCapture(op, builder, moduleTranslation);
3301-
})
3302-
.Case([&](omp::SectionsOp) {
3303-
return convertOmpSections(*op, builder, moduleTranslation);
3304-
})
3305-
.Case([&](omp::SingleOp op) {
3306-
return convertOmpSingle(op, builder, moduleTranslation);
3307-
})
3308-
.Case([&](omp::TeamsOp op) {
3309-
return convertOmpTeams(op, builder, moduleTranslation);
3310-
})
3311-
.Case([&](omp::TaskOp op) {
3312-
return convertOmpTaskOp(op, builder, moduleTranslation);
3313-
})
3314-
.Case([&](omp::TaskgroupOp op) {
3315-
return convertOmpTaskgroupOp(op, builder, moduleTranslation);
3316-
})
3317-
.Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareReductionOp,
3318-
omp::CriticalDeclareOp>([](auto op) {
3319-
// `yield` and `terminator` can be just omitted. The block structure
3320-
// was created in the region that handles their parent operation.
3321-
// `declare_reduction` will be used by reductions and is not
3322-
// converted directly, skip it.
3323-
// `critical.declare` is only used to declare names of critical
3324-
// sections which will be used by `critical` ops and hence can be
3325-
// ignored for lowering. The OpenMP IRBuilder will create unique
3326-
// name for critical section names.
3327-
return success();
3328-
})
3329-
.Case([&](omp::ThreadprivateOp) {
3330-
return convertOmpThreadprivate(*op, builder, moduleTranslation);
3331-
})
3332-
.Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
3333-
omp::TargetUpdateOp>([&](auto op) {
3334-
return convertOmpTargetData(op, builder, moduleTranslation);
3335-
})
3336-
.Case([&](omp::TargetOp) {
3337-
return convertOmpTarget(*op, builder, moduleTranslation);
3338-
})
3339-
.Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
3340-
[&](auto op) {
3341-
// No-op, should be handled by relevant owning operations e.g.
3342-
// TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp etc.
3343-
// and then discarded
3344-
return success();
3345-
})
3346-
.Default([&](Operation *inst) {
3347-
return inst->emitError("unsupported OpenMP operation: ")
3348-
<< inst->getName();
3349-
});
3414+
return convertHostOrTargetOperation(op, builder, moduleTranslation);
33503415
}
33513416

33523417
void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {

mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
// for nested omp do loop inside omp target region
55

66
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true } {
7-
llvm.func @target_parallel_wsloop(%arg0: !llvm.ptr) attributes {
7+
llvm.func @target_parallel_wsloop(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>,
88
target_cpu = "gfx90a",
9-
target_features = #llvm.target_features<["+gfx9-insts", "+wavefrontsize64"]>
10-
} {
9+
target_features = #llvm.target_features<["+gfx9-insts", "+wavefrontsize64"]>}
10+
{
1111
omp.parallel {
1212
%loop_ub = llvm.mlir.constant(9 : i32) : i32
1313
%loop_lb = llvm.mlir.constant(0 : i32) : i32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
module attributes {omp.is_target_device = true, omp.is_gpu = true} {
4+
llvm.func @omp_target_region_() {
5+
%0 = llvm.mlir.constant(20 : i32) : i32
6+
%1 = llvm.mlir.constant(10 : i32) : i32
7+
%2 = llvm.mlir.constant(1 : i64) : i64
8+
%3 = llvm.alloca %2 x i32 {bindc_name = "a", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEa"} : (i64) -> !llvm.ptr
9+
%4 = llvm.mlir.constant(1 : i64) : i64
10+
%5 = llvm.alloca %4 x i32 {bindc_name = "b", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEb"} : (i64) -> !llvm.ptr
11+
%6 = llvm.mlir.constant(1 : i64) : i64
12+
%7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operandSegmentSizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr
13+
llvm.store %1, %3 : i32, !llvm.ptr
14+
llvm.store %0, %5 : i32, !llvm.ptr
15+
omp.task {
16+
%map1 = omp.map.info var_ptr(%3 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
17+
%map2 = omp.map.info var_ptr(%5 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
18+
%map3 = omp.map.info var_ptr(%7 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
19+
omp.target map_entries(%map1 -> %arg0, %map2 -> %arg1, %map3 -> %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
20+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr):
21+
%8 = llvm.load %arg0 : !llvm.ptr -> i32
22+
%9 = llvm.load %arg1 : !llvm.ptr -> i32
23+
%10 = llvm.add %8, %9 : i32
24+
llvm.store %10, %arg2 : i32, !llvm.ptr
25+
omp.terminator
26+
}
27+
omp.terminator
28+
}
29+
llvm.return
30+
}
31+
32+
llvm.func @omp_target_no_map() {
33+
omp.target {
34+
omp.terminator
35+
}
36+
llvm.return
37+
}
38+
}
39+
40+
// CHECK: define weak_odr protected void @__omp_offloading_{{.*}}_{{.*}}_omp_target_region__l19
41+
// CHECK: ret void

mlir/test/Target/LLVMIR/omptarget-teams-llvm.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
module attributes {omp.is_target_device = true} {
77
llvm.func @foo(i32)
8-
llvm.func @omp_target_teams_shared_simple(%arg0 : i32) {
8+
llvm.func @omp_target_teams_shared_simple(%arg0 : i32) attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} {
99
omp.teams {
1010
llvm.call @foo(%arg0) : (i32) -> ()
1111
omp.terminator

mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// for nested omp do loop with collapse clause inside omp target region
55

66
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true } {
7-
llvm.func @target_collapsed_wsloop(%arg0: !llvm.ptr) {
7+
llvm.func @target_collapsed_wsloop(%arg0: !llvm.ptr) attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} {
88
%loop_ub = llvm.mlir.constant(99 : i32) : i32
99
%loop_lb = llvm.mlir.constant(0 : i32) : i32
1010
%loop_step = llvm.mlir.constant(1 : index) : i32

0 commit comments

Comments
 (0)