-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[MLIR][NVVM] Add support for f6x2 conversion #136537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1066,6 +1066,60 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> { | |
}]; | ||
} | ||
|
||
def CVTFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">; | ||
def CVTFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">; | ||
|
||
def CVTFP6Type : I32EnumAttr<"CVTFP6Type", "NVVM CVTFP6Type kind", | ||
[CVTFP6E2M3, CVTFP6E3M2]> { | ||
let genSpecializedAttr = 0; | ||
let cppNamespace = "::mlir::NVVM"; | ||
} | ||
def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> { | ||
let assemblyFormat = "`<` $value `>`"; | ||
} | ||
|
||
def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { | ||
let summary = "Convert a pair of float inputs to f6x2"; | ||
let description = [{ | ||
This Op converts each of the given float inputs to the specified fp6 type. | ||
The result `dst` is represented either as an i16 type or as a vector | ||
of two i8 types. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please also include the details on the nature of the result.. (from spec)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated the description in the latest revision, thanks! |
||
If `dst` is returned as an i16 type, the converted values are packed such | ||
that the value converted from `a` is stored in the upper 8 bits of `dst` | ||
with 2 MSB bits padded with zeros and the value converted from `b` is | ||
stored in the lower 8 bits of `dst` with 2 MSB bits padded with zeros. | ||
If `dst` is returned as a vector type, each converted value is stored as an | ||
i8 element in the vector. | ||
The `relu` attribute, when set, lowers to the '.relu' variant of | ||
the cvt instruction. | ||
|
||
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) | ||
}]; | ||
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); | ||
let arguments = (ins | ||
CVTFP6TypeAttr:$type, | ||
F32:$a, | ||
F32:$b, | ||
DefaultValuedAttr<BoolAttr, "false">:$relu); | ||
let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)"; | ||
|
||
let extraClassDeclaration = [{ | ||
static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP6Type, | ||
bool hasRelu); | ||
bool isReturnVectorType(); | ||
}]; | ||
|
||
string llvmBuilder = [{ | ||
auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu); | ||
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); | ||
if(op.isReturnVectorType()) | ||
$dst = builder.CreateBitCast(packedI16, | ||
llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); | ||
else | ||
$dst = packedI16; | ||
}]; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// NVVM MMA Ops | ||
//===----------------------------------------------------------------------===// | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s | ||
|
||
// CHECK-LABEL: @convert_float_to_fp6x2_packed | ||
llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) { | ||
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16 | ||
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16 | ||
llvm.return | ||
} | ||
|
||
// CHECK-LABEL: @convert_float_to_fp6x2_vector | ||
llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) { | ||
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8> | ||
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8> | ||
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> | ||
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8> | ||
llvm.return | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The FP6 types themselves can be used in other places/context too.
So, we do not need to attach "CVT" context to these types and keep them generic as "FP6E2M3" etc.
(May be even FP6_E2M3 for better readability..)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did this here because I see that for the
wgmma.mma_async
Op, we have types likeWGMMATypeF8E4M3
to represent the operand types. If we have a more general enumeration, should we change the usage inwgmma.mma_async
too (since we will be having cvt Ops for the FP8 types as well)?