
Commit 3454f84

Authored by pggPL, pre-commit-ci[bot], and cyanguwa
[common] Remove kvpacked and qkvpacked attention functions for every kernel type. (NVIDIA#2287)
* code drop
  Signed-off-by: Pawel Gadzinski <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
  Signed-off-by: Pawel Gadzinski <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
  Signed-off-by: Pawel Gadzinski <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
  Signed-off-by: Pawel Gadzinski <[email protected]>
* deprecated compile-time warning + \warning -> \deprecated
  Signed-off-by: Pawel Gadzinski <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <[email protected]>
1 parent d20311b commit 3454f84
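One of the squashed commits above replaces \warning notes on the packed entry points with \deprecated and adds a compile-time deprecation warning. As a rough, hypothetical illustration only (the declaration and message below are placeholders, not the actual Transformer Engine API), such an annotation change typically looks like this in C++:

// Hypothetical sketch of the documentation/attribute change described in the commit
// message above; the function name and parameters are placeholders, not real NVTE symbols.

/*! \deprecated Packed-QKV fused-attention entry points are deprecated. */
[[deprecated("Packed-QKV fused-attention entry points are deprecated.")]]
void nvte_example_fused_attn_fwd_qkvpacked(const void *packed_qkv, void *output);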

File tree

8 files changed: +302 additions, -1388 deletions


transformer_engine/common/fused_attn/fused_attn.cpp

Lines changed: 280 additions & 53 deletions
Large diffs are not rendered by default.

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu

Lines changed: 2 additions & 528 deletions
Large diffs are not rendered by default.

transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h

Lines changed: 0 additions & 47 deletions
@@ -18,53 +18,6 @@
 
 namespace transformer_engine {
 #if (CUDNN_VERSION >= 8900)
-void fused_attn_arbitrary_seqlen_fwd_qkvpacked(
-    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
-    bool is_training, bool return_max_logit, float attn_scale, float p_dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
-    Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens,
-    const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace,
-    cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_arbitrary_seqlen_bwd_qkvpacked(
-    size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens,
-    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O,
-    const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset,
-    Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
-    const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_arbitrary_seqlen_fwd_kvpacked(
-    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
-    size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v,
-    size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit,
-    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV,
-    const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
-    const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded,
-    const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_arbitrary_seqlen_bwd_kvpacked(
-    size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
-    size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv,
-    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV,
-    const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
-    const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV,
-    Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q,
-    const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
-    const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
-    cudaStream_t stream, cudnnHandle_t handle);
-
 void fused_attn_arbitrary_seqlen_fwd(
     size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,

transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu

Lines changed: 0 additions & 264 deletions
@@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv
 }  // namespace fused_attn
 
 using namespace transformer_engine::fused_attn;
-void fused_attn_max_512_fwd_qkvpacked(
-    size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training,
-    float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O,
-    NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state,
-    Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  // QKV shape is [b, s, 3, h, d]
-  void *devPtrQKV = input_QKV->data.dptr;
-  const auto stride = 2 * num_head * head_dim;
-
-  void *devPtrQ = static_cast<void *>(devPtrQKV);
-  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
-
-  void *devPtrBias = static_cast<void *>(input_Bias->data.dptr);
-
-  void *devPtrO = output_O->data.dptr;
-
-  void *devPtrS = nullptr;
-
-  if (Aux_CTX_Tensors->size == 0) {
-    Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    output_S->data.dptr = nullptr;
-    output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen};
-    output_S->data.dtype = input_QKV->data.dtype;
-  } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    devPtrS = output_S->data.dptr;
-  } else {
-    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
-  }
-
-  void *devPtrCuSeqlen = cu_seqlens->data.dptr;
-
-  const DType rng_state_type = rng_state->data.dtype;
-  NVTE_CHECK(rng_state_type == DType::kInt64);
-  void *devPtrDropoutSeed = rng_state->data.dptr;
-  void *devPtrDropoutOffset =
-      static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
-
-  const DType QKV_type = input_QKV->data.dtype;
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_fwd_impl(
-      batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
-      devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
-      &workspace_size, get_cudnn_dtype(QKV_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
-
-void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, bool is_training,
-                                     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_Bias, Tensor *output_O,
-                                     NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
-                                     const Tensor *kv_cu_seqlens, const Tensor *rng_state,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
-                 bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS,
-             "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512.");
-
-  // Q shape is [b, s, h, d]
-  void *devPtrQ = input_Q->data.dptr;
-
-  // KV shape is [b, s, 2, h, d]
-  const auto stride = 2 * num_head * head_dim;
-  void *devPtrK = input_KV->data.dptr;
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
-
-  void *devPtrBias = input_Bias->data.dptr;
-
-  void *devPtrO = output_O->data.dptr;
-
-  void *devPtrS = nullptr;
-
-  const DType q_type = input_Q->data.dtype;
-  const DType kv_type = input_KV->data.dtype;
-  NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
-
-  if (Aux_CTX_Tensors->size == 0) {
-    Aux_CTX_Tensors->size = 1;
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    output_S->data.dptr = nullptr;
-    output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen};
-    output_S->data.dtype = q_type;
-  } else if (Aux_CTX_Tensors->size == 1) {
-    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]);
-    devPtrS = output_S->data.dptr;
-  } else {
-    NVTE_ERROR("Unexpected Aux_CTX_Tensors->size.");
-  }
-
-  void *devQCuSeqlen = q_cu_seqlens->data.dptr;
-  void *devKVCuSeqlen = kv_cu_seqlens->data.dptr;
-
-  const DType rng_state_type = rng_state->data.dtype;
-  NVTE_CHECK(rng_state_type == DType::kInt64);
-  void *devPtrDropoutSeed = rng_state->data.dptr;
-  void *devPtrDropoutOffset =
-      static_cast<void *>(static_cast<uint64_t *>(rng_state->data.dptr) + 1);
-
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_fwd_impl(
-      batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout,
-      qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias,
-      devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr,
-      &workspace_size, get_cudnn_dtype(q_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
 void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, bool is_training,
                             float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
@@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
 }
 }
 
-void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
-                                      size_t head_dim, float attn_scale, float p_dropout,
-                                      NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                      NVTE_Mask_Type mask_type, const Tensor *input_QKV,
-                                      const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
-                                      Tensor *output_dBias, const Tensor *cu_seqlens,
-                                      Tensor *workspace, cudaStream_t stream,
-                                      cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  // QKV shape is [b, s, 3, h, d]
-  void *devPtrQKV = input_QKV->data.dptr;
-
-  auto stride = 2 * num_head * head_dim;
-  void *devPtrQ = devPtrQKV;
-  void *devPtrK = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + stride);
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrQKV) + 2 * stride);
-
-  void *devPtrdO = input_dO->data.dptr;
-
-  // dQKV shape is [b, s, 3, h, d]
-  void *devPtrdQKV = output_dQKV->data.dptr;
-  void *devPtrdQ = devPtrdQKV;
-  void *devPtrdK = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + stride);
-  void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdQKV) + 2 * stride);
-
-  void *devPtrdBias = output_dBias->data.dptr;
-
-  void *devPtrS = output_S->data.dptr;
-
-  // devPtrdS reuses the memory of devPtrS
-  void *devPtrdS = devPtrS;
-
-  void *devPtrCuSeqlens = cu_seqlens->data.dptr;
-
-  const auto qkv_type = input_QKV->data.dtype;
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale,
-                              p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK,
-                              devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS,
-                              devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr,
-                              &workspace_size, get_cudnn_dtype(qkv_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
-
-void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, float attn_scale,
-                                     float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
-                                     Tensor *output_dKV, Tensor *output_dBias,
-                                     const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
-  using namespace transformer_engine;
-
-  // Q shape is [b, s, h, d]
-  // KV shape is [b, s, 2, h, d]
-  auto stride = 2 * num_head * head_dim;
-  void *devPtrQ = input_Q->data.dptr;
-  void *devPtrK = input_KV->data.dptr;
-  void *devPtrV = static_cast<void *>(static_cast<int8_t *>(devPtrK) + stride);
-
-  void *devPtrdO = input_dO->data.dptr;
-
-  // dQ shape is [b, s, h, d]
-  // dKV shape is [b, s, 2, h, d]
-  void *devPtrdQ = output_dQ->data.dptr;
-  void *devPtrdK = output_dKV->data.dptr;
-  void *devPtrdV = static_cast<void *>(static_cast<int8_t *>(devPtrdK) + stride);
-
-  void *devPtrdBias = output_dBias->data.dptr;
-
-  void *devPtrS = output_S->data.dptr;
-
-  // devPtrdS reuses the memory of devPtrS
-  void *devPtrdS = devPtrS;
-
-  void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr;
-  void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr;
-
-  const auto q_type = input_Q->data.dtype;
-  const auto kv_type = input_KV->data.dtype;
-  NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV.");
-  size_t workspace_size = 0;
-
-  fused_attn_max_512_bwd_impl(
-      batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout,
-      mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV,
-      devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr,
-      &workspace_size, get_cudnn_dtype(q_type), stream, handle);
-
-  if (workspace_size > 0) {
-    if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {workspace_size};
-      workspace->data.dtype = DType::kByte;
-      return;
-    }
-  } else if (workspace_size == 0) {
-    workspace->data.shape = {1};
-    workspace->data.dtype = DType::kByte;
-    return;
-  } else {
-    NVTE_ERROR("Unexpected workspace_size.");
-  }
-}
 void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                             float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,

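The removed *_qkvpacked and *_kvpacked wrappers above differ from the retained entry points mainly in how they derive the Q, K, and V device pointers from a single packed buffer before calling the shared implementation. Below is a minimal standalone sketch of that pointer arithmetic, assuming a [b, s, 3, h, d] layout with 2-byte (fp16/bf16) elements so that the byte stride between the Q, K, and V slices is 2 * num_head * head_dim, as in the deleted code; the names here are illustrative, not Transformer Engine symbols.

#include <cstddef>
#include <cstdint>

// Illustrative sketch only (not Transformer Engine code): how the deleted *_qkvpacked
// wrappers derived Q/K/V device pointers from one packed [b, s, 3, h, d] buffer.
// With 2-byte (fp16/bf16) elements, the byte stride between the Q, K and V slices of a
// token is 2 * num_head * head_dim, matching the expression in the removed functions.
struct QkvPointers {
  void *q;
  void *k;
  void *v;
};

inline QkvPointers split_packed_qkv(void *packed_qkv, size_t num_head, size_t head_dim,
                                    size_t bytes_per_element = 2) {
  const size_t stride = bytes_per_element * num_head * head_dim;
  auto *base = static_cast<int8_t *>(packed_qkv);
  return {static_cast<void *>(base),                // Q starts at the buffer base
          static_cast<void *>(base + stride),       // K follows Q within the token
          static_cast<void *>(base + 2 * stride)};  // V follows K
}

A caller holding a packed half-precision buffer would call split_packed_qkv(buffer, num_head, head_dim) and pass the three resulting pointers to one of the retained unpacked entry points.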
transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h

Lines changed: 0 additions & 37 deletions
@@ -18,25 +18,6 @@
 
 namespace transformer_engine {
 #if (CUDNN_VERSION >= 8901)
-void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
-                                      size_t head_size, bool is_training, float attn_scale,
-                                      float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                      NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                      const Tensor *input_QKV, const Tensor *input_Bias,
-                                      Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
-                                      const Tensor *cu_seqlens, const Tensor *rng_state,
-                                      Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, bool is_training,
-                                     float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_Bias, Tensor *output_O,
-                                     NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens,
-                                     const Tensor *kv_cu_seqlens, const Tensor *rng_state,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
 void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, bool is_training,
                             float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
@@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace,
                             cudaStream_t stream, cudnnHandle_t handle);
 
-void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen,
-                                      size_t head_dim, float attn_scale, float p_dropout,
-                                      NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                      NVTE_Mask_Type mask_type, const Tensor *input_QKV,
-                                      const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV,
-                                      Tensor *output_dBias, const Tensor *cu_seqlens,
-                                      Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
-void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen,
-                                     size_t kv_max_seqlen, size_t head_dim, float attn_scale,
-                                     float p_dropout, NVTE_QKV_Layout qkv_layout,
-                                     NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
-                                     const Tensor *input_Q, const Tensor *input_KV,
-                                     const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ,
-                                     Tensor *output_dKV, Tensor *output_dBias,
-                                     const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens,
-                                     Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
-
 void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen,
                             size_t kv_max_seqlen, size_t head_dim, float attn_scale,
                             float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
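Both the removed functions and the retained fused_attn_max_512_fwd/bwd entry points visible in these diffs follow the same workspace-query convention: the backend implementation is called once with whatever workspace pointer the caller has; if a non-zero workspace size is reported while the pointer is still null, the function records the required shape and dtype and returns so the caller can allocate and call again. Below is a minimal sketch of that two-pass pattern, using placeholder types rather than the actual NVTE Tensor machinery.

#include <cstddef>
#include <vector>

// Placeholder workspace type; the real code uses transformer_engine::Tensor instead.
struct Workspace {
  void *dptr = nullptr;       // null on the size-query pass
  std::vector<size_t> shape;  // filled with the required byte count on the query pass
};

// Sketch of the two-pass convention: `impl` reports its workspace requirement through
// `ws_size`; on the query pass we only record the size and return.
void run_with_workspace_query(void (*impl)(void *ws, size_t *ws_size), Workspace *workspace) {
  size_t workspace_size = 0;
  impl(workspace->dptr, &workspace_size);
  if (workspace_size > 0 && workspace->dptr == nullptr) {
    workspace->shape = {workspace_size};  // tell the caller how much to allocate
    return;
  }
  if (workspace_size == 0) {
    workspace->shape = {1};  // dummy 1-byte workspace, nothing further to allocate
  }
}

This mirrors the deleted bodies above, where the query pass sets workspace->data.shape = {workspace_size} and workspace->data.dtype = DType::kByte before returning.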
