@@ -76,9 +76,9 @@ def modified_forward(
7676
7777 # TODO: Keep it fixed or make it dynamic?
7878 if args .top_k <= 1 :
79- args .top_k = int (args .top_k * key_states .shape [- 2 ])
79+ self .top_k = int (args .top_k * key_states .shape [- 2 ])
8080 else :
81- args .top_k = int (args .top_k )
81+ self .top_k = int (args .top_k )
8282
8383 key_states = torch .matmul (key_states , self .pca_components )
8484 query_states = torch .matmul (query_states , self .pca_components )
@@ -96,7 +96,7 @@ def modified_forward(
9696 # We do not need a causal mask here since this is the generation step
9797 attn_weights = torch .matmul (query_states [:,:,:,:args .top_r ], key_states .transpose (2 , 3 )[:,:,:args .top_r ,:]) / math .sqrt (self .head_dim )
9898
99- key_states_topk_indices = torch .topk (attn_weights , args .top_k , dim = - 1 ).indices .to ("cuda" )
99+ key_states_topk_indices = torch .topk (attn_weights , self .top_k , dim = - 1 ).indices .to ("cuda" )
100100 key_states_topk_indices , _ = torch .sort (key_states_topk_indices , dim = - 1 )
101101 key_states_topk_indices = key_states_topk_indices .reshape (- 1 , key_states_topk_indices .shape [- 1 ])
102102
@@ -109,7 +109,8 @@ def modified_forward(
109109 key_states_topk_indices ,
110110 chunk = 256 # Varying this changes performance
111111 #chunk=min(k2, 65536 // Q.shape[-1]),
112- )
112+ ) / math .sqrt (self .head_dim )
113+
113114 attn_weights = nn .functional .softmax (attn_weights , dim = - 1 , dtype = torch .float32 ).to (query_states .dtype )
114115 attn_weights = nn .functional .dropout (attn_weights , p = self .attention_dropout , training = self .training )
115116
0 commit comments