Skip to content

Commit efeb194

Browse files
committed
feat(dcp): use constructor pattern for ListOfRanges optimization
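Add DCPListOfRangesConstructor and the S3ReaderConstructor.dcp_list_of_ranges() factory method so that the DCP range optimization is enabled through the reader_constructor parameter instead of being forced inside create_stream. Range injection is improved: prepare_local_plan now computes per-file ranges and passes them to the constructor only when a DCPListOfRangesConstructor is configured. Both direct ListOfRanges usage (S3ReaderConstructor.list_of_ranges(ranges)) and the DCP optimization are supported. Users can now opt in via: reader_constructor=S3ReaderConstructor.dcp_list_of_ranges()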
1 parent 653ce81 commit efeb194

File tree

3 files changed

+56
-32
lines changed

3 files changed

+56
-32
lines changed

s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@
2828

2929
from s3torchconnector._s3client import S3Client
3030
from s3torchconnector._s3dataset_common import parse_s3_uri
31-
from ..s3reader import S3ReaderConstructor, S3ReaderConstructorProtocol
31+
from ..s3reader import (
32+
S3ReaderConstructor,
33+
DCPListOfRangesConstructor,
34+
S3ReaderConstructorProtocol,
35+
)
3236
from .. import S3ClientConfig
3337
from .s3_prefix_strategy import S3PrefixStrategyBase, DefaultPrefixStrategy
3438
from .._user_agent import UserAgent
@@ -79,7 +83,6 @@ def create_stream(
7983
self,
8084
path: Union[str, os.PathLike],
8185
mode: str,
82-
reader_constructor: Optional[S3ReaderConstructorProtocol] = None,
8386
) -> Generator[io.IOBase, None, None]:
8487
"""
8588
Create a stream for reading or writing to S3.
@@ -102,18 +105,8 @@ def create_stream(
102105
with self._client.put_object(bucket, key) as stream:
103106
yield stream
104107
elif mode == "rb": # read mode
105-
logger.debug("create_stream readable for %s", path_str)
106-
relative_path = os.path.relpath(path, self._path)
107-
if self.file_ranges and relative_path in self.file_ranges:
108-
ranges = self.file_ranges[relative_path]
109-
# ! Force use list_of_ranges reader for now
110-
# TODO: (Important) improve this by passing in ranges parameter properly
111-
reader_constructor = S3ReaderConstructor.list_of_ranges(ranges)
112-
113-
# Use provided reader_constructor or fall back to default
114-
constructor = reader_constructor or self._reader_constructor
115108
with self._client.get_object(
116-
bucket, key, reader_constructor=constructor
109+
bucket, key, reader_constructor=self._reader_constructor
117110
) as stream:
118111
yield stream
119112
else:
@@ -380,20 +373,21 @@ def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
380373
LoadPlan: The same plan with items sorted by storage offset.
381374
"""
382375

383-
# Calculate ranges per file
384-
per_file_ranges = {}
385-
for read_item in plan.items:
386-
item_md = self.storage_data[read_item.storage_index]
387-
path = item_md.relative_path
388-
if path not in per_file_ranges:
389-
per_file_ranges[path] = []
390-
per_file_ranges[path].append(
391-
RangeRequest(start=item_md.offset, end=item_md.offset + item_md.length)
392-
)
393-
394-
# Store ranges in filesystem
395-
# TODO find a better place to handle this information
396-
self.fs.file_ranges = per_file_ranges
376+
# Inject ranges if using DCP list-of-ranges reader constructor
377+
if isinstance(self.fs._reader_constructor, DCPListOfRangesConstructor):
378+
# Calculate ranges per file
379+
per_file_ranges = {}
380+
for read_item in plan.items:
381+
item_md = self.storage_data[read_item.storage_index]
382+
path = item_md.relative_path
383+
if path not in per_file_ranges:
384+
per_file_ranges[path] = []
385+
per_file_ranges[path].append(
386+
RangeRequest(
387+
start=item_md.offset, end=item_md.offset + item_md.length
388+
)
389+
)
390+
self.fs._reader_constructor.set_file_ranges(per_file_ranges)
397391

398392
# Sort items in plan based on their offset in checkpoints shards
399393
plan.items.sort(key=lambda item: self.storage_data[item.storage_index].offset)

s3torchconnector/src/s3torchconnector/s3reader/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
# // SPDX-License-Identifier: BSD
33

44
from .s3reader import S3Reader
5-
from .constructor import S3ReaderConstructor
5+
from .constructor import S3ReaderConstructor, DCPListOfRangesConstructor
66
from .sequential import SequentialS3Reader
77
from .ranged import RangedS3Reader
8+
from .list_of_ranges import ListOfRangesS3Reader
89
from .protocol import GetStreamCallable, S3ReaderConstructorProtocol
910

1011
__all__ = [
1112
"S3Reader",
1213
"S3ReaderConstructor",
1314
"SequentialS3Reader",
1415
"RangedS3Reader",
16+
"ListOfRangesS3Reader",
1517
]

s3torchconnector/src/s3torchconnector/s3reader/constructor.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,35 @@
22
# // SPDX-License-Identifier: BSD
33

44
from functools import partial
5-
from typing import Optional, List
5+
from typing import Optional, List, Dict
66

7+
from .s3reader import S3Reader
78
from .protocol import S3ReaderConstructorProtocol
89
from .sequential import SequentialS3Reader
910
from .ranged import RangedS3Reader
1011
from .list_of_ranges import ListOfRangesS3Reader, RangeRequest
1112

1213

14+
class DCPListOfRangesConstructor:
    """Reader constructor that picks a reader per object based on injected ranges.

    The instance starts with no ranges. After ``set_file_ranges`` has been
    called, objects whose filename has an entry in the mapping are opened with
    ``ListOfRangesS3Reader``; every other object is opened with
    ``SequentialS3Reader``.
    """

    def __init__(self) -> None:
        # Maps a filename (last "/"-separated component of the S3 key) to the
        # ranges that will be requested from that object.
        self._file_ranges: Dict[str, List[RangeRequest]] = {}

    def set_file_ranges(self, file_ranges: Dict[str, List[RangeRequest]]) -> None:
        """Replace the stored filename -> ranges mapping.

        Args:
            file_ranges: Ranges to read, keyed by filename.
        """
        self._file_ranges = file_ranges

    def __call__(
        self, bucket: str, key: str, get_object_info, get_stream, **kwargs
    ) -> S3Reader:
        """Create a reader for ``bucket``/``key``.

        Args:
            bucket: S3 bucket name.
            key: S3 object key.
            get_object_info: Passed through unchanged to the reader.
            get_stream: Passed through unchanged to the reader.
            **kwargs: Forwarded to ``SequentialS3Reader`` only.

        Returns:
            A ``ListOfRangesS3Reader`` when ranges are registered for the
            key's filename, otherwise a ``SequentialS3Reader``.
        """
        # TODO: Check if using filename only works with prefix strategies
        filename = key.rpartition("/")[2]

        # Look the filename up instead of indexing: once ranges have been set,
        # objects that are not part of the mapping (e.g. .metadata) must still
        # be readable rather than raising KeyError.
        ranges = self._file_ranges.get(filename)
        if ranges is not None:
            return ListOfRangesS3Reader(
                bucket, key, ranges, get_object_info, get_stream
            )

        # Fallback to SequentialS3Reader if no ranges are known for this file
        # (e.g. when reading .metadata before any ranges were injected).
        return SequentialS3Reader(bucket, key, get_object_info, get_stream, **kwargs)
32+
33+
1334
class S3ReaderConstructor:
1435
"""Constructor for creating ``partial(S3Reader)`` instances.
1536
@@ -93,6 +114,12 @@ def list_of_ranges(ranges: List[RangeRequest]) -> S3ReaderConstructorProtocol:
93114
# TODO update docstring
94115
return partial(ListOfRangesS3Reader, ranges=ranges)
95116

117+
@staticmethod
def dcp_list_of_ranges() -> S3ReaderConstructorProtocol:
    """Build the opt-in reader constructor for DCP list-of-ranges reads.

    Returns:
        A new ``DCPListOfRangesConstructor``; it uses ListOfRanges when
        ranges are available.
    """
    # TODO update docstring
    dcp_constructor = DCPListOfRangesConstructor()
    return dcp_constructor
122+
96123
@staticmethod
97124
def default() -> S3ReaderConstructorProtocol:
98125
"""Creates default reader constructor (sequential)
@@ -112,10 +139,11 @@ def get_reader_type_string(
112139
S3ReaderConstructor.default()
113140
)
114141

115-
if not isinstance(constructor, partial):
142+
if isinstance(constructor, DCPListOfRangesConstructor):
143+
return "dcp_list_of_ranges"
144+
elif not isinstance(constructor, partial):
116145
return "unknown"
117-
118-
if constructor.func == RangedS3Reader:
146+
elif constructor.func == RangedS3Reader:
119147
return "range_based"
120148
elif constructor.func == SequentialS3Reader:
121149
return "sequential"

0 commit comments

Comments
 (0)