diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4532fefd..90dfc9aa 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,11 +1,12 @@
 ## TBD
 
 ### Bug fixes
-* Add seekable() method in S3Reader to eliminate tensor copies during DCP loading (#359)
 * Override S3Writer closed property and block writes after close (#360)
 * Fix SequentialS3Reader seek beyond EOF to clamp position to object size (#362)
 
 ### Other changes
+* Add seekable() method in S3Reader to eliminate tensor copies during DCP loading (#359)
+* Add load ordering optimization to S3StorageReader for sequential access patterns (#372)
 * Add benchmark to run DCP Loading Workloads (#357)
 * Add thread_count parameter to S3StorageWriter (#370)
 
diff --git a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py
index f7ab995b..3759e5a5 100644
--- a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py
+++ b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py
@@ -252,7 +252,7 @@ def _escape_path(string):
     return "/".join(parts)
 
 
-from torch.distributed.checkpoint.planner import SavePlan
+from torch.distributed.checkpoint.planner import SavePlan, LoadPlan
 import dataclasses
 from dataclasses import dataclass
 
@@ -345,6 +345,20 @@ def __init__(
     def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
         return S3FileSystem.validate_checkpoint_id(checkpoint_id)
 
+    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
+        """
+        Sort load items by storage offset for sequential access optimization.
+
+        Args:
+            plan (LoadPlan): The load plan from PyTorch DCP.
+
+        Returns:
+            LoadPlan: The same plan with items sorted by storage offset.
+        """
+        # Sort items in the plan based on their offsets in the checkpoint shards
+        plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset)
+        return plan
+
 
 def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str:
     return path if isinstance(path, str) else str(path)
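A note on `prepare_local_plan` above: DCP hands the reader a `LoadPlan`, and sorting its items by their byte offset within the checkpoint shards turns a scatter of ranged reads into a forward-only pass that a sequential reader can serve without backwards seeks. A minimal standalone sketch of the idea, using a hypothetical `FakeReadItem` and a plain dict of offsets rather than the real DCP types:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeReadItem:
    # Stand-in for torch.distributed.checkpoint.planner.ReadItem (illustration only)
    storage_index: str


# Stand-in for self.storage_data: maps storage_index -> byte offset within a shard.
offsets = {"layer2.bias": 9_000, "layer0.weight": 128, "layer1.weight": 4_096}

items = [FakeReadItem(key) for key in ("layer2.bias", "layer0.weight", "layer1.weight")]

# Same sort-key shape as prepare_local_plan: order the reads by shard offset.
items.sort(key=lambda item: offsets[item.storage_index])

assert [offsets[item.storage_index] for item in items] == [128, 4_096, 9_000]
```

The real method sorts in place and returns the same `LoadPlan` object, so no plan metadata needs to be copied or rebuilt.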
diff --git a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py
new file mode 100644
index 00000000..d5e62af6
--- /dev/null
+++ b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py
@@ -0,0 +1,91 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# // SPDX-License-Identifier: BSD
+
+import pytest
+from unittest.mock import patch
+
+import torch
+import torch.nn as nn
+import torch.distributed.checkpoint as dcp
+
+from s3torchconnector import S3ReaderConstructor
+from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
+from s3torchconnector.s3reader.sequential import SequentialS3Reader
+
+
+SIMPLE_MODEL = torch.nn.Sequential(
+    nn.Linear(5, 5),
+    nn.Linear(20, 20),
+    nn.Linear(10, 10),
+)
+
+
+class NeuralNetwork(nn.Module):
+    """NeuralNetwork from PyTorch quickstart tutorial."""
+
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.linear_relu_stack = nn.Sequential(
+            nn.Linear(28 * 28, 512),
+            nn.ReLU(),
+            nn.Linear(512, 512),
+            nn.ReLU(),
+            nn.Linear(512, 10),
+        )
+
+
+LARGER_MODEL = NeuralNetwork()
+
+
+@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL])
+def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model):
+    """
+    Test that prepare_local_plan lets dcp.load() read items in offset order.
+
+    This does not prevent backwards seeks entirely, since torch.load() itself
+    still performs backwards seek operations.
+
+    The load() function in pytorch/torch/serialization.py calls _is_zipfile(),
+    which issues this read() call: f.read(len(local_header_magic_number)). It
+    is followed by readinto() calls on the actual tensor data.
+
+    Hence we can track read() call positions to determine whether load
+    ordering is applied correctly.
+    """
+    region = checkpoint_directory.region
+    s3_uri = checkpoint_directory.s3_uri
+
+    state_dict = model.state_dict()
+    storage_writer = S3StorageWriter(region=region, path=s3_uri, overwrite=True)
+    dcp.save(state_dict, storage_writer=storage_writer)
+
+    read_positions = []
+
+    original_read = SequentialS3Reader.read
+
+    def track_reads(self, size=None):
+        if not self.key.endswith(".metadata"):
+            read_positions.append(self._position)
+        return original_read(self, size)
+
+    # Load with position tracking on read() (called at the start of each torch.load())
+    with patch.object(SequentialS3Reader, "read", track_reads):
+        loaded_state_dict = {k: torch.empty_like(v) for k, v in state_dict.items()}
+        storage_reader = S3StorageReader(
+            region=region,
+            path=s3_uri,
+            reader_constructor=S3ReaderConstructor.sequential(),
+        )
+        dcp.load(loaded_state_dict, storage_reader=storage_reader)
+
+    print(f"Read positions: {read_positions}")
+
+    # Assert load ordering works (read() calls should be in sorted order)
+    assert read_positions == sorted(read_positions)
+
+    # Assert all tensors are correctly loaded
+    assert len(loaded_state_dict) == len(state_dict)
+    assert loaded_state_dict.keys() == state_dict.keys()
+    for key in state_dict:
+        assert torch.equal(loaded_state_dict[key], state_dict[key])
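The tracking approach in this test deserves a note: instead of inspecting S3 requests, it wraps `SequentialS3Reader.read` and records `self._position` at each call, relying on the small `read()` that every `torch.load()` issues during `_is_zipfile()`. A self-contained sketch of the same wrap-and-record pattern, where the `Reader` class is a toy stand-in for the real reader, not the connector's actual type:

```python
from unittest.mock import patch


class Reader:
    """Toy stand-in for SequentialS3Reader: tracks a forward-only position."""

    def __init__(self) -> None:
        self._position = 0

    def read(self, size=None):
        data = b"\x00" * (size or 0)
        self._position += size or 0
        return data


positions = []
original_read = Reader.read  # keep a reference so the wrapper can delegate


def track_reads(self, size=None):
    positions.append(self._position)  # record where this read starts
    return original_read(self, size)


with patch.object(Reader, "read", track_reads):
    reader = Reader()
    reader.read(4)
    reader.read(8)

assert positions == [0, 4]  # reads observed at ascending positions
```

Because `patch.object` swaps the method on the class, every instance created inside the `with` block is tracked, which is what lets the test observe reads made deep inside `dcp.load()`.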
diff --git a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py
new file mode 100644
index 00000000..f46dd56f
--- /dev/null
+++ b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py
@@ -0,0 +1,64 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# // SPDX-License-Identifier: BSD
+
+from unittest.mock import Mock
+from hypothesis import given
+from hypothesis.strategies import composite, integers, lists
+
+from torch.distributed.checkpoint.planner import LoadPlan, ReadItem
+
+from s3torchconnector.dcp import S3StorageReader
+
+TEST_REGION = "eu-east-1"
+TEST_PATH = "s3://test-bucket/test-checkpoint/"
+
+
+@composite
+def load_plan_with_offsets(draw):
+    """Generate a LoadPlan with random offsets."""
+    offsets = draw(lists(integers(0, 10_000_000), min_size=1, max_size=10_000))
+
+    storage_data = {}
+    items = []
+
+    for i, offset in enumerate(offsets):
+        storage_index = f"item{i}"
+        storage_data[storage_index] = Mock(offset=offset)
+        items.append(Mock(spec=ReadItem, storage_index=storage_index))
+
+    return LoadPlan(items), storage_data
+
+
+def test_s3storage_reader_prepare_local_plan_empty():
+    """Test that prepare_local_plan handles empty plans."""
+    s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH)
+
+    sorted_plan = s3_storage_reader.prepare_local_plan(LoadPlan([]))
+    # Output: LoadPlan(items=[], storage_data=None, planner_data=None)
+
+    assert isinstance(sorted_plan, LoadPlan)
+    assert len(sorted_plan.items) == 0
+
+
+@given(load_plan_with_offsets())
+def test_s3storage_reader_prepare_local_plan(loadplan_and_storagedata):
+    """Test that prepare_local_plan sorts items by storage_data offset."""
+    load_plan, storage_data = loadplan_and_storagedata
+
+    s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH)
+    s3_storage_reader.storage_data = storage_data
+
+    sorted_plan = s3_storage_reader.prepare_local_plan(load_plan)
+    sorted_offsets = [
+        storage_data[item.storage_index].offset for item in sorted_plan.items
+    ]
+
+    # Verify the return type
+    assert isinstance(sorted_plan, LoadPlan)
+
+    # Verify load ordering sorts offsets
+    assert sorted_offsets == sorted(sorted_offsets)
+
+    # Verify load ordering keeps the same set of items
+    assert len(sorted_plan.items) == len(load_plan.items)
+    assert set(sorted_plan.items) == set(load_plan.items)
diff --git a/s3torchconnectorclient/pyproject.toml b/s3torchconnectorclient/pyproject.toml
index 80d641d2..ee912318 100644
--- a/s3torchconnectorclient/pyproject.toml
+++ b/s3torchconnectorclient/pyproject.toml
@@ -142,8 +142,8 @@ inherit.test-command = "append"
 test-command = [
     "python -m pip install -e '{package}/../s3torchconnector[dcp-test]'",
     "pytest {package}/../s3torchconnector/tst/unit/dcp",
-    "CI_STORAGE_CLASS='' CI_REGION=${S3_REGION} CI_BUCKET=${S3_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL=${S3_CUSTOM_ENDPOINT_URL} pytest -s {package}/../s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py",
-    "AWS_DEFAULT_REGION=${S3_EXPRESS_REGION} CI_STORAGE_CLASS=EXPRESS_ONEZONE CI_REGION=${S3_EXPRESS_REGION} CI_BUCKET=${S3_EXPRESS_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL='' pytest -s {package}/../s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py",
+    "CI_STORAGE_CLASS='' CI_REGION=${S3_REGION} CI_BUCKET=${S3_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL=${S3_CUSTOM_ENDPOINT_URL} pytest -s {package}/../s3torchconnector/tst/e2e/dcp",
+    "AWS_DEFAULT_REGION=${S3_EXPRESS_REGION} CI_STORAGE_CLASS=EXPRESS_ONEZONE CI_REGION=${S3_EXPRESS_REGION} CI_BUCKET=${S3_EXPRESS_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL='' pytest -s {package}/../s3torchconnector/tst/e2e/dcp",
 ]
 
 [[tool.cibuildwheel.overrides]]
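One remark on the property-based unit test above: the `@composite` strategy builds the `LoadPlan` and its matching `storage_data` together, so the sort key in `prepare_local_plan` always finds an offset to look up. A trimmed-down sketch of that pattern using hypothesis only (hypothetical names; no torch or connector imports needed):

```python
from hypothesis import given
from hypothesis.strategies import composite, integers, lists


@composite
def keys_with_offsets(draw):
    """Draw offsets, then derive both the keys to sort and the lookup table."""
    offsets = draw(lists(integers(0, 10_000_000), min_size=1, max_size=100))
    lookup = {f"item{i}": offset for i, offset in enumerate(offsets)}
    return list(lookup), lookup


@given(keys_with_offsets())
def test_sort_by_lookup_orders_offsets(keys_and_lookup):
    keys, lookup = keys_and_lookup
    keys.sort(key=lookup.__getitem__)  # same shape as the offset sort key
    resulting_offsets = [lookup[key] for key in keys]
    assert resulting_offsets == sorted(resulting_offsets)
```

Deriving the lookup table from the drawn offsets, rather than drawing the two independently, keeps every generated example internally consistent.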