1010
1111from s3torchconnector import S3ReaderConstructor
1212from s3torchconnector .dcp import S3StorageWriter , S3StorageReader
13- from s3torchconnector .s3reader . sequential import SequentialS3Reader
13+ from s3torchconnector .s3reader import SequentialS3Reader , DCPOptimizedS3Reader
1414
1515
1616SIMPLE_MODEL = torch .nn .Sequential (
@@ -39,19 +39,34 @@ def __init__(self):
3939
4040
4141@pytest .mark .parametrize ("model" , [SIMPLE_MODEL , LARGER_MODEL ])
42- def test_dcp_load_reads_tensors_in_sequential_order (checkpoint_directory , model ):
42+ @pytest .mark .parametrize (
43+ "reader_class,reader_constructor" ,
44+ [
45+ (SequentialS3Reader , S3ReaderConstructor .sequential ()),
46+ (DCPOptimizedS3Reader , S3ReaderConstructor .dcp_optimized ()),
47+ ],
48+ )
49+ def test_dcp_load_reads_tensors_in_sequential_order (
50+ checkpoint_directory , model , reader_class , reader_constructor
51+ ):
4352 """
4453 Test that prepare_local_plan allows dcp.load() to read items in offset order.
4554
4655 This does not prevent backwards seek, since torch.load() would still call
4756 backwards seek operations.
4857
58+ SequentialS3Reader:
4959 pytorch/torch/serialization.py load() function will call _is_zipfile(), which
5060 includes this read() call: f.read(len(local_header_magic_number)). This is
5161 followed by readinto() calls on the actual tensor.
5262
63+ DCPOptimizedS3Reader:
64+ DCPOptimizedS3Reader.seekable() returns False, hence PyTorch would use read()
65+ calls and make it seekable with `seekable = io.BytesIO(transform_from.read(-1))` in
66+ pytorch/torch/distributed/checkpoint/filesystem.py read_data() method.
67+
5368 Hence we can track read() call positions to determine if load ordering is
54- being applied correctly.
69+ being applied correctly for both cases.
5570 """
5671 region = checkpoint_directory .region
5772 s3_uri = checkpoint_directory .s3_uri
@@ -61,21 +76,17 @@ def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model)
6176 dcp .save (state_dict , storage_writer = storage_writer )
6277
6378 read_positions = []
64-
65- original_read = SequentialS3Reader .read
79+ original_read = reader_class .read
6680
6781 def track_reads (self , size = None ):
6882 if not self .key .endswith (".metadata" ):
6983 read_positions .append (self ._position )
7084 return original_read (self , size )
7185
72- # Load with position tracking on read() (called at the start of each torch.load())
73- with patch .object (SequentialS3Reader , "read" , track_reads ):
86+ with patch .object (reader_class , "read" , track_reads ):
7487 loaded_state_dict = {k : torch .empty_like (v ) for k , v in state_dict .items ()}
7588 storage_reader = S3StorageReader (
76- region = region ,
77- path = s3_uri ,
78- reader_constructor = S3ReaderConstructor .sequential (),
89+ region = region , path = s3_uri , reader_constructor = reader_constructor
7990 )
8091 dcp .load (loaded_state_dict , storage_reader = storage_reader )
8192
0 commit comments