[MLIR][XeGPU][TransformOps] Add set_op_layout_attr op #166854
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Tuomas Kärnä (tkarna)

Changes: Adds the `transform.xegpu.set_op_layout_attr` transform op, which attaches an `xegpu.layout` attribute to the target op. Also adds the `getLayoutAttrFromOperands` utility function. For reference, the rationale behind the xegpu transform ops is outlined in this RFC document.

Patch is 23.42 KiB, truncated to 20.00 KiB below (6 files affected); full version: https://github.com/llvm/llvm-project/pull/166854.diff
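For orientation before the diff, here is a condensed sketch of the new op in use, adapted from the `set_op_layout_attr_result_default_index` test included in the patch (payload abbreviated; names come from that test):

```mlir
// Payload: the op whose result should receive a layout.
%6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>

// Transform script: match the op and attach a layout to result 0 (the default index).
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["xegpu.dpas"]} in %root : (!transform.any_op) -> !transform.any_op
    transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
    transform.yield
  }
}

// After interpretation, the dpas op carries:
//   {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
```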
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..4e0eae1007c8f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
}];
}
+def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
+ AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface
+]> {
+
+ let summary = "Set xegpu.layout attribute of an op.";
+ let description = [{
+ Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
+ `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
+ target operand/result value is defined by the `index` argument. The layout
+ is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface : $target,
+ DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+ Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+ DefaultValuedAttr<UnitAttr, "false">:$result
+ );
+
+ let results = (outs);
+ let builders = [
+ OpBuilder<(ins "Value":$target,
+ "int64_t":$index,
+ "ArrayRef<OpFoldResult>":$mixedSgLayout,
+ "ArrayRef<OpFoldResult>":$mixedSgData,
+ "ArrayRef<OpFoldResult>":$mixedInstData,
+ CArg<"bool", "false">:$result
+ )>,
+ ];
+
+ let assemblyFormat = [{
+ $target (`result` $result^)? (`index` `=` $index^)?
+ `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+ `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+ (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+ attr-dict `:` qualified(type(operands))
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+ Builder b(getContext());
+ return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+ }
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+ Builder b(getContext());
+ return getMixedValues(getStaticSgData(), getSgData(), b);
+ }
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+ Builder b(getContext());
+ return getMixedValues(getStaticInstData(), getInstData(), b);
+ }
+ }];
+}
+
#endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..456cfb9ddd2bc 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
/*order=*/nullptr);
}
+/// Generate `xegpu::LayoutAttr` from op mixed layout values.
+DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
+ transform::TransformState &state,
+ TransformOpInterface transformOp,
+ ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+ ArrayRef<::mlir::OpFoldResult> mixedSgData,
+ ArrayRef<::mlir::OpFoldResult> mixedInstData,
+ xegpu::LayoutAttr &layoutAttr) {
+ SmallVector<int32_t> sgLayout, sgData, instData;
+ auto status =
+ convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+ if (!status.succeeded())
+ return status;
+
+ status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+ if (!status.succeeded())
+ return status;
+ auto maybeInstData = instData.empty()
+ ? std::nullopt
+ : std::optional<ArrayRef<int32_t>>(instData);
+
+ layoutAttr =
+ createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+
+ return DiagnosedSilenceableFailure::success();
+}
+
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
static xegpu::CreateNdDescOp
setDescLayout(transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
Operation *target = *targetOps.begin();
- SmallVector<int32_t> sgLayout;
- DiagnosedSilenceableFailure status =
- convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
if (!status.succeeded())
return status;
- SmallVector<int32_t> sgData;
- status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
- if (!status.succeeded())
- return status;
-
- SmallVector<int32_t> instData;
- status =
- convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
- if (!status.succeeded())
- return status;
- auto maybeInstData = instData.empty()
- ? std::nullopt
- : std::optional<ArrayRef<int32_t>>(instData);
-
// For now only create_nd_desc op is supported.
auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
}
// Set layout attr in desc op's return type. Replaces old desc op.
- auto layoutAttr =
- createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
// Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
modifiesPayload(effects);
}
+void transform::SetOpLayoutAttrOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+ ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+ ArrayRef<OpFoldResult> mixedInstData, bool result) {
+ SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+ dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+ dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+ dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*index=*/index,
+ /*sg_layout=*/dynamicSgLayout,
+ /*sg_data=*/dynamicSgData,
+ /*inst_data=*/dynamicInstData,
+ /*static_sg_layout=*/staticSgLayout,
+ /*static_sg_data=*/staticSgData,
+ /*static_inst_data=*/staticInstData,
+ /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ bool resultTarget = getResult();
+
+ int64_t index = getIndex();
+ if (resultTarget && index >= target->getNumResults()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op results";
+ }
+ if (!resultTarget && index >= target->getNumOperands()) {
+ return emitSilenceableFailure(getLoc())
+ << "Index exceeds the number of op operands";
+ }
+
+ xegpu::LayoutAttr layoutAttr = nullptr;
+ auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+ getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), layoutAttr);
+ if (!status.succeeded())
+ return status;
+
+ // Set layout attribute for the op result or operand
+ if (resultTarget) {
+ xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
+ } else {
+ xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
+ }
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getSgLayoutMutable(), effects);
+ onlyReadsHandle(getSgDataMutable(), effects);
+ onlyReadsHandle(getInstDataMutable(), effects);
+ modifiesPayload(effects);
+}
+
namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..46a1f032630d1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,50 @@ def __init__(
loc=loc,
ip=ip,
)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
+ """Specialization for SetOpLayoutAttrOp class."""
+
+ def __init__(
+ self,
+ target: Union[Operation, Value],
+ sg_layout: MixedValues,
+ sg_data: MixedValues,
+ *,
+ inst_data: MixedValues = None,
+ index: Union[int, Attribute] = None,
+ result: Union[bool, Attribute] = None,
+ loc=None,
+ ip=None,
+ ):
+ inst_data = [] if inst_data is None else inst_data
+ (
+ dynamic_sg_layout,
+ static_sg_layout,
+ _,
+ ) = _dispatch_dynamic_index_list(sg_layout)
+ (
+ dynamic_sg_data,
+ static_sg_data,
+ _,
+ ) = _dispatch_dynamic_index_list(sg_data)
+ (
+ dynamic_inst_data,
+ static_inst_data,
+ _,
+ ) = _dispatch_dynamic_index_list(inst_data)
+ super().__init__(
+ _get_op_result_or_value(target),
+ dynamic_sg_layout,
+ dynamic_sg_data,
+ dynamic_inst_data,
+ static_sg_layout=static_sg_layout,
+ static_sg_data=static_sg_data,
+ static_inst_data=static_inst_data,
+ index=index,
+ result=result,
+ loc=loc,
+ ip=ip,
+ )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..726b6748452ae 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_result_index
+func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Index exceeds the number of op results}}
+ transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
+func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Index exceeds the number of op operands}}
+ transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_multiple
+func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+ transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..089a8fb4fd9b6 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,137 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_default_index
+func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[0, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // CHECK: = xegpu.dpas
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param
+func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param2
+func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+ %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+ transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result0
+func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand_minimal
+func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // CHECK: = arith.extf %1
+ // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
+ %2 = arith.extf %1 : vector<256x32...
[truncated]
rolfmorel left a comment:
Generally seems good to me.
Some nits and a question about whether a helper function needs the interface it has. Otherwise, looks fine to me to go in.
Referring to:
> let arguments = (ins TransformHandleTypeInterface : $target,
nit: colon style
Referring to:
> getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
>                           transform::TransformState &state,
How about just passing the context here?
Ah, I see the state is also used by the convertMixedValuesToInt utility. Is it because it does the mapping from an SSA value to the interpreter's associated values?
Yes, in convertMixedValuesToInt we use state.getPayloadOps and state.getParams to get handles to payload ops and transform params, respectively. I don't think we can get rid of that.
Do you need the rewriter, though?
No, you're right. We just need the context. I'll fix it.
Referring to:
> DiagnosedSilenceableFailure
> getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
>                           transform::TransformState &state,
>                           TransformOpInterface transformOp,
If you're just constructing the attribute, it's not obvious why you need the op here. It is maybe to do error reporting? If so, maybe passing the location suffices?
Hmm -- I see convertMixedValuesToInt is an existing function which takes the op. Might still be worth looking at that utility and see if passing the op is really needed.
Yes, transformOp is only used for error handling in convertMixedValuesToInt. I'll see if errors could be propagated in another way.
Referring to:
> transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
>                                     transform::TransformResults &results,
>                                     transform::TransformState &state) {
nit: whitespace
Referring to:
> inst_data: MixedValues = None,
> index: Union[int, Attribute] = None,
> result: Union[bool, Attribute] = None,
Pedantically speaking, I believe these types should be wrapped in Optional[...]
Referring to:
> %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
> transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
I had a feeling transform.annotate might have been enough, though this demonstrates a case that cannot be handled this way. Maybe some day ...
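For context on the alternative mentioned (a sketch, not part of this patch; it assumes `transform.annotate` and `transform.param.constant` accept the forms shown): `transform.annotate` can attach an attribute that already exists as a transform param, but it has no way to assemble a `#xegpu.layout` out of individual scalar params such as `%layout0` above, which is what `set_op_layout_attr` provides.

```mlir
// Sketch: annotate can only attach a pre-built attribute held in a param.
%attr = transform.param.constant #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]> -> !transform.any_param
transform.annotate %0 "layout_result_0" = %attr : !transform.any_op, !transform.any_param
```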
Referring to:
> // Set layout attribute for the op result or operand
> if (resultTarget) {
>   xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
nit: braces can be skipped
Jianhui-Li left a comment:
LGTM. A warning that a near future refactoring may impact this operation's interface.
Referring to:
> ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
> ArrayRef<::mlir::OpFoldResult> mixedSgData,
> ArrayRef<::mlir::OpFoldResult> mixedInstData,
> xegpu::LayoutAttr &layoutAttr) {
Missing order attribute. It defines the layout of subgroup ids and lane ids. It is optional; if not set, the layout of these ids is row-major. It can be added later when you need to handle vector.transpose.
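For readers unfamiliar with it: `order` is an optional parameter of `#xegpu.layout` (the patch currently passes `/*order=*/nullptr` when building the attribute). If the transform op later grows support for it, the produced attribute could look roughly like the sketch below; the exact printed form and the `[0, 1]` value are illustrative assumptions.

```mlir
// Hypothetical layout carrying a non-default order; not produced by the current
// patch, which always leaves order unset (row-major ids, per the comment above).
#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [0, 1]>
```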
Referring to:
> let summary = "Set xegpu.layout attribute of an op.";
> let description = [{
>   Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
>   `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
Just a warning here: XeGPU is refactoring the layout-setting API so that the user only needs to set the anchor op's layout, and XeGPU will take care of the propagation automatically. The new API that sets anchor ops' layout doesn't require "layout_result_" and "layout_operand_", so it may impact this operation as well.