Skip to content

Commit 68165e6

Browse files
committed
test(dcp): update dcp e2e tests with DCPOptimizedS3Reader
- Add dcp_reader_constructor fixture for DCP tests - Update test_e2e_s3_file_system.py to use dcp_reader_constructor fixture - Update test_e2e_s3_storage_reader.py load ordering test to also cover dcp-optimized s3 reader
1 parent a515b4a commit 68165e6

File tree

3 files changed

+43
-17
lines changed

3 files changed

+43
-17
lines changed

s3torchconnector/tst/conftest.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,32 @@
1111
# Shared reader constructors for parametrized tests
1212
# TODO: use this variable in test_distributed_training.py and test_multiprocess_dataloading.py
1313
READER_CONSTRUCTORS = [
14-
S3ReaderConstructor.sequential(), # Sequential Reader
15-
S3ReaderConstructor.range_based(), # Default range-based reader, with buffer
16-
S3ReaderConstructor.range_based(buffer_size=0), # range-based reader, no buffer
14+
("sequential", S3ReaderConstructor.sequential()),
15+
("range_based_with_buffer", S3ReaderConstructor.range_based()),
16+
("range_based_no_buffer", S3ReaderConstructor.range_based(buffer_size=0)),
17+
]
18+
19+
# Include dcp_optimized for DCP tests
20+
DCP_READER_CONSTRUCTORS = READER_CONSTRUCTORS + [
21+
("dcp_optimized", S3ReaderConstructor.dcp_optimized()),
1722
]
1823

1924

2025
@pytest.fixture(
21-
params=READER_CONSTRUCTORS,
22-
ids=["sequential", "range_based_with_buffer", "range_based_no_buffer"],
26+
params=[constructor for _, constructor in READER_CONSTRUCTORS],
27+
ids=[name for name, _ in READER_CONSTRUCTORS],
2328
scope="module",
2429
)
2530
def reader_constructor(request) -> S3ReaderConstructorProtocol:
2631
"""Provide reader constructor (partial(S3Reader)) instances for all supported reader types."""
2732
return request.param
33+
34+
35+
@pytest.fixture(
36+
params=[constructor for _, constructor in DCP_READER_CONSTRUCTORS],
37+
ids=[name for name, _ in DCP_READER_CONSTRUCTORS],
38+
scope="module",
39+
)
40+
def dcp_reader_constructor(request) -> S3ReaderConstructorProtocol:
41+
"""Provide reader constructor instances for DCP tests including dcp_optimized."""
42+
return request.param

s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def test_dcp_when_multi_process(
212212
tensor_dimensions,
213213
thread_count,
214214
port_offset,
215-
reader_constructor,
215+
dcp_reader_constructor,
216216
):
217217
multi_process_dcp_save_load(
218218
world_size=3,
@@ -221,7 +221,7 @@ def test_dcp_when_multi_process(
221221
tensor_dimensions=tensor_dimensions,
222222
port_offset=port_offset,
223223
prefix_strategy=None,
224-
reader_constructor=reader_constructor,
224+
reader_constructor=dcp_reader_constructor,
225225
)
226226

227227

s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from s3torchconnector import S3ReaderConstructor
1212
from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
13-
from s3torchconnector.s3reader.sequential import SequentialS3Reader
13+
from s3torchconnector.s3reader import SequentialS3Reader, DCPOptimizedS3Reader
1414

1515

1616
SIMPLE_MODEL = torch.nn.Sequential(
@@ -39,19 +39,34 @@ def __init__(self):
3939

4040

4141
@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL])
42-
def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model):
42+
@pytest.mark.parametrize(
43+
"reader_class,reader_constructor",
44+
[
45+
(SequentialS3Reader, S3ReaderConstructor.sequential()),
46+
(DCPOptimizedS3Reader, S3ReaderConstructor.dcp_optimized()),
47+
],
48+
)
49+
def test_dcp_load_reads_tensors_in_sequential_order(
50+
checkpoint_directory, model, reader_class, reader_constructor
51+
):
4352
"""
4453
Test that prepare_local_plan allows dcp.load() to read items in offset order.
4554
4655
This does not prevent backwards seek, since torch.load() would still call
4756
backwards seek operations.
4857
58+
SequentialS3Reader:
4959
pytorch/torch/serialization.py load() function will call _is_zipfile(), which
5060
includes this read() call: f.read(len(local_header_magic_number)). This is
5161
followed by readinto() calls on the actual tensor.
5262
63+
DCPOptimizedS3Reader:
64+
DCPOptimizedS3Reader.seekable() returns false, hence PyTorch would use read()
65+
calls and make it seekable with `seekable = io.BytesIO(transform_from.read(-1))` in
66+
pytorch/torch/distributed/checkpoint/filesystem.py read_data() method.
67+
5368
Hence we can track read() call positions to determine if load ordering is
54-
being applied correctly.
69+
being applied correctly for both cases.
5570
"""
5671
region = checkpoint_directory.region
5772
s3_uri = checkpoint_directory.s3_uri
@@ -61,21 +76,17 @@ def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model)
6176
dcp.save(state_dict, storage_writer=storage_writer)
6277

6378
read_positions = []
64-
65-
original_read = SequentialS3Reader.read
79+
original_read = reader_class.read
6680

6781
def track_reads(self, size=None):
6882
if not self.key.endswith(".metadata"):
6983
read_positions.append(self._position)
7084
return original_read(self, size)
7185

72-
# Load with position tracking on read() (called at the start of each torch.load())
73-
with patch.object(SequentialS3Reader, "read", track_reads):
86+
with patch.object(reader_class, "read", track_reads):
7487
loaded_state_dict = {k: torch.empty_like(v) for k, v in state_dict.items()}
7588
storage_reader = S3StorageReader(
76-
region=region,
77-
path=s3_uri,
78-
reader_constructor=S3ReaderConstructor.sequential(),
89+
region=region, path=s3_uri, reader_constructor=reader_constructor
7990
)
8091
dcp.load(loaded_state_dict, storage_reader=storage_reader)
8192

0 commit comments

Comments
 (0)