@tkarna (Contributor) commented Nov 6, 2025

Adds a transform.xegpu.set_op_layout_attr transform op that attaches an xegpu.layout attribute to the target op.

Also adds a getLayoutAttrFromOperands utility function.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.
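
For illustration, here is a minimal use of the new op in a transform sequence, adapted from the tests added in this patch (the matched payload op and the layout values are just examples; the payload function is not shown):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // Attaches layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>
    // to the matched op's first result (because of `result` and the default index = 0).
    transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
    transform.yield
  }
}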

@llvmbot (Member) commented Nov 6, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Tuomas Kärnä (tkarna)

Changes

Adds a transform.xegpu.set_op_layout_attr transform op that attaches an xegpu.layout attribute to the target op.

Also adds a getLayoutAttrFromOperands utility function.

For reference, the rationale behind xegpu transform ops is outlined in this RFC document.


Patch is 23.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/166854.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td (+65)
  • (modified) mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp (+106-19)
  • (modified) mlir/python/mlir/dialects/transform/xegpu.py (+47)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir (+58)
  • (modified) mlir/test/Dialect/XeGPU/transform-ops.mlir (+134)
  • (modified) mlir/test/python/dialects/transform_xegpu_ext.py (+49)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..4e0eae1007c8f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 }
 
+def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Set xegpu.layout attribute of an op.";
+  let description = [{
+    Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
+    `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
+    target operand/result value is defined by the `index` argument. The layout
+    is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface : $target,
+                   DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+                   DefaultValuedAttr<UnitAttr, "false">:$result
+                   );
+
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "int64_t":$index,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData,
+                   CArg<"bool", "false">:$result
+                   )>,
+  ];
+
+  let assemblyFormat = [{
+    $target (`result` $result^)? (`index` `=` $index^)?
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..456cfb9ddd2bc 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
       /*order=*/nullptr);
 }
 
+/// Generate `xegpu::LayoutAttr` from op mixed layout values.
+DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
+                          transform::TransformState &state,
+                          TransformOpInterface transformOp,
+                          ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+                          ArrayRef<::mlir::OpFoldResult> mixedSgData,
+                          ArrayRef<::mlir::OpFoldResult> mixedInstData,
+                          xegpu::LayoutAttr &layoutAttr) {
+  SmallVector<int32_t> sgLayout, sgData, instData;
+  auto status =
+      convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+  if (!status.succeeded())
+    return status;
+
+  status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+  if (!status.succeeded())
+    return status;
+
+  status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+  if (!status.succeeded())
+    return status;
+  auto maybeInstData = instData.empty()
+                           ? std::nullopt
+                           : std::optional<ArrayRef<int32_t>>(instData);
+
+  layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Replace xegpu.create_nd_desc op with a new one with the given layout.
 static xegpu::CreateNdDescOp
 setDescLayout(transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
   Operation *target = *targetOps.begin();
 
-  SmallVector<int32_t> sgLayout;
-  DiagnosedSilenceableFailure status =
-      convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
   if (!status.succeeded())
     return status;
 
-  SmallVector<int32_t> sgData;
-  status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
-  if (!status.succeeded())
-    return status;
-
-  SmallVector<int32_t> instData;
-  status =
-      convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
-  if (!status.succeeded())
-    return status;
-  auto maybeInstData = instData.empty()
-                           ? std::nullopt
-                           : std::optional<ArrayRef<int32_t>>(instData);
-
   // For now only create_nd_desc op is supported.
   auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
   if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Set layout attr in desc op's return type. Replaces old desc op.
-  auto layoutAttr =
-      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
   auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
 
   // Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
   modifiesPayload(effects);
 }
 
+void transform::SetOpLayoutAttrOp::build(
+    OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+    ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+    ArrayRef<OpFoldResult> mixedInstData, bool result) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, ostate, target.getType(),
+        /*target=*/target,
+        /*index=*/index,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData,
+        /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  bool resultTarget = getResult();
+
+  int64_t index = getIndex();
+  if (resultTarget && index >= target->getNumResults()) {
+    return emitSilenceableFailure(getLoc())
+           << "Index exceeds the number of op results";
+  }
+  if (!resultTarget && index >= target->getNumOperands()) {
+    return emitSilenceableFailure(getLoc())
+           << "Index exceeds the number of op operands";
+  }
+
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
+  if (!status.succeeded())
+    return status;
+
+  // Set layout attribute for the op result or operand
+  if (resultTarget) {
+    xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
+  } else {
+    xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..46a1f032630d1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,50 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
+    """Specialization for SetOpLayoutAttrOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        *,
+        inst_data: MixedValues = None,
+        index: Union[int, Attribute] = None,
+        result: Union[bool, Attribute] = None,
+        loc=None,
+        ip=None,
+    ):
+        inst_data = [] if inst_data is None else inst_data
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+        super().__init__(
+            _get_op_result_or_value(target),
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            index=index,
+            result=result,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..726b6748452ae 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_result_index
+func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Index exceeds the number of op results}}
+    transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
+func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Index exceeds the number of op operands}}
+    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_multiple
+func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+    transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..089a8fb4fd9b6 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,137 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_default_index
+func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[0, 0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param
+func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param2
+func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result0
+func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand_minimal
+func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
+  %2 = arith.extf %1 : vector<256x32...
[truncated]

@tkarna (Contributor, Author) commented Nov 6, 2025

@rolfmorel (Contributor) left a comment

Generally seems good to me.

Some nits and a question about whether a helper function needs the interface that it has. Otherwise, looks fine to me to go in.

is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
}];

let arguments = (ins TransformHandleTypeInterface : $target,

Contributor: nit: colon style

Comment on lines +95 to +96
getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
transform::TransformState &state,

Contributor: How about just passing the context here?

Contributor: Ah, I see the state is also used by the convertMixedValuesToInt utility. Is it because it does the mapping from an SSA value to the interpreter's associated values?

Contributor Author: Yes, in convertMixedValuesToInt we use state.getPayloadOps and state.getParams to get handles to payload ops and transform params, respectively. I don't think we can get rid of that.

Contributor: Do you need the rewriter, though?

Contributor Author: No, you're right. We just need the context. I'll fix it.
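
For reference, this is the kind of input that makes the state dependency necessary: a layout dimension supplied as a transform param handle rather than a static integer, as exercised by the tests added in this patch (%0 is assumed to be a handle to a single payload op, e.g. from transform.structured.match):

// The interpreter state maps %sg0, an SSA handle, to its associated i64 param
// value when the op is applied; this is what convertMixedValuesToInt resolves.
%sg0 = transform.param.constant 8 : i64 -> !transform.param<i64>
transform.xegpu.set_op_layout_attr %0 result sg_layout = [%sg0, 4] sg_data = [32, 64] : !transform.any_op, !transform.param<i64>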

DiagnosedSilenceableFailure
getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
transform::TransformState &state,
TransformOpInterface transformOp,
Contributor: If you're just constructing the attribute, it's not obvious why you need the op here. Is it maybe for error reporting? If so, perhaps passing the location suffices?

Contributor: Hmm -- I see convertMixedValuesToInt is an existing function which takes the op. It might still be worth looking at that utility to see if passing the op is really needed.

Contributor Author: Yes, transformOp is only used for error handling in convertMixedValuesToInt. I'll see if errors could be propagated in another way.

transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {

Contributor: nit: whitespace

Comment on lines +79 to +81
inst_data: MixedValues = None,
index: Union[int, Attribute] = None,
result: Union[bool, Attribute] = None,
Contributor: Pedantically speaking, I believe these types should be wrapped in Optional[...]

%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
Contributor: I had a feeling transform.annotate might have been enough, though this demonstrates a case that cannot be handled this way. Maybe some day ...
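
For contrast, a rough sketch of what the transform.annotate route might look like (the transform.annotate and transform.param.constant syntax here is recalled from the transform dialect, not from this patch, and the "layout_result_0" name follows the convention shown in the tests). The whole layout has to be materialized as a single constant attribute parameter up front, so it cannot be assembled from separately supplied dynamic sg_layout / sg_data params the way set_op_layout_attr allows:

// Hypothetical alternative: attach a pre-built layout attribute by name.
%layout = transform.param.constant #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]> -> !transform.any_param
transform.annotate %0 "layout_result_0" = %layout : !transform.any_op, !transform.any_param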


// Set layout attribute for the op result or operand
if (resultTarget) {
xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
Contributor: nit: braces can be skipped

@Jianhui-Li (Contributor) left a comment

LGTM. A warning that a near-future refactoring may impact this operation's interface.

ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
ArrayRef<::mlir::OpFoldResult> mixedSgData,
ArrayRef<::mlir::OpFoldResult> mixedInstData,
xegpu::LayoutAttr &layoutAttr) {
Contributor: Missing order attribute. It defines the layout of subgroup ids and lane ids. It is optional; if not set, the layout of these ids is row-major. It can be added later when you need to handle vector.transpose.

let summary = "Set xegpu.layout attribute of an op.";
let description = [{
Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
`layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
Contributor: Just a warning here: XeGPU is refactoring the layout-setting API so that users only need to set the anchor op's layout, and XeGPU will take care of the propagation automatically. The new API for setting anchor ops' layouts doesn't require "layout_result_" and "layout_operand_", so it may impact this operation as well.
