Skip to content

Commit a08c429

Browse files
authored
[Xe] 4-bit unit stride -> VNNI reorders (#593)
Adds support for unit-stride -> VNNI reorders for 4-bit types, enabling a row-major int4 B matrix. Note that this is an expensive reorder sequence (9 cycles/register), and it significantly impacts performance for the current mainloops. We'll need to enable an SLM-based mainloop to reduce the cost of this reorder.
1 parent 80524d7 commit a08c429

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

examples/cute/tutorial/xe_gemm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ int main(int argc, char** argv)
414414
test_case<int4_t, uint8_t, int32_t, 'R', 'C'>(Q, m, n, k);
415415

416416
test_case<uint4_t, uint4_t, uint32_t, 'R', 'C'>(Q, m, n, k);
417+
test_case<uint4_t, uint4_t, uint32_t, 'R', 'R'>(Q, m, n, k);
417418

418419
// Upconversion cases
419420
test_case<half_t, float_e5m2_t, float, 'R', 'R'>(Q, m, n, k);

include/cute/arch/reorder_xe.hpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1264,5 +1264,48 @@ struct Xe_Reorder<ReorderKind::UU, float_ue8m0_t, float>
12641264
}
12651265
};
12661266

1267+
// Xe_Reorder specialization for ReorderKind::UV with 4-bit unsigned source and
// destination elements (per the commit description: unit stride -> VNNI).
//
// One source register (intel::uchar4) is rearranged into one destination
// register (intel::uchar4) by an inline-asm sequence of shifts and a bitwise
// merge. Only available when CUTE_ARCH_COPY_XE_ENABLED is defined.
template <>
1268+
struct Xe_Reorder<ReorderKind::UV, uint4_t, uint4_t>
1269+
{
// The reorder consumes one source register and produces one destination register.
1270+
using SRegisters = intel::uchar4[1];
1271+
using DRegisters = intel::uchar4[1];
1272+
// Reorders the packed 4-bit data in `src0` and writes the result to `dst0`.
//
// For each group r = 0..3 and each element e = 0..15, the asm below produces
//   low  nibble of OUT byte (r + 4*e) = nibble e of IN bytes [16*r,     16*r + 8)
//   high nibble of OUT byte (r + 4*e) = nibble e of IN bytes [16*r + 8, 16*r + 16)
// (nibble 0 of a byte being its low 4 bits) -- assuming the BFN control behaves
// as noted at the merge step below.
//
// On builds without CUTE_ARCH_COPY_XE_ENABLED this hits CUTE_INVALID_CONTROL_PATH.
1273+
CUTE_HOST_DEVICE static void
1274+
reorder(intel::uchar4 const& src0, intel::uchar4& dst0)
1275+
{
1276+
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
// Each constant is viewed by the asm as two UW (16-bit) shift amounts:
// lshifts -> (4, 0) and rshifts -> (0, 4) as (low word, high word).
1277+
const uint32_t lshifts = 0x00000004;
1278+
const uint32_t rshifts = 0x00040000;
// NOTE(review): the output constraint is plain "=rw" (no early-clobber), yet OUT is
// written before all of IN has been read -- confirm the compiler cannot assign
// dst0 and src0 to the same register.
1279+
asm ( /* 9 cycles/output register */
1280+
"{\n"
// Byte/word views aliased onto the operands: %0 = dst0, %1 = src0,
// %2 = lshifts, %3 = rshifts.
1281+
".decl IN_UB v_type=G type=UB num_elts=64 alias=<%1,0>\n"
1282+
".decl OUT_UB v_type=G type=UB num_elts=64 alias=<%0,0>\n"
1283+
".decl OUT_UW v_type=G type=UW num_elts=32 alias=<%0,0>\n"
1284+
".decl LSHIFTS v_type=G type=UW num_elts=2 alias=<%2,0>\n"
1285+
".decl RSHIFTS v_type=G type=UW num_elts=2 alias=<%3,0>\n"
// 64-byte scratch variable, also viewed as 32 words for the final merge.
1286+
".decl TMP_UB v_type=G type=UB num_elts=64 align=64\n"
1287+
".decl TMP_UW v_type=G type=UW num_elts=32 alias=<TMP_UB,0>\n"
// Step 1 (4x shr): region <1;2,0> reads every source byte twice in a row and
// region <0;2,1> alternates the two shift amounts (0, 4), so consecutive
// results are byte >> 0 and byte >> 4. Each instruction covers 8 source bytes
// (offsets 0, 16, 32, 48) and scatters its 16 results to every 4th byte of
// OUT, starting at byte 0, 1, 2, 3 respectively.
1288+
"shr (M1_NM, 16) OUT_UB(0,0)<4> IN_UB(0, 0)<1;2,0> RSHIFTS(0,0)<0;2,1>\n"
1289+
"shr (M1_NM, 16) OUT_UB(0,1)<4> IN_UB(0,16)<1;2,0> RSHIFTS(0,0)<0;2,1>\n"
1290+
"shr (M1_NM, 16) OUT_UB(0,2)<4> IN_UB(0,32)<1;2,0> RSHIFTS(0,0)<0;2,1>\n"
1291+
"shr (M1_NM, 16) OUT_UB(0,3)<4> IN_UB(0,48)<1;2,0> RSHIFTS(0,0)<0;2,1>\n"
// Step 2 (4x shl): same access pattern on the other 8-byte halves (offsets
// 8, 24, 40, 56) with shift amounts alternating (4, 0), written to TMP so the
// wanted nibble ends up in the high half of each byte.
1292+
"shl (M1_NM, 16) TMP_UB(0,0)<4> IN_UB(0, 8)<1;2,0> LSHIFTS(0,0)<0;2,1>\n"
1293+
"shl (M1_NM, 16) TMP_UB(0,1)<4> IN_UB(0,24)<1;2,0> LSHIFTS(0,0)<0;2,1>\n"
1294+
"shl (M1_NM, 16) TMP_UB(0,2)<4> IN_UB(0,40)<1;2,0> LSHIFTS(0,0)<0;2,1>\n"
1295+
"shl (M1_NM, 16) TMP_UB(0,3)<4> IN_UB(0,56)<1;2,0> LSHIFTS(0,0)<0;2,1>\n"
// Step 3: merge OUT and TMP word-wise under the mask 0xF0F0.
// NOTE(review): control 0xCA presumably acts as a bitwise select (mask bit
// set -> take TMP, clear -> take OUT), i.e. high nibbles from TMP and low
// nibbles from OUT -- confirm against the BFN truth-table encoding.
1296+
"bfn.xCA (M1_NM, 32) OUT_UW(0,0)<1> OUT_UW(0,0)<1;1,0> TMP_UW(0,0)<1;1,0> 0xF0F0:uw\n"
1297+
"}\n"
1298+
: "=rw"(dst0)
1299+
: "rw"(src0), "rw.u"(lshifts), "rw.u"(rshifts)
1300+
);
1301+
#else
1302+
CUTE_INVALID_CONTROL_PATH("Not Xe");
1303+
#endif
1304+
}
1305+
};
1306+
1307+
// Signed 4-bit elements inherit the uint4_t UV reorder unchanged.
template <>
struct Xe_Reorder<ReorderKind::UV, int4_t, int4_t>
  : Xe_Reorder<ReorderKind::UV, uint4_t, uint4_t>
{};
1308+
// float_e2m1_t elements inherit the uint4_t UV reorder unchanged.
template <>
struct Xe_Reorder<ReorderKind::UV, float_e2m1_t, float_e2m1_t>
  : Xe_Reorder<ReorderKind::UV, uint4_t, uint4_t>
{};
1309+
12671310

12681311
} // end namespace cute

0 commit comments

Comments
 (0)