Skip to content

Commit 39853e4

Browse files
committed
perf(s3reader): optimize ListOfRangesS3Reader for sequential DCP workloads
- Remove dependency on SequentialS3Reader for self-managed streams
- Implement direct stream management with per-group buffering
- Optimize read() method with no BytesIO buffer, assuming sequential reading
- Enforce non-seekable behaviour to force sequential reading patterns

This implementation is significantly faster for distributed checkpoint loading patterns while maintaining correctness for sequential access. It relies on the load-ordering optimisation, which enforces sequential reading via read() operations; it will not work with readinto() operations, since those still exhibit backward-seek patterns.
1 parent 59aae7b commit 39853e4

File tree

2 files changed

+130
-139
lines changed

2 files changed

+130
-139
lines changed

s3torchconnector/src/s3torchconnector/s3reader/list_of_ranges.py

Lines changed: 120 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,23 @@
33

44
import logging
55
from dataclasses import dataclass
6-
from typing import List, Optional, Callable, Union, Dict
7-
from io import SEEK_SET
6+
from typing import List, Optional, Callable, Union, Dict, Iterator
7+
from io import SEEK_SET, SEEK_CUR
88

99
from s3torchconnectorclient._mountpoint_s3_client import (
1010
ObjectInfo,
1111
GetObjectStream,
1212
HeadObjectResult,
1313
)
1414
from .s3reader import S3Reader
15-
from .sequential import SequentialS3Reader
1615

1716
log = logging.getLogger(__name__)
1817

1918

2019
@dataclass
2120
class RangeRequest:
21+
"""Singular range request; Inclusive start, exclusive end"""
22+
2223
start: int
2324
end: int
2425
request_id: Optional[str] = None
@@ -31,6 +32,8 @@ class RangeGroup:
3132
requests: List[RangeRequest]
3233

3334

35+
# TODO: Update name, since it now requires sequential reading and is optimised for DCP
36+
# TODO: Update docstring to emphasise this requires Load Ordering in prepare_local_plan
3437
class ListOfRangesS3Reader(S3Reader):
3538
"""Optimized reader with pre-calculated request mapping and batch prefetch."""
3639

@@ -42,30 +45,20 @@ def __init__(
4245
get_object_info: Callable[[], Union[ObjectInfo, HeadObjectResult]],
4346
get_stream: Callable[[Optional[int], Optional[int]], GetObjectStream],
4447
max_gap_size: int = 200 * 1024 * 1024,
45-
**kwargs,
4648
):
4749
self._bucket = bucket
4850
self._key = key
4951
self._get_object_info = get_object_info
5052
self._get_stream = get_stream
5153

5254
# Calculate range groups using coalescing logic
53-
self._range_groups = self._calculate_range_groups(ranges, max_gap_size)
54-
55-
# Pre-create all readers
56-
self._group_readers: Dict[int, SequentialS3Reader] = {}
57-
for i, group in enumerate(self._range_groups):
58-
reader = SequentialS3Reader(
59-
bucket=bucket,
60-
key=key,
61-
get_object_info=get_object_info,
62-
get_stream=get_stream,
63-
start_offset=group.start,
64-
end_offset=group.end,
65-
)
66-
# TODO - judge if this is beneficial or not.
67-
reader.prefetch() # Batch prefetch all ranges
68-
self._group_readers[i] = reader
55+
self._range_groups = self._coalesce_ranges(ranges, max_gap_size)
56+
self._current_group_idx: int = 0
57+
58+
# Per-group stream cache
59+
self._streams: Dict[int, Iterator[bytes]] = {}
60+
self._stream_positions: Dict[int, int] = {}
61+
self._stream_buffers: Dict[int, bytes] = {}
6962

7063
self._position: int = 0
7164

@@ -77,76 +70,130 @@ def bucket(self) -> str:
7770
def key(self) -> str:
7871
return self._key
7972

