Skip to content

Commit f88c0da

Browse files
committed
refactor: revert READER_TYPE_STRING and sequential dcp e2e test
Reverts non-DCP optimized reader changes to make the PR changes clearer: - Revert fix(tests): resolve e2e test import errors after adding __init__ files - Revert test: place READER_TYPE_STRING_TO_CLASS in conftest - Revert a minor test escape sequence fix.
1 parent 4e9cb1f commit f88c0da

13 files changed

+26
-23
lines changed

s3torchconnector/tst/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

s3torchconnector/tst/e2e/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from typing import Optional
3030

3131
from s3torchconnector.dcp.s3_prefix_strategy import RoundRobinPrefixStrategy
32-
from ..test_common import _list_folders_in_bucket
32+
from test_common import _list_folders_in_bucket
3333

3434

3535
def generate_random_port():

s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
from s3torchconnector import S3ReaderConstructor
1212
from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
13+
from s3torchconnector.s3reader import SequentialS3Reader, DCPOptimizedS3Reader
1314
from s3torchconnector._s3client import S3Client
1415

15-
from ...conftest import READER_TYPE_STRING_TO_CLASS
1616

1717
SIMPLE_MODEL = torch.nn.Sequential(
1818
nn.Linear(5, 5),
@@ -40,8 +40,15 @@ def __init__(self):
4040

4141

4242
@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL])
43+
@pytest.mark.parametrize(
44+
"reader_class,reader_constructor",
45+
[
46+
(SequentialS3Reader, S3ReaderConstructor.sequential()),
47+
(DCPOptimizedS3Reader, S3ReaderConstructor.dcp_optimized()),
48+
],
49+
)
4350
def test_dcp_load_reads_tensors_in_sequential_order(
44-
dcp_reader_constructor, checkpoint_directory, model
51+
checkpoint_directory, model, reader_class, reader_constructor
4552
):
4653
"""
4754
Test that prepare_local_plan allows dcp.load() to read items in offset order.
@@ -63,11 +70,6 @@ def test_dcp_load_reads_tensors_in_sequential_order(
6370
storage_writer = S3StorageWriter(region=region, path=s3_uri, overwrite=True)
6471
dcp.save(state_dict, storage_writer=storage_writer)
6572

66-
reader_type_string = S3ReaderConstructor.get_reader_type_string(
67-
dcp_reader_constructor
68-
)
69-
reader_class = READER_TYPE_STRING_TO_CLASS[reader_type_string]
70-
7173
read_positions = []
7274
original_read = reader_class.read
7375

@@ -76,10 +78,11 @@ def track_reads(self, size=None):
7678
read_positions.append(self._position)
7779
return original_read(self, size)
7880

81+
# Load with position tracking on read() (called at the start of each torch.load())
7982
with patch.object(reader_class, "read", track_reads):
8083
loaded_state_dict = {k: torch.empty_like(v) for k, v in state_dict.items()}
8184
storage_reader = S3StorageReader(
82-
region=region, path=s3_uri, reader_constructor=dcp_reader_constructor
85+
region=region, path=s3_uri, reader_constructor=reader_constructor
8386
)
8487
dcp.load(loaded_state_dict, storage_reader=storage_reader)
8588

s3torchconnector/tst/e2e/test_distributed_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .conftest import BucketPrefixFixture, BucketPrefixData
2121

2222

23-
from .test_common import _get_fork_methods, _read_data, _set_start_method
23+
from test_common import _get_fork_methods, _read_data, _set_start_method
2424

2525

2626
start_methods = _get_fork_methods()

s3torchconnector/tst/e2e/test_e2e_s3_lightning_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from s3torchconnector.lightning import S3LightningCheckpoint
2020
from s3torchconnectorclient import S3Exception, __version__
2121

22-
from .models.net import Net
23-
from .models.lightning_transformer import LightningTransformer, L
22+
from models.net import Net
23+
from models.lightning_transformer import LightningTransformer, L
2424

2525

2626
LIGHTNING_ACCELERATOR = "cpu"

s3torchconnector/tst/e2e/test_e2e_s3checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
from s3torchconnector import S3Checkpoint
8-
from .models.net import Net
8+
from models.net import Net
99

1010

1111
@pytest.mark.parametrize(

s3torchconnector/tst/e2e/test_mountpoint_client_parallel_access.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from s3torchconnector._s3client import S3Client
66
from s3torchconnectorclient._mountpoint_s3_client import MountpointS3Client
77

8-
from .test_common import _get_fork_methods
9-
from .conftest import getenv
8+
from test_common import _get_fork_methods
9+
from conftest import getenv
1010

1111

1212
NATIVE_S3_CLIENT = None

s3torchconnector/tst/e2e/test_multiprocess_dataloading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
if TYPE_CHECKING:
2323
from .conftest import BucketPrefixFixture
2424

25-
from .test_common import _get_fork_methods, _read_data, _set_start_method
25+
from test_common import _get_fork_methods, _read_data, _set_start_method
2626

2727

2828
start_methods = _get_fork_methods()

s3torchconnector/tst/unit/test_s3dataset_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
TEST_REGION = "us-east-1"
2929
S3_PREFIX = f"s3://{TEST_BUCKET}"
3030
TEST_ENDPOINT = "https://s3.us-east-1.amazonaws.com"
31+
READER_TYPE_STRING_TO_CLASS = {
32+
"sequential": SequentialS3Reader,
33+
"range_based": RangedS3Reader,
34+
}
3135

3236

3337
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)