 import torch
 from torch import nn
 from transformers import LlamaConfig
-from transformers.cache_utils import Cache, StaticCache
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
-from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, logger, repeat_kv
+from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, repeat_kv

 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
@@ -211,80 +210,6 @@ def _run_swiftkv_layers(
         hidden_states = self.norm(hidden_states)
         return hidden_states, past_key_values

-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        position_ids: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        self.config._attn_implementation = "eager"
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and 0.0 in attention_mask:
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_length()
-        else:
-            target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens
-
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
-            if attention_mask.max() != 0:
-                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
-            causal_mask = attention_mask
-        else:
-            causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-            else:
-                causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length)
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -298,15 +223,7 @@ def forward(
         use_cache = True

         if use_cache and not isinstance(past_key_values, Cache):
-            if past_key_values is None:
-                past_key_values = QEffDynamicCache()
-            else:
-                past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
-                logger.warning_once(
-                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
-                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
-                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
-                )
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)

         past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
         cache_position = torch.arange(
@@ -315,9 +232,7 @@ def forward(
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

-        causal_mask = self._update_causal_mask(
-            None, inputs_embeds, cache_position, position_ids, past_key_values, False
-        )
+        causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens)
         hidden_states = inputs_embeds

         next_decoder_cache = None
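
Note on the masking change: the forward path now builds its mask with QEfficient's `_create_causal_mask(position_ids=..., target_length=...)` and the removed `_update_causal_mask` (with its flash-attention, SDPA, and StaticCache branches) is gone from this file. The implementation of `_create_causal_mask` is not part of this diff; purely as an illustrative sketch (the helper name is reused here, and the boolean convention and output shape are assumptions rather than the library's actual code), a position-ids-driven causal mask of the shape attention layers typically expect could look like this:

import torch

def sketch_create_causal_mask(position_ids: torch.Tensor, target_length: int) -> torch.Tensor:
    # position_ids: (batch, seq_len) absolute positions of the current query tokens.
    # target_length: number of key/value slots visible to attention (past + current).
    key_positions = torch.arange(target_length, device=position_ids.device)
    # True wherever a key position lies beyond the query's own position, i.e. must be masked out.
    mask = key_positions[None, None, :] > position_ids[:, :, None]
    # Shape (batch, 1, seq_len, target_length), broadcastable across attention heads.
    return mask[:, None, :, :]

Because such a mask depends only on `position_ids` and a target length, it does not vary with the attention backend or with whether a static cache is in use, which is presumably why the removed `_update_causal_mask` branches are no longer needed on this path.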