 
 from s3torchconnector import S3ReaderConstructor
 from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
-from s3torchconnector.s3reader import SequentialS3Reader, DCPOptimizedS3Reader
+from s3torchconnector._s3client import S3Client
 
 from ...conftest import READER_TYPE_STRING_TO_CLASS
 
@@ -93,3 +93,91 @@ def track_reads(self, size=None):
     assert loaded_state_dict.keys() == state_dict.keys()
     for key in state_dict:
         assert torch.equal(loaded_state_dict[key], state_dict[key])
+
+
+@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL])
+@pytest.mark.parametrize(
+    "max_gap_size,load_filter,filter_name,expected_streams",
+    [
+        # Full load - all tensors are consecutive, so always 1 stream
+        (0, lambda k: True, "Full", 1),
+        (float("inf"), lambda k: True, "Full", 1),
+        # Weights only - weights are interleaved with biases, so the stream count depends on max_gap_size
+        (0, lambda k: k.endswith(".weight"), "Weights", 3),
+        (float("inf"), lambda k: k.endswith(".weight"), "Weights", 1),
+        # Layer 2 only - its bias and weight tensors are consecutive, so always 1 stream
+        (0, lambda k: "2." in k, "Layer 2", 1),
+        (float("inf"), lambda k: "2." in k, "Layer 2", 1),
+    ],
+)
+def test_dcp_optimized_loading_patterns(
+    checkpoint_directory,
+    model,
+    max_gap_size,
+    load_filter,
+    filter_name,
+    expected_streams,
+):
+    """Test the DCPOptimized reader with full and partial loading patterns and different max_gap_size values.
+
+    Validates that full loads use a single stream, and that the number of streams for partial
+    loads depends on max_gap_size and on whether the requested tensors are stored consecutively.
+
+    SIMPLE_MODEL tensors: ['0.bias', '0.weight', '1.bias', '1.weight', '2.bias', '2.weight']
+    LARGER_MODEL tensors: ['linear_relu_stack.0.bias', 'linear_relu_stack.0.weight', 'linear_relu_stack.2.bias',
+        'linear_relu_stack.2.weight', 'linear_relu_stack.4.bias', 'linear_relu_stack.4.weight']
+    """
+    region = checkpoint_directory.region
+    s3_uri = checkpoint_directory.s3_uri
+
+    state_dict = model.state_dict()
+    dcp.save(state_dict, storage_writer=S3StorageWriter(region, s3_uri, overwrite=True))
+
+    # Print model structure (once per model)
+    all_keys = list(state_dict.keys())
+    if max_gap_size == 0 and filter_name == "Full":
+        print(f"\nTensors: {sorted(all_keys)}")
+
+    # Apply filter for partial load
+    filtered_keys = [k for k in all_keys if load_filter(k)]
+    excluded_keys = [k for k in all_keys if not load_filter(k)]
+    assert filtered_keys, f"No keys match {filter_name} filter for this model"
+    filtered_dict = {k: torch.empty_like(state_dict[k]) for k in filtered_keys}
+
+    # Load full / partial checkpoint with stream call tracker
+    stream_calls = []
+    original_get_object_stream = S3Client._get_object_stream
+    def track_get_object_stream(self, bucket, key, start=None, end=None):
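+        # Skip the DCP .metadata object read; count only tensor-data streams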
+        if not key.endswith(".metadata"):
+            stream_calls.append((start, end))
+        return original_get_object_stream(self, bucket, key, start=start, end=end)
+
+    with patch.object(S3Client, "_get_object_stream", track_get_object_stream):
+        reader_constructor = S3ReaderConstructor.dcp_optimized(max_gap_size)
+        reader = S3StorageReader(region, s3_uri, reader_constructor=reader_constructor)
+        dcp.load(filtered_dict, storage_reader=reader)
+
+    # Verify correctness
+    assert len(filtered_dict) == len(filtered_keys)
+    for k, v in filtered_dict.items():
+        assert torch.equal(v, state_dict[k])
+        assert load_filter(k)
+
+    # Verify excluded keys are not loaded
+    for k in excluded_keys:
+        assert k not in filtered_dict, f"Key {k} should not be in {filter_name} load"
+
+    # Verify expected stream count
+    assert len(stream_calls) == expected_streams
+    if len(stream_calls) > 1:
+        for i in range(1, len(stream_calls)):
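+            # Streams must be ordered and non-overlapping, with any remaining gap at least max_gap_size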
+            assert stream_calls[i][0] >= stream_calls[i - 1][1]
+            assert stream_calls[i][0] - stream_calls[i - 1][1] >= max_gap_size
+
+    # Print number of stream calls
+    coalesce = "no coalescing" if max_gap_size == 0 else "full coalescing"
+    print(
+        f"{filter_name} load, {coalesce}: {len(stream_calls)} streams, {len(filtered_keys)} tensors"
+    )