@@ -304,20 +304,22 @@ def sdpa_mask_recent_torch(
304304
305305 # Similar to `kv_arange = mint.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
306306 # but without data-dependent slicing (i.e. torch.compile friendly)
307- kv_arange = mint .arange (kv_length , device = cache_position . device )
307+ kv_arange = mint .arange (kv_length )
308308 kv_arange += kv_offset
309309
310310 # Potentially add the padding 2D mask
311311 if padding_mask is not None :
312312 mask_function = and_masks (mask_function , padding_mask_function (padding_mask ))
313313
314- batch_arange = mint .arange (batch_size , device = cache_position . device )
315- head_arange = mint .arange (1 , device = cache_position . device )
314+ batch_arange = mint .arange (batch_size )
315+ head_arange = mint .arange (1 )
316316 # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
317317 # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it)
318318 # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
319319 # with TransformGetItemToIndex():
320- causal_mask = _vmap_for_bhqkv (mask_function )(batch_arange , head_arange , cache_position , kv_arange )
320+ # TODO: compilation fails when using 'mindspore.vmap', so we abandon that operator and build the mask step by step (2D mask --> 4D mask)
321+ causal_mask = mask_function ()(batch_arange , head_arange , cache_position , kv_arange )
322+ causal_mask = causal_mask [None , None , :, :].broadcast_to ((batch_size , - 1 , - 1 , - 1 ))
321323
322324 return causal_mask
323325
@@ -383,7 +385,8 @@ def sdpa_mask_older_torch(
383385 # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow).
384386 # However, in more recent versions of PyTorch, a trick was introduced to handle it - which is the reason we have
385387 # `sdpa_mask_recent_torch`, as it allows more general `mask_function`
386- causal_mask = _vmap_for_bhqkv (mask_function , bh_indices = False )(None , None , cache_position , kv_arange )
388+ # TODO: compilation fails when using 'mindspore.vmap', so we abandon that operator and build the mask step by step (2D mask --> 4D mask)
389+ causal_mask = mask_function ()(None , None , cache_position , kv_arange )
387390 causal_mask = causal_mask [None , None , :, :].broadcast_to ((batch_size , - 1 , - 1 , - 1 ))
388391 if padding_mask is not None :
389392 causal_mask = causal_mask * padding_mask [:, None , None , :]
@@ -436,7 +439,8 @@ def _ignore_causal_mask_sdpa(
436439
437440# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
438441# (especially mask_function indexing a tensor, such as the padding mask function)
439- sdpa_mask = sdpa_mask_older_torch # TODO: use sdpa_mask_recent_torch orsdpa_mask_older_torch?
442+ # TODO: we do not dispatch to the older sdpa mask function as the transformers setup does; the default is `sdpa_mask_recent_torch`
443+ sdpa_mask = sdpa_mask_recent_torch
440444
441445
442446def eager_mask (
@@ -669,6 +673,8 @@ def create_causal_mask(
669673 cache_position : ms .Tensor ,
670674 past_key_values : Optional [Cache ],
671675 position_ids : Optional [ms .Tensor ] = None ,
676+ or_mask_function : Optional [Callable ] = None ,
677+ and_mask_function : Optional [Callable ] = None ,
672678) -> Optional [Union [ms .Tensor , BlockMask ]]:
673679 """
674680 Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
@@ -717,18 +723,17 @@ def create_causal_mask(
717723 # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
718724 allow_is_causal_skip = not past_key_values .is_compileable if past_key_values is not None else True
719725
720- # TODO there is a compile problem during and_masks/or_masks func used as mask_factory_function, Comment this part firstly
721- # # If we detected packing format
722- # if packed_sequence_mask is not None:
723- # mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
724- # allow_is_causal_skip = False
725- # # Allow slight deviations from causal mask
726- # if or_mask_function is not None:
727- # mask_factory_function = or_masks(mask_factory_function, or_mask_function)
728- # allow_is_causal_skip = False
729- # if and_mask_function is not None:
730- # mask_factory_function = and_masks(mask_factory_function, and_mask_function)
731- # allow_is_causal_skip = False
726+ # If we detected packing format
727+ if packed_sequence_mask is not None :
728+ mask_factory_function = and_masks (mask_factory_function , packed_sequence_mask_function (packed_sequence_mask ))
729+ allow_is_causal_skip = False
730+ # Allow slight deviations from causal mask
731+ if or_mask_function is not None :
732+ mask_factory_function = or_masks (mask_factory_function , or_mask_function )
733+ allow_is_causal_skip = False
734+ if and_mask_function is not None :
735+ mask_factory_function = and_masks (mask_factory_function , and_mask_function )
736+ allow_is_causal_skip = False
732737
733738 # We now create the mask
734739 causal_mask = mask_interface (
@@ -806,18 +811,17 @@ def create_sliding_window_causal_mask(
806811 # Do not allow skip if we are compiling (this is to match BC)
807812 # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
808813 allow_is_causal_skip = not past_key_values .is_compileable if past_key_values is not None else True
809- # # TODO there is a compile problem during and_masks/or_masks func used as mask_factory_function, Comment this part firstly
810- # # If we detected packing format
811- # if packed_sequence_mask is not None:
812- # mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
813- # allow_is_causal_skip = False
814- # # Allow slight deviations from sliding causal mask
815- # if or_mask_function is not None:
816- # mask_factory_function = or_masks(mask_factory_function, or_mask_function)
817- # allow_is_causal_skip = False
818- # if and_mask_function is not None:
819- # mask_factory_function = and_masks(mask_factory_function, and_mask_function)
820- # allow_is_causal_skip = False
814+ # If we detected packing format
815+ if packed_sequence_mask is not None :
816+ mask_factory_function = and_masks (mask_factory_function , packed_sequence_mask_function (packed_sequence_mask ))
817+ allow_is_causal_skip = False
818+ # Allow slight deviations from sliding causal mask
819+ if or_mask_function is not None :
820+ mask_factory_function = or_masks (mask_factory_function , or_mask_function )
821+ allow_is_causal_skip = False
822+ if and_mask_function is not None :
823+ mask_factory_function = and_masks (mask_factory_function , and_mask_function )
824+ allow_is_causal_skip = False
821825
822826 # We now create the mask
823827 causal_mask = mask_interface (
0 commit comments