Skip to content

Commit f040873

Browse files
committed
[MLIR][NVVM] Add support for f6x2 conversion
This patch adds the `cvt.to.fp6x2` NVVM dialect Op for conversion into f6x2 types. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl-sync
1 parent 8435de0 commit f040873

File tree

4 files changed

+133
-0
lines changed

4 files changed

+133
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

+50
Original file line numberDiff line numberDiff line change
@@ -1066,6 +1066,56 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
10661066
}];
10671067
}
10681068

1069+
def FP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
1070+
def FP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
1071+
1072+
def FP6Type : I32EnumAttr<"FP6Type", "NVVM FP6Type kind",
1073+
[FP6E2M3, FP6E3M2]> {
1074+
let genSpecializedAttr = 0;
1075+
let cppNamespace = "::mlir::NVVM";
1076+
}
1077+
def FP6TypeAttr : EnumAttr<NVVM_Dialect, FP6Type, "fp6_type"> {
1078+
let assemblyFormat = "`<` $value `>`";
1079+
}
1080+
1081+
def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
1082+
let summary = "Convert the given float input to f6x2";
1083+
let description = [{
1084+
This Op converts the given float input to f6x2.
1085+
The result `dst` is represented either as an i16 type or a vector
1086+
of two i8 types.
1087+
The `relu` attribute, when set, lowers to the '.relu' variant of
1088+
the cvt instruction. The `rnd` and `sat` attributes specify the
1089+
the rounding and saturation modes respectively.
1090+
1091+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1092+
}];
1093+
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
1094+
let arguments = (ins
1095+
FP6TypeAttr:$type,
1096+
F32:$a,
1097+
F32:$b,
1098+
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RN">:$rnd,
1099+
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::SATFINITE">:$sat,
1100+
DefaultValuedAttr<BoolAttr, "false">:$relu);
1101+
let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
1102+
1103+
let extraClassDeclaration = [{
1104+
static llvm::Intrinsic::ID getIntrinsicID(NVVM::FP6Type,
1105+
bool hasRelu);
1106+
bool isPacked();
1107+
llvm::Value* getCastedResult(llvm::Value* packedI16, llvm::IRBuilderBase &builder);
1108+
}];
1109+
1110+
string llvmBuilder = [{
1111+
auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
1112+
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
1113+
$dst = op.getCastedResult(packedI16, builder);
1114+
}];
1115+
1116+
let hasVerifier = 1;
1117+
}
1118+
10691119
//===----------------------------------------------------------------------===//
10701120
// NVVM MMA Ops
10711121
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

+45
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/AsmParser/Parser.h"
3434
#include "llvm/IR/Attributes.h"
3535
#include "llvm/IR/Function.h"
36+
#include "llvm/IR/IRBuilder.h"
3637
#include "llvm/IR/IntrinsicsNVPTX.h"
3738
#include "llvm/IR/Type.h"
3839
#include "llvm/Support/Casting.h"
@@ -133,6 +134,33 @@ LogicalResult CvtFloatToTF32Op::verify() {
133134
return success();
134135
}
135136

137+
bool CvtToF6x2Op::isPacked() {
138+
if (getDst().getType().isInteger(16)) {
139+
return true;
140+
}
141+
return false;
142+
}
143+
144+
llvm::Value *CvtToF6x2Op::getCastedResult(llvm::Value *packedI16,
145+
llvm::IRBuilderBase &builder) {
146+
if (isPacked()) {
147+
return packedI16;
148+
}
149+
return builder.CreateBitCast(
150+
packedI16, llvm::FixedVectorType::get(
151+
llvm::Type::getInt8Ty(builder.getContext()), 2));
152+
}
153+
154+
LogicalResult CvtToF6x2Op::verify() {
155+
if (getRnd() != NVVM::FPRoundingMode::RN) {
156+
return emitOpError("RN rounding mode required for CvtToF6x2Op.");
157+
}
158+
if (getSat() != NVVM::SaturationMode::SATFINITE) {
159+
return emitOpError("SATFINITE saturation mode required for CvtToF6x2Op.");
160+
}
161+
return success();
162+
}
163+
136164
LogicalResult BulkStoreOp::verify() {
137165
if (getInitVal() != 0)
138166
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1290,6 +1318,23 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
12901318
}
12911319
}
12921320

1321+
#define CVT_TO_F6X2_ID_IMPL(type, relu) \
1322+
hasRelu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn##relu##_satfinite \
1323+
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
1324+
1325+
llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::FP6Type type,
1326+
bool hasRelu) {
1327+
switch (type) {
1328+
case NVVM::FP6Type::E2M3:
1329+
return CVT_TO_F6X2_ID_IMPL(e2m3x2, _relu);
1330+
case NVVM::FP6Type::E3M2:
1331+
return CVT_TO_F6X2_ID_IMPL(e3m2x2, _relu);
1332+
default:
1333+
break;
1334+
}
1335+
llvm_unreachable("Invalid FP6Type for CvtToF6x2Op");
1336+
}
1337+
12931338
llvm::Intrinsic::ID
12941339
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
12951340
LLVM::ModuleTranslation &mt,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: @convert_float_to_fp6x2_packed
4+
llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
5+
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
6+
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
7+
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
8+
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
9+
llvm.return
10+
}
11+
12+
// CHECK-LABEL: @convert_float_to_fp6x2_vector
13+
llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
14+
//CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
15+
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
16+
%res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
17+
//CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
18+
//CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
19+
%res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
20+
llvm.return
21+
}
22+

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

+16
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,19 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
176176
%0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
177177
llvm.return
178178
}
179+
180+
// -----
181+
182+
llvm.func @nvvm_cvt_to_f6x2(%a : f32, %b : f32) {
183+
// expected-error @below {{RN rounding mode required for CvtToF6x2Op.}}
184+
%res = nvvm.cvt.to.f6x2 <e2m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rna>} : i16
185+
llvm.return
186+
}
187+
188+
// -----
189+
190+
llvm.func @nvvm_cvt_to_f6x2_packed(%a : f32, %b : f32) {
191+
// expected-error @below {{SATFINITE saturation mode required for CvtToF6x2Op.}}
192+
%res = nvvm.cvt.to.f6x2 <e3m2> %a, %b {sat = #nvvm.sat_mode<none>} : i16
193+
llvm.return
194+
}

0 commit comments

Comments
 (0)