Commit 656194e

Ilya Isaev authored and jet-tong committed
Sort tensors/weight based on their offset in object when loading checkpoints
Cherry-picked prepare_local_plan method from upstream PR awslabs#352. Sequentially loads items based on their actual offset in checkpoint shards, ensuring sequential access patterns and improving I/O efficiency.
1 parent 6b1c431 commit 656194e

1 file changed: +6 -1 lines changed

s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py

Lines changed: 6 additions & 1 deletion
@@ -252,7 +252,7 @@ def _escape_path(string):
     return "/".join(parts)
 
 
-from torch.distributed.checkpoint.planner import SavePlan
+from torch.distributed.checkpoint.planner import SavePlan, LoadPlan
 import dataclasses
 from dataclasses import dataclass
 
@@ -345,6 +345,11 @@ def __init__(
     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
         return S3FileSystem.validate_checkpoint_id(checkpoint_id)
 
+    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
+        # Sort items in plan based on their offset in checkpoints shards
+        plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset)
+        return plan
+
 
 def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str:
     return path if isinstance(path, str) else str(path)
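
To illustrate what the new `prepare_local_plan` does, here is a minimal, torch-free sketch of the same sort. The `StorageInfo`, `ReadItem`, and `Plan` classes, the `sort_plan_by_offset` helper, and the sample keys and offsets are hypothetical stand-ins invented for this example; in the real reader the plan is a `LoadPlan` and `storage_index` is a metadata index object, not a string.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass(frozen=True)
class StorageInfo:
    """Hypothetical stand-in for the per-item storage metadata kept by the reader."""
    relative_path: str
    offset: int
    length: int


@dataclass
class ReadItem:
    """Hypothetical stand-in for a plan item; only the field used by the sort key."""
    storage_index: str


@dataclass
class Plan:
    """Hypothetical stand-in for LoadPlan."""
    items: List[ReadItem]


def sort_plan_by_offset(plan: Plan, storage_data: Dict[str, StorageInfo]) -> Plan:
    # Same key as in the diff: look up each item's byte offset and sort in place.
    plan.items.sort(key=lambda item: storage_data[item.storage_index].offset)
    return plan


# Made-up layout of three tensors inside one shard file.
storage_data = {
    "layer1.weight": StorageInfo("__0_0.distcp", offset=4096, length=1024),
    "layer0.bias": StorageInfo("__0_0.distcp", offset=0, length=128),
    "layer0.weight": StorageInfo("__0_0.distcp", offset=128, length=3968),
}

# The plan arrives in an arbitrary order relative to the file layout.
plan = Plan([ReadItem("layer1.weight"), ReadItem("layer0.weight"), ReadItem("layer0.bias")])

sorted_plan = sort_plan_by_offset(plan, storage_data)
order = [item.storage_index for item in sorted_plan.items]
offsets = [storage_data[key].offset for key in order]
print(order)
print(offsets)
```

After sorting, the items are visited in ascending offset order, so reads within a shard move forward through the object rather than jumping back and forth. Note that the sort is in place and keys on the offset alone.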
