2 files changed: +7 -6 lines changed
Original file line number | Diff line number | Diff line change
@@ -304,15 +304,15 @@ def sdpa_mask_recent_torch(
304304
305305 # Similar to `kv_arange = mint.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
306306 # but without data-dependent slicing (i.e. torch.compile friendly)
307- kv_arange = mint .arange (kv_length , device = cache_position . device )
307+ kv_arange = mint .arange (kv_length )
308308 kv_arange += kv_offset
309309
310310 # Potentially add the padding 2D mask
311311 if padding_mask is not None :
312312 mask_function = and_masks (mask_function , padding_mask_function (padding_mask ))
313313
314- batch_arange = mint .arange (batch_size , device = cache_position . device )
315- head_arange = mint .arange (1 , device = cache_position . device )
314+ batch_arange = mint .arange (batch_size )
315+ head_arange = mint .arange (1 )
316316 # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
317317 # a scalar tensor (it internally calls `.item()`, which vmap does not allow, but this context works around it).
318318 # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices
Original file line number Diff line number Diff line change @@ -2458,7 +2458,7 @@ def from_pretrained(
24582458
24592459 if transformers_explicit_filename is not None :
24602460 if not transformers_explicit_filename .endswith (
2461- ".safetensors"
2461+ ".safetensors"
24622462 ) and not transformers_explicit_filename .endswith (".safetensors.index.json" ):
24632463 raise ValueError (
24642464 "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
@@ -2483,8 +2483,9 @@ def from_pretrained(
24832483 if is_local :
24842484 if transformers_explicit_filename is not None :
24852485 # If the filename is explicitly defined, load this by default.
2486- archive_file = os .path .join (pretrained_model_name_or_path , subfolder ,
2487- transformers_explicit_filename )
2486+ archive_file = os .path .join (
2487+ pretrained_model_name_or_path , subfolder , transformers_explicit_filename
2488+ )
24882489 is_sharded = transformers_explicit_filename .endswith (".safetensors.index.json" )
24892490 elif from_tf and os .path .isfile (
24902491 os .path .join (pretrained_model_name_or_path , subfolder , TF_WEIGHTS_NAME + ".index" )
You can’t perform that action at this time.
0 commit comments