Commit ee27ea9

pre-commit

1 parent acb4f23
2 files changed: 7 additions, 6 deletions

mindone/transformers/masking_utils.py

3 additions, 3 deletions

```diff
@@ -304,15 +304,15 @@ def sdpa_mask_recent_torch(
 
     # Similar to `kv_arange = mint.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
     # but without data-dependent slicing (i.e. torch.compile friendly)
-    kv_arange = mint.arange(kv_length, device=cache_position.device)
+    kv_arange = mint.arange(kv_length)
     kv_arange += kv_offset
 
     # Potentially add the padding 2D mask
     if padding_mask is not None:
         mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
 
-    batch_arange = mint.arange(batch_size, device=cache_position.device)
-    head_arange = mint.arange(1, device=cache_position.device)
+    batch_arange = mint.arange(batch_size)
+    head_arange = mint.arange(1)
     # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
     # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it
     # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
```
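This hunk drops the torch-style `device=` keyword, presumably because MindSpore's mint ops take device placement from the global context rather than a per-call argument, while keeping the compile-friendly trick the comment describes. A minimal sketch of that trick, with illustrative values and assuming a MindSpore install recent enough to expose the mint API:

```python
# Minimal sketch of the compile-friendly index construction from the hunk above.
# The concrete kv_offset / kv_length values are illustrative only.
from mindspore import mint

kv_offset, kv_length = 4, 8

# Equivalent to mint.arange(kv_offset, kv_offset + kv_length), but the arange
# length stays static, so the graph compiler never sees data-dependent slicing.
kv_arange = mint.arange(kv_length)
kv_arange += kv_offset

print(kv_arange)  # [ 4  5  6  7  8  9 10 11]
```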

mindone/transformers/modeling_utils.py

4 additions, 3 deletions

```diff
@@ -2458,7 +2458,7 @@ def from_pretrained(
 
         if transformers_explicit_filename is not None:
             if not transformers_explicit_filename.endswith(
-                    ".safetensors"
+                ".safetensors"
             ) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
                 raise ValueError(
                     "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
@@ -2483,8 +2483,9 @@ def from_pretrained(
         if is_local:
             if transformers_explicit_filename is not None:
                 # If the filename is explicitly defined, load this by default.
-                archive_file = os.path.join(pretrained_model_name_or_path, subfolder,
-                                            transformers_explicit_filename)
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, transformers_explicit_filename
+                )
                 is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
             elif from_tf and os.path.isfile(
                 os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
```
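These two hunks only reflow formatting, but the logic they touch is compact enough to restate: validate that the explicitly configured filename is a safetensors file or a safetensors index, join it onto the local path, and flag sharded checkpoints. A self-contained sketch of that flow; the helper name `resolve_explicit_filename` is ours for illustration and is not part of mindone.transformers:

```python
# Illustrative restatement of the explicit-filename handling reformatted above;
# not the library's API, just the same checks in a standalone helper.
import os


def resolve_explicit_filename(root: str, subfolder: str, filename: str) -> tuple[str, bool]:
    # Only safetensors weights or a safetensors index file are accepted.
    if not filename.endswith(".safetensors") and not filename.endswith(".safetensors.index.json"):
        raise ValueError(
            "The transformers file in the config seems to be incorrect: it is neither a "
            "safetensors file nor a safetensors index file."
        )
    archive_file = os.path.join(root, subfolder, filename)
    is_sharded = filename.endswith(".safetensors.index.json")
    return archive_file, is_sharded


print(resolve_explicit_filename(".", "", "model.safetensors.index.json"))
# ('./model.safetensors.index.json', True)
```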
