Commit 44e24a6

Code cleaning
Signed-off-by: Amit Raj <[email protected]>
1 parent bbaaf61 commit 44e24a6

File tree

1 file changed: +4 −89 lines


QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py

Lines changed: 4 additions & 89 deletions
@@ -16,11 +16,10 @@
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from transformers.cache_utils import Cache, StaticCache
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
-from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, logger, repeat_kv
+from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, repeat_kv
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
@@ -211,80 +210,6 @@ def _run_swiftkv_layers(
         hidden_states = self.norm(hidden_states)
         return hidden_states, past_key_values
 
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        position_ids: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        self.config._attn_implementation = "eager"
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_length()
-        else:
-            target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens
-
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
-            if attention_mask.max() != 0:
-                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
-            causal_mask = attention_mask
-        else:
-            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-            else:
-                causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length)
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -298,15 +223,7 @@ def forward(
         use_cache = True
 
         if use_cache and not isinstance(past_key_values, Cache):
-            if past_key_values is None:
-                past_key_values = QEffDynamicCache()
-            else:
-                past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
-                logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
-                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
-                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
-                )
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
 
         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         cache_position = torch.arange(
@@ -315,9 +232,7 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        causal_mask = self._update_causal_mask(
-            None, inputs_embeds, cache_position, position_ids, past_key_values, False
-        )
+        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
         hidden_states = inputs_embeds
 
         next_decoder_cache = None
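
For reference, a minimal sketch of what the trimmed path does after this commit, using only the imports the file keeps (Cache, QEffDynamicCache, _create_causal_mask). The wrapper name prepare_decode_state is hypothetical and only stands in for the corresponding lines inside forward; it is not part of the model's API.

# Illustrative sketch only; mirrors the "+" lines of this commit.
from transformers.cache_utils import Cache

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


def prepare_decode_state(past_key_values, position_ids, use_cache=True):
    # Legacy tuple-of-tuples caches are converted in a single step; the removed
    # None-branch and deprecation warning are no longer needed here.
    if use_cache and not isinstance(past_key_values, Cache):
        past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0

    # The model-local _update_causal_mask() is gone; the mask is built directly
    # by the shared helper from position_ids and the cached sequence length.
    causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
    return causal_mask, past_key_values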
