diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 0a6e66919f021..0e1481ea374db 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -1066,6 +1066,59 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> { }]; } +def CVTFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">; +def CVTFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">; + +def CVTFP6Type : I32EnumAttr<"CVTFP6Type", "NVVM CVTFP6Type kind", + [CVTFP6E2M3, CVTFP6E3M2]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::NVVM"; +} +def CVTFP6TypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { + let summary = "Convert a pair of float inputs to f6x2"; + let description = [{ + This Op converts each of the given float inputs to the specified fp6 type. + The result `dst` is represented either as an i16 type or as a vector + of two i8 types. + If `dst` is returned as an i16 type, the converted values are packed such + that the value converted from `a` is stored in the upper 8 bits of `dst` + with 2 MSB bits padded with zeros and the value converted from `b` is + stored in the lower 8 bits of `dst` with 2 MSB bits padded with zeros. + If `dst` is returned as a vector type, each converted value is stored as an + i8 element in the vector. + The `relu` attribute, when set, lowers to the '.relu' variant of + the cvt instruction. + + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) + }]; + let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); + let arguments = (ins + CVTFP6TypeAttr:$type, + F32:$a, + F32:$b, + DefaultValuedAttr:$relu); + let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)"; + + let extraClassDeclaration = [{ + static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP6Type, + bool hasRelu); + }]; + + string llvmBuilder = [{ + auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu); + llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); + if(op.getDst().getType().isInteger(16)) + $dst = packedI16; + else + $dst = builder.CreateBitCast(packedI16, + llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); + }]; +} + //===----------------------------------------------------------------------===// // NVVM MMA Ops //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index e3d496c983e59..44040401b0406 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1290,6 +1290,22 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd, } } +#define CVT_TO_F6X2_ID_IMPL(type, has_relu) \ + has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \ + : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite + +llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type, + bool hasRelu) { + switch (type) { + case NVVM::CVTFP6Type::E2M3: + return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu); + case NVVM::CVTFP6Type::E3M2: + return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu); + default: + llvm_unreachable("Invalid CVTFP6Type for CvtToF6x2Op"); + } +} + llvm::Intrinsic::ID Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir new file mode 100644 index 0000000000000..2237e6faad52d --- /dev/null +++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: @convert_float_to_fp6x2_packed +llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) { + //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + %res1 = nvvm.cvt.to.f6x2 %srcA, %srcB : i16 + //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + %res2 = nvvm.cvt.to.f6x2 %srcA, %srcB : i16 + llvm.return +} + +// CHECK-LABEL: @convert_float_to_fp6x2_vector +llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) { + //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8> + %res1 = nvvm.cvt.to.f6x2 %srcA, %srcB : vector<2xi8> + //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) + //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> + %res2 = nvvm.cvt.to.f6x2 %srcA, %srcB : vector<2xi8> + llvm.return +} +