@@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
 } // namespace fused_attn
 
 using namespace transformer_engine::fused_attn;
-void fused_attn_max_512_fwd_qkvpacked(
-    size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training,
-    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  // QKV shape is [b, s, 3, h, d]
-  void *devPtrQKV = input_QKV->data.dptr;
-  const auto stride = 2 * num_head * head_dim;
-
-  void *devPtrQ = static_cast<void *>(devPtrQKV);
-  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
-
-  void *devPtrBias = static_cast<void *>(input_Bias->data.dptr);
-
-  void *devPtrO = output_O->data.dptr;
-
-  void *devPtrS = nullptr;
-
-  if (Aux_CTX_Tensors->size == 0) {
-    Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    output_S->data.dptr = nullptr;
-    output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
-    output_S->data.dtype = input_QKV->data.dtype;
-  } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    devPtrS = output_S->data.dptr;
-  } else {
-    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
-  }
-
-  void *devPtrCuSeqlen = cu_seqlens->data.dptr;
-
-  const DType rng_state_type = rng_state->data.dtype;
-  NVTE_CHECK(rng_state_type == DType::kInt64);
-  void *devPtrDropoutSeed = rng_state->data.dptr;
-  void *devPtrDropoutOffset =
-      static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
-
-  const DType QKV_type = input_QKV->data.dtype;
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_fwd_impl(
-      batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
-      devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
-      &workspace_size, get_cudnn_dtype(QKV_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
-
-void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, bool is_training,
-                                     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_Bias, Tensor *output_O,
-                                     NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
-                                     const Tensor *kv_cu_seqlens, const Tensor *rng_state,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
-                 bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
-             "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512.");
-
-  // Q shape is [b, s, h, d]
-  void *devPtrQ = input_Q->data.dptr;
-
-  // KV shape is [b, s, 2, h, d]
-  const auto stride = 2 * num_head * head_dim;
-  void *devPtrK = input_KV->data.dptr;
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
-
-  void *devPtrBias = input_Bias->data.dptr;
-
-  void *devPtrO = output_O->data.dptr;
-
-  void *devPtrS = nullptr;
-
-  const DType q_type = input_Q->data.dtype;
-  const DType kv_type = input_KV->data.dtype;
-  NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
-
-  if (Aux_CTX_Tensors->size == 0) {
-    Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    output_S->data.dptr = nullptr;
-    output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
-    output_S->data.dtype = q_type;
-  } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    devPtrS = output_S->data.dptr;
-  } else {
-    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
-  }
-
-  void *devQCuSeqlen = q_cu_seqlens->data.dptr;
-  void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
-
-  const DType rng_state_type = rng_state->data.dtype;
-  NVTE_CHECK(rng_state_type == DType::kInt64);
-  void *devPtrDropoutSeed = rng_state->data.dptr;
-  void *devPtrDropoutOffset =
-      static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
-
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_fwd_impl(
-      batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
-      devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
-      &workspace_size, get_cudnn_dtype(q_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
 void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, bool is_training,
                             float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
@@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
   }
 }
 
-void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
-                                      size_t head_dim, float attn_scale, float p_dropout,
-                                      NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                      NVTE_Mask_Type mask_type, const Tensor *input_QKV,
-                                      const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
-                                      Tensor *output_dBias, const Tensor *cu_seqlens,
-                                      Tensor *workspace, cudaStream_t stream,
-                                      cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  // QKV shape is [b, s, 3, h, d]
-  void *devPtrQKV = input_QKV->data.dptr;
-
-  auto stride = 2 * num_head * head_dim;
-  void *devPtrQ = devPtrQKV;
-  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
-
-  void *devPtrdO = input_dO->data.dptr;
-
-  // dQKV shape is [b, s, 3, h, d]
-  void *devPtrdQKV = output_dQKV->data.dptr;
-  void *devPtrdQ = devPtrdQKV;
-  void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
-  void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
-
-  void *devPtrdBias = output_dBias->data.dptr;
-
-  void *devPtrS = output_S->data.dptr;
-
-  // devPtrdS reuses the memory of devPtrS
-  void *devPtrdS = devPtrS;
-
-  void *devPtrCuSeqlens = cu_seqlens->data.dptr;
-
-  const auto qkv_type = input_QKV->data.dtype;
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale,
-                              p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK,
-                              devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS,
-                              devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr,
-                              &workspace_size, get_cudnn_dtype(qkv_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
-
-void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, float attn_scale,
-                                     float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
-                                     Tensor *output_dKV, Tensor *output_dBias,
-                                     const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  // Q shape is [b, s, h, d]
-  // KV shape is [b, s, 2, h, d]
-  auto stride = 2 * num_head * head_dim;
-  void *devPtrQ = input_Q->data.dptr;
-  void *devPtrK = input_KV->data.dptr;
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
-
-  void *devPtrdO = input_dO->data.dptr;
-
-  // dQ shape is [b, s, h, d]
-  // dKV shape is [b, s, 2, h, d]
-  void *devPtrdQ = output_dQ->data.dptr;
-  void *devPtrdK = output_dKV->data.dptr;
-  void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdK) + stride);
-
-  void *devPtrdBias = output_dBias->data.dptr;
-
-  void *devPtrS = output_S->data.dptr;
-
-  // devPtrdS reuses the memory of devPtrS
-  void *devPtrdS = devPtrS;
-
-  void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr;
-  void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr;
-
-  const auto q_type = input_Q->data.dtype;
-  const auto kv_type = input_KV->data.dtype;
-  NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_bwd_impl(
-      batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
-      mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV,
-      devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr,
-      &workspace_size, get_cudnn_dtype(q_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
 void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                             float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,