@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
}];
}

def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformOpInterface
]> {

let summary = "Set xegpu.layout attribute of an op.";
let description = [{
    Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
    `layout_result_{index}` attribute, otherwise the `layout_operand_{index}`
    attribute. The target operand/result value is selected by the `index`
    argument. The layout is defined by the `sg_layout`, `sg_data`, and
    optional `inst_data` attributes.
  }];

  Review comment (Contributor): Just a warning here: XeGPU is refactoring the
  layout-setting API so that the user only needs to set the anchor op's layout,
  and XeGPU takes care of the propagation automatically. The new API that sets
  anchor ops' layouts doesn't require "layout_result_" and "layout_operand_",
  so it may impact this operation as well.

let arguments = (ins TransformHandleTypeInterface : $target,
  Review comment (Contributor): nit: colon style

DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
DefaultValuedAttr<UnitAttr, "false">:$result
);

let results = (outs);
let builders = [
OpBuilder<(ins "Value":$target,
"int64_t":$index,
"ArrayRef<OpFoldResult>":$mixedSgLayout,
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
CArg<"bool", "false">:$result
)>,
];

let assemblyFormat = [{
$target (`result` $result^)? (`index` `=` $index^)?
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
attr-dict `:` qualified(type(operands))
}];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure apply(
::mlir::transform::TransformRewriter &rewriter,
::mlir::transform::TransformResults &transformResults,
::mlir::transform::TransformState &state);

::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
Builder b(getContext());
return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
}
::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
Builder b(getContext());
return getMixedValues(getStaticSgData(), getSgData(), b);
}
::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
Builder b(getContext());
return getMixedValues(getStaticInstData(), getInstData(), b);
}
}];
}

#endif // XEGPU_TRANSFORM_OPS
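For context, a minimal usage sketch of the new op, following the assembly format above and the tests added in this PR (the `%extf` handle name and the `inst_data` values are illustrative, not taken from the patch):

```mlir
// Attach an #xegpu.layout to result 0 of the op bound to %extf.
transform.xegpu.set_op_layout_attr %extf result
  sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16]
  : !transform.any_op
```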
125 changes: 106 additions & 19 deletions mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
/*order=*/nullptr);
}

/// Generate `xegpu::LayoutAttr` from an op's mixed layout values.
DiagnosedSilenceableFailure
getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
                          transform::TransformState &state,
                          TransformOpInterface transformOp,
                          ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
                          ArrayRef<::mlir::OpFoldResult> mixedSgData,
                          ArrayRef<::mlir::OpFoldResult> mixedInstData,
                          xegpu::LayoutAttr &layoutAttr) {
  SmallVector<int32_t> sgLayout, sgData, instData;
  auto status =
      convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
  if (!status.succeeded())
    return status;

  status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
  if (!status.succeeded())
    return status;

  status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
  if (!status.succeeded())
    return status;

  auto maybeInstData = instData.empty()
                           ? std::nullopt
                           : std::optional<ArrayRef<int32_t>>(instData);

  layoutAttr =
      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);

  return DiagnosedSilenceableFailure::success();
}

Review thread on lines +95 to +96 (the `rewriter`/`state` parameters):
  Contributor: How about just passing the context here?
  Contributor: Ah, I see the state is also used by the convertMixedValuesToInt
  utility. Is it because it does the mapping from an SSA value to the
  interpreter's associated values?
  Author: Yes, in convertMixedValuesToInt we use state.getPayloadOps and
  state.getParams to get handles to payload ops and transform params,
  respectively. I don't think we can get rid of that.
  Contributor: Do you need the rewriter, though?
  Author: No, you're right. We just need the context. I'll fix it.

Review thread on the `transformOp` parameter:
  Contributor: If you're just constructing the attribute, it's not obvious why
  you need the op here. Is it maybe for error reporting? If so, maybe passing
  the location suffices?
  Contributor: Hmm -- I see convertMixedValuesToInt is an existing function
  which takes the op. Might still be worth looking at that utility and see if
  passing the op is really needed.
  Author: Yes, transformOp is only used for error handling in
  convertMixedValuesToInt. I'll see if errors could be propagated in another
  way.

Review comment on the constructed layout:
  Contributor: Missing order attribute. It defines the layout of subgroup ids
  and lane ids. It is optional; if not set, the layout of these ids is
  row-major. It can be added later, when you need to handle vector.transpose.
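Note that any entry of `sg_layout`/`sg_data`/`inst_data` may be supplied as a transform param instead of a literal; the helper above resolves such entries through `state.getParams` inside `convertMixedValuesToInt`. A minimal sketch of what that looks like at the transform-IR level (handle names are hypothetical, not taken from the patch):

```mlir
// One sg_layout entry comes from a transform param rather than a constant.
%c8 = transform.param.constant 8 : i64 -> !transform.param<i64>
transform.xegpu.set_op_layout_attr %op
  sg_layout = [%c8, 4] sg_data = [32, 64]
  : !transform.any_op, !transform.param<i64>
```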

/// Replace xegpu.create_nd_desc op with a new one with the given layout.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
Operation *target = *targetOps.begin();

-  SmallVector<int32_t> sgLayout;
-  DiagnosedSilenceableFailure status =
-      convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
  if (!status.succeeded())
    return status;
-
-  SmallVector<int32_t> sgData;
-  status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
-  if (!status.succeeded())
-    return status;
-
-  SmallVector<int32_t> instData;
-  status =
-      convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
-  if (!status.succeeded())
-    return status;
-  auto maybeInstData = instData.empty()
-                           ? std::nullopt
-                           : std::optional<ArrayRef<int32_t>>(instData);

// For now only create_nd_desc op is supported.
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}

// Set layout attr in desc op's return type. Replaces old desc op.
-  auto layoutAttr =
-      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);

// Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
modifiesPayload(effects);
}

void transform::SetOpLayoutAttrOp::build(
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
ArrayRef<OpFoldResult> mixedInstData, bool result) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
build(builder, ostate, target.getType(),
/*target=*/target,
/*index=*/index,
/*sg_layout=*/dynamicSgLayout,
/*sg_data=*/dynamicSgData,
/*inst_data=*/dynamicInstData,
/*static_sg_layout=*/staticSgLayout,
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
/*result=*/result);
}

DiagnosedSilenceableFailure
transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {

  Review comment (Contributor): nit: whitespace

auto targetOps = state.getPayloadOps(getTarget());
if (!llvm::hasSingleElement(targetOps)) {
return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
<< llvm::range_size(targetOps) << ")";
}
Operation *target = *targetOps.begin();

bool resultTarget = getResult();

int64_t index = getIndex();
if (resultTarget && index >= target->getNumResults()) {
return emitSilenceableFailure(getLoc())
<< "Index exceeds the number of op results";
}
if (!resultTarget && index >= target->getNumOperands()) {
return emitSilenceableFailure(getLoc())
<< "Index exceeds the number of op operands";
}

xegpu::LayoutAttr layoutAttr = nullptr;
auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
getMixedSgLayout(), getMixedSgData(),
getMixedInstData(), layoutAttr);
if (!status.succeeded())
return status;

// Set layout attribute for the op result or operand
if (resultTarget) {
xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
  Review comment (Contributor): nit: braces can be skipped

} else {
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
}
return DiagnosedSilenceableFailure::success();
}

void transform::SetOpLayoutAttrOp::getEffects(
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getTargetMutable(), effects);
onlyReadsHandle(getSgLayoutMutable(), effects);
onlyReadsHandle(getSgDataMutable(), effects);
onlyReadsHandle(getInstDataMutable(), effects);
modifiesPayload(effects);
}
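To make the effect on the payload concrete, a hedged before/after sketch after applying `transform.xegpu.set_op_layout_attr %op result sg_layout = [8, 4] sg_data = [32, 64]`. The `layout_result_0` name follows the convention in the op documentation; the printed attribute form is assumed from `xegpu::LayoutAttr`, not taken from this patch:

```mlir
// Before:
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>

// After (layout attached to result 0):
%2 = arith.extf %1
  {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
  : vector<256x32xf16> to vector<256x32xf32>
```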

namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
Expand Down
47 changes: 47 additions & 0 deletions mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,50 @@ def __init__(
loc=loc,
ip=ip,
)


@_ods_cext.register_operation(_Dialect, replace=True)
class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
"""Specialization for SetOpLayoutAttrOp class."""

def __init__(
self,
target: Union[Operation, Value],
sg_layout: MixedValues,
sg_data: MixedValues,
*,
inst_data: MixedValues = None,
index: Union[int, Attribute] = None,
result: Union[bool, Attribute] = None,
        Review comment (Contributor) on lines +79 to +81: Pedantically
        speaking, I believe these types should be wrapped in Optional[...]

loc=None,
ip=None,
):
inst_data = [] if inst_data is None else inst_data
(
dynamic_sg_layout,
static_sg_layout,
_,
) = _dispatch_dynamic_index_list(sg_layout)
(
dynamic_sg_data,
static_sg_data,
_,
) = _dispatch_dynamic_index_list(sg_data)
(
dynamic_inst_data,
static_inst_data,
_,
) = _dispatch_dynamic_index_list(inst_data)
super().__init__(
_get_op_result_or_value(target),
dynamic_sg_layout,
dynamic_sg_data,
dynamic_inst_data,
static_sg_layout=static_sg_layout,
static_sg_data=static_sg_data,
static_inst_data=static_inst_data,
index=index,
result=result,
loc=loc,
ip=ip,
)
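Assuming the binding above, a call such as `SetOpLayoutAttrOp(target, [8, 4], [32, 64], result=True)` would be expected to round-trip to transform IR along these lines (a sketch, not verified output of this patch):

```mlir
transform.xegpu.set_op_layout_attr %target result
  sg_layout = [8, 4] sg_data = [32, 64]
  : !transform.any_op
```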
58 changes: 58 additions & 0 deletions mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// CHECK-LABEL: @set_op_layout_attr_bad_result_index
func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error@below {{Index exceeds the number of op results}}
transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error@below {{Index exceeds the number of op operands}}
transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: @set_op_layout_attr_multiple
func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
%2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
%3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error@below {{Requires exactly one targetOp handle (got 2)}}
transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
}