-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[MLIR][XeGPU][TransformOps] Add set_op_layout_attr op #166854
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [ | |
| }]; | ||
| } | ||
|
|
||
| def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [ | ||
| AttrSizedOperandSegments, | ||
| DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, | ||
| TransformOpInterface | ||
| ]> { | ||
|
|
||
| let summary = "Set xegpu.layout attribute of an op."; | ||
| let description = [{ | ||
| Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the | ||
| `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The | ||
| target operand/result value is defined by the `index` argument. The layout | ||
| is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes. | ||
| }]; | ||
|
|
||
| let arguments = (ins TransformHandleTypeInterface : $target, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: colon style |
||
| DefaultValuedOptionalAttr<I64Attr, "0"> : $index, | ||
| Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout, | ||
| Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data, | ||
| Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data, | ||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout, | ||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data, | ||
| DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data, | ||
| DefaultValuedAttr<UnitAttr, "false">:$result | ||
| ); | ||
|
|
||
| let results = (outs); | ||
| let builders = [ | ||
| OpBuilder<(ins "Value":$target, | ||
| "int64_t":$index, | ||
| "ArrayRef<OpFoldResult>":$mixedSgLayout, | ||
| "ArrayRef<OpFoldResult>":$mixedSgData, | ||
| "ArrayRef<OpFoldResult>":$mixedInstData, | ||
| CArg<"bool", "false">:$result | ||
| )>, | ||
| ]; | ||
|
|
||
| let assemblyFormat = [{ | ||
| $target (`result` $result^)? (`index` `=` $index^)? | ||
| `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout) | ||
| `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data) | ||
| (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)? | ||
| attr-dict `:` qualified(type(operands)) | ||
| }]; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| ::mlir::DiagnosedSilenceableFailure apply( | ||
| ::mlir::transform::TransformRewriter &rewriter, | ||
| ::mlir::transform::TransformResults &transformResults, | ||
| ::mlir::transform::TransformState &state); | ||
|
|
||
| ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() { | ||
| Builder b(getContext()); | ||
| return getMixedValues(getStaticSgLayout(), getSgLayout(), b); | ||
| } | ||
| ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() { | ||
| Builder b(getContext()); | ||
| return getMixedValues(getStaticSgData(), getSgData(), b); | ||
| } | ||
| ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() { | ||
| Builder b(getContext()); | ||
| return getMixedValues(getStaticInstData(), getInstData(), b); | ||
| } | ||
| }]; | ||
| } | ||
|
|
||
| #endif // XEGPU_TRANSFORM_OPS | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout, | |
| /*order=*/nullptr); | ||
| } | ||
|
|
||
| /// Generate `xegpu::LayoutAttr` from op mixed layout values. | ||
| DiagnosedSilenceableFailure | ||
| getLayoutAttrFromOperands(transform::TransformRewriter &rewriter, | ||
| transform::TransformState &state, | ||
|
Comment on lines
+95
to
+96
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about just passing the context here?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I see the state is also used by the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, in
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you need the rewriter, though?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, you're right. We just need the context. I'll fix it. |
||
| TransformOpInterface transformOp, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you're just constructing the attribute, it's not obvious why you need the op here. It is maybe to do error reporting? If so, maybe passing the location suffices?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm -- I see
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, |
||
| ArrayRef<::mlir::OpFoldResult> mixedSgLayout, | ||
| ArrayRef<::mlir::OpFoldResult> mixedSgData, | ||
| ArrayRef<::mlir::OpFoldResult> mixedInstData, | ||
| xegpu::LayoutAttr &layoutAttr) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing order attribute. It defines the layout of subgroup ids and lane ids. It is optional, if not set, the layout of these ids are row major. It can be added later when you need to handle vector.transpose. |
||
| SmallVector<int32_t> sgLayout, sgData, instData; | ||
| auto status = | ||
| convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData); | ||
| if (!status.succeeded()) | ||
| return status; | ||
| auto maybeInstData = instData.empty() | ||
| ? std::nullopt | ||
| : std::optional<ArrayRef<int32_t>>(instData); | ||
|
|
||
| layoutAttr = | ||
| createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData); | ||
|
|
||
| return DiagnosedSilenceableFailure::success(); | ||
| } | ||
|
|
||
| /// Replace xegpu.create_nd_desc op with a new one with the given layout. | ||
| static xegpu::CreateNdDescOp | ||
| setDescLayout(transform::TransformRewriter &rewriter, | ||
|
|
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, | |
| } | ||
| Operation *target = *targetOps.begin(); | ||
|
|
||
| SmallVector<int32_t> sgLayout; | ||
| DiagnosedSilenceableFailure status = | ||
| convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout()); | ||
| xegpu::LayoutAttr layoutAttr = nullptr; | ||
| auto status = getLayoutAttrFromOperands(rewriter, state, (*this), | ||
| getMixedSgLayout(), getMixedSgData(), | ||
| getMixedInstData(), layoutAttr); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| SmallVector<int32_t> sgData; | ||
| status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData()); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| SmallVector<int32_t> instData; | ||
| status = | ||
| convertMixedValuesToInt(state, (*this), instData, getMixedInstData()); | ||
| if (!status.succeeded()) | ||
| return status; | ||
| auto maybeInstData = instData.empty() | ||
| ? std::nullopt | ||
| : std::optional<ArrayRef<int32_t>>(instData); | ||
|
|
||
| // For now only create_nd_desc op is supported. | ||
| auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target); | ||
| if (!descOp) { | ||
|
|
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter, | |
| } | ||
|
|
||
| // Set layout attr in desc op's return type. Replaces old desc op. | ||
| auto layoutAttr = | ||
| createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData); | ||
| auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr); | ||
|
|
||
| // Map result handles. | ||
|
|
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects( | |
| modifiesPayload(effects); | ||
| } | ||
|
|
||
| void transform::SetOpLayoutAttrOp::build( | ||
| OpBuilder &builder, OperationState &ostate, Value target, int64_t index, | ||
| ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData, | ||
| ArrayRef<OpFoldResult> mixedInstData, bool result) { | ||
| SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData; | ||
| SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData; | ||
| dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout); | ||
| dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData); | ||
| dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData); | ||
| build(builder, ostate, target.getType(), | ||
| /*target=*/target, | ||
| /*index=*/index, | ||
| /*sg_layout=*/dynamicSgLayout, | ||
| /*sg_data=*/dynamicSgData, | ||
| /*inst_data=*/dynamicInstData, | ||
| /*static_sg_layout=*/staticSgLayout, | ||
| /*static_sg_data=*/staticSgData, | ||
| /*static_inst_data=*/staticInstData, | ||
| /*result=*/result); | ||
| } | ||
|
|
||
| DiagnosedSilenceableFailure | ||
| transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter, | ||
| transform::TransformResults &results, | ||
| transform::TransformState &state) { | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: whitespace |
||
| auto targetOps = state.getPayloadOps(getTarget()); | ||
| if (!llvm::hasSingleElement(targetOps)) { | ||
| return emitDefiniteFailure() << "Requires exactly one targetOp handle (got " | ||
| << llvm::range_size(targetOps) << ")"; | ||
| } | ||
| Operation *target = *targetOps.begin(); | ||
|
|
||
| bool resultTarget = getResult(); | ||
|
|
||
| int64_t index = getIndex(); | ||
| if (resultTarget && index >= target->getNumResults()) { | ||
| return emitSilenceableFailure(getLoc()) | ||
| << "Index exceeds the number of op results"; | ||
| } | ||
| if (!resultTarget && index >= target->getNumOperands()) { | ||
| return emitSilenceableFailure(getLoc()) | ||
| << "Index exceeds the number of op operands"; | ||
| } | ||
|
|
||
| xegpu::LayoutAttr layoutAttr = nullptr; | ||
| auto status = getLayoutAttrFromOperands(rewriter, state, (*this), | ||
| getMixedSgLayout(), getMixedSgData(), | ||
| getMixedInstData(), layoutAttr); | ||
| if (!status.succeeded()) | ||
| return status; | ||
|
|
||
| // Set layout attribute for the op result or operand | ||
| if (resultTarget) { | ||
| xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: braces can be skipped |
||
| } else { | ||
| xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr); | ||
| } | ||
| return DiagnosedSilenceableFailure::success(); | ||
| } | ||
|
|
||
| void transform::SetOpLayoutAttrOp::getEffects( | ||
| ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | ||
| onlyReadsHandle(getTargetMutable(), effects); | ||
| onlyReadsHandle(getSgLayoutMutable(), effects); | ||
| onlyReadsHandle(getSgDataMutable(), effects); | ||
| onlyReadsHandle(getInstDataMutable(), effects); | ||
| modifiesPayload(effects); | ||
| } | ||
|
|
||
| namespace { | ||
| class XeGPUTransformDialectExtension | ||
| : public transform::TransformDialectExtension< | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -64,3 +64,50 @@ def __init__( | |
| loc=loc, | ||
| ip=ip, | ||
| ) | ||
|
|
||
|
|
||
| @_ods_cext.register_operation(_Dialect, replace=True) | ||
| class SetOpLayoutAttrOp(SetOpLayoutAttrOp): | ||
| """Specialization for SetOpLayoutAttrOp class.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| target: Union[Operation, Value], | ||
| sg_layout: MixedValues, | ||
| sg_data: MixedValues, | ||
| *, | ||
| inst_data: MixedValues = None, | ||
| index: Union[int, Attribute] = None, | ||
| result: Union[bool, Attribute] = None, | ||
|
Comment on lines
+79
to
+81
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pedantically speaking, I believe these types should be wrapped in |
||
| loc=None, | ||
| ip=None, | ||
| ): | ||
| inst_data = [] if inst_data is None else inst_data | ||
| ( | ||
| dynamic_sg_layout, | ||
| static_sg_layout, | ||
| _, | ||
| ) = _dispatch_dynamic_index_list(sg_layout) | ||
| ( | ||
| dynamic_sg_data, | ||
| static_sg_data, | ||
| _, | ||
| ) = _dispatch_dynamic_index_list(sg_data) | ||
| ( | ||
| dynamic_inst_data, | ||
| static_inst_data, | ||
| _, | ||
| ) = _dispatch_dynamic_index_list(inst_data) | ||
| super().__init__( | ||
| _get_op_result_or_value(target), | ||
| dynamic_sg_layout, | ||
| dynamic_sg_data, | ||
| dynamic_inst_data, | ||
| static_sg_layout=static_sg_layout, | ||
| static_sg_data=static_sg_data, | ||
| static_inst_data=static_inst_data, | ||
| index=index, | ||
| result=result, | ||
| loc=loc, | ||
| ip=ip, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a warning here: XeGPU is refactoring the layout setting API which user only need to set the anchor op's layout, and XeGPU will take care the propagation automatically. The new API that sets anchor ops' layout doesn't require "layout_result_" and "layout_operand_", so may impact this operation also.