Skip to content

Commit 9e85202

Browse files
author
Ilya Isaev
committed
Sort tensors/weights by their offset within the object when loading checkpoints. Enable custom sorting of tensors/weights when creating checkpoints.
1 parent 5ec1d89 commit 9e85202

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import urllib.parse
88
from contextlib import contextmanager
99
from pathlib import Path
10-
from typing import Generator, Union, Optional
10+
from typing import Generator, Union, Optional, Callable, Any
1111
from typing import List
1212

1313
from s3torchconnectorclient._mountpoint_s3_client import S3Exception
@@ -252,7 +252,7 @@ def _escape_path(string):
252252
return "/".join(parts)
253253

254254

255-
from torch.distributed.checkpoint.planner import SavePlan
255+
from torch.distributed.checkpoint.planner import SavePlan, LoadPlan
256256
import dataclasses
257257
from dataclasses import dataclass
258258

@@ -264,13 +264,29 @@ class StorageMetadata:
264264
prefix: str
265265

266266

267+
from torch.distributed.checkpoint.filesystem import (
268+
_split_by_size_and_type as original_split,
269+
)
270+
271+
272+
def _ordered_split_by_size_and_type(
273+
bins: int, items: List[Any], sort_key: Optional[Callable] = None
274+
) -> List[List[Any]]:
275+
buckets = original_split(bins, items)
276+
if sort_key:
277+
for bucket in buckets:
278+
bucket.sort(key=sort_key)
279+
return buckets
280+
281+
267282
class S3StorageWriter(FileSystemWriter):
268283
def __init__(
269284
self,
270285
region: str,
271286
path: str,
272287
s3client_config: Optional[S3ClientConfig] = None,
273288
prefix_strategy: Optional[S3PrefixStrategyBase] = None,
289+
sort_key: Optional[Callable] = None,
274290
**kwargs,
275291
) -> None:
276292
"""
@@ -292,6 +308,16 @@ def __init__(
292308
self.fs = S3FileSystem(region, s3client_config=s3client_config) # type: ignore
293309
self.path = self.fs.init_path(path)
294310
self.prefix_strategy = prefix_strategy or DefaultPrefixStrategy()
311+
self.sort_key = sort_key
312+
313+
if self.sort_key:
314+
# Replace the original split function with ours that will sort tensors/weights
315+
import torch.distributed.checkpoint.filesystem as fs_module
316+
from functools import partial
317+
318+
fs_module._split_by_size_and_type = partial(
319+
_ordered_split_by_size_and_type, sort_key=sort_key
320+
)
295321

296322
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
297323
"""
@@ -342,6 +368,11 @@ def __init__(
342368
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
343369
return S3FileSystem.validate_checkpoint_id(checkpoint_id)
344370

371+
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
    """Order the plan's items by their offset in the checkpoint shards.

    Each item's ``storage_index`` is looked up in ``self.storage_data`` and
    the ``offset`` of that entry is used as the sort key. ``plan.items`` is
    sorted in place (stable for equal offsets) and the same ``plan`` object
    is returned.

    Args:
        plan: Load plan whose ``items`` list is reordered.

    Returns:
        The ``plan`` that was passed in, with ``items`` sorted by offset.
    """

    def _shard_offset(read_item):
        # Offset recorded for this item in the storage metadata.
        return self.storage_data[read_item.storage_index].offset

    plan.items.sort(key=_shard_offset)
    return plan
375+
345376

346377
def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str:
347378
return path if isinstance(path, str) else str(path)

0 commit comments

Comments
 (0)