
import logging
from dataclasses import dataclass
-from typing import List, Optional, Callable, Union, Dict
-from io import SEEK_SET
+from typing import List, Optional, Callable, Union, Dict, Iterator
+from io import SEEK_SET, SEEK_CUR

from s3torchconnectorclient._mountpoint_s3_client import (
    ObjectInfo,
    GetObjectStream,
    HeadObjectResult,
)
from .s3reader import S3Reader
-from .sequential import SequentialS3Reader

log = logging.getLogger(__name__)


@dataclass
class RangeRequest:
21+ """Singular range request; Inclusive start, exclusive end"""
22+
    start: int
    end: int
    request_id: Optional[str] = None
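+    # e.g. RangeRequest(start=0, end=1024) covers bytes 0..1023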
@@ -31,6 +32,8 @@ class RangeGroup:
    requests: List[RangeRequest]


+# TODO: Update name, since it now requires sequential reading and is optimised for DCP
+# TODO: Update docstring to emphasise this requires Load Ordering in prepare_local_plan
class ListOfRangesS3Reader(S3Reader):
    """Optimized reader with pre-calculated request mapping and batch prefetch."""

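+    # Ranges are coalesced into groups; each group is served by a single ranged
+    # GET whose stream is only ever consumed forward.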
@@ -42,30 +45,20 @@ def __init__(
        get_object_info: Callable[[], Union[ObjectInfo, HeadObjectResult]],
        get_stream: Callable[[Optional[int], Optional[int]], GetObjectStream],
        max_gap_size: int = 200 * 1024 * 1024,
-        **kwargs,
    ):
        self._bucket = bucket
        self._key = key
        self._get_object_info = get_object_info
        self._get_stream = get_stream

        # Calculate range groups using coalescing logic
-        self._range_groups = self._calculate_range_groups(ranges, max_gap_size)
-
-        # Pre-create all readers
-        self._group_readers: Dict[int, SequentialS3Reader] = {}
-        for i, group in enumerate(self._range_groups):
-            reader = SequentialS3Reader(
-                bucket=bucket,
-                key=key,
-                get_object_info=get_object_info,
-                get_stream=get_stream,
-                start_offset=group.start,
-                end_offset=group.end,
-            )
-            # TODO - judge if this is beneficial or not.
-            reader.prefetch()  # Batch prefetch all ranges
-            self._group_readers[i] = reader
+        self._range_groups = self._coalesce_ranges(ranges, max_gap_size)
+        self._current_group_idx: int = 0
+
+        # Per-group stream cache
+        self._streams: Dict[int, Iterator[bytes]] = {}
+        self._stream_positions: Dict[int, int] = {}
+        self._stream_buffers: Dict[int, bytes] = {}

        self._position: int = 0

@@ -77,76 +70,131 @@ def bucket(self) -> str:
    def key(self) -> str:
        return self._key

-    def _calculate_range_groups(
+    def seekable(self) -> bool:
+        """Not seekable; torch/distributed/checkpoint/filesystem.py will use read() instead of readinto()."""
+        return False
+
+    def _coalesce_ranges(
        self, ranges: List[RangeRequest], max_gap_size: int
    ) -> List[RangeGroup]:
-        """Coalescing logic - group ranges within max_gap_size."""
-        # TODO: optimise this logic
80+ """Coalescing nearby byte ranges within max_gap_size."""
        if not ranges:
            return []

-        # TODO: could be pre-sorted in prepare_local_plan for dcp.load
-        sorted_ranges = sorted(ranges, key=lambda r: r.start)
-        groups = []
-        current_group = [sorted_ranges[0]]
-
-        for i in range(1, len(sorted_ranges)):
-            prev_end = current_group[-1].end
-            curr_start = sorted_ranges[i].start
+        # TODO: could be pre-sorted in prepare_local_plan (small optimisation)
+        ranges = sorted(ranges, key=lambda r: r.start)
+        groups: List[RangeGroup] = []
+        current = [ranges[0]]

-            if curr_start - prev_end <= max_gap_size:
-                current_group.append(sorted_ranges[i])
+        for r in ranges[1:]:
+            if r.start - current[-1].end <= max_gap_size:
+                current.append(r)
            else:
-                groups.append(self._create_range_group(current_group))
-                current_group = [sorted_ranges[i]]
+                groups.append(RangeGroup(current[0].start, current[-1].end, current))
+                current = [r]

-        groups.append(self._create_range_group(current_group))
+        groups.append(RangeGroup(current[0].start, current[-1].end, current))
        return groups

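+    # Example (illustrative values): with max_gap_size=10, the sorted ranges
+    # [0, 5), [8, 12), [100, 110) coalesce into two groups, because 8 - 5 <= 10
+    # while 100 - 12 > 10:
+    #   RangeGroup(start=0, end=12, requests=[r0, r1])
+    #   RangeGroup(start=100, end=110, requests=[r2])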
-    def _create_range_group(self, ranges: List[RangeRequest]) -> RangeGroup:
-        """Create range group - always succeeds since we only use gap size."""
-        # TODO remove min/max code by tracking incrementally in _calculate_range_groups
-        # * (was kept since it's easier to understand and test)
-        group_start = min(r.start for r in ranges)
-        group_end = max(r.end for r in ranges)
-        return RangeGroup(start=group_start, end=group_end, requests=ranges)
-
-    def _find_reader_for_offset(self, offset: int) -> Optional[SequentialS3Reader]:
-        """Find reader that contains the given offset."""
-        for i, group in enumerate(self._range_groups):
-            if group.start <= offset < group.end:
-                self._current_reader_index = i
-                return self._group_readers[i]
-            if group.start > offset:  # TODO handle this case properly by raising errors
-                break
-        return None
-
-    def seek(self, offset: int, whence: int = SEEK_SET, /) -> int:
-        self._position = offset
-        reader = self._find_reader_for_offset(offset)
-        if not reader:
-            return self._position
-        return reader.seek(offset, whence)
+    def _get_stream_for_group(self, idx: int) -> Iterator[bytes]:
+        """
+        Returns a cached iterator for the given range group,
+        or creates a new one if not present.
+        """
+        if idx not in self._streams:
+            group = self._range_groups[idx]
+            stream = self._get_stream(group.start, group.end)
+            self._streams[idx] = stream
+            self._stream_positions[idx] = group.start
+            self._stream_buffers[idx] = b""
+        return self._streams[idx]
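+    # Note: each group opens one ranged GET covering [group.start, group.end),
+    # so bytes in the coalesced gaps between requests are downloaded and then
+    # skipped in read().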

    def read(self, size: Optional[int] = None) -> bytes:
-        reader = self._find_reader_for_offset(self._position)
-        if not reader:
+        """Reads up to `size` bytes sequentially across grouped ranges."""
+        if not size or size <= 0:  # size=None is treated as "read nothing"
            return b""
-        data = reader.read(size)
-        self._position += len(data)
-        return data
+
+        pos = self._position
+
+        # Find the group containing pos (cached index fast path)
+        if (
+            self._current_group_idx < len(self._range_groups)
+            and self._range_groups[self._current_group_idx].start
+            <= pos
+            < self._range_groups[self._current_group_idx].end
+        ):
+            group_idx = self._current_group_idx
+        else:
+            # Fall back to a linear search for the matching group
+            for i, g in enumerate(self._range_groups):
+                if g.start <= pos < g.end:
+                    group_idx = i
+                    self._current_group_idx = group_idx
+                    break
+            else:
+                return b""
+
+        stream = self._get_stream_for_group(group_idx)
+
+        current_pos = self._stream_positions[group_idx]
+        buffer = self._stream_buffers[group_idx]
+        remaining = size
+        chunks: List[bytes] = []
+
+        # 1. Serve from buffered leftover bytes
+        if buffer and current_pos <= pos < current_pos + len(buffer):
+            offset = pos - current_pos
+            end = offset + min(remaining, len(buffer) - offset)
+            chunks.append(buffer[offset:end])
+            remaining -= end - offset
+            current_pos = pos + (end - offset)
+            self._stream_buffers[group_idx] = buffer[end:] if end < len(buffer) else b""
+        elif buffer and pos >= current_pos + len(buffer):
+            # Forward read past the leftover buffer: discard it but account for
+            # its length, since the stream is already positioned after it.
+            current_pos += len(buffer)
+            self._stream_buffers[group_idx] = b""
+
+        # 2. Read more data from S3 stream
+        while remaining > 0:
+            try:
+                chunk = next(stream)
+            except StopIteration:
+                break
+
+            # Skip ahead if behind target
+            if current_pos < pos:
+                skip = min(pos - current_pos, len(chunk))
+                chunk = chunk[skip:]
+                current_pos += skip
+
+            # Take needed part of chunk
+            take = min(len(chunk), remaining)
+            chunks.append(chunk[:take])
+            remaining -= take
+            current_pos += take
+
+            # Save leftover bytes
+            if take < len(chunk):
+                self._stream_buffers[group_idx] = chunk[take:]
+                break
+
+        self._stream_positions[group_idx] = current_pos
+        self._position = pos + (size - remaining)
+        return b"".join(chunks)
+
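+    # Example (illustrative): with one group covering [0, 100) and a 64-byte
+    # first chunk from the stream, read(16) at position 0 returns chunk[:16]
+    # and buffers chunk[16:]; a following read(48) is served entirely from the
+    # leftover buffer without pulling more data from the stream.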
+    def seek(self, offset: int, whence: int = SEEK_SET, /) -> int:
+        if whence == SEEK_SET:
+            self._position = offset
+        elif whence == SEEK_CUR:
+            self._position += offset
+        return self._position
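+    # seek() only records the position; read() resolves it against the range
+    # groups. Streams are consumed forward only, so callers must keep reads
+    # load-ordered (see TODOs above).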

    def readinto(self, buf) -> int:
-        reader = self._find_reader_for_offset(self._position)
-        if not reader:
-            return 0
-        bytes_read = reader.readinto(buf)
-        self._position += bytes_read
-        return bytes_read
+        data = self.read(len(buf))
+        n = len(data)
+        buf[:n] = data
+        return n
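+    # readinto() delegates to read() so both entry points share the same
+    # group lookup and buffering logic.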

    def tell(self) -> int:
        return self._position

    def close(self) -> None:
-        for reader in self._group_readers.values():
-            reader.close()
+        self._streams.clear()
+        self._stream_positions.clear()
+        self._stream_buffers.clear()
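+
+    # Usage sketch (hypothetical wiring; in s3torchconnector the two callables
+    # are provided by the S3 client when constructing readers):
+    #
+    #   reader = ListOfRangesS3Reader(
+    #       bucket="my-bucket",
+    #       key="checkpoints/shard0.distcp",
+    #       ranges=[RangeRequest(0, 1024), RangeRequest(4096, 8192)],
+    #       get_object_info=head_object_fn,   # hypothetical callable
+    #       get_stream=get_object_stream_fn,  # hypothetical callable
+    #   )
+    #   reader.seek(0)
+    #   header = reader.read(1024)  # bytes 0..1023 from the first group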