
Commit cdbbe0a

Bug Fix: Optimised PCA-TopK modify_llama code
1 parent: df8b247

File tree: 3 files changed, +9 −7 lines

methods/pca_topk/cache_utils.py

Lines changed: 2 additions & 2 deletions
@@ -178,7 +178,7 @@ def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_gen_steps=200
         #.squeeze(0).squeeze(-1),
         chunk=256
         #chunk=min(k2, 65536 // Q.shape[-1]),
-    )
+    ) / math.sqrt(head_dim)
     attn_weights = torch.softmax(attn_weights, dim=-1)
 
     attn_output = G.gather_inner_matrix_only_bmv(
@@ -276,6 +276,6 @@ def benchmark_attention(batch_size=1,
 if __name__ == "__main__":
     #test_pcatopk_cache()
     with torch.no_grad():
-        benchmark_attention(prompt_length=512, num_gen_steps=16, batch_size=128, topk=128)
+        benchmark_attention(prompt_length=4096, num_gen_steps=2000, batch_size=16, topk=1024)
 
 
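
The first hunk restores the standard 1/sqrt(head_dim) scaling of the gathered attention scores before the softmax; the second simply switches the benchmark to a longer prompt and more generation steps. A minimal sketch of the scaling being restored, using plain torch in place of the repo's gather kernels (e.g. G.gather_inner_matrix_only_bmv); the shapes below are illustrative assumptions, not values from the benchmark:

import math
import torch

# Illustrative shapes (assumptions): 2 sequences, 4 heads, one query token, 128 gathered keys.
batch, heads, head_dim, top_k = 2, 4, 64, 128
Q = torch.randn(batch, heads, 1, head_dim)
K_topk = torch.randn(batch, heads, top_k, head_dim)

# The fix: divide the scores by sqrt(head_dim) before the softmax,
# i.e. the "/ math.sqrt(head_dim)" term added after the gathered matmul.
attn_weights = torch.matmul(Q, K_topk.transpose(-1, -2)) / math.sqrt(head_dim)
attn_weights = torch.softmax(attn_weights, dim=-1)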

methods/pca_topk/modify_llama_optimized.py

Lines changed: 5 additions & 4 deletions
@@ -76,9 +76,9 @@ def modified_forward(
 
     # TODO: Keep it fixed or make it dynamic?
     if args.top_k <= 1:
-        args.top_k = int(args.top_k * key_states.shape[-2])
+        self.top_k = int(args.top_k * key_states.shape[-2])
     else:
-        args.top_k = int(args.top_k)
+        self.top_k = int(args.top_k)
 
     key_states = torch.matmul(key_states, self.pca_components)
     query_states = torch.matmul(query_states, self.pca_components)
@@ -96,7 +96,7 @@ def modified_forward(
     # We do not need a causal mask here since this is the generation step
     attn_weights = torch.matmul(query_states[:,:,:,:args.top_r], key_states.transpose(2, 3)[:,:,:args.top_r,:]) / math.sqrt(self.head_dim)
 
-    key_states_topk_indices = torch.topk(attn_weights, args.top_k, dim=-1).indices.to("cuda")
+    key_states_topk_indices = torch.topk(attn_weights, self.top_k, dim=-1).indices.to("cuda")
     key_states_topk_indices , _ = torch.sort(key_states_topk_indices, dim=-1)
     key_states_topk_indices = key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1])
 
@@ -109,7 +109,8 @@ def modified_forward(
         key_states_topk_indices,
         chunk=256 # Varying this changes performance
         #chunk=min(k2, 65536 // Q.shape[-1]),
-    )
+    ) / math.sqrt(self.head_dim)
+
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
 
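
The main fix here stores the resolved top-k count on the attention module (self.top_k) instead of writing it back into args.top_k, and applies the same sqrt(head_dim) scaling to the gathered scores. Since the args object is presumably shared across layers and generation steps, overwriting a fractional top_k (<= 1) with an absolute count on the first call would make every later call reinterpret that count. A hedged sketch of the resolution logic in isolation; the helper name is made up for illustration and is not part of the repo:

def resolve_top_k(top_k_arg: float, num_keys: int) -> int:
    """Interpret top_k_arg as a fraction of the key length if <= 1, else as a count.

    Illustrative helper (assumption): the commit inlines this logic in
    modified_forward and stores the result on self.top_k, leaving args.top_k intact.
    """
    if top_k_arg <= 1:
        return int(top_k_arg * num_keys)
    return int(top_k_arg)

assert resolve_top_k(0.25, 4096) == 1024   # fractional spec -> absolute count
assert resolve_top_k(128, 4096) == 128     # absolute spec passes through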

methods/pca_topk/utils.py

Lines changed: 2 additions & 1 deletion
@@ -104,7 +104,8 @@ def mask_attn_pca_topk(args, layer_idx, attn_weights, attention_mask, query_stat
 
     # Compute attention with the query_states and key_states_sparse
     attn_weights_s_hat = torch.matmul(query_states_sparse, key_states_sparse.transpose(-1, -2)) / math.sqrt(head_dim)
-    methods.LOGGER.update_config({"scaling_factor": "fixed"})
+    if methods.LOGGER is not None:
+        methods.LOGGER.update_config({"scaling_factor": "fixed"})
     if attention_mask is not None: # no matter the length, we just slice it
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights_s_hat = attn_weights_s_hat + causal_mask
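
methods.LOGGER is presumably left as None when no experiment logger is configured, so the update_config call is now guarded. A small sketch of the pattern; the no-op logger shown alongside is an illustrative alternative, not something the repo does:

class NoOpLogger:
    """Illustrative alternative (assumption): a do-nothing logger removes the need for None checks."""
    def update_config(self, cfg: dict) -> None:
        pass

LOGGER = None  # stands in for methods.LOGGER when logging is disabled

# The guarded call, mirroring the change in mask_attn_pca_topk:
if LOGGER is not None:
    LOGGER.update_config({"scaling_factor": "fixed"})

# With a no-op logger the call site could stay unconditional:
LOGGER = NoOpLogger()
LOGGER.update_config({"scaling_factor": "fixed"})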