80-
def _calculate_range_groups(
73+
def seekable(self) -> bool:
74+
"""Not seekable — torch/distributed/checkpoint/filesystem.py will use read() instead of readinto()."""
75+
return False
76+
77+
def _coalesce_ranges(
8178
self, ranges: List[RangeRequest], max_gap_size: int
8279
) -> List[RangeGroup]:
83-
"""Coalescing logic - group ranges within max_gap_size."""
84-
# TODO: optimise this logic
80+
"""Coalescing nearby byte ranges within max_gap_size."""
8581
if not ranges:
8682
return []
8783

88-
# TODO: could be pre-sorted in prepare_local_plan for dcp.load
89-
sorted_ranges = sorted(ranges, key=lambda r: r.start)
90-
groups = []
91-
current_group = [sorted_ranges[0]]
92-
93-
for i in range(1, len(sorted_ranges)):
94-
prev_end = current_group[-1].end
95-
curr_start = sorted_ranges[i].start
84+
# TODO: could be pre-sorted in prepare_local_plan (small optimisation)
85+
ranges = sorted(ranges, key=lambda r: r.start)
86+
groups: List[RangeGroup] = []
87+
current = [ranges[0]]
9688

97-
if curr_start - prev_end <= max_gap_size:
98-
current_group.append(sorted_ranges[i])
89+
for r in ranges[1:]:
90+
if r.start - current[-1].end <= max_gap_size:
91+
current.append(r)
9992
else:
100-
groups.append(self._create_range_group(current_group))
101-
current_group = [sorted_ranges[i]]
93+
groups.append(RangeGroup(current[0].start, current[-1].end, current))
94+
current = [r]
10295

103-
groups.append(self._create_range_group(current_group))
96+
groups.append(RangeGroup(current[0].start, current[-1].end, current))
10497
return groups
10598

106-
def _create_range_group(self, ranges: List[RangeRequest]) -> RangeGroup:
107-
"""Create range group - always succeeds since we only use gap size."""
108-
# TODO remove min/max code by tracking incrementally in _calculate_range_groups
109-
# * (was kept since it's easier to understand and test)
110-
group_start = min(r.start for r in ranges)
111-
group_end = max(r.end for r in ranges)
112-
return RangeGroup(start=group_start, end=group_end, requests=ranges)
113-
114-
def _find_reader_for_offset(self, offset: int) -> Optional[SequentialS3Reader]:
115-
"""Find reader that contains the given offset."""
116-
for i, group in enumerate(self._range_groups):
117-
if group.start <= offset < group.end:
118-
self._current_reader_index = i
119-
return self._group_readers[i]
120-
if group.start > offset: # TODO handle this case properly by raising errors
121-
break
122-
return None
123-
124-
def seek(self, offset: int, whence: int = SEEK_SET, /) -> int:
125-
self._position = offset
126-
reader = self._find_reader_for_offset(offset)
127-
if not reader:
128-
return self._position
129-
return reader.seek(offset, whence)
99+
def _get_stream_for_group(self, idx: int) -> Iterator[bytes]:
100+
"""
101+
Returns a cached iterator for the given range group,
102+
or creates a new one if not present.
103+
"""
104+
if idx not in self._streams:
105+
group = self._range_groups[idx]
106+
stream = self._get_stream(group.start, group.end)
107+
self._streams[idx] = stream
108+
self._stream_positions[idx] = group.start
109+
self._stream_buffers[idx] = b""
110+
return self._streams[idx]
130111

131112
def read(self, size: Optional[int] = None) -> bytes:
132-
reader = self._find_reader_for_offset(self._position)
133-
if not reader:
113+
"""Reads up to `size` bytes sequentially across grouped ranges."""
114+
if not size or size <= 0:
134115
return b""
135-
data = reader.read(size)
136-
self._position += len(data)
137-
return data
116+
117+
pos = self._position
118+
119+
# Find group (with cache)
120+
if (
121+
self._current_group_idx < len(self._range_groups)
122+
and self._range_groups[self._current_group_idx].start
123+
<= pos
124+
< self._range_groups[self._current_group_idx].end
125+
):
126+
group_idx = self._current_group_idx
127+
else:
128+
group_idx = next(
129+
(i for i, g in enumerate(self._range_groups) if g.start <= pos < g.end),
130+
None,
131+
)
132+
if group_idx is None:
133+
return b""
134+
self._current_group_idx = group_idx
135+
136+
stream = self._get_stream_for_group(group_idx)
137+
138+
current_pos = self._stream_positions[group_idx]
139+
buffer = self._stream_buffers[group_idx]
140+
remaining = size
141+
chunks: List[bytes] = []
142+
143+
# 1. Serve from buffered leftover bytes
144+
if buffer and current_pos <= pos < current_pos + len(buffer):
145+
offset = pos - current_pos
146+
end = offset + min(remaining, len(buffer) - offset)
147+
chunks.append(buffer[offset:end])
148+
remaining -= end - offset
149+
current_pos = pos + (end - offset)
150+
self._stream_buffers[group_idx] = buffer[end:] if end < len(buffer) else b""
151+
152+
# 2. Read more data from S3 stream
153+
while remaining > 0:
154+
try:
155+
chunk = next(stream)
156+
except StopIteration:
157+
break
158+
159+
# Skip ahead if behind target
160+
if current_pos < pos:
161+
skip = min(pos - current_pos, len(chunk))
162+
chunk = chunk[skip:]
163+
current_pos += skip
164+
165+
# Take needed part of chunk
166+
take = min(len(chunk), remaining)
167+
chunks.append(chunk[:take])
168+
remaining -= take
169+
current_pos += take
170+
171+
# Save leftover bytes
172+
if take < len(chunk):
173+
self._stream_buffers[group_idx] = chunk[take:]
174+
break
175+
176+
self._stream_positions[group_idx] = current_pos
177+
self._position = pos + (size - remaining)
178+
return b"".join(chunks)
179+
180+
def seek(self, offset: int, whence: int = SEEK_SET, /) -> int:
181+
if whence == SEEK_SET:
182+
self._position = offset
183+
elif whence == SEEK_CUR:
184+
self._position += offset
185+
return self._position
138186

139187
def readinto(self, buf) -> int:
140-
reader = self._find_reader_for_offset(self._position)
141-
if not reader:
142-
return 0
143-
bytes_read = reader.readinto(buf)
144-
self._position += bytes_read
145-
return bytes_read
188+
data = self.read(len(buf))
189+
n = len(data)
190+
buf[:n] = data
191+
return n
146192

147193
def tell(self) -> int:
148194
return self._position
149195

150196
def close(self) -> None:
151-
for reader in self._group_readers.values():
152-
reader.close()
197+
self._streams.clear()
198+
self._stream_positions.clear()
199+
self._stream_buffers.clear()

s3torchconnector/src/s3torchconnector/s3reader/sequential.py

Lines changed: 10 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# // SPDX-License-Identifier: BSD
33

4-
import os
5-
import logging
64
import io
75
from functools import cached_property
86
from io import SEEK_CUR, SEEK_END, SEEK_SET
@@ -15,35 +13,6 @@
1513
)
1614
from .s3reader import S3Reader
1715

18-
log = logging.getLogger(__name__)
19-
20-
21-
class _OffsetReaderView:
22-
"""Wrapper that translates absolute positions to buffer-relative positions"""
23-
24-
def __init__(self, buffer: io.BytesIO, start_offset: int = 0):
25-
self._buffer = buffer
26-
self._start_offset = start_offset
27-
28-
def seek(self, offset: int, whence: int = SEEK_SET) -> int:
29-
if whence == SEEK_SET:
30-
buffer_pos = offset - self._start_offset
31-
return self._buffer.seek(buffer_pos) + self._start_offset
32-
else:
33-
return self._buffer.seek(offset, whence) + self._start_offset
34-
35-
def tell(self) -> int:
36-
return self._buffer.tell() + self._start_offset
37-
38-
def read(self, size=-1):
39-
return self._buffer.read(size)
40-
41-
def readinto(self, buf):
42-
return self._buffer.readinto(buf)
43-
44-
def write(self, data):
45-
return self._buffer.write(data)
46-
4716

4817
class SequentialS3Reader(S3Reader):
4918
"""Sequential S3 reader implementation
@@ -57,32 +26,19 @@ def __init__(
5726
bucket: str,
5827
key: str,
5928
get_object_info: Callable[[], Union[ObjectInfo, HeadObjectResult]],
60-
get_stream: Callable[[Optional[int], Optional[int]], GetObjectStream],
61-
start_offset: Optional[int] = None,
62-
end_offset: Optional[int] = None,
29+
get_stream: Callable[[], GetObjectStream],
6330
):
6431
if not bucket:
6532
raise ValueError("Bucket should be specified")
6633
self._bucket = bucket
6734
self._key = key
68-
self._filename = os.path.basename(self._key)
6935
self._get_object_info = get_object_info
7036
self._get_stream = get_stream
7137
self._stream: Optional[Iterator[bytes]] = None
72-
self._raw_buffer = io.BytesIO()
73-
self._buffer = _OffsetReaderView(self._raw_buffer, start_offset or 0)
38+
self._buffer = io.BytesIO()
7439
self._size: Optional[int] = None
7540
# Invariant: _position == _buffer._tell() unless _position_at_end()
76-
self._position = start_offset or 0
77-
self._pid: int = os.getpid()
78-
79-
self._start_offset = start_offset
80-
self._end_offset = end_offset
81-
82-
# Log start of reading for tracking
83-
log.debug(
84-
f"file={self._filename}, pid={self._pid}, type=read_initialized, start_offset={start_offset}, end_offset={end_offset}"
85-
)
41+
self._position = 0
8642

8743
@property
8844
def bucket(self) -> str:
@@ -104,10 +60,7 @@ def prefetch(self) -> None:
10460
"""
10561

10662
if self._stream is None:
107-
if self._start_offset is not None or self._end_offset is not None:
108-
self._stream = self._get_stream(self._start_offset, self._end_offset)
109-
else:
110-
self._stream = self._get_stream(None, None)
63+
self._stream = self._get_stream()
11164

11265
def readinto(self, buf) -> int:
11366
"""Read up to len(buf) bytes into a pre-allocated, writable bytes-like object buf.
@@ -120,9 +73,6 @@ def readinto(self, buf) -> int:
12073
int : numer of bytes read or zero, if no bytes available
12174
"""
12275
buf_size = len(buf)
123-
log.debug(
124-
f"file={self._filename}, pid={self._pid}, type=readinto, position={self._position}, size={buf_size}"
125-
)
12676
if self._position_at_end() or buf_size == 0:
12777
# If no bytes are available or no place to write data, zero should be returned
12878
return 0
@@ -158,11 +108,6 @@ def read(self, size: Optional[int] = None) -> bytes:
158108

159109
if size is not None and not isinstance(size, int):
160110
raise TypeError(f"argument should be integer or None, not {type(size)!r}")
161-
162-
log.debug(
163-
f"file={self._filename}, pid={self._pid}, type=read, position={self._position}, size={size}"
164-
)
165-
166111
if self._position_at_end():
167112
# Invariant: if we're at EOF, it doesn't matter what `size` is, we'll always return no data and have no
168113
# side effect.
@@ -255,15 +200,14 @@ def _position_at_end(self) -> bool:
255200
# We can never be special cased to EOF if we never saw how long it is.
256201
# If we _are_ at EOF, we'll just not take the early exits.
257202
return False
258-
end_pos = min(self._end_offset, self._size) if self._end_offset else self._size
259-
return self._position >= end_pos
203+
return self._position == self._size
260204

261205
def _buffer_size(self) -> int:
262-
cur_pos = self._raw_buffer.tell()
263-
self._raw_buffer.seek(0, SEEK_END)
264-
buffer_size = self._raw_buffer.tell()
265-
self._raw_buffer.seek(cur_pos)
266-
return buffer_size + (self._start_offset or 0)
206+
cur_pos = self._buffer.tell()
207+
self._buffer.seek(0, SEEK_END)
208+
buffer_size = self._buffer.tell()
209+
self._buffer.seek(cur_pos)
210+
return buffer_size
267211

268212
def tell(self) -> int:
269213
"""

0 commit comments

Comments (0)