
Commit 08a815a

perf(s3reader): optimize for sequential DCP workloads
This commit improves performance of ListOfRangesS3Reader by up to 30% for DCP loads:

- Remove the dependency on SequentialS3Reader in favour of self-managed streams
- Implement direct stream management with per-group buffering
- Optimize the read() method to avoid a BytesIO buffer, assuming sequential reads
- Enforce non-seekable behaviour to force sequential reading patterns

This implementation is significantly faster for distributed checkpoint loading patterns while maintaining correctness for sequential access. It relies on the load-ordering optimisation, which enforces sequential reading via read() operations; it will not work with readinto() operations, since those still exhibit backward seek patterns.
1 parent 59aae7b commit 08a815a
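For illustration, a minimal sketch of the access pattern this change targets. The module path is inferred from the file location in this commit; the bucket, key, offsets, and the get_object_info / get_stream callables are placeholders normally supplied by the connector, not part of this diff.

    from s3torchconnector.s3reader.list_of_ranges import ListOfRangesS3Reader, RangeRequest

    # Ranges are assumed to already be ordered by the DCP load-ordering step
    # (prepare_local_plan), so the reader only ever moves forward.
    ranges = [RangeRequest(start=0, end=4096), RangeRequest(start=8192, end=16384)]

    reader = ListOfRangesS3Reader(
        bucket="my-bucket",                # placeholder bucket/key
        key="checkpoints/__0_0.distcp",
        ranges=ranges,
        get_object_info=get_object_info,   # placeholder callables from the S3 client
        get_stream=get_stream,
    )

    for r in ranges:
        reader.seek(r.start)               # forward-only seek; backward seeks are not supported
        data = reader.read(r.end - r.start)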

3 files changed: +135 -151 lines changed


s3torchconnector/src/s3torchconnector/s3reader/constructor.py

Lines changed: 4 additions & 12 deletions
@@ -46,10 +46,7 @@ class S3ReaderConstructor:
     """
 
     @staticmethod
-    def sequential(
-        start_offset: Optional[int] = None,
-        end_offset: Optional[int] = None,
-    ) -> S3ReaderConstructorProtocol:
+    def sequential() -> S3ReaderConstructorProtocol:
         """Creates a constructor for sequential readers
 
         Returns:
@@ -60,12 +57,7 @@ def sequential(
             reader_constructor = S3ReaderConstructor.sequential()
 
         """
-        # TODO update docstrings (after implementation fixed)
-        return partial(
-            SequentialS3Reader,
-            start_offset=start_offset,
-            end_offset=end_offset,
-        )
+        return partial(SequentialS3Reader)
 
     @staticmethod
     def range_based(buffer_size: Optional[int] = None) -> S3ReaderConstructorProtocol:
@@ -111,13 +103,13 @@ def range_based(buffer_size: Optional[int] = None) -> S3ReaderConstructorProtocol:
     @staticmethod
     def list_of_ranges(ranges: List[RangeRequest]) -> S3ReaderConstructorProtocol:
         """Creates a constructor for ListOfRangesS3Reader with specific ranges"""
-        # TODO update docstring
+        # TODO update docstring, and name
        return partial(ListOfRangesS3Reader, ranges=ranges)
 
     @staticmethod
     def dcp_list_of_ranges() -> S3ReaderConstructorProtocol:
         """Creates a DCP-optimized constructor that uses ListOfRanges when ranges are available"""
-        # TODO update docstring
+        # TODO update docstring, and name
         return DCPListOfRangesConstructor()
 
     @staticmethod
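As a usage illustration of the constructor API after this change (a sketch; the top-level import of S3ReaderConstructor and the example byte ranges are assumptions, not part of this diff):

    from s3torchconnector import S3ReaderConstructor
    from s3torchconnector.s3reader.list_of_ranges import RangeRequest

    # sequential() no longer accepts start_offset / end_offset
    seq_constructor = S3ReaderConstructor.sequential()

    # Explicit list of byte ranges, e.g. derived from a DCP load plan
    ranges = [RangeRequest(start=0, end=1024), RangeRequest(start=4096, end=8192)]
    lor_constructor = S3ReaderConstructor.list_of_ranges(ranges)

    # DCP-optimized constructor: uses ListOfRanges when ranges are available (per its docstring)
    dcp_constructor = S3ReaderConstructor.dcp_list_of_ranges()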

s3torchconnector/src/s3torchconnector/s3reader/list_of_ranges.py

Lines changed: 121 additions & 73 deletions
@@ -3,22 +3,23 @@
 
 import logging
 from dataclasses import dataclass
-from typing import List, Optional, Callable, Union, Dict
-from io import SEEK_SET
+from typing import List, Optional, Callable, Union, Dict, Iterator
+from io import SEEK_SET, SEEK_CUR
 
 from s3torchconnectorclient._mountpoint_s3_client import (
     ObjectInfo,
     GetObjectStream,
     HeadObjectResult,
 )
 from .s3reader import S3Reader
-from .sequential import SequentialS3Reader
 
 log = logging.getLogger(__name__)
 
 
 @dataclass
 class RangeRequest:
+    """Singular range request; Inclusive start, exclusive end"""
+
     start: int
     end: int
     request_id: Optional[str] = None
@@ -31,6 +32,8 @@ class RangeGroup:
     requests: List[RangeRequest]
 
 
+# TODO: Update name, since it now requires sequential reading and is optimised for DCP
+# TODO: Update docstring to emphasise this requires Load Ordering in prepare_local_plan
 class ListOfRangesS3Reader(S3Reader):
     """Optimized reader with pre-calculated request mapping and batch prefetch."""
 
@@ -42,30 +45,20 @@ def __init__(
         get_object_info: Callable[[], Union[ObjectInfo, HeadObjectResult]],
         get_stream: Callable[[Optional[int], Optional[int]], GetObjectStream],
         max_gap_size: int = 200 * 1024 * 1024,
-        **kwargs,
     ):
         self._bucket = bucket
         self._key = key
         self._get_object_info = get_object_info
         self._get_stream = get_stream
 
         # Calculate range groups using coalescing logic
-        self._range_groups = self._calculate_range_groups(ranges, max_gap_size)
-
-        # Pre-create all readers
-        self._group_readers: Dict[int, SequentialS3Reader] = {}
-        for i, group in enumerate(self._range_groups):
-            reader = SequentialS3Reader(
-                bucket=bucket,
-                key=key,
-                get_object_info=get_object_info,
-                get_stream=get_stream,
-                start_offset=group.start,
-                end_offset=group.end,
-            )
-            # TODO - judge if this is beneficial or not.
-            reader.prefetch()  # Batch prefetch all ranges
-            self._group_readers[i] = reader
+        self._range_groups = self._coalesce_ranges(ranges, max_gap_size)
+        self._current_group_idx: int = 0
+
+        # Per-group stream cache
+        self._streams: Dict[int, Iterator[bytes]] = {}
+        self._stream_positions: Dict[int, int] = {}
+        self._stream_buffers: Dict[int, bytes] = {}
 
         self._position: int = 0
 
@@ -77,76 +70,131 @@ def bucket(self) -> str:
     def key(self) -> str:
         return self._key
 
-    def _calculate_range_groups(
+    def seekable(self) -> bool:
+        """Not seekable — torch/distributed/checkpoint/filesystem.py will use read() instead of readinto()."""
+        return False
+
+    def _coalesce_ranges(
         self, ranges: List[RangeRequest], max_gap_size: int
     ) -> List[RangeGroup]:
-        """Coalescing logic - group ranges within max_gap_size."""
-        # TODO: optimise this logic
+        """Coalescing nearby byte ranges within max_gap_size."""
         if not ranges:
             return []
 
-        # TODO: could be pre-sorted in prepare_local_plan for dcp.load
-        sorted_ranges = sorted(ranges, key=lambda r: r.start)
-        groups = []
-        current_group = [sorted_ranges[0]]
-
-        for i in range(1, len(sorted_ranges)):
-            prev_end = current_group[-1].end
-            curr_start = sorted_ranges[i].start
+        # TODO: could be pre-sorted in prepare_local_plan (small optimisation)
+        ranges = sorted(ranges, key=lambda r: r.start)
+        groups: List[RangeGroup] = []
+        current = [ranges[0]]
 
-            if curr_start - prev_end <= max_gap_size:
-                current_group.append(sorted_ranges[i])
+        for r in ranges[1:]:
+            if r.start - current[-1].end <= max_gap_size:
+                current.append(r)
             else:
-                groups.append(self._create_range_group(current_group))
-                current_group = [sorted_ranges[i]]
+                groups.append(RangeGroup(current[0].start, current[-1].end, current))
+                current = [r]
 
-        groups.append(self._create_range_group(current_group))
+        groups.append(RangeGroup(current[0].start, current[-1].end, current))
         return groups
 
-    def _create_range_group(self, ranges: List[RangeRequest]) -> RangeGroup:
-        """Create range group - always succeeds since we only use gap size."""
-        # TODO remove min/max code by tracking incrementally in _calculate_range_groups
-        # * (was kept since it's easier to understand and test)
-        group_start = min(r.start for r in ranges)
-        group_end = max(r.end for r in ranges)
-        return RangeGroup(start=group_start, end=group_end, requests=ranges)
-
-    def _find_reader_for_offset(self, offset: int) -> Optional[SequentialS3Reader]:
-        """Find reader that contains the given offset."""
-        for i, group in enumerate(self._range_groups):
-            if group.start <= offset < group.end:
-                self._current_reader_index = i
-                return self._group_readers[i]
-            if group.start > offset:  # TODO handle this case properly by raising errors
-                break
-        return None
-
-    def seek(self, offset: int, whence: int = SEEK_SET, /) -> int:
-        self._position = offset
-        reader = self._find_reader_for_offset(offset)
-        if not reader:
-            return self._position
-        return reader.seek(offset, whence)
+    def _get_stream_for_group(self, idx: int) -> Iterator[bytes]:
+        """
+        Returns a cached iterator for the given range group,
+        or creates a new one if not present.
+        """
+        if idx not in self._streams:
+            group = self._range_groups[idx]
+            stream = self._get_stream(group.start, group.end)
+            self._streams[idx] = stream
+            self._stream_positions[idx] = group.start
+            self._stream_buffers[idx] = b""
+        return self._streams[idx]
 
     def read(self, size: Optional[int] = None) -> bytes:
-        reader = self._find_reader_for_offset(self._position)
-        if not reader:
+        """Reads up to `size` bytes sequentially across grouped ranges."""
+        if not size or size <= 0:
             return b""
-        data = reader.read(size)
-        self._position += len(data)
-        return data
+
+        pos = self._position
+
+        # Find group (with cache)
+        if (
+            self._current_group_idx < len(self._range_groups)
+            and self._range_groups[self._current_group_idx].start
+            <= pos
+            < self._range_groups[self._current_group_idx].end
+        ):
+            group_idx = self._current_group_idx
+        else:
+            # Search for matching group
+            for i, g in enumerate(self._range_groups):
+                if g.start <= pos < g.end:
+                    group_idx = i
+                    self._current_group_idx = group_idx
+                    break
+            else:
+                return b""
+
+        stream = self._get_stream_for_group(group_idx)
+
+        current_pos = self._stream_positions[group_idx]
+        buffer = self._stream_buffers[group_idx]
+        remaining = size
+        chunks: List[bytes] = []
+
+        # 1. Serve from buffered leftover bytes
+        if buffer and current_pos <= pos < current_pos + len(buffer):
+            offset = pos - current_pos
+            end = offset + min(remaining, len(buffer) - offset)
+            chunks.append(buffer[offset:end])
+            remaining -= end - offset
+            current_pos = pos + (end - offset)
+            self._stream_buffers[group_idx] = buffer[end:] if end < len(buffer) else b""
+
+        # 2. Read more data from S3 stream
+        while remaining > 0:
+            try:
+                chunk = next(stream)
+            except StopIteration:
+                break
+
+            # Skip ahead if behind target
+            if current_pos < pos:
+                skip = min(pos - current_pos, len(chunk))
+                chunk = chunk[skip:]
+                current_pos += skip
+
+            # Take needed part of chunk
+            take = min(len(chunk), remaining)
+            chunks.append(chunk[:take])
+            remaining -= take
+            current_pos += take
+
+            # Save leftover bytes
+            if take < len(chunk):
+                self._stream_buffers[group_idx] = chunk[take:]
+                break
+
+        self._stream_positions[group_idx] = current_pos
+        self._position = pos + (size - remaining)
+        return b"".join(chunks)
+
+    def seek(self, offset: int, whence: int = SEEK_SET, /) -> int:
+        if whence == SEEK_SET:
+            self._position = offset
+        elif whence == SEEK_CUR:
+            self._position += offset
+        return self._position
 
     def readinto(self, buf) -> int:
-        reader = self._find_reader_for_offset(self._position)
-        if not reader:
-            return 0
-        bytes_read = reader.readinto(buf)
-        self._position += bytes_read
-        return bytes_read
+        data = self.read(len(buf))
+        n = len(data)
+        buf[:n] = data
+        return n
 
     def tell(self) -> int:
         return self._position
 
     def close(self) -> None:
-        for reader in self._group_readers.values():
-            reader.close()
+        self._streams.clear()
+        self._stream_positions.clear()
+        self._stream_buffers.clear()
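To make the coalescing behaviour concrete, a small illustrative example (the ranges and the 1 MiB max_gap_size are made up; the default max_gap_size in this diff is 200 MiB):

    from s3torchconnector.s3reader.list_of_ranges import RangeRequest

    MiB = 1024 * 1024
    requests = [
        RangeRequest(start=0, end=4 * MiB),
        RangeRequest(start=4 * MiB + 512 * 1024, end=6 * MiB),  # 512 KiB gap -> same group
        RangeRequest(start=100 * MiB, end=101 * MiB),           # ~94 MiB gap -> new group
    ]
    # With max_gap_size = 1 MiB, _coalesce_ranges would yield two RangeGroups,
    # each later served by a single GET stream over [start, end):
    #   RangeGroup(start=0,         end=6 * MiB,   requests=[requests[0], requests[1]])
    #   RangeGroup(start=100 * MiB, end=101 * MiB, requests=[requests[2]])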
