Skip to content

Commit 24debc8

Browse files
committed
revert masking_utils modification
1 parent f2258df commit 24debc8

File tree

1 file changed

+34
-30
lines changed

1 file changed

+34
-30
lines changed

mindone/transformers/masking_utils.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -304,20 +304,22 @@ def sdpa_mask_recent_torch(
304304

305305
# Similar to `kv_arange = mint.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
306306
# but without data-dependent slicing (i.e. torch.compile friendly)
307-
kv_arange = mint.arange(kv_length, device=cache_position.device)
307+
kv_arange = mint.arange(kv_length)
308308
kv_arange += kv_offset
309309

310310
# Potentially add the padding 2D mask
311311
if padding_mask is not None:
312312
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
313313

314-
batch_arange = mint.arange(batch_size, device=cache_position.device)
315-
head_arange = mint.arange(1, device=cache_position.device)
314+
batch_arange = mint.arange(batch_size)
315+
head_arange = mint.arange(1)
316316
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
317317
# scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it).
318318
# We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
319319
# with TransformGetItemToIndex():
320-
causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
320+
# TODO: using `mindspore.vmap` here causes a compile problem, so we drop that operator and build the mask step by step (2D mask --> 4D mask).
321+
causal_mask = mask_function()(batch_arange, head_arange, cache_position, kv_arange)
322+
causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, -1, -1, -1))
321323

322324
return causal_mask
323325

@@ -383,7 +385,8 @@ def sdpa_mask_older_torch(
383385
# as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow
384386
# However, in more recent version of Pytorch, a trick was introduced to handle it - which is the reason we have
385387
# `sdpa_mask_recent_torch`, as it allows more general `mask_function`
386-
causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
388+
# TODO: using `mindspore.vmap` here causes a compile problem, so we drop that operator and build the mask step by step (2D mask --> 4D mask).
389+
causal_mask = mask_function()(None, None, cache_position, kv_arange)
387390
causal_mask = causal_mask[None, None, :, :].broadcast_to((batch_size, -1, -1, -1))
388391
if padding_mask is not None:
389392
causal_mask = causal_mask * padding_mask[:, None, None, :]
@@ -436,7 +439,8 @@ def _ignore_causal_mask_sdpa(
436439

437440
# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
438441
# (especially mask_function indexing a tensor, such as the padding mask function)
439-
sdpa_mask = sdpa_mask_older_torch # TODO: use sdpa_mask_recent_torch orsdpa_mask_older_torch?
442+
# TODO: unlike the transformers setting, we do not go through the older sdpa function; the default is always `sdpa_mask_recent_torch`.
443+
sdpa_mask = sdpa_mask_recent_torch
440444

441445

442446
def eager_mask(
@@ -669,6 +673,8 @@ def create_causal_mask(
669673
cache_position: ms.Tensor,
670674
past_key_values: Optional[Cache],
671675
position_ids: Optional[ms.Tensor] = None,
676+
or_mask_function: Optional[Callable] = None,
677+
and_mask_function: Optional[Callable] = None,
672678
) -> Optional[Union[ms.Tensor, BlockMask]]:
673679
"""
674680
Create a standard causal mask based on the attention implementation used (stored in the config). If `past_key_values`
@@ -717,18 +723,17 @@ def create_causal_mask(
717723
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
718724
allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
719725

720-
# TODO: there is a compile problem when and_masks/or_masks are used to build mask_factory_function, so this part is commented out for now.
721-
# # If we detected packing format
722-
# if packed_sequence_mask is not None:
723-
# mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
724-
# allow_is_causal_skip = False
725-
# # Allow slight deviations from causal mask
726-
# if or_mask_function is not None:
727-
# mask_factory_function = or_masks(mask_factory_function, or_mask_function)
728-
# allow_is_causal_skip = False
729-
# if and_mask_function is not None:
730-
# mask_factory_function = and_masks(mask_factory_function, and_mask_function)
731-
# allow_is_causal_skip = False
726+
# If we detected packing format
727+
if packed_sequence_mask is not None:
728+
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
729+
allow_is_causal_skip = False
730+
# Allow slight deviations from causal mask
731+
if or_mask_function is not None:
732+
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
733+
allow_is_causal_skip = False
734+
if and_mask_function is not None:
735+
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
736+
allow_is_causal_skip = False
732737

733738
# We now create the mask
734739
causal_mask = mask_interface(
@@ -806,18 +811,17 @@ def create_sliding_window_causal_mask(
806811
# Do not allow skip if we are compiling (this is to match BC)
807812
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
808813
allow_is_causal_skip = not past_key_values.is_compileable if past_key_values is not None else True
809-
# # TODO: there is a compile problem when and_masks/or_masks are used to build mask_factory_function, so this part is commented out for now.
810-
# # If we detected packing format
811-
# if packed_sequence_mask is not None:
812-
# mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
813-
# allow_is_causal_skip = False
814-
# # Allow slight deviations from sliding causal mask
815-
# if or_mask_function is not None:
816-
# mask_factory_function = or_masks(mask_factory_function, or_mask_function)
817-
# allow_is_causal_skip = False
818-
# if and_mask_function is not None:
819-
# mask_factory_function = and_masks(mask_factory_function, and_mask_function)
820-
# allow_is_causal_skip = False
814+
# If we detected packing format
815+
if packed_sequence_mask is not None:
816+
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
817+
allow_is_causal_skip = False
818+
# Allow slight deviations from sliding causal mask
819+
if or_mask_function is not None:
820+
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
821+
allow_is_causal_skip = False
822+
if and_mask_function is not None:
823+
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
824+
allow_is_causal_skip = False
821825

822826
# We now create the mask
823827
causal_mask = mask_interface(

0 commit comments

Comments
 (0)