Skip to content

Commit 8c10ef5

Browse files
author
Ilya Isaev
committed
Sort tensors/weights by their offset in the object when loading checkpoints
1 parent 706ac44 commit 8c10ef5

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _escape_path(string):
252252
return "/".join(parts)
253253

254254

255-
from torch.distributed.checkpoint.planner import SavePlan
255+
from torch.distributed.checkpoint.planner import SavePlan, LoadPlan
256256
import dataclasses
257257
from dataclasses import dataclass
258258

@@ -359,6 +359,11 @@ def __init__(
359359
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
360360
return S3FileSystem.validate_checkpoint_id(checkpoint_id)
361361

362+
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
363+
# Sort items in plan based on their offset in checkpoint shards
364+
plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset)
365+
return plan
366+
362367

363368
def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str:
364369
return path if isinstance(path, str) else str(path)

0 commit comments

Comments
 (0)