@@ -1001,10 +1001,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        // cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
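Note: the pattern in this hunk repeats in every hunk below. The KQ mask gains a third dimension so that when cparams.n_seq_virt > 1 each virtual sequence gets its own [n_kv x n_tokens/n_seqs] mask plane, instead of all sequences sharing one [n_kv x n_tokens] plane. A minimal standalone sketch of the shape math (plain C++, no ggml dependency; the GGML_KQ_MASK_PAD value of 64 and the round-up-to-multiple behavior of GGML_PAD are assumptions, not taken from this diff):

#include <cstdint>
#include <cstdio>

// Assumed equivalent of GGML_PAD: round x up to the next multiple of n.
static int64_t pad_to(int64_t x, int64_t n) { return ((x + n - 1) / n) * n; }

int main() {
    const int64_t GGML_KQ_MASK_PAD = 64; // assumed padding constant
    const int64_t n_kv     = 4096;       // KV cells visible to the ubatch
    const int64_t n_tokens = 512;        // tokens in the current ubatch
    const int64_t n_seqs   = 4;          // virtual sequences (n_seq_virt > 1)

    // Old 2D mask: one [n_kv, padded n_tokens] plane shared by all sequences.
    const int64_t ne2d[2] = { n_kv, pad_to(n_tokens, GGML_KQ_MASK_PAD) };

    // New 3D mask: tokens are split across the virtual sequences and each
    // sequence gets its own padded [n_kv, n_tokens/n_seqs] plane.
    const int64_t ne3d[3] = { n_kv, pad_to(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs };

    std::printf("2D mask: %lld x %lld\n",
                (long long) ne2d[0], (long long) ne2d[1]);
    std::printf("3D mask: %lld x %lld x %lld\n",
                (long long) ne3d[0], (long long) ne3d[1], (long long) ne3d[2]);
    return 0;
}

With n_tokens = 512 and n_seqs = 4 this prints a 4096 x 512 plane for the 2D case and four 4096 x 128 planes for the 3D case: the same total mask area, partitioned per sequence.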
@@ -1206,14 +1206,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv   = mctx_cur->get_n_kv();
         const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs);
 
         inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
-        // cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1440,14 +1439,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        // cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1461,8 +1461,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
         inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs_swa);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        // cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
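One thing all four call sites share: the second dimension is n_tokens/n_seqs with integer division, so the layout appears to assume the ubatch is split evenly across the virtual sequences. When n_seq_virt == 1, n_seqs collapses to 1 and the 3D tensor is the old 2D mask with a trailing unit dimension, leaving single-sequence behavior unchanged.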