From 8d4bf043a1f9b93ae0ba07f03bbc6c5649f6201a Mon Sep 17 00:00:00 2001 From: Ilya Isaev Date: Fri, 26 Sep 2025 10:52:09 +0100 Subject: [PATCH 01/13] Sort tensors/weight based on their offset in object when loading checkpoints Cherry-picked prepare_local_plan method from upstream PR #352. Sequentially loads items based on their actual offset in checkpoint shards, ensuring sequential access patterns and improving I/O efficiency. --- .../src/s3torchconnector/dcp/s3_file_system.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py index f7ab995b..4bcc6824 100644 --- a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py +++ b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py @@ -252,7 +252,7 @@ def _escape_path(string): return "/".join(parts) -from torch.distributed.checkpoint.planner import SavePlan +from torch.distributed.checkpoint.planner import SavePlan, LoadPlan import dataclasses from dataclasses import dataclass @@ -345,6 +345,11 @@ def __init__( def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: return S3FileSystem.validate_checkpoint_id(checkpoint_id) + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + # Sort items in plan based on their offset in checkpoints shards + plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset) + return plan + def _path_or_str_to_str(path: Union[str, os.PathLike]) -> str: return path if isinstance(path, str) else str(path) From 78f814d30e2d365891a0c4787ad94d94fe81cfb8 Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Mon, 29 Sep 2025 12:31:56 +0100 Subject: [PATCH 02/13] test(dcp): add unit tests for S3StorageReader load ordering - Hypothesis composite to generate LoadPlan with random offsets - Test prepare_local_plan method sorts items by storage offset - Test DCP automatically applies sorting via prepare_local_plan --- .../tst/unit/dcp/test_s3_storage_reader.py | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py diff --git a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py new file mode 100644 index 00000000..aa8b7ba8 --- /dev/null +++ b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py @@ -0,0 +1,120 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# // SPDX-License-Identifier: BSD + +from typing import Dict, Any +from unittest.mock import Mock +from hypothesis import given, assume +from hypothesis.strategies import composite, integers, lists + +import torch +import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint.planner import LoadPlan, ReadItem, LoadItemType +from torch.distributed.checkpoint.metadata import ( + Metadata, + MetadataIndex, + TensorStorageMetadata, + ChunkStorageMetadata, +) + +from s3torchconnector.dcp import S3StorageReader + +TEST_REGION = "eu-east-1" +TEST_PATH = "s3://test-bucket/test-checkpoint/" + + +@composite +def load_plan_with_offsets(draw): + """Generate LoadPlan with random offsets.""" + offsets = draw(lists(integers(0, 10_000_000), min_size=0, max_size=10_000)) + + storage_data = {} + items = [] + + for i, offset in enumerate(offsets): + metadata_index = MetadataIndex(fqn=f"item{i}", offset=torch.Size([0]), index=0) + + # Mock storage info + storage_data[metadata_index] = Mock( + offset=offset, + length=draw(integers(1000, 50000)), # DCP requires length - use random integers + relative_path=f"__{draw(integers(0, 7))}_0.distcp", + ) + + items.append( + Mock(spec=ReadItem, storage_index=metadata_index, type=LoadItemType.TENSOR) + ) + + return LoadPlan(items), storage_data # type: ignore + + +@given(load_plan_with_offsets()) +def test_s3storage_reader_prepare_local_plan(loadplan_and_storagedata): + """Test prepare local plan sorts items by storage_data offset.""" + load_plan, storage_data = loadplan_and_storagedata + + s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH) + s3_storage_reader.storage_data = storage_data + + sorted_plan = s3_storage_reader.prepare_local_plan(load_plan) + offsets = [storage_data[item.storage_index].offset for item in sorted_plan.items] + + # Verify Load Ordering sorts offsets + assert offsets == sorted(offsets) + + # Verify Load Ordering keeps items the same + assert len(sorted_plan.items) == len(load_plan.items) + assert {item.storage_index for item in sorted_plan.items} == { + item.storage_index for item in load_plan.items + } + + +@given(load_plan_with_offsets()) +def test_s3storage_reader_dcp_load_uses_load_ordering(loadplan_and_storagedata): + """Test that DCP automatically calls our load ordering optimization via prepare_local_plan.""" + load_plan, storage_data = loadplan_and_storagedata + + # Skip test cases where input is already sorted + original_offsets = [storage_data[item.storage_index].offset for item in load_plan.items] + assume(original_offsets != sorted(original_offsets)) + assume(len(original_offsets) > 0) + + # Minimal tensor metadata to satisfy DCP's validation requirements + state_dict_metadata: Dict[str, Any] = { + f"item{i}": TensorStorageMetadata( + properties=Mock(dtype=torch.float32), # tensor type validation + size=torch.Size([10]), # memory allocation + chunks=[ # chunk info for distributed loading + ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([10])) + ], + ) + for i in range(len(load_plan.items)) + } + + # Create S3StorageReader with mock read_metadata (iterable) and read_data + s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH) + s3_storage_reader.read_metadata = Mock( + return_value=Metadata( + state_dict_metadata=state_dict_metadata, # Real dict for DCP iteration + storage_data=storage_data, # Our test data with random offsets + ) + ) + s3_storage_reader.read_data = Mock() + + # Create state_dict matching the metadata structure + state_dict = {f"item{i}": torch.zeros(10) for i in range(len(load_plan.items))} + + # 1. In torch/distributed/checkpoint/state_dict_loader.py: dcp.load() calls _load_state_dict; + # 2. According to torch/distributed/checkpoint/storage.py StorageWriter docstring, _load_state_dict() calls: + # read_metadata() > set_up_storage_reader() > prepare_local_plan() > prepare_global_plan() > read_data() + dcp.load(state_dict, storage_reader=s3_storage_reader) + + # When read_data is called, verify prepare_local_plan was called and sorted the items + sorted_plan = s3_storage_reader.read_data.call_args[0][0] # First arg is the plan + sorted_offsets = [storage_data[item.storage_index].offset for item in sorted_plan.items] + assert sorted_offsets == sorted(sorted_offsets) + + # Verify Load Ordering keeps items the same + assert len(sorted_plan.items) == len(load_plan.items) + assert {item.storage_index for item in sorted_plan.items} == { + item.storage_index for item in load_plan.items + } \ No newline at end of file From 85d18ca6d490360753844d34e74cc89a8500a368 Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Mon, 29 Sep 2025 12:44:45 +0100 Subject: [PATCH 03/13] docs(dcp): add docstrings for S3StorageReader load ordering optimization - Add docstring to prepare_local_plan method - Update CHANGELOG --- CHANGELOG.md | 3 ++- .../src/s3torchconnector/dcp/s3_file_system.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4532fefd..90dfc9aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,12 @@ ## TBD ### Bug fixes -* Add seekable() method in S3Reader to eliminate tensor copies during DCP loading (#359) * Override S3Writer closed property and block writes after close (#360) * Fix SequentialS3Reader seek beyond EOF to clamp position to object size (#362) ### Other changes +* Add seekable() method in S3Reader to eliminate tensor copies during DCP loading (#359) +* Add load ordering optimization to S3StorageReader for sequential access patterns (#372) * Add benchmark to run DCP Loading Workloads (#357) * Add thread_count parameter to S3StorageWriter (#370) diff --git a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py index 4bcc6824..427e05ea 100644 --- a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py +++ b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py @@ -346,6 +346,15 @@ def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: return S3FileSystem.validate_checkpoint_id(checkpoint_id) def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + """ + Sort load items by storage offset for sequential access optimization. + + Args: + plan (LoadPlan): The load plan from PyTorch DCP. + + Returns: + LoadPlan: The same plan with items sorted by storage offset. + """ # Sort items in plan based on their offset in checkpoints shards plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset) return plan From 825504685b4e42bb6ac4ef2f8cac84af6a0831cc Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Mon, 29 Sep 2025 14:35:31 +0100 Subject: [PATCH 04/13] style: apply black formatting --- .../s3torchconnector/dcp/s3_file_system.py | 4 +-- .../tst/unit/dcp/test_s3_storage_reader.py | 28 +++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py index 427e05ea..3759e5a5 100644 --- a/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py +++ b/s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py @@ -348,10 +348,10 @@ def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: """ Sort load items by storage offset for sequential access optimization. - + Args: plan (LoadPlan): The load plan from PyTorch DCP. - + Returns: LoadPlan: The same plan with items sorted by storage offset. """ diff --git a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py index aa8b7ba8..79c7e4ba 100644 --- a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py +++ b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py @@ -36,7 +36,9 @@ def load_plan_with_offsets(draw): # Mock storage info storage_data[metadata_index] = Mock( offset=offset, - length=draw(integers(1000, 50000)), # DCP requires length - use random integers + length=draw( + integers(1000, 50000) + ), # DCP requires length - use random integers relative_path=f"__{draw(integers(0, 7))}_0.distcp", ) @@ -74,16 +76,18 @@ def test_s3storage_reader_dcp_load_uses_load_ordering(loadplan_and_storagedata): load_plan, storage_data = loadplan_and_storagedata # Skip test cases where input is already sorted - original_offsets = [storage_data[item.storage_index].offset for item in load_plan.items] + original_offsets = [ + storage_data[item.storage_index].offset for item in load_plan.items + ] assume(original_offsets != sorted(original_offsets)) assume(len(original_offsets) > 0) # Minimal tensor metadata to satisfy DCP's validation requirements state_dict_metadata: Dict[str, Any] = { f"item{i}": TensorStorageMetadata( - properties=Mock(dtype=torch.float32), # tensor type validation - size=torch.Size([10]), # memory allocation - chunks=[ # chunk info for distributed loading + properties=Mock(dtype=torch.float32), # tensor type validation + size=torch.Size([10]), # memory allocation + chunks=[ # chunk info for distributed loading ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([10])) ], ) @@ -94,8 +98,8 @@ def test_s3storage_reader_dcp_load_uses_load_ordering(loadplan_and_storagedata): s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH) s3_storage_reader.read_metadata = Mock( return_value=Metadata( - state_dict_metadata=state_dict_metadata, # Real dict for DCP iteration - storage_data=storage_data, # Our test data with random offsets + state_dict_metadata=state_dict_metadata, # Real dict for DCP iteration + storage_data=storage_data, # Our test data with random offsets ) ) s3_storage_reader.read_data = Mock() @@ -103,18 +107,20 @@ def test_s3storage_reader_dcp_load_uses_load_ordering(loadplan_and_storagedata): # Create state_dict matching the metadata structure state_dict = {f"item{i}": torch.zeros(10) for i in range(len(load_plan.items))} - # 1. In torch/distributed/checkpoint/state_dict_loader.py: dcp.load() calls _load_state_dict; - # 2. According to torch/distributed/checkpoint/storage.py StorageWriter docstring, _load_state_dict() calls: + # 1. In torch/distributed/checkpoint/state_dict_loader.py: dcp.load() calls _load_state_dict; + # 2. According to torch/distributed/checkpoint/storage.py StorageWriter docstring, _load_state_dict() calls: # read_metadata() > set_up_storage_reader() > prepare_local_plan() > prepare_global_plan() > read_data() dcp.load(state_dict, storage_reader=s3_storage_reader) # When read_data is called, verify prepare_local_plan was called and sorted the items sorted_plan = s3_storage_reader.read_data.call_args[0][0] # First arg is the plan - sorted_offsets = [storage_data[item.storage_index].offset for item in sorted_plan.items] + sorted_offsets = [ + storage_data[item.storage_index].offset for item in sorted_plan.items + ] assert sorted_offsets == sorted(sorted_offsets) # Verify Load Ordering keeps items the same assert len(sorted_plan.items) == len(load_plan.items) assert {item.storage_index for item in sorted_plan.items} == { item.storage_index for item in load_plan.items - } \ No newline at end of file + } From af210e454d31c8b9412bfa0639d63ee134943b56 Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Tue, 30 Sep 2025 13:20:17 +0100 Subject: [PATCH 05/13] fix(test): address review comments - Verify return type (LoadPlan) - Remove redundant assume() calls - Converted to real ReadItem so we can check sorted_plan items directly --- .../tst/unit/dcp/test_s3_storage_reader.py | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py index 79c7e4ba..8726536e 100644 --- a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py +++ b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py @@ -43,7 +43,14 @@ def load_plan_with_offsets(draw): ) items.append( - Mock(spec=ReadItem, storage_index=metadata_index, type=LoadItemType.TENSOR) + ReadItem( + storage_index=metadata_index, + type=LoadItemType.TENSOR, + dest_index=metadata_index, + dest_offsets=torch.Size([0]), + storage_offsets=torch.Size([0]), + lengths=torch.Size([10]), + ) ) return LoadPlan(items), storage_data # type: ignore @@ -58,16 +65,19 @@ def test_s3storage_reader_prepare_local_plan(loadplan_and_storagedata): s3_storage_reader.storage_data = storage_data sorted_plan = s3_storage_reader.prepare_local_plan(load_plan) - offsets = [storage_data[item.storage_index].offset for item in sorted_plan.items] + sorted_offsets = [ + storage_data[item.storage_index].offset for item in sorted_plan.items + ] + + # Verify return type + assert isinstance(sorted_plan, LoadPlan) # Verify Load Ordering sorts offsets - assert offsets == sorted(offsets) + assert sorted_offsets == sorted(sorted_offsets) # Verify Load Ordering keeps items the same assert len(sorted_plan.items) == len(load_plan.items) - assert {item.storage_index for item in sorted_plan.items} == { - item.storage_index for item in load_plan.items - } + assert set(sorted_plan.items) == set(load_plan.items) @given(load_plan_with_offsets()) @@ -75,13 +85,6 @@ def test_s3storage_reader_dcp_load_uses_load_ordering(loadplan_and_storagedata): """Test that DCP automatically calls our load ordering optimization via prepare_local_plan.""" load_plan, storage_data = loadplan_and_storagedata - # Skip test cases where input is already sorted - original_offsets = [ - storage_data[item.storage_index].offset for item in load_plan.items - ] - assume(original_offsets != sorted(original_offsets)) - assume(len(original_offsets) > 0) - # Minimal tensor metadata to satisfy DCP's validation requirements state_dict_metadata: Dict[str, Any] = { f"item{i}": TensorStorageMetadata( @@ -117,10 +120,13 @@ def test_s3storage_reader_dcp_load_uses_load_ordering(loadplan_and_storagedata): sorted_offsets = [ storage_data[item.storage_index].offset for item in sorted_plan.items ] + + # Verify return type + assert isinstance(sorted_plan, LoadPlan) + + # Verify Load Ordering sorts offsets assert sorted_offsets == sorted(sorted_offsets) # Verify Load Ordering keeps items the same assert len(sorted_plan.items) == len(load_plan.items) - assert {item.storage_index for item in sorted_plan.items} == { - item.storage_index for item in load_plan.items - } + assert set(sorted_plan.items) == set(load_plan.items) From ef4a7205701d03e20bdf4e649453f6318491f824 Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Tue, 30 Sep 2025 14:09:01 +0100 Subject: [PATCH 06/13] fix(test): add empty plan test and remove dcp load test - Added empty plan test to separate sorting test from 0-length test - Removed dcp 'integration' test with mock items, since it only tests for whether prepare_local_plan is called. Improving the dcp 'integration' test by checking read_data reads will require too many patches, and I'm considering moving that into integration tests. --- .../tst/unit/dcp/test_s3_storage_reader.py | 102 +++--------------- 1 file changed, 17 insertions(+), 85 deletions(-) diff --git a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py index 8726536e..f46dd56f 100644 --- a/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py +++ b/s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py @@ -1,20 +1,11 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # // SPDX-License-Identifier: BSD -from typing import Dict, Any from unittest.mock import Mock -from hypothesis import given, assume +from hypothesis import given from hypothesis.strategies import composite, integers, lists -import torch -import torch.distributed.checkpoint as dcp -from torch.distributed.checkpoint.planner import LoadPlan, ReadItem, LoadItemType -from torch.distributed.checkpoint.metadata import ( - Metadata, - MetadataIndex, - TensorStorageMetadata, - ChunkStorageMetadata, -) +from torch.distributed.checkpoint.planner import LoadPlan, ReadItem from s3torchconnector.dcp import S3StorageReader @@ -25,98 +16,39 @@ @composite def load_plan_with_offsets(draw): """Generate LoadPlan with random offsets.""" - offsets = draw(lists(integers(0, 10_000_000), min_size=0, max_size=10_000)) + offsets = draw(lists(integers(0, 10_000_000), min_size=1, max_size=10_000)) storage_data = {} items = [] for i, offset in enumerate(offsets): - metadata_index = MetadataIndex(fqn=f"item{i}", offset=torch.Size([0]), index=0) - - # Mock storage info - storage_data[metadata_index] = Mock( - offset=offset, - length=draw( - integers(1000, 50000) - ), # DCP requires length - use random integers - relative_path=f"__{draw(integers(0, 7))}_0.distcp", - ) - - items.append( - ReadItem( - storage_index=metadata_index, - type=LoadItemType.TENSOR, - dest_index=metadata_index, - dest_offsets=torch.Size([0]), - storage_offsets=torch.Size([0]), - lengths=torch.Size([10]), - ) - ) - - return LoadPlan(items), storage_data # type: ignore + storage_index = f"item{i}" + storage_data[storage_index] = Mock(offset=offset) + items.append(Mock(spec=ReadItem, storage_index=storage_index)) + return LoadPlan(items), storage_data -@given(load_plan_with_offsets()) -def test_s3storage_reader_prepare_local_plan(loadplan_and_storagedata): - """Test prepare local plan sorts items by storage_data offset.""" - load_plan, storage_data = loadplan_and_storagedata +def test_s3storage_reader_prepare_local_plan_empty(): + """Test prepare_local_plan handles empty plans.""" s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH) - s3_storage_reader.storage_data = storage_data - sorted_plan = s3_storage_reader.prepare_local_plan(load_plan) - sorted_offsets = [ - storage_data[item.storage_index].offset for item in sorted_plan.items - ] + sorted_plan = s3_storage_reader.prepare_local_plan(LoadPlan([])) + # Output: LoadPlan(items=[], storage_data=None, planner_data=None) - # Verify return type assert isinstance(sorted_plan, LoadPlan) - - # Verify Load Ordering sorts offsets - assert sorted_offsets == sorted(sorted_offsets) - - # Verify Load Ordering keeps items the same - assert len(sorted_plan.items) == len(load_plan.items) - assert set(sorted_plan.items) == set(load_plan.items) + assert len(sorted_plan.items) == 0 @given(load_plan_with_offsets()) -def test_s3storage_reader_dcp_load_uses_load_ordering(loadplan_and_storagedata): - """Test that DCP automatically calls our load ordering optimization via prepare_local_plan.""" +def test_s3storage_reader_prepare_local_plan(loadplan_and_storagedata): + """Test prepare local plan sorts items by storage_data offset.""" load_plan, storage_data = loadplan_and_storagedata - # Minimal tensor metadata to satisfy DCP's validation requirements - state_dict_metadata: Dict[str, Any] = { - f"item{i}": TensorStorageMetadata( - properties=Mock(dtype=torch.float32), # tensor type validation - size=torch.Size([10]), # memory allocation - chunks=[ # chunk info for distributed loading - ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([10])) - ], - ) - for i in range(len(load_plan.items)) - } - - # Create S3StorageReader with mock read_metadata (iterable) and read_data s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH) - s3_storage_reader.read_metadata = Mock( - return_value=Metadata( - state_dict_metadata=state_dict_metadata, # Real dict for DCP iteration - storage_data=storage_data, # Our test data with random offsets - ) - ) - s3_storage_reader.read_data = Mock() - - # Create state_dict matching the metadata structure - state_dict = {f"item{i}": torch.zeros(10) for i in range(len(load_plan.items))} - - # 1. In torch/distributed/checkpoint/state_dict_loader.py: dcp.load() calls _load_state_dict; - # 2. According to torch/distributed/checkpoint/storage.py StorageWriter docstring, _load_state_dict() calls: - # read_metadata() > set_up_storage_reader() > prepare_local_plan() > prepare_global_plan() > read_data() - dcp.load(state_dict, storage_reader=s3_storage_reader) - - # When read_data is called, verify prepare_local_plan was called and sorted the items - sorted_plan = s3_storage_reader.read_data.call_args[0][0] # First arg is the plan + s3_storage_reader.storage_data = storage_data + + sorted_plan = s3_storage_reader.prepare_local_plan(load_plan) sorted_offsets = [ storage_data[item.storage_index].offset for item in sorted_plan.items ] From 5271ed9e8091afe697a4db1bebb0e61352fc2925 Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Thu, 2 Oct 2025 15:48:05 +0100 Subject: [PATCH 07/13] test(dcp): add e2e load ordering test - Test load ordering in e2e by tracking read() calls - Use parametrized models (Sequential + ResNet) --- .../tst/e2e/dcp/test_e2e_s3_storage_reader.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py diff --git a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py new file mode 100644 index 00000000..60adf24d --- /dev/null +++ b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# // SPDX-License-Identifier: BSD + +import pytest +from unittest.mock import patch + +import torch +import torch.distributed.checkpoint as dcp +import torchvision.models as models + +from s3torchconnector import S3ReaderConstructor +from s3torchconnector.dcp import S3StorageWriter, S3StorageReader +from s3torchconnector.s3reader.sequential import SequentialS3Reader + + +@pytest.mark.parametrize( + "model", + [ + torch.nn.Sequential( + torch.nn.Linear(5, 5), + torch.nn.Linear(20, 20), + torch.nn.Linear(10, 10), + ), + models.resnet18(pretrained=False), + ], +) +def test_prepare_local_plan_sorts_by_storage_offset(checkpoint_directory, model): + """ + Test that prepare_local_plan allows dcp.load() to read items in offset order. + + This does not prevent backwards seek, since torch.load() would still call + backwards seek operations. + + pytorch/torch/serialization.py load() function will call _is_zipfile(), which + includes this read() call: f.read(len(local_header_magic_number)). This is + followed by readinto() calls on the actual tensor. + + Hence we can track read() call positions to determine if load ordering is + being applied correctly. + """ + region = checkpoint_directory.region + s3_uri = checkpoint_directory.s3_uri + + state_dict = model.state_dict() + storage_writer = S3StorageWriter(region=region, path=s3_uri, overwrite=True) + dcp.save(state_dict, storage_writer=storage_writer) + + read_positions = [] + + original_read = SequentialS3Reader.read + + def track_reads(self, size=None): + if not self.key.endswith(".metadata"): + read_positions.append(self._position) + return original_read(self, size) + + # Load with position tracking on read() (called at the start of each torch.load()) + with patch.object(SequentialS3Reader, "read", track_reads): + loaded_state_dict = {k: torch.empty_like(v) for k, v in state_dict.items()} + storage_reader = S3StorageReader( + region=region, + path=s3_uri, + reader_constructor=S3ReaderConstructor.sequential(), + ) + dcp.load(loaded_state_dict, storage_reader=storage_reader) + + print(f"Read positions: {read_positions}") + + # Assert load ordering works (read() calls should be in sorted order) + assert read_positions == sorted(read_positions) + + # Assert all tensors are correctly loaded + assert len(loaded_state_dict) == len(state_dict) + assert loaded_state_dict.keys() == state_dict.keys() + for key in state_dict: + assert torch.equal(loaded_state_dict[key], state_dict[key]), f"Tensor mismatch for {key}" + + From 45cfcfad8e4f617db7f2243f5789a9ac3d0e1e4a Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Thu, 2 Oct 2025 15:49:26 +0100 Subject: [PATCH 08/13] ci: add torchvision dependency and test all dcp files - Add torchvision to test with ResNet model - Fix to run all test files under dcp/ directory --- s3torchconnectorclient/pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/s3torchconnectorclient/pyproject.toml b/s3torchconnectorclient/pyproject.toml index 80d641d2..3e33a8d0 100644 --- a/s3torchconnectorclient/pyproject.toml +++ b/s3torchconnectorclient/pyproject.toml @@ -33,6 +33,7 @@ test = [ "flake8", "black", "mypy", + "torchvision", "Pillow<=11.2.1" # installation of the newer versions fails in manylinux2014 images ] @@ -142,8 +143,8 @@ inherit.test-command = "append" test-command = [ "python -m pip install -e '{package}/../s3torchconnector[dcp-test]'", "pytest {package}/../s3torchconnector/tst/unit/dcp", - "CI_STORAGE_CLASS='' CI_REGION=${S3_REGION} CI_BUCKET=${S3_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL=${S3_CUSTOM_ENDPOINT_URL} pytest -s {package}/../s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py", - "AWS_DEFAULT_REGION=${S3_EXPRESS_REGION} CI_STORAGE_CLASS=EXPRESS_ONEZONE CI_REGION=${S3_EXPRESS_REGION} CI_BUCKET=${S3_EXPRESS_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL='' pytest -s {package}/../s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py", + "CI_STORAGE_CLASS='' CI_REGION=${S3_REGION} CI_BUCKET=${S3_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL=${S3_CUSTOM_ENDPOINT_URL} pytest -s {package}/../s3torchconnector/tst/e2e/dcp", + "AWS_DEFAULT_REGION=${S3_EXPRESS_REGION} CI_STORAGE_CLASS=EXPRESS_ONEZONE CI_REGION=${S3_EXPRESS_REGION} CI_BUCKET=${S3_EXPRESS_BUCKET} CI_PREFIX=${S3_PREFIX} CI_CUSTOM_ENDPOINT_URL='' pytest -s {package}/../s3torchconnector/tst/e2e/dcp", ] [[tool.cibuildwheel.overrides]] From e08e76f8db1212e892c7f26c2dcb8ad15b0bab8d Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Thu, 2 Oct 2025 16:04:22 +0100 Subject: [PATCH 09/13] style: apply black formatting --- s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py index 60adf24d..8cdfadf7 100644 --- a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py +++ b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py @@ -73,6 +73,4 @@ def track_reads(self, size=None): assert len(loaded_state_dict) == len(state_dict) assert loaded_state_dict.keys() == state_dict.keys() for key in state_dict: - assert torch.equal(loaded_state_dict[key], state_dict[key]), f"Tensor mismatch for {key}" - - + assert torch.equal(loaded_state_dict[key], state_dict[key]) From b95ae2ff64ee59433a85d227d88686432267a201 Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Thu, 2 Oct 2025 16:15:21 +0100 Subject: [PATCH 10/13] ci: move torchvision dependency to dcp-test Since pytorch lightning tests run into error: RuntimeError: operator torchvision::nms does not exist --- s3torchconnector/pyproject.toml | 1 + s3torchconnectorclient/pyproject.toml | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/s3torchconnector/pyproject.toml b/s3torchconnector/pyproject.toml index 82a8d98e..0ac0065f 100644 --- a/s3torchconnector/pyproject.toml +++ b/s3torchconnector/pyproject.toml @@ -65,6 +65,7 @@ dcp-test = [ "s3torchconnector[dcp]", "pytest", "importlib_metadata; python_version == '3.9'", + "torchvision", ] [tool.setuptools.packages] diff --git a/s3torchconnectorclient/pyproject.toml b/s3torchconnectorclient/pyproject.toml index 3e33a8d0..ee912318 100644 --- a/s3torchconnectorclient/pyproject.toml +++ b/s3torchconnectorclient/pyproject.toml @@ -33,7 +33,6 @@ test = [ "flake8", "black", "mypy", - "torchvision", "Pillow<=11.2.1" # installation of the newer versions fails in manylinux2014 images ] From b5d0c4d75074503888f5036e892ea36ca86506a9 Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Thu, 2 Oct 2025 17:04:44 +0100 Subject: [PATCH 11/13] ci: move torchvision dependency to pyproject.toml So torchvision dynamically adapts to torch version. --- .github/workflows/python-integration.yml | 2 +- s3torchconnector/pyproject.toml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/python-integration.yml b/.github/workflows/python-integration.yml index 6ef25d24..21c96b73 100644 --- a/.github/workflows/python-integration.yml +++ b/.github/workflows/python-integration.yml @@ -105,7 +105,7 @@ jobs: run: | python -m pip install --upgrade pip # Manually install CPU-only version of torch so we're not carrying around giant GPU drivers/kernels - python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu + python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu python -m pip install -e "s3torchconnectorclient[test,e2e]" python -m pip install -e "s3torchconnector[test,e2e]" diff --git a/s3torchconnector/pyproject.toml b/s3torchconnector/pyproject.toml index 0ac0065f..82a8d98e 100644 --- a/s3torchconnector/pyproject.toml +++ b/s3torchconnector/pyproject.toml @@ -65,7 +65,6 @@ dcp-test = [ "s3torchconnector[dcp]", "pytest", "importlib_metadata; python_version == '3.9'", - "torchvision", ] [tool.setuptools.packages] From 6ea4a642ae37ec2af28ab89b726b320513b076d9 Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Thu, 2 Oct 2025 17:34:59 +0100 Subject: [PATCH 12/13] ci: move torchvision dependency to DCP dependencies line Reinstall torch/torchvision after s3torchconnector[dcp-test]. pip install './s3torchconnector[dcp-test]' would reinstall torch without torchvision otherwise. --- .github/workflows/python-integration.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python-integration.yml b/.github/workflows/python-integration.yml index 21c96b73..eeb761cc 100644 --- a/.github/workflows/python-integration.yml +++ b/.github/workflows/python-integration.yml @@ -105,7 +105,7 @@ jobs: run: | python -m pip install --upgrade pip # Manually install CPU-only version of torch so we're not carrying around giant GPU drivers/kernels - python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu + python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu python -m pip install -e "s3torchconnectorclient[test,e2e]" python -m pip install -e "s3torchconnector[test,e2e]" @@ -139,6 +139,7 @@ jobs: if: matrix.runner != 'macos-13' run: | python -m pip install './s3torchconnector[dcp-test]' + python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu - name: Run s3torchconnector DCP e2e tests if: matrix.runner != 'macos-13' run: | From af0070e0843673a252af189aec423637f41d2587 Mon Sep 17 00:00:00 2001 From: Jensen Tong Date: Thu, 2 Oct 2025 18:21:33 +0100 Subject: [PATCH 13/13] fix: remove torchvision and resnet, replace with larger model - Remove torchvision dependency and stop using resnet model - Add neural network from PyTorch quickstart tutorial for e2e test --- .github/workflows/python-integration.yml | 1 - .../tst/e2e/dcp/test_e2e_s3_storage_reader.py | 39 +++++++++++++------ 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/.github/workflows/python-integration.yml b/.github/workflows/python-integration.yml index eeb761cc..6ef25d24 100644 --- a/.github/workflows/python-integration.yml +++ b/.github/workflows/python-integration.yml @@ -139,7 +139,6 @@ jobs: if: matrix.runner != 'macos-13' run: | python -m pip install './s3torchconnector[dcp-test]' - python -m pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu - name: Run s3torchconnector DCP e2e tests if: matrix.runner != 'macos-13' run: | diff --git a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py index 8cdfadf7..d5e62af6 100644 --- a/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py +++ b/s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py @@ -5,26 +5,41 @@ from unittest.mock import patch import torch +import torch.nn as nn import torch.distributed.checkpoint as dcp -import torchvision.models as models from s3torchconnector import S3ReaderConstructor from s3torchconnector.dcp import S3StorageWriter, S3StorageReader from s3torchconnector.s3reader.sequential import SequentialS3Reader -@pytest.mark.parametrize( - "model", - [ - torch.nn.Sequential( - torch.nn.Linear(5, 5), - torch.nn.Linear(20, 20), - torch.nn.Linear(10, 10), - ), - models.resnet18(pretrained=False), - ], +SIMPLE_MODEL = torch.nn.Sequential( + nn.Linear(5, 5), + nn.Linear(20, 20), + nn.Linear(10, 10), ) -def test_prepare_local_plan_sorts_by_storage_offset(checkpoint_directory, model): + + +class NeuralNetwork(nn.Module): + """NeuralNetwork from PyTorch quickstart tutorial.""" + + def __init__(self): + super().__init__() + self.flatten = nn.Flatten() + self.linear_relu_stack = nn.Sequential( + nn.Linear(28 * 28, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 10), + ) + + +LARGER_MODEL = NeuralNetwork() + + +@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL]) +def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model): """ Test that prepare_local_plan allows dcp.load() to read items in offset order.