Skip to content

Commit 94a7006

Browse files
authored
[MLIR][XeGPU][TransformOps] Add set_op_layout_attr op (#166854)
Adds `transform.xegpu.set_op_layout_attr` transform op that attaches `xegpu.layout` attribute to the target op.
1 parent 71cdd40 commit 94a7006

File tree

6 files changed

+455
-19
lines changed

6 files changed

+455
-19
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,4 +96,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
9696
}];
9797
}
9898

99+
def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
100+
AttrSizedOperandSegments,
101+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
102+
TransformOpInterface
103+
]> {
104+
105+
let summary = "Set xegpu.layout attribute of an op.";
106+
let description = [{
107+
Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
108+
`layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
109+
target operand/result value is defined by the `index` argument. The layout
110+
is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
111+
}];
112+
113+
let arguments = (ins TransformHandleTypeInterface:$target,
114+
DefaultValuedOptionalAttr<I64Attr, "0">:$index,
115+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
116+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
117+
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
118+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
119+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
120+
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
121+
DefaultValuedAttr<UnitAttr, "false">:$result
122+
);
123+
124+
let results = (outs);
125+
let builders = [
126+
OpBuilder<(ins "Value":$target,
127+
"int64_t":$index,
128+
"ArrayRef<OpFoldResult>":$mixedSgLayout,
129+
"ArrayRef<OpFoldResult>":$mixedSgData,
130+
"ArrayRef<OpFoldResult>":$mixedInstData,
131+
CArg<"bool", "false">:$result
132+
)>,
133+
];
134+
135+
let assemblyFormat = [{
136+
$target (`result` $result^)? (`index` `=` $index^)?
137+
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
138+
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
139+
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
140+
attr-dict `:` qualified(type(operands))
141+
}];
142+
143+
let extraClassDeclaration = [{
144+
::mlir::DiagnosedSilenceableFailure apply(
145+
::mlir::transform::TransformRewriter &rewriter,
146+
::mlir::transform::TransformResults &transformResults,
147+
::mlir::transform::TransformState &state);
148+
149+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
150+
Builder b(getContext());
151+
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
152+
}
153+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
154+
Builder b(getContext());
155+
return getMixedValues(getStaticSgData(), getSgData(), b);
156+
}
157+
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
158+
Builder b(getContext());
159+
return getMixedValues(getStaticInstData(), getInstData(), b);
160+
}
161+
}];
162+
}
163+
99164
#endif // XEGPU_TRANSFORM_OPS

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 102 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,36 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
132132
/*order=*/nullptr);
133133
}
134134

135+
/// Generate `xegpu::LayoutAttr` from op mixed layout values.
136+
DiagnosedSilenceableFailure
137+
getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
138+
TransformOpInterface transformOp,
139+
ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
140+
ArrayRef<::mlir::OpFoldResult> mixedSgData,
141+
ArrayRef<::mlir::OpFoldResult> mixedInstData,
142+
xegpu::LayoutAttr &layoutAttr) {
143+
SmallVector<int32_t> sgLayout, sgData, instData;
144+
auto status =
145+
convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
146+
if (!status.succeeded())
147+
return status;
148+
149+
status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
150+
if (!status.succeeded())
151+
return status;
152+
153+
status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
154+
if (!status.succeeded())
155+
return status;
156+
auto maybeInstData = instData.empty()
157+
? std::nullopt
158+
: std::optional<ArrayRef<int32_t>>(instData);
159+
160+
layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
161+
162+
return DiagnosedSilenceableFailure::success();
163+
}
164+
135165
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
136166
static xegpu::CreateNdDescOp
137167
setDescLayout(transform::TransformRewriter &rewriter,
@@ -207,26 +237,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
207237
}
208238
Operation *target = *targetOps.begin();
209239

210-
SmallVector<int32_t> sgLayout;
211-
DiagnosedSilenceableFailure status =
212-
convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
240+
xegpu::LayoutAttr layoutAttr = nullptr;
241+
auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
242+
getMixedSgLayout(), getMixedSgData(),
243+
getMixedInstData(), layoutAttr);
213244
if (!status.succeeded())
214245
return status;
215246

216-
SmallVector<int32_t> sgData;
217-
status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
218-
if (!status.succeeded())
219-
return status;
220-
221-
SmallVector<int32_t> instData;
222-
status =
223-
convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
224-
if (!status.succeeded())
225-
return status;
226-
auto maybeInstData = instData.empty()
227-
? std::nullopt
228-
: std::optional<ArrayRef<int32_t>>(instData);
229-
230247
// For now only create_nd_desc op is supported.
231248
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
232249
if (!descOp) {
@@ -238,8 +255,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
238255
}
239256

240257
// Set layout attr in desc op's return type. Replaces old desc op.
241-
auto layoutAttr =
242-
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
243258
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
244259

245260
// Map result handles.
@@ -258,6 +273,74 @@ void transform::SetDescLayoutOp::getEffects(
258273
modifiesPayload(effects);
259274
}
260275

276+
void transform::SetOpLayoutAttrOp::build(
277+
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
278+
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
279+
ArrayRef<OpFoldResult> mixedInstData, bool result) {
280+
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
281+
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
282+
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
283+
dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
284+
dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
285+
build(builder, ostate, target.getType(),
286+
/*target=*/target,
287+
/*index=*/index,
288+
/*sg_layout=*/dynamicSgLayout,
289+
/*sg_data=*/dynamicSgData,
290+
/*inst_data=*/dynamicInstData,
291+
/*static_sg_layout=*/staticSgLayout,
292+
/*static_sg_data=*/staticSgData,
293+
/*static_inst_data=*/staticInstData,
294+
/*result=*/result);
295+
}
296+
297+
DiagnosedSilenceableFailure
298+
transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
299+
transform::TransformResults &results,
300+
transform::TransformState &state) {
301+
auto targetOps = state.getPayloadOps(getTarget());
302+
if (!llvm::hasSingleElement(targetOps)) {
303+
return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
304+
<< llvm::range_size(targetOps) << ")";
305+
}
306+
Operation *target = *targetOps.begin();
307+
308+
bool resultTarget = getResult();
309+
310+
int64_t index = getIndex();
311+
if (resultTarget && index >= target->getNumResults()) {
312+
return emitSilenceableFailure(getLoc())
313+
<< "Index exceeds the number of op results";
314+
}
315+
if (!resultTarget && index >= target->getNumOperands()) {
316+
return emitSilenceableFailure(getLoc())
317+
<< "Index exceeds the number of op operands";
318+
}
319+
320+
xegpu::LayoutAttr layoutAttr = nullptr;
321+
auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
322+
getMixedSgLayout(), getMixedSgData(),
323+
getMixedInstData(), layoutAttr);
324+
if (!status.succeeded())
325+
return status;
326+
327+
// Set layout attribute for the op result or operand
328+
if (resultTarget)
329+
xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
330+
else
331+
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
332+
return DiagnosedSilenceableFailure::success();
333+
}
334+
335+
void transform::SetOpLayoutAttrOp::getEffects(
336+
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
337+
onlyReadsHandle(getTargetMutable(), effects);
338+
onlyReadsHandle(getSgLayoutMutable(), effects);
339+
onlyReadsHandle(getSgDataMutable(), effects);
340+
onlyReadsHandle(getInstDataMutable(), effects);
341+
modifiesPayload(effects);
342+
}
343+
261344
namespace {
262345
class XeGPUTransformDialectExtension
263346
: public transform::TransformDialectExtension<

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,50 @@ def __init__(
8585
loc=loc,
8686
ip=ip,
8787
)
88+
89+
90+
@_ods_cext.register_operation(_Dialect, replace=True)
91+
class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
92+
"""Specialization for SetOpLayoutAttrOp class."""
93+
94+
def __init__(
95+
self,
96+
target: Union[Operation, Value],
97+
sg_layout: MixedValues,
98+
sg_data: MixedValues,
99+
*,
100+
inst_data: Optional[MixedValues] = None,
101+
index: Optional[Union[int, Attribute]] = None,
102+
result: Optional[Union[bool, Attribute]] = None,
103+
loc=None,
104+
ip=None,
105+
):
106+
inst_data = [] if inst_data is None else inst_data
107+
(
108+
dynamic_sg_layout,
109+
static_sg_layout,
110+
_,
111+
) = _dispatch_dynamic_index_list(sg_layout)
112+
(
113+
dynamic_sg_data,
114+
static_sg_data,
115+
_,
116+
) = _dispatch_dynamic_index_list(sg_data)
117+
(
118+
dynamic_inst_data,
119+
static_inst_data,
120+
_,
121+
) = _dispatch_dynamic_index_list(inst_data)
122+
super().__init__(
123+
_get_op_result_or_value(target),
124+
dynamic_sg_layout,
125+
dynamic_sg_data,
126+
dynamic_inst_data,
127+
static_sg_layout=static_sg_layout,
128+
static_sg_data=static_sg_data,
129+
static_inst_data=static_inst_data,
130+
index=index,
131+
result=result,
132+
loc=loc,
133+
ip=ip,
134+
)

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
1313
transform.yield
1414
}
1515
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: @set_op_layout_attr_bad_result_index
20+
func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
21+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
22+
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
23+
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
24+
return
25+
}
26+
27+
module attributes {transform.with_named_sequence} {
28+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
29+
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
30+
// expected-error@below {{Index exceeds the number of op results}}
31+
transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
32+
transform.yield
33+
}
34+
}
35+
36+
// -----
37+
38+
// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
39+
func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
40+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
41+
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
42+
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
43+
return
44+
}
45+
46+
module attributes {transform.with_named_sequence} {
47+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
48+
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
49+
// expected-error@below {{Index exceeds the number of op operands}}
50+
transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
51+
transform.yield
52+
}
53+
}
54+
55+
// -----
56+
57+
// CHECK-LABEL: @set_op_layout_attr_multiple
58+
func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
59+
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
60+
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
61+
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
62+
%3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
63+
return
64+
}
65+
66+
module attributes {transform.with_named_sequence} {
67+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
68+
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
69+
// expected-error@below {{Requires exactly one targetOp handle (got 2)}}
70+
transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
71+
transform.yield
72+
}
73+
}

0 commit comments

Comments
 (0)