Commit 0a8f48c

committed
test(dcp): add dcp optimized reader e2e test for coalescing behaviour
Add an e2e integration test for DCPOptimizedS3Reader range coalescing behaviour, covering full and partial loading patterns and different max_gap_size values.
1 parent 3b3c17d commit 0a8f48c
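For context, the coalescing behaviour under test merges nearby byte ranges into a single ranged-GET stream whenever the gap between consecutive ranges is at most max_gap_size. A minimal sketch of the idea, assuming sorted, half-open (start, end) ranges; this is illustrative only, not the DCPOptimizedS3Reader implementation:

# Illustrative sketch of range coalescing, not the library's implementation.
def coalesce_ranges(ranges, max_gap_size):
    merged = []
    for start, end in sorted(ranges):
        if merged and start - merged[-1][1] <= max_gap_size:
            # Gap fits within max_gap_size: extend the previous range
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            # Gap too large: this range becomes a new stream
            merged.append((start, end))
    return merged

# Mirrors the parametrized cases below: scattered ranges stay separate at
# max_gap_size=0 and collapse into one stream at max_gap_size=float("inf").
assert len(coalesce_ranges([(0, 10), (20, 30), (40, 50)], 0)) == 3
assert len(coalesce_ranges([(0, 10), (20, 30), (40, 50)], float("inf"))) == 1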

1 file changed: 87 additions, 1 deletion

s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py

Lines changed: 87 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 from s3torchconnector import S3ReaderConstructor
 from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
-from s3torchconnector.s3reader import SequentialS3Reader, DCPOptimizedS3Reader
+from s3torchconnector._s3client import S3Client
 
 from ...conftest import READER_TYPE_STRING_TO_CLASS
 
@@ -93,3 +93,89 @@ def track_reads(self, size=None):
     assert loaded_state_dict.keys() == state_dict.keys()
     for key in state_dict:
         assert torch.equal(loaded_state_dict[key], state_dict[key])
+
+
+@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL])
+@pytest.mark.parametrize(
+    "max_gap_size,load_filter,filter_name,expected_streams",
+    [
+        # Full load - all tensors are consecutive, so always 1 stream
+        (0, lambda k: True, "Full", 1),
+        (float("inf"), lambda k: True, "Full", 1),
+        # Weights only - scattered by biases, so stream count depends on max_gap_size
+        (0, lambda k: k.endswith(".weight"), "Weights", 3),
+        (float("inf"), lambda k: k.endswith(".weight"), "Weights", 1),
+        # Layer 2 only - its bias+weight tensors are consecutive, so always 1 stream
+        (0, lambda k: "2." in k, "Layer 2", 1),
+        (float("inf"), lambda k: "2." in k, "Layer 2", 1),
+    ],
+)
+def test_dcp_optimized_loading_patterns(
+    checkpoint_directory,
+    model,
+    max_gap_size,
+    load_filter,
+    filter_name,
+    expected_streams,
+):
+    """Test the DCPOptimized reader with full and partial loading patterns and different max_gap_size values.
+
+    Validates that full loads use 1 stream, and that partial-load stream usage depends
+    on max_gap_size and on whether the requested tensors are stored consecutively.
+
+    SIMPLE_MODEL tensors: ['0.bias', '0.weight', '1.bias', '1.weight', '2.bias', '2.weight']
+    LARGER_MODEL tensors: ['linear_relu_stack.0.bias', 'linear_relu_stack.0.weight', 'linear_relu_stack.2.bias',
+        'linear_relu_stack.2.weight', 'linear_relu_stack.4.bias', 'linear_relu_stack.4.weight']
+    """
+    region = checkpoint_directory.region
+    s3_uri = checkpoint_directory.s3_uri
+
+    state_dict = model.state_dict()
+    dcp.save(state_dict, storage_writer=S3StorageWriter(region, s3_uri, overwrite=True))
+
+    # Print model structure (once per model)
+    all_keys = list(state_dict.keys())
+    if max_gap_size == 0 and filter_name == "Full":
+        print(f"\nTensors: {sorted(all_keys)}")
+
+    # Apply the filter for a partial load
+    filtered_keys = [k for k in all_keys if load_filter(k)]
+    excluded_keys = [k for k in all_keys if not load_filter(k)]
+    assert filtered_keys, f"No keys match {filter_name} filter for this model"
+    filtered_dict = {k: torch.empty_like(state_dict[k]) for k in filtered_keys}
+
+    # Load the full / partial checkpoint with a stream-call tracker
+    stream_calls = []
+    original_get_object_stream = S3Client._get_object_stream
+    def track_get_object_stream(self, bucket, key, start=None, end=None):
+        if not key.endswith(".metadata"):
+            stream_calls.append((start, end))
+        return original_get_object_stream(self, bucket, key, start=start, end=end)
+
+    with patch.object(S3Client, "_get_object_stream", track_get_object_stream):
+        reader_constructor = S3ReaderConstructor.dcp_optimized(max_gap_size)
+        reader = S3StorageReader(region, s3_uri, reader_constructor=reader_constructor)
+        dcp.load(filtered_dict, storage_reader=reader)
+
+    # Verify correctness
+    assert len(filtered_dict) == len(filtered_keys)
+    for k, v in filtered_dict.items():
+        assert torch.equal(v, state_dict[k])
+        assert load_filter(k)
+
+    # Verify excluded keys are not loaded
+    for k in excluded_keys:
+        assert k not in filtered_dict, f"Key {k} should not be in {filter_name} load"
+
+    # Verify the expected stream count; gaps between separate streams are at least max_gap_size
+    assert len(stream_calls) == expected_streams
+    if len(stream_calls) > 1:
+        for i in range(1, len(stream_calls)):
+            assert stream_calls[i][0] >= stream_calls[i - 1][1]
+            assert stream_calls[i][0] - stream_calls[i - 1][1] >= max_gap_size
+
+    # Print the number of stream calls
+    coalesce = "no coalesce" if max_gap_size == 0 else "full coalesce"
+    print(
+        f"{filter_name} load, {coalesce}: {len(stream_calls)} streams, {len(filtered_keys)} tensors"
+    )
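For reference, a condensed version of the load path the test exercises; the region, URI, and tensor shapes below are placeholder assumptions, and the state_dict keys must mirror what the checkpoint actually contains:

import torch
import torch.distributed.checkpoint as dcp
from s3torchconnector import S3ReaderConstructor
from s3torchconnector.dcp import S3StorageReader

# Buffers to load into; keys and shapes must match the saved checkpoint
state_dict = {"0.weight": torch.empty(16, 16), "0.bias": torch.empty(16)}

# Ranges whose gaps are within max_gap_size are coalesced into one stream
reader_constructor = S3ReaderConstructor.dcp_optimized(1024 * 1024)
reader = S3StorageReader(
    "us-east-1", "s3://my-bucket/checkpoint/", reader_constructor=reader_constructor
)
dcp.load(state_dict, storage_reader=reader)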
