77import urllib .parse
88from contextlib import contextmanager
99from pathlib import Path
10- from typing import Generator , Union , Optional
10+ from typing import Generator , Union , Optional , Callable , Any
1111from typing import List
1212
1313from s3torchconnectorclient ._mountpoint_s3_client import S3Exception
@@ -252,7 +252,7 @@ def _escape_path(string):
252252 return "/" .join (parts )
253253
254254
255- from torch .distributed .checkpoint .planner import SavePlan
255+ from torch .distributed .checkpoint .planner import SavePlan , LoadPlan
256256import dataclasses
257257from dataclasses import dataclass
258258
@@ -264,13 +264,29 @@ class StorageMetadata:
264264 prefix : str
265265
266266
267+ from torch .distributed .checkpoint .filesystem import (
268+ _split_by_size_and_type as original_split ,
269+ )
270+
271+
272+ def _ordered_split_by_size_and_type (
273+ bins : int , items : List [Any ], sort_key : Optional [Callable ] = None
274+ ) -> List [List [Any ]]:
275+ buckets = original_split (bins , items )
276+ if sort_key :
277+ for bucket in buckets :
278+ bucket .sort (key = sort_key )
279+ return buckets
280+
281+
267282class S3StorageWriter (FileSystemWriter ):
268283 def __init__ (
269284 self ,
270285 region : str ,
271286 path : str ,
272287 s3client_config : Optional [S3ClientConfig ] = None ,
273288 prefix_strategy : Optional [S3PrefixStrategyBase ] = None ,
289+ sort_key : Optional [Callable ] = None ,
274290 ** kwargs ,
275291 ) -> None :
276292 """
@@ -292,6 +308,16 @@ def __init__(
292308 self .fs = S3FileSystem (region , s3client_config = s3client_config ) # type: ignore
293309 self .path = self .fs .init_path (path )
294310 self .prefix_strategy = prefix_strategy or DefaultPrefixStrategy ()
311+ self .sort_key = sort_key
312+
313+ if self .sort_key :
314+ # Replace the original split function with ours that will sort tensors/weights
315+ import torch .distributed .checkpoint .filesystem as fs_module
316+ from functools import partial
317+
318+ fs_module ._split_by_size_and_type = partial (
319+ _ordered_split_by_size_and_type , sort_key = sort_key
320+ )
295321
296322 def prepare_global_plan (self , plans : List [SavePlan ]) -> List [SavePlan ]:
297323 """
@@ -342,6 +368,11 @@ def __init__(
342368 def validate_checkpoint_id (cls , checkpoint_id : Union [str , os .PathLike ]) -> bool :
343369 return S3FileSystem .validate_checkpoint_id (checkpoint_id )
344370
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
    """
    Order the items of ``plan`` by the offset recorded for them in ``self.storage_data``.

    The list held in ``plan.items`` is sorted in place (ascending offset) and the
    same ``plan`` object is handed back.
    """

    def _offset_in_shard(read_item):
        # Each item is looked up through its storage index; the stored entry
        # carries the offset used as the ordering criterion.
        return self.storage_data[read_item.storage_index].offset

    plan.items.sort(key=_offset_in_shard)
    return plan
375+
345376
346377def _path_or_str_to_str (path : Union [str , os .PathLike ]) -> str :
347378 return path if isinstance (path , str ) else str (path )
0 commit comments