Skip to content

Commit 28db275

Browse files
committed
fix VideoFromFile stream source to _ReentrantBytesIO for parallel async use
1 parent 807538f commit 28db275

File tree

1 file changed

+147
-33
lines changed

1 file changed

+147
-33
lines changed

comfy_api/latest/_input_impl/video_types.py

Lines changed: 147 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,124 @@
1313
from .._util import VideoContainer, VideoCodec, VideoComponents
1414

1515

16+
class _ReentrantBytesIO(io.BytesIO):
17+
"""Read-only, seekable BytesIO-compatible view over shared immutable bytes."""
18+
19+
def __init__(self, data: bytes):
20+
super().__init__(b"") # Initialize base BytesIO with an empty buffer; we do not use its internal storage.
21+
if data is None:
22+
raise TypeError("data must be bytes, not None")
23+
self._data = data
24+
self._pos = 0
25+
self._len = len(data)
26+
27+
def getvalue(self) -> bytes:
28+
if self.closed:
29+
raise ValueError("I/O operation on closed file.")
30+
return self._data
31+
32+
def getbuffer(self) -> memoryview:
33+
if self.closed:
34+
raise ValueError("I/O operation on closed file.")
35+
return memoryview(self._data)
36+
37+
def readable(self) -> bool:
38+
return True
39+
40+
def writable(self) -> bool:
41+
return False
42+
43+
def seekable(self) -> bool:
44+
return True
45+
46+
def tell(self) -> int:
47+
return self._pos
48+
49+
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
50+
if self.closed:
51+
raise ValueError("I/O operation on closed file.")
52+
if whence == io.SEEK_SET:
53+
new_pos = offset
54+
elif whence == io.SEEK_CUR:
55+
new_pos = self._pos + offset
56+
elif whence == io.SEEK_END:
57+
new_pos = self._len + offset
58+
else:
59+
raise ValueError(f"Invalid whence: {whence}")
60+
if new_pos < 0:
61+
raise ValueError("Negative seek position")
62+
self._pos = new_pos
63+
return self._pos
64+
65+
def readinto(self, b) -> int:
66+
if self.closed:
67+
raise ValueError("I/O operation on closed file.")
68+
mv = memoryview(b)
69+
if mv.readonly:
70+
raise TypeError("readinto() argument must be writable")
71+
mv = mv.cast("B")
72+
if self._pos >= self._len:
73+
return 0
74+
n = min(len(mv), self._len - self._pos)
75+
mv[:n] = self._data[self._pos:self._pos + n]
76+
self._pos += n
77+
return n
78+
79+
def readinto1(self, b) -> int:
80+
return self.readinto(b)
81+
82+
def read(self, size: int = -1) -> bytes:
83+
if self.closed:
84+
raise ValueError("I/O operation on closed file.")
85+
if size is None or size < 0:
86+
size = self._len - self._pos
87+
if self._pos >= self._len:
88+
return b""
89+
end = min(self._pos + size, self._len)
90+
out = self._data[self._pos:end]
91+
self._pos = end
92+
return out
93+
94+
def read1(self, size: int = -1) -> bytes:
95+
return self.read(size)
96+
97+
def readline(self, size: int = -1) -> bytes:
98+
if self.closed:
99+
raise ValueError("I/O operation on closed file.")
100+
if self._pos >= self._len:
101+
return b""
102+
end_limit = self._len if size is None or size < 0 else min(self._len, self._pos + size)
103+
nl = self._data.find(b"\n", self._pos, end_limit)
104+
end = (nl + 1) if nl != -1 else end_limit
105+
out = self._data[self._pos:end]
106+
self._pos = end
107+
return out
108+
109+
def readlines(self, hint: int = -1) -> list[bytes]:
110+
if self.closed:
111+
raise ValueError("I/O operation on closed file.")
112+
lines: list[bytes] = []
113+
total = 0
114+
while True:
115+
line = self.readline()
116+
if not line:
117+
break
118+
lines.append(line)
119+
total += len(line)
120+
if hint is not None and 0 <= hint <= total:
121+
break
122+
return lines
123+
124+
def write(self, b) -> int:
125+
raise io.UnsupportedOperation("not writable")
126+
127+
def writelines(self, lines) -> None:
128+
raise io.UnsupportedOperation("not writable")
129+
130+
def truncate(self, size: int | None = None) -> int:
131+
raise io.UnsupportedOperation("not writable")
132+
133+
16134
def container_to_output_format(container_format: str | None) -> str | None:
17135
"""
18136
A container's `format` may be a comma-separated list of formats.
@@ -57,21 +175,31 @@ class VideoFromFile(VideoInput):
57175
Class representing video input from a file.
58176
"""
59177

60-
def __init__(self, file: str | io.BytesIO):
178+
__data: str | bytes
179+
180+
def __init__(self, file: str | io.BytesIO | bytes | bytearray | memoryview):
61181
"""
62182
Initialize the VideoFromFile object from either a path on disk or an in-memory source (a BytesIO, bytes, bytearray, or memoryview object)
63183
containing the file contents.
64184
"""
65-
self.__file = file
185+
if isinstance(file, str):
186+
self.__data = file
187+
elif isinstance(file, io.BytesIO):
188+
# Snapshot to immutable bytes once to ensure re-entrant, parallel-safe readers.
189+
self.__data = file.getbuffer().tobytes()
190+
elif isinstance(file, (bytes, bytearray, memoryview)):
191+
self.__data = bytes(file)
192+
else:
193+
raise TypeError(f"Unsupported video source type: {type(file)!r}")
66194

67195
def get_stream_source(self) -> str | io.BytesIO:
68196
"""
69197
Return the underlying file source for efficient streaming.
70198
This avoids unnecessary memory copies when the source is already a file path.
71199
"""
72-
if isinstance(self.__file, io.BytesIO):
73-
self.__file.seek(0)
74-
return self.__file
200+
if isinstance(self.__data, str):
201+
return self.__data
202+
return _ReentrantBytesIO(self.__data)
75203

76204
def get_dimensions(self) -> tuple[int, int]:
77205
"""
@@ -80,14 +208,12 @@ def get_dimensions(self) -> tuple[int, int]:
80208
Returns:
81209
Tuple of (width, height)
82210
"""
83-
if isinstance(self.__file, io.BytesIO):
84-
self.__file.seek(0) # Reset the BytesIO object to the beginning
85-
with av.open(self.__file, mode='r') as container:
211+
with av.open(self.get_stream_source(), mode="r") as container:
86212
for stream in container.streams:
87213
if stream.type == 'video':
88214
assert isinstance(stream, av.VideoStream)
89215
return stream.width, stream.height
90-
raise ValueError(f"No video stream found in file '{self.__file}'")
216+
raise ValueError(f"No video stream found in {self._source_label()}")
91217

92218
def get_duration(self) -> float:
93219
"""
@@ -96,9 +222,7 @@ def get_duration(self) -> float:
96222
Returns:
97223
Duration in seconds
98224
"""
99-
if isinstance(self.__file, io.BytesIO):
100-
self.__file.seek(0)
101-
with av.open(self.__file, mode="r") as container:
225+
with av.open(self.get_stream_source(), mode="r") as container:
102226
if container.duration is not None:
103227
return float(container.duration / av.time_base)
104228

@@ -119,17 +243,14 @@ def get_duration(self) -> float:
119243
if frame_count > 0:
120244
return float(frame_count / video_stream.average_rate)
121245

122-
raise ValueError(f"Could not determine duration for file '{self.__file}'")
246+
raise ValueError(f"Could not determine duration for file '{self._source_label()}'")
123247

124248
def get_frame_count(self) -> int:
125249
"""
126250
Returns the number of frames in the video without materializing them as
127251
torch tensors.
128252
"""
129-
if isinstance(self.__file, io.BytesIO):
130-
self.__file.seek(0)
131-
132-
with av.open(self.__file, mode="r") as container:
253+
with av.open(self.get_stream_source(), mode="r") as container:
133254
video_stream = self._get_first_video_stream(container)
134255
# 1. Prefer the frames field if available
135256
if video_stream.frames and video_stream.frames > 0:
@@ -160,18 +281,15 @@ def get_frame_count(self) -> int:
160281
frame_count += 1
161282

162283
if frame_count == 0:
163-
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
284+
raise ValueError(f"Could not determine frame count for file '{self._source_label()}'")
164285
return frame_count
165286

166287
def get_frame_rate(self) -> Fraction:
167288
"""
168289
Returns the average frame rate of the video using container metadata
169290
without decoding all frames.
170291
"""
171-
if isinstance(self.__file, io.BytesIO):
172-
self.__file.seek(0)
173-
174-
with av.open(self.__file, mode="r") as container:
292+
with av.open(self.get_stream_source(), mode="r") as container:
175293
video_stream = self._get_first_video_stream(container)
176294
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
177295
if video_stream.average_rate:
@@ -193,9 +311,7 @@ def get_container_format(self) -> str:
193311
Returns:
194312
Container format as string
195313
"""
196-
if isinstance(self.__file, io.BytesIO):
197-
self.__file.seek(0)
198-
with av.open(self.__file, mode='r') as container:
314+
with av.open(self.get_stream_source(), mode='r') as container:
199315
return container.format.name
200316

201317
def get_components_internal(self, container: InputContainer) -> VideoComponents:
@@ -239,11 +355,8 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
239355
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
240356

241357
def get_components(self) -> VideoComponents:
242-
if isinstance(self.__file, io.BytesIO):
243-
self.__file.seek(0) # Reset the BytesIO object to the beginning
244-
with av.open(self.__file, mode='r') as container:
358+
with av.open(self.get_stream_source(), mode='r') as container:
245359
return self.get_components_internal(container)
246-
raise ValueError(f"No video stream found in file '{self.__file}'")
247360

248361
def save_to(
249362
self,
@@ -252,9 +365,7 @@ def save_to(
252365
codec: VideoCodec = VideoCodec.AUTO,
253366
metadata: Optional[dict] = None
254367
):
255-
if isinstance(self.__file, io.BytesIO):
256-
self.__file.seek(0) # Reset the BytesIO object to the beginning
257-
with av.open(self.__file, mode='r') as container:
368+
with av.open(self.get_stream_source(), mode='r') as container:
258369
container_format = container.format.name
259370
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
260371
reuse_streams = True
@@ -306,9 +417,12 @@ def save_to(
306417
def _get_first_video_stream(self, container: InputContainer):
307418
video_stream = next((s for s in container.streams if s.type == "video"), None)
308419
if video_stream is None:
309-
raise ValueError(f"No video stream found in file '{self.__file}'")
420+
raise ValueError(f"No video stream found in file '{self._source_label()}'")
310421
return video_stream
311422

423+
def _source_label(self) -> str:
424+
return self.__data if isinstance(self.__data, str) else f"<in-memory video: {len(self.__data)} bytes>"
425+
312426

313427
class VideoFromComponents(VideoInput):
314428
"""

0 commit comments

Comments
 (0)