Commit 567b077
docs(dcp): document DCP-optimized S3 reader in README and docstrings
- Add documentation to README, constructor, and DCPOptimizedS3Reader class
- Include class docstrings for S3FileSystem, S3StorageWriter, and S3StorageReader
- Update reader configurations in README with examples
- Use sphinx-friendly formatting for docstrings
- Remove some unplanned TODOs and update some comments
1 parent f88c0da commit 567b077

File tree: 5 files changed (+109, -32)

README.md (29 additions, 12 deletions)
````diff
@@ -128,7 +128,9 @@ Amazon S3 Connector for PyTorch provides robust support for PyTorch distributed
 - `S3StorageWriter`: Implementation of PyTorch's StorageWriter interface.
-- `S3StorageReader`: Implementation of PyTorch's StorageReader interface. Supports configurable reading strategies via the `reader_constructor` parameter (see [Reader Configurations](#reader-configurations)).
+- `S3StorageReader`: Implementation of PyTorch's StorageReader interface.
+  - Supports configurable reading strategies via the `reader_constructor` parameter (see [Reader Configurations](#reader-configurations)).
+  - `S3ReaderConstructor.dcp_optimized()` is recommended for up to 2x faster loading with partial checkpoint optimizations.
 - `S3FileSystem`: An implementation of PyTorch's FileSystemBase.

 These tools enable seamless integration of Amazon S3 with
````
````diff
@@ -151,6 +153,7 @@ can be found in the [examples/dcp](https://github.com/awslabs/s3-connector-for-p
 ```py
 from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
+from s3torchconnector import S3ReaderConstructor

 import torchvision
 import torch.distributed.checkpoint as DCP
````
````diff
@@ -175,7 +178,13 @@ DCP.save(
 # Load distributed checkpoint from S3
 model = torchvision.models.resnet18()
 model_state_dict = model.state_dict()
-s3_storage_reader = S3StorageReader(region=REGION, path=CHECKPOINT_URI)
+# Use DCP-optimized reader for faster loading
+reader_constructor = S3ReaderConstructor.dcp_optimized()
+s3_storage_reader = S3StorageReader(
+    region=REGION,
+    path=CHECKPOINT_URI,
+    reader_constructor=reader_constructor,  # optional; constructor for S3Reader types
+)
 DCP.load(
     state_dict=model_state_dict,
     storage_reader=s3_storage_reader,
````
````diff
@@ -409,7 +418,7 @@ data = s3reader.read()
 ## Reader Configurations

-Amazon S3 Connector for PyTorch supports two types of readers, configurable through `S3ReaderConstructor`.
+Amazon S3 Connector for PyTorch supports three types of readers, configurable through `S3ReaderConstructor`.

 ### Reader Types
````

````diff
@@ -420,21 +429,32 @@ Amazon S3 Connector for PyTorch supports two types of readers, configurable thro
 #### 2. Range-based Reader

-- Performs byte-range requests to read specific portions of S3 objects without downloading the entire file.
-- Prioritizes memory efficiency, with performance gains only for sparse partial reads.
+- Performs byte-range requests to read specific portions of S3 objects without downloading the entire object.
+- Prioritizes memory efficiency, with performance gains only for sparse partial reads in large objects.
 - Features adaptive buffering with forward overlap handling:
   - **Small reads** (< `buffer_size`): Use internal buffer to reduce S3 API calls.
   - **Large reads** (≥ `buffer_size`): Bypass buffer for direct transfer.

+#### 3. DCP-Optimized Reader (DCP only)
+
+- Specialized usage for PyTorch Distributed Checkpoint (DCP) loading.
+- Provides up to 2x performance improvement through zero-copy buffers and sequential access patterns.
+- Enables efficient partial checkpoint loading (e.g. model-only) through range-based streams and range coalescing.
+- Automatically handles range metadata injection from the DCP load plan.
+- Requires sequential access patterns (automatically enforced in `S3StorageReader.prepare_local_plan()`).
+
 ### When to Use Each Reader

-- **Sequential Reader**: For processing entire files, and when repeated access to the data is required. Best for most general use cases.
+- **Sequential Reader**: For processing entire objects, and when repeated access to the data is required. Best for most general use cases.
 - **Range-based Reader**: For larger objects (100MB+) that require sparse partial reads, and in memory-constrained environments.
+- **DCP-Optimized Reader**: For PyTorch Distributed Checkpoint loading scenarios.

 **Note**: S3Reader instances are not thread-safe and should not be shared across threads. For multiprocessing with DataLoader, each worker process creates its own S3Reader instance automatically.

 ### Examples

+For `S3ReaderConstructor` usage details, please refer to the [`S3ReaderConstructor` documentation](https://awslabs.github.io/s3-connector-for-pytorch/autoapi/s3torchconnector/s3reader/constructor/index.html). Below are some examples for `S3ReaderConstructor` usage.
+
 Direct method - `S3Client` usage with range-based reader without buffer:
 ```py
 # Direct S3Client usage for zero-copy partial reads into pre-allocated buffers, for memory efficiency and fast data transfer
````
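The adaptive-buffering rule described in the Range-based Reader bullets can be sketched as a toy model. This is an illustrative assumption-laden sketch, not the connector's internals: `CountingRangeReader`, its fields, and the 64KB buffer size are all made up for the example; the point is only that sequential small reads share one buffered fetch while large reads go direct.

```python
# Illustrative model (not the connector's code) of adaptive buffering:
# small reads (< buffer_size) are served from an internal buffer filled by
# one ranged GET; large reads (>= buffer_size) bypass the buffer entirely.
class CountingRangeReader:
    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.api_calls = 0       # number of simulated S3 byte-range requests
        self.buf_start = None    # currently buffered range [buf_start, buf_end)
        self.buf_end = None

    def read(self, offset, size):
        if size >= self.buffer_size:
            self.api_calls += 1  # large read: direct byte-range GET, no buffering
            return
        covered = (self.buf_start is not None
                   and self.buf_start <= offset
                   and offset + size <= self.buf_end)
        if not covered:
            self.api_calls += 1  # fill buffer with one ranged GET
            self.buf_start = offset
            self.buf_end = offset + self.buffer_size
        # small read served from the buffer: no extra API call

reader = CountingRangeReader(buffer_size=64 * 1024)
for i in range(16):              # sixteen sequential 4KB reads
    reader.read(i * 4096, 4096)
print(reader.api_calls)          # 1: one buffered fetch covers all sixteen reads
```

A single 128KB read through the same reader would bypass the buffer and cost one additional call, which is why the README recommends the range-based reader only for sparse partial reads.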
````diff
@@ -456,15 +476,13 @@ s3reader.seek(100 * 1024 * 1024)  # Skip to 100MB offset
 bytes_read = s3reader.readinto(buffer)  # Direct read into buffer
 ```

-DCP interface - `S3StorageReader` usage with range-based reader with buffer:
+DCP interface - `S3StorageReader` usage with DCP-optimized reader:
 ```py
-# Load distributed checkpoint with range-based reader to optimize memory usage for large checkpoint files
+# Load checkpoint with DCP-optimized reader for better performance
 from s3torchconnector.dcp import S3StorageReader
 from s3torchconnector import S3ReaderConstructor

-reader_constructor = S3ReaderConstructor.range_based(
-    buffer_size=16*1024*1024  # 16MB buffer
-)
+reader_constructor = S3ReaderConstructor.dcp_optimized()
 s3_storage_reader = S3StorageReader(
     region=REGION,
     path=CHECKPOINT_URI,
````
````diff
@@ -492,7 +510,6 @@ for item in dataset:
     ...
 ```

-For `S3ReaderConstructor` usage details, please refer to the [`S3ReaderConstructor` documentation](https://awslabs.github.io/s3-connector-for-pytorch/autoapi/s3torchconnector/s3reader/constructor/index.html).

 ## Contributing
````

s3torchconnector/pyproject.toml (1 addition, 1 deletion)
```diff
@@ -35,7 +35,7 @@ test = [
     "flake8",
     "black",
     "mypy",
-    "importlib_metadata; python_version == '3.9'",  # PyTorch 2.7.0+ DCP w/ Python 3.9 requires this module
+    "importlib_metadata; python_version == '3.9'",  # PyTorch 2.7.0+ DCP w/ Python 3.9 requires this module; for dcp_optimized reader unit tests
 ]

 e2e = [
```

s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py (12 additions, 1 deletion)
```diff
@@ -44,6 +44,7 @@

 class S3FileSystem(FileSystemBase):
+    """S3-based implementation of PyTorch's FileSystemBase for distributed checkpointing."""
     def __init__(
         self,
         region: str,
@@ -267,6 +268,7 @@ class StorageMetadata:

 class S3StorageWriter(FileSystemWriter):
+    """S3 implementation of PyTorch's FileSystemWriter for distributed checkpoints."""
     def __init__(
         self,
         region: str,
@@ -321,6 +323,7 @@ def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:

 class S3StorageReader(FileSystemReader):
+    """S3 implementation of PyTorch's FileSystemReader with configurable reader strategies."""
     def __init__(
         self,
         region: str,
@@ -356,13 +359,21 @@ def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:

     def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
         """
-        Sort load items by storage offset for sequential access optimization.
+        Performs two key optimizations:
+
+        1. **Load Ordering**: Sorts load items by storage offset to enable sequential access
+
+        2. **Range Injection**: Provides byte range metadata to DCP reader constructors to enable
+           usage of DCPOptimizedS3Reader for range-based streams and range coalescing

         Args:
             plan (LoadPlan): The load plan from PyTorch DCP.

         Returns:
             LoadPlan: The same plan with items sorted by storage offset.
+
+        Note:
+            Both optimizations are required for DCPOptimizedS3Reader.
         """
         # Sort items in plan based on their offset in checkpoint shards
         plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset)
```
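The load-ordering half of `prepare_local_plan()` can be illustrated with a minimal sketch. The `Item` and `Info` classes below are simplified stand-ins for DCP's load-plan items and `_StorageInfo` metadata (they are not the real types); the sort key mirrors the `plan.items.sort(...)` line in the diff above.

```python
# Minimal sketch of load ordering: sorting load items by the byte offset of
# their stored tensor turns scattered checkpoint reads into a forward-only
# sequential scan, which the DCP-optimized reader requires.
from dataclasses import dataclass

@dataclass
class Item:
    storage_index: str  # stand-in for DCP's MetadataIndex

@dataclass
class Info:
    offset: int  # stand-in for _StorageInfo.offset

def sort_items_by_offset(items, storage_data):
    """Sort load items in place by their stored byte offset (load ordering)."""
    items.sort(key=lambda item: storage_data[item.storage_index].offset)
    return items

# Three weights stored at scattered offsets within one checkpoint shard:
storage_data = {"w1": Info(offset=300), "w2": Info(offset=0), "w3": Info(offset=120)}
items = [Item("w1"), Item("w2"), Item("w3")]
sort_items_by_offset(items, storage_data)
print([i.storage_index for i in items])  # ['w2', 'w3', 'w1']
```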

s3torchconnector/src/s3torchconnector/s3reader/constructor.py (34 additions, 7 deletions)
```diff
@@ -38,7 +38,6 @@ def set_item_ranges_by_file(
     storage_data: "Dict[MetadataIndex, _StorageInfo]",
 ) -> None:

-    # TODO: Check if we want to return DCPOptimizedConstructor for immutability here instead
     if not plan_items:
         return  # Allow lack of plan_items, for SequentialS3Reader fallbacks

@@ -142,15 +141,43 @@ def range_based(buffer_size: Optional[int] = None) -> S3ReaderConstructorProtoco
     def dcp_optimized(
         max_gap_size: Union[int, float] = DEFAULT_MAX_GAP_SIZE,
     ) -> DCPS3ReaderConstructorProtocol:
-        """
-        Creates a DCPOptimizedConstructor that uses DCPOptimizedS3Reader when ranges are available
+        """Creates a constructor for DCP-optimized readers for faster checkpoint loading.
+
+        The DCP-optimized reader provides up to 2x performance improvement over the default sequential reader through:
+
+        - Zero-copy buffer management by storing data as memoryview segments
+        - Sequential access optimization to reduce buffer sizes from file-level to item-level
+        - Range-based fetching that downloads only required byte ranges and coalesces nearby ranges to reduce S3 request latency

         Args:
-            max_gap_size: Maximum gap size in bytes to coalesce ranges into multiple ranged-streams.
-                Use float("inf") to coalesce all ranges regardless of gaps.
-                Use 0 to disable coalescing.
+            max_gap_size: Maximum gap size in bytes between ranges to coalesce into the same S3 read stream.
+                Most users should use the default value.
+
+                - Default: 32MB (``32 * 1024 * 1024``)
+                - Use ``float("inf")`` to coalesce all ranges regardless of gaps
+                - Use 0 to disable coalescing, which creates a new range-based stream for each gap
+
+        Returns:
+            DCPOptimizedConstructorProtocol:
+                Constructor that creates DCPOptimizedS3Reader when ranges are available, falling back to
+                SequentialS3Reader otherwise.
+
+        Requirements:
+            Should be used with S3StorageReader, in which ``prepare_local_plan()`` automatically handles:
+
+            - Load ordering: Sorts items by storage offset for sequential access
+            - Range injection: Provides byte ranges from the DCP load plan to the reader
+
+            Advanced users implementing custom readers must include these optimizations
+            in their ``prepare_local_plan()``/``read_data()`` implementation to use the DCP-optimized reader.
+
+        Example::
+
+            reader_constructor = S3ReaderConstructor.dcp_optimized()
+            storage_reader = S3StorageReader(region, path, reader_constructor=reader_constructor)
+            DCP.load(state_dict, storage_reader=storage_reader)
+
         """
-        # TODO update docstring with guide and requirements to use this reader for DCP
         return DCPOptimizedConstructor(max_gap_size=max_gap_size)

     @staticmethod
```
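The `max_gap_size` semantics documented above (merge ranges whose gap fits, `float("inf")` merges everything, `0` disables coalescing) can be sketched in a few lines. This is a hypothetical standalone function, not the library's `DCPOptimizedConstructor` internals; it only demonstrates the coalescing rule on sorted `(start, end)` byte ranges.

```python
# Sketch of max_gap_size-style range coalescing: merge sorted byte ranges
# whose gap is <= max_gap_size, so each merged range becomes one S3 stream.
def coalesce_ranges(ranges, max_gap_size):
    """Merge sorted (start, end) byte ranges whose inter-range gap fits."""
    if not ranges:
        return []
    merged = [list(ranges[0])]
    for start, end in ranges[1:]:
        if start - merged[-1][1] <= max_gap_size:
            merged[-1][1] = max(merged[-1][1], end)  # extend current stream
        else:
            merged.append([start, end])              # gap too large: new stream
    return [tuple(r) for r in merged]

# Gaps of 50 bytes merge under max_gap_size=64; the 100-byte gap does not:
print(coalesce_ranges([(0, 100), (150, 300), (400, 500)], max_gap_size=64))
# [(0, 300), (400, 500)]
```

With `max_gap_size=0` every gap starts a new stream (more requests, no wasted bytes); with `float("inf")` all ranges collapse into one stream (fewest requests, some discarded gap bytes), which is the trade-off the docstring describes.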

s3torchconnector/src/s3torchconnector/s3reader/dcp_optimized.py (33 additions, 11 deletions)
```diff
@@ -164,18 +164,38 @@ def readinto(self, buf) -> int:

 class DCPOptimizedS3Reader(S3Reader):
-    """
-    This reader optimizes PyTorch Distributed Checkpoint (DCP) partial loading by
-    1. exploiting sequential access patterns to avoid BytesIO buffer copy, and
-    2. only fetching required byte ranges instead of entire objects.
+    """S3 reader implementation optimized for PyTorch Distributed Checkpoint (DCP) loading.
+
+    Provides up to 2x performance improvement over the default sequential reader through:
+
+    1. **Zero-Copy Buffer**: Custom ``_ItemViewBuffer`` storing data as memoryview
+       segments to eliminate BytesIO allocation and copy overhead.
+
+    2. **Sequential Access Optimization**: Exploits sequential access patterns over tensors
+       enforced by ``S3StorageReader.prepare_local_plan()`` to reduce buffer sizes from file-level to
+       item-level.
+
+    3. **Range-based Fetching**: For partial checkpoint loading, uses load plan item range information
+       to group nearby byte ranges within ``max_gap_size`` to minimize S3 first-byte latency (compared to
+       the range-based reader), while only fetching required byte ranges instead of entire files
+       (compared to the sequential reader).

-    REQUIRES:
-    - DCP Loading - reader is only designed for usage via dcp_optimized reader_constructor for dcp.load()
-    - Load Ordering, applied automatically in prepare_local_plan, to ensure sequential access patterns.
-    - item_ranges provided (List[ItemRange]) must be pre-sorted - also applied in prepare_local_plan.
-    - Only supports sequentially reading exact item_ranges provided - otherwise would result in errors.
-      Non-sequential access will result in errors.
+    **Requirements**:
+
+    - DCP Loading: reader is only designed for usage via the dcp_optimized reader_constructor for ``dcp.load()``.
+    - Pre-sorted list of item_ranges, injected automatically in ``prepare_local_plan``.
+    - Sequential access over the exact item_ranges provided, also enforced automatically by ``prepare_local_plan``.

+    **Usage**:
+        Typically created automatically by ``DCPOptimizedConstructor`` when used with ``S3StorageReader`` and
+        ``S3ReaderConstructor.dcp_optimized()``::
+
+            reader_constructor = S3ReaderConstructor.dcp_optimized(max_gap_size=32*1024*1024)
+            storage_reader = S3StorageReader(region, path, reader_constructor=reader_constructor)
+            DCP.load(state_dict, storage_reader=storage_reader)
+
+    **Error Handling**:
+        Non-sequential access attempts raise ValueError with descriptive messages.
     """

     def __init__(
@@ -392,7 +412,6 @@ def _get_item_buffer(self, item: ItemRange) -> _ItemViewBuffer:

         chunk_len = len(chunk)

-        # TODO: separate skip part and take part for clearer logic
         # Skip past unwanted data (due to coalescing)
         if pos < item.start:
             skip_bytes = min(item.start - pos, chunk_len)
@@ -532,6 +551,9 @@ def tell(self) -> int:
         return self._position

     def close(self) -> None:
+        """
+        Close the stream and release resources.
+        """
         if not self._closed:
             self._closed = True
             self._stream = None
```
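The zero-copy buffer idea behind `_ItemViewBuffer` can be sketched with a small hypothetical class. `ViewBuffer` below is not the library's implementation; it only shows the technique the docstring names: keep incoming chunks as memoryview segments and serve `readinto()` by slicing views, instead of concatenating everything through a `BytesIO` copy.

```python
# Hypothetical sketch of a memoryview-segment buffer: append() keeps a view
# of each chunk (no concatenation copy); readinto() copies bytes once,
# directly into the caller's buffer, consuming segments front to back.
class ViewBuffer:
    def __init__(self):
        self._segments = []  # list of memoryview segments, in arrival order

    def append(self, chunk: bytes) -> None:
        self._segments.append(memoryview(chunk))  # view only, no data copy

    def readinto(self, out: bytearray) -> int:
        """Copy buffered bytes into a caller-provided buffer, front first."""
        written = 0
        while self._segments and written < len(out):
            seg = self._segments[0]
            n = min(len(seg), len(out) - written)
            out[written:written + n] = seg[:n]  # single copy, into caller's buffer
            if n == len(seg):
                self._segments.pop(0)           # segment fully consumed
            else:
                self._segments[0] = seg[n:]     # re-slice the view, still no copy
            written += n
        return written

buf = ViewBuffer()
buf.append(b"hello")
buf.append(b"world")
out = bytearray(7)
print(buf.readinto(out), bytes(out))  # 7 b'hellowo'
```

Deferring the copy until `readinto()` is what lets a reader hand tensor bytes to the caller with one copy total, which is the "eliminate BytesIO allocation and copy overhead" claim in the docstring.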
