3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,11 +1,12 @@
 ## TBD
 
 ### Bug fixes
-* Add seekable() method in S3Reader to eliminate tensor copies during DCP loading (#359)
 * Override S3Writer closed property and block writes after close (#360)
 * Fix SequentialS3Reader seek beyond EOF to clamp position to object size (#362)
 
 ### Other changes
+* Add seekable() method in S3Reader to eliminate tensor copies during DCP loading (#359)
+* Add load ordering optimization to S3StorageReader for sequential access patterns (#372)
 * Add benchmark to run DCP Loading Workloads (#357)
 * Add thread_count parameter to S3StorageWriter (#370)
 
16 changes: 15 additions & 1 deletion s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py
@@ -252,7 +252,7 @@ def _escape_path(string):
return "/".join(parts)


from torch.distributed.checkpoint.planner import SavePlan
from torch.distributed.checkpoint.planner import SavePlan, LoadPlan
import dataclasses
from dataclasses import dataclass

@@ -345,6 +345,20 @@ def __init__(
     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
         return S3FileSystem.validate_checkpoint_id(checkpoint_id)
 
+    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
+        """
+        Sort load items by storage offset for sequential access optimization.
+
+        Args:
+            plan (LoadPlan): The load plan from PyTorch DCP.
+
+        Returns:
+            LoadPlan: The same plan with items sorted by storage offset.
+        """
+        # Sort items in the plan by their offset within the checkpoint shards
+        plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset)
+        return plan
+

def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str:
return path if isinstance(path, str) else str(path)
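
For context, here is a minimal usage sketch of the optimization above, assuming only the public API exercised by the tests in this PR (the region and S3 URI are placeholders): when `dcp.load()` is given an `S3StorageReader` with a sequential reader constructor, `prepare_local_plan()` reorders the read items so each checkpoint shard is read in ascending offset order.

```python
# Hedged usage sketch; the region and S3 URI below are placeholders.
import torch.nn as nn
import torch.distributed.checkpoint as dcp

from s3torchconnector import S3ReaderConstructor
from s3torchconnector.dcp import S3StorageReader

model = nn.Linear(10, 10)
state_dict = model.state_dict()  # pre-allocated tensors; dcp.load fills them in place

storage_reader = S3StorageReader(
    region="us-east-1",                                   # placeholder region
    path="s3://my-bucket/my-checkpoint/",                 # placeholder checkpoint URI
    reader_constructor=S3ReaderConstructor.sequential(),  # sequential readers benefit most
)
# prepare_local_plan() runs inside dcp.load(), sorting read items by shard offset.
dcp.load(state_dict, storage_reader=storage_reader)
```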
91 changes: 91 additions & 0 deletions s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py
@@ -0,0 +1,91 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

import pytest
from unittest.mock import patch

import torch
import torch.nn as nn
import torch.distributed.checkpoint as dcp

from s3torchconnector import S3ReaderConstructor
from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
from s3torchconnector.s3reader.sequential import SequentialS3Reader


SIMPLE_MODEL = torch.nn.Sequential(
    nn.Linear(5, 5),
    nn.Linear(20, 20),
    nn.Linear(10, 10),
)


class NeuralNetwork(nn.Module):
    """NeuralNetwork from PyTorch quickstart tutorial."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )


LARGER_MODEL = NeuralNetwork()


@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL])
def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model):
"""
Test that prepare_local_plan allows dcp.load() to read items in offset order.

This does not prevent backwards seek, since torch.load() would still call
backwards seek operations.

pytorch/torch/serialization.py load() function will call _is_zipfile(), which
includes this read() call: f.read(len(local_header_magic_number)). This is
followed by readinto() calls on the actual tensor.

Hence we can track read() call positions to determine if load ordering is
being applied correctly.
"""
region = checkpoint_directory.region
s3_uri = checkpoint_directory.s3_uri

state_dict = model.state_dict()
storage_writer = S3StorageWriter(region=region, path=s3_uri, overwrite=True)
dcp.save(state_dict, storage_writer=storage_writer)

read_positions = []

original_read = SequentialS3Reader.read

def track_reads(self, size=None):
if not self.key.endswith(".metadata"):
read_positions.append(self._position)
return original_read(self, size)

# Load with position tracking on read() (called at the start of each torch.load())
with patch.object(SequentialS3Reader, "read", track_reads):
loaded_state_dict = {k: torch.empty_like(v) for k, v in state_dict.items()}
storage_reader = S3StorageReader(
region=region,
path=s3_uri,
reader_constructor=S3ReaderConstructor.sequential(),
)
dcp.load(loaded_state_dict, storage_reader=storage_reader)

print(f"Read positions: {read_positions}")

# Assert load ordering works (read() calls should be in sorted order)
assert read_positions == sorted(read_positions)

# Assert all tensors are correctly loaded
assert len(loaded_state_dict) == len(state_dict)
assert loaded_state_dict.keys() == state_dict.keys()
for key in state_dict:
assert torch.equal(loaded_state_dict[key], state_dict[key])
64 changes: 64 additions & 0 deletions s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py
@@ -0,0 +1,64 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

from unittest.mock import Mock
from hypothesis import given
from hypothesis.strategies import composite, integers, lists

from torch.distributed.checkpoint.planner import LoadPlan, ReadItem

from s3torchconnector.dcp import S3StorageReader

TEST_REGION = "eu-east-1"
TEST_PATH = "s3://test-bucket/test-checkpoint/"


@composite
def load_plan_with_offsets(draw):
"""Generate LoadPlan with random offsets."""
offsets = draw(lists(integers(0, 10_000_000), min_size=1, max_size=10_000))

storage_data = {}
items = []

for i, offset in enumerate(offsets):
storage_index = f"item{i}"
storage_data[storage_index] = Mock(offset=offset)
items.append(Mock(spec=ReadItem, storage_index=storage_index))

return LoadPlan(items), storage_data


def test_s3storage_reader_prepare_local_plan_empty():
"""Test prepare_local_plan handles empty plans."""
s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH)

sorted_plan = s3_storage_reader.prepare_local_plan(LoadPlan([]))
# Output: LoadPlan(items=[], storage_data=None, planner_data=None)

assert isinstance(sorted_plan, LoadPlan)
assert len(sorted_plan.items) == 0


@given(load_plan_with_offsets())
def test_s3storage_reader_prepare_local_plan(loadplan_and_storagedata):
"""Test prepare local plan sorts items by storage_data offset."""
load_plan, storage_data = loadplan_and_storagedata

s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH)
s3_storage_reader.storage_data = storage_data

sorted_plan = s3_storage_reader.prepare_local_plan(load_plan)
sorted_offsets = [
storage_data[item.storage_index].offset for item in sorted_plan.items
]

# Verify return type
assert isinstance(sorted_plan, LoadPlan)

# Verify Load Ordering sorts offsets
assert sorted_offsets == sorted(sorted_offsets)

# Verify Load Ordering keeps items the same
assert len(sorted_plan.items) == len(load_plan.items)
assert set(sorted_plan.items) == set(load_plan.items)
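
As a sanity check on the sort key itself, here is a small self-contained sketch of the reordering that `prepare_local_plan` performs; `FakeReadItem` and the offset values are hypothetical stand-ins, not the connector's types:

```python
# Self-contained sketch; FakeReadItem and the offsets are made-up stand-ins.
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeReadItem:  # hypothetical stand-in for torch's ReadItem
    storage_index: str


# Maps storage_index -> byte offset within a checkpoint shard (made-up values).
offsets = {"a": 300, "b": 100, "c": 200}
items = [FakeReadItem("a"), FakeReadItem("b"), FakeReadItem("c")]

# The same stable in-place sort used by prepare_local_plan.
items.sort(key=lambda item: offsets[item.storage_index])
assert [i.storage_index for i in items] == ["b", "c", "a"]
```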
4 changes: 2 additions & 2 deletions s3torchconnectorclient/pyproject.toml
@@ -142,8 +142,8 @@ inherit.test-command = "append"
 test-command = [
     "python -m pip install -e '{package}/../s3torchconnector[dcp-test]'",
     "pytest {package}/../s3torchconnector/tst/unit/dcp",
-    "CI_STORAGE_CLASS='' CI_REGION=${S3_REGION} CI_BUCKET=${S3_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL=${S3_CUSTOM_ENDPOINT_URL} pytest -s {package}/../s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py",
-    "AWS_DEFAULT_REGION=${S3_EXPRESS_REGION} CI_STORAGE_CLASS=EXPRESS_ONEZONE CI_REGION=${S3_EXPRESS_REGION} CI_BUCKET=${S3_EXPRESS_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL='' pytest -s {package}/../s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py",
+    "CI_STORAGE_CLASS='' CI_REGION=${S3_REGION} CI_BUCKET=${S3_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL=${S3_CUSTOM_ENDPOINT_URL} pytest -s {package}/../s3torchconnector/tst/e2e/dcp",
+    "AWS_DEFAULT_REGION=${S3_EXPRESS_REGION} CI_STORAGE_CLASS=EXPRESS_ONEZONE CI_REGION=${S3_EXPRESS_REGION} CI_BUCKET=${S3_EXPRESS_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL='' pytest -s {package}/../s3torchconnector/tst/e2e/dcp",
 ]

[[tool.cibuildwheel.overrides]]