Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import urllib.parse
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Union, Optional
from typing import Generator, Union, Optional, Callable, Any
from typing import List

from s3torchconnectorclient._mountpoint_s3_client import S3Exception
Expand Down Expand Up @@ -252,7 +252,7 @@ def _escape_path(string):
return "/".join(parts)


from torch.distributed.checkpoint.planner import SavePlan
from torch.distributed.checkpoint.planner import SavePlan, LoadPlan
import dataclasses
from dataclasses import dataclass

Expand All @@ -264,13 +264,29 @@ class StorageMetadata:
prefix: str


from torch.distributed.checkpoint.filesystem import (
_split_by_size_and_type as original_split,
)


def _ordered_split_by_size_and_type(
bins: int, items: List[Any], sort_key: Optional[Callable] = None
) -> List[List[Any]]:
buckets = original_split(bins, items)
if sort_key:
for bucket in buckets:
bucket.sort(key=sort_key)
return buckets


class S3StorageWriter(FileSystemWriter):
def __init__(
self,
region: str,
path: str,
s3client_config: Optional[S3ClientConfig] = None,
prefix_strategy: Optional[S3PrefixStrategyBase] = None,
sort_key: Optional[Callable] = None,
**kwargs,
) -> None:
"""
Expand All @@ -292,6 +308,16 @@ def __init__(
self.fs = S3FileSystem(region, s3client_config=s3client_config) # type: ignore
self.path = self.fs.init_path(path)
self.prefix_strategy = prefix_strategy or DefaultPrefixStrategy()
self.sort_key = sort_key

if self.sort_key:
# Replace the original split function with ours that will sort tensors/weights
import torch.distributed.checkpoint.filesystem as fs_module
from functools import partial

fs_module._split_by_size_and_type = partial(
_ordered_split_by_size_and_type, sort_key=sort_key
)

def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
"""
Expand Down Expand Up @@ -342,6 +368,11 @@ def __init__(
def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
return S3FileSystem.validate_checkpoint_id(checkpoint_id)

def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
# Sort items in plan based on their offset in checkpoints shards
plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset)
return plan


def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str:
return path if isinstance(path, str) else str(path)