Skip to content

Commit 0f3f670

Browse files
committed
fix VideoFromFile stream source to be re-entrant for parallel async use
1 parent 0899012 commit 0f3f670

File tree

2 files changed

+129
-38
lines changed

2 files changed

+129
-38
lines changed

comfy_api/latest/_input/video_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ def save_to(
3434
"""
3535
pass
3636

37-
def get_stream_source(self) -> Union[str, io.BytesIO]:
37+
def get_stream_source(self) -> Union[str, IO[bytes]]:
3838
"""
3939
Get a streamable source for the video. This allows processing without
4040
loading the entire video into memory.
4141
4242
Returns:
43-
Either a file path (str) or a BytesIO object that can be opened with av.
43+
Either a file path (str) or a binary file-like object (IO[bytes]) that can be opened with av.
4444
4545
Default implementation creates a BytesIO buffer, but subclasses should
4646
override this for better performance when possible.

comfy_api/latest/_input_impl/video_types.py

Lines changed: 127 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from av.container import InputContainer
33
from av.subtitles.stream import SubtitleStream
44
from fractions import Fraction
5-
from typing import Optional
5+
from typing import Optional, IO, Iterator
6+
from contextlib import contextmanager, suppress
67
from .._input import AudioInput, VideoInput
78
import av
89
import io
@@ -13,6 +14,87 @@
1314
from .._util import VideoContainer, VideoCodec, VideoComponents
1415

1516

17+
class _ReentrantBytesReader(io.RawIOBase):
    """Read-only, seekable stream over a shared bytes payload with its own cursor.

    Every instance tracks a private position, so several readers can be built
    over the same ``bytes`` object and consumed independently of one another.
    """

    def __init__(self, data: bytes):
        super().__init__()
        self._data = data
        self._view = memoryview(data)
        self._pos = 0

    def _ensure_open(self) -> None:
        """Raise ValueError when the reader has already been closed."""
        if self.closed:
            raise ValueError("I/O operation on closed file.")

    def readable(self) -> bool:
        return True

    def writable(self) -> bool:
        return False

    def seekable(self) -> bool:
        return True

    def tell(self) -> int:
        """Return the current cursor position.

        Raises:
            ValueError: If the reader is closed (same as the other accessors).
        """
        self._ensure_open()
        return self._pos

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """Move the cursor and return the new absolute position.

        Positions beyond the end of the payload are accepted; subsequent reads
        from there simply return no data.

        Raises:
            ValueError: If the reader is closed, ``whence`` is unknown, or the
                resulting position would be negative.
        """
        self._ensure_open()
        if whence == io.SEEK_SET:
            new_pos = offset
        elif whence == io.SEEK_CUR:
            new_pos = self._pos + offset
        elif whence == io.SEEK_END:
            new_pos = len(self._view) + offset
        else:
            raise ValueError(f"Invalid whence: {whence}")
        if new_pos < 0:
            raise ValueError("Negative seek position")
        self._pos = new_pos
        return self._pos

    def readinto(self, b) -> int:
        """Copy up to ``len(b)`` bytes into the writable buffer ``b``.

        Returns:
            The number of bytes copied; 0 once the cursor is at or past the end.

        Raises:
            ValueError: If the reader is closed.
            TypeError: If ``b`` is a read-only buffer.
        """
        self._ensure_open()
        mv = memoryview(b)
        if mv.readonly:
            raise TypeError("readinto() argument must be writable")
        # Normalise to a flat byte view so slice assignment works for any item format.
        mv = mv.cast("B")
        total = len(self._view)
        if self._pos >= total:
            return 0
        n = min(len(mv), total - self._pos)
        mv[:n] = self._view[self._pos:self._pos + n]
        self._pos += n
        return n

    def read(self, size: int = -1) -> bytes:
        """Return up to ``size`` bytes from the cursor; all remaining bytes if ``size`` is None or negative.

        Raises:
            ValueError: If the reader is closed.
        """
        self._ensure_open()
        total = len(self._view)
        if self._pos >= total:
            return b""
        if size is None or size < 0:
            end = total
        else:
            end = min(self._pos + size, total)
        out = self._view[self._pos:end].tobytes()
        self._pos = end
        return out

    def close(self) -> None:
        """Release the internal memoryview and mark the stream closed."""
        # Best-effort release: close() also runs from IOBase.__del__, where
        # _view may not exist if __init__ failed early (AttributeError), and
        # release() raises BufferError while exports of the view are alive.
        with suppress(AttributeError, BufferError):
            self._view.release()
        super().close()

    def getvalue(self) -> bytes:
        """Return the entire underlying byte payload like io.BytesIO.getvalue()."""
        self._ensure_open()
        return self._data

    def getbuffer(self) -> memoryview:
        """Return a memoryview over the payload like io.BytesIO.getbuffer()."""
        self._ensure_open()
        # Hand out a fresh memoryview so an external .release() can't break our internal _view.
        return memoryview(self._data)
96+
97+
1698
def container_to_output_format(container_format: str | None) -> str | None:
1799
"""
18100
A container's `format` may be a comma-separated list of formats.
@@ -57,21 +139,31 @@ class VideoFromFile(VideoInput):
57139
Class representing video input from a file.
58140
"""
59141

60-
def __init__(self, file: str | io.BytesIO):
142+
def __init__(self, file: str | io.BytesIO | bytes | bytearray | memoryview):
61143
"""
62144
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
63145
containing the file contents.
64146
"""
65-
self.__file = file
66-
67-
def get_stream_source(self) -> str | io.BytesIO:
147+
self.__path: Optional[str] = None
148+
self.__data: Optional[bytes] = None
149+
if isinstance(file, str):
150+
self.__path = file
151+
elif isinstance(file, io.BytesIO):
152+
# Snapshot to immutable bytes once to ensure re-entrant, parallel-safe readers.
153+
self.__data = file.getbuffer().tobytes()
154+
elif isinstance(file, (bytes, bytearray, memoryview)):
155+
self.__data = bytes(file)
156+
else:
157+
raise TypeError(f"Unsupported video source type: {type(file)!r}")
158+
159+
def get_stream_source(self) -> str | IO[bytes]:
68160
"""
69161
Return the underlying file source for efficient streaming.
70162
This avoids unnecessary memory copies when the source is already a file path.
71163
"""
72-
if isinstance(self.__file, io.BytesIO):
73-
self.__file.seek(0)
74-
return self.__file
164+
if self.__path is not None:
165+
return self.__path
166+
return _ReentrantBytesReader(self.__data)
75167

76168
def get_dimensions(self) -> tuple[int, int]:
77169
"""
@@ -80,14 +172,12 @@ def get_dimensions(self) -> tuple[int, int]:
80172
Returns:
81173
Tuple of (width, height)
82174
"""
83-
if isinstance(self.__file, io.BytesIO):
84-
self.__file.seek(0) # Reset the BytesIO object to the beginning
85-
with av.open(self.__file, mode='r') as container:
175+
with self._open_source() as src, av.open(src, mode="r") as container:
86176
for stream in container.streams:
87177
if stream.type == 'video':
88178
assert isinstance(stream, av.VideoStream)
89179
return stream.width, stream.height
90-
raise ValueError(f"No video stream found in file '{self.__file}'")
180+
raise ValueError(f"No video stream found in {self._source_label()}")
91181

92182
def get_duration(self) -> float:
93183
"""
@@ -96,9 +186,7 @@ def get_duration(self) -> float:
96186
Returns:
97187
Duration in seconds
98188
"""
99-
if isinstance(self.__file, io.BytesIO):
100-
self.__file.seek(0)
101-
with av.open(self.__file, mode="r") as container:
189+
with self._open_source() as src, av.open(src, mode="r") as container:
102190
if container.duration is not None:
103191
return float(container.duration / av.time_base)
104192

@@ -119,17 +207,14 @@ def get_duration(self) -> float:
119207
if frame_count > 0:
120208
return float(frame_count / video_stream.average_rate)
121209

122-
raise ValueError(f"Could not determine duration for file '{self.__file}'")
210+
raise ValueError(f"Could not determine duration for file '{self._source_label()}'")
123211

124212
def get_frame_count(self) -> int:
125213
"""
126214
Returns the number of frames in the video without materializing them as
127215
torch tensors.
128216
"""
129-
if isinstance(self.__file, io.BytesIO):
130-
self.__file.seek(0)
131-
132-
with av.open(self.__file, mode="r") as container:
217+
with self._open_source() as src, av.open(src, mode="r") as container:
133218
video_stream = self._get_first_video_stream(container)
134219
# 1. Prefer the frames field if available
135220
if video_stream.frames and video_stream.frames > 0:
@@ -160,18 +245,15 @@ def get_frame_count(self) -> int:
160245
frame_count += 1
161246

162247
if frame_count == 0:
163-
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
248+
raise ValueError(f"Could not determine frame count for file '{self._source_label()}'")
164249
return frame_count
165250

166251
def get_frame_rate(self) -> Fraction:
167252
"""
168253
Returns the average frame rate of the video using container metadata
169254
without decoding all frames.
170255
"""
171-
if isinstance(self.__file, io.BytesIO):
172-
self.__file.seek(0)
173-
174-
with av.open(self.__file, mode="r") as container:
256+
with self._open_source() as src, av.open(src, mode="r") as container:
175257
video_stream = self._get_first_video_stream(container)
176258
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
177259
if video_stream.average_rate:
@@ -193,9 +275,7 @@ def get_container_format(self) -> str:
193275
Returns:
194276
Container format as string
195277
"""
196-
if isinstance(self.__file, io.BytesIO):
197-
self.__file.seek(0)
198-
with av.open(self.__file, mode='r') as container:
278+
with self._open_source() as src, av.open(src, mode='r') as container:
199279
return container.format.name
200280

201281
def get_components_internal(self, container: InputContainer) -> VideoComponents:
@@ -239,11 +319,8 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
239319
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
240320

241321
def get_components(self) -> VideoComponents:
    """Open the video source with av and build its VideoComponents via get_components_internal()."""
    with self._open_source() as source:
        with av.open(source, mode='r') as container:
            return self.get_components_internal(container)
247324

248325
def save_to(
249326
self,
@@ -252,9 +329,7 @@ def save_to(
252329
codec: VideoCodec = VideoCodec.AUTO,
253330
metadata: Optional[dict] = None
254331
):
255-
if isinstance(self.__file, io.BytesIO):
256-
self.__file.seek(0) # Reset the BytesIO object to the beginning
257-
with av.open(self.__file, mode='r') as container:
332+
with self._open_source() as src, av.open(src, mode='r') as container:
258333
container_format = container.format.name
259334
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
260335
reuse_streams = True
@@ -306,9 +381,25 @@ def save_to(
306381
def _get_first_video_stream(self, container: InputContainer):
307382
video_stream = next((s for s in container.streams if s.type == "video"), None)
308383
if video_stream is None:
309-
raise ValueError(f"No video stream found in file '{self.__file}'")
384+
raise ValueError(f"No video stream found in file '{self._source_label()}'")
310385
return video_stream
311386

387+
def _source_label(self) -> str:
388+
if self.__path is not None:
389+
return self.__path
390+
return f"<in-memory video: {len(self.__data)} bytes>"
391+
392+
@contextmanager
393+
def _open_source(self) -> Iterator[str | IO[bytes]]:
394+
"""Internal helper to ensure file-like sources are closed after use."""
395+
src = self.get_stream_source()
396+
try:
397+
yield src
398+
finally:
399+
if not isinstance(src, str):
400+
with suppress(Exception):
401+
src.close()
402+
312403

313404
class VideoFromComponents(VideoInput):
314405
"""

0 commit comments

Comments
 (0)