@@ -132,6 +132,36 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
132132 /* order=*/ nullptr );
133133}
134134
135+ // / Generate `xegpu::LayoutAttr` from op mixed layout values.
136+ DiagnosedSilenceableFailure
137+ getLayoutAttrFromOperands (MLIRContext *ctx, transform::TransformState &state,
138+ TransformOpInterface transformOp,
139+ ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
140+ ArrayRef<::mlir::OpFoldResult> mixedSgData,
141+ ArrayRef<::mlir::OpFoldResult> mixedInstData,
142+ xegpu::LayoutAttr &layoutAttr) {
143+ SmallVector<int32_t > sgLayout, sgData, instData;
144+ auto status =
145+ convertMixedValuesToInt (state, transformOp, sgLayout, mixedSgLayout);
146+ if (!status.succeeded ())
147+ return status;
148+
149+ status = convertMixedValuesToInt (state, transformOp, sgData, mixedSgData);
150+ if (!status.succeeded ())
151+ return status;
152+
153+ status = convertMixedValuesToInt (state, transformOp, instData, mixedInstData);
154+ if (!status.succeeded ())
155+ return status;
156+ auto maybeInstData = instData.empty ()
157+ ? std::nullopt
158+ : std::optional<ArrayRef<int32_t >>(instData);
159+
160+ layoutAttr = createLayoutAttr (ctx, sgLayout, sgData, maybeInstData);
161+
162+ return DiagnosedSilenceableFailure::success ();
163+ }
164+
135165// / Replace xegpu.create_nd_desc op with a new one with the given layout.
136166static xegpu::CreateNdDescOp
137167setDescLayout (transform::TransformRewriter &rewriter,
@@ -207,26 +237,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
207237 }
208238 Operation *target = *targetOps.begin ();
209239
210- SmallVector<int32_t > sgLayout;
211- DiagnosedSilenceableFailure status =
212- convertMixedValuesToInt (state, (*this ), sgLayout, getMixedSgLayout ());
240+ xegpu::LayoutAttr layoutAttr = nullptr ;
241+ auto status = getLayoutAttrFromOperands (getContext (), state, (*this ),
242+ getMixedSgLayout (), getMixedSgData (),
243+ getMixedInstData (), layoutAttr);
213244 if (!status.succeeded ())
214245 return status;
215246
216- SmallVector<int32_t > sgData;
217- status = convertMixedValuesToInt (state, (*this ), sgData, getMixedSgData ());
218- if (!status.succeeded ())
219- return status;
220-
221- SmallVector<int32_t > instData;
222- status =
223- convertMixedValuesToInt (state, (*this ), instData, getMixedInstData ());
224- if (!status.succeeded ())
225- return status;
226- auto maybeInstData = instData.empty ()
227- ? std::nullopt
228- : std::optional<ArrayRef<int32_t >>(instData);
229-
230247 // For now only create_nd_desc op is supported.
231248 auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
232249 if (!descOp) {
@@ -238,8 +255,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
238255 }
239256
240257 // Set layout attr in desc op's return type. Replaces old desc op.
241- auto layoutAttr =
242- createLayoutAttr (rewriter.getContext (), sgLayout, sgData, maybeInstData);
243258 auto newdescOp = setDescLayout (rewriter, descOp, layoutAttr);
244259
245260 // Map result handles.
@@ -258,6 +273,74 @@ void transform::SetDescLayoutOp::getEffects(
258273 modifiesPayload (effects);
259274}
260275
276+ void transform::SetOpLayoutAttrOp::build (
277+ OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
278+ ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
279+ ArrayRef<OpFoldResult> mixedInstData, bool result) {
280+ SmallVector<int64_t > staticSgLayout, staticSgData, staticInstData;
281+ SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
282+ dispatchIndexOpFoldResults (mixedSgLayout, dynamicSgLayout, staticSgLayout);
283+ dispatchIndexOpFoldResults (mixedSgData, dynamicSgData, staticSgData);
284+ dispatchIndexOpFoldResults (mixedInstData, dynamicInstData, staticInstData);
285+ build (builder, ostate, target.getType (),
286+ /* target=*/ target,
287+ /* index=*/ index,
288+ /* sg_layout=*/ dynamicSgLayout,
289+ /* sg_data=*/ dynamicSgData,
290+ /* inst_data=*/ dynamicInstData,
291+ /* static_sg_layout=*/ staticSgLayout,
292+ /* static_sg_data=*/ staticSgData,
293+ /* static_inst_data=*/ staticInstData,
294+ /* result=*/ result);
295+ }
296+
297+ DiagnosedSilenceableFailure
298+ transform::SetOpLayoutAttrOp::apply (transform::TransformRewriter &rewriter,
299+ transform::TransformResults &results,
300+ transform::TransformState &state) {
301+ auto targetOps = state.getPayloadOps (getTarget ());
302+ if (!llvm::hasSingleElement (targetOps)) {
303+ return emitDefiniteFailure () << " Requires exactly one targetOp handle (got "
304+ << llvm::range_size (targetOps) << " )" ;
305+ }
306+ Operation *target = *targetOps.begin ();
307+
308+ bool resultTarget = getResult ();
309+
310+ int64_t index = getIndex ();
311+ if (resultTarget && index >= target->getNumResults ()) {
312+ return emitSilenceableFailure (getLoc ())
313+ << " Index exceeds the number of op results" ;
314+ }
315+ if (!resultTarget && index >= target->getNumOperands ()) {
316+ return emitSilenceableFailure (getLoc ())
317+ << " Index exceeds the number of op operands" ;
318+ }
319+
320+ xegpu::LayoutAttr layoutAttr = nullptr ;
321+ auto status = getLayoutAttrFromOperands (getContext (), state, (*this ),
322+ getMixedSgLayout (), getMixedSgData (),
323+ getMixedInstData (), layoutAttr);
324+ if (!status.succeeded ())
325+ return status;
326+
327+ // Set layout attribute for the op result or operand
328+ if (resultTarget)
329+ xegpu::setDistributeLayoutAttr (target->getResult (index), layoutAttr);
330+ else
331+ xegpu::setDistributeLayoutAttr (target->getOpOperand (index), layoutAttr);
332+ return DiagnosedSilenceableFailure::success ();
333+ }
334+
335+ void transform::SetOpLayoutAttrOp::getEffects (
336+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
337+ onlyReadsHandle (getTargetMutable (), effects);
338+ onlyReadsHandle (getSgLayoutMutable (), effects);
339+ onlyReadsHandle (getSgDataMutable (), effects);
340+ onlyReadsHandle (getInstDataMutable (), effects);
341+ modifiesPayload (effects);
342+ }
343+
261344namespace {
262345class XeGPUTransformDialectExtension
263346 : public transform::TransformDialectExtension<
0 commit comments