Skip to content

Commit a06fd20

Browse files
committed
fix VideoFromFile stream source to be re-entrant for parallel async use
1 parent 0899012 commit a06fd20

File tree

2 files changed

+118
-38
lines changed

2 files changed

+118
-38
lines changed

comfy_api/latest/_input/video_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ def save_to(
3434
"""
3535
pass
3636

37-
def get_stream_source(self) -> Union[str, io.BytesIO]:
37+
def get_stream_source(self) -> Union[str, IO[bytes]]:
3838
"""
3939
Get a streamable source for the video. This allows processing without
4040
loading the entire video into memory.
4141
4242
Returns:
43-
Either a file path (str) or a BytesIO object that can be opened with av.
43+
Either a file path (str) or a binary file-like object (IO[bytes]) that can be opened with av.
4444
4545
Default implementation creates a BytesIO buffer, but subclasses should
4646
override this for better performance when possible.

comfy_api/latest/_input_impl/video_types.py

Lines changed: 116 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from av.container import InputContainer
33
from av.subtitles.stream import SubtitleStream
44
from fractions import Fraction
5-
from typing import Optional
5+
from typing import Optional, IO, Iterator
6+
from contextlib import contextmanager, suppress
67
from .._input import AudioInput, VideoInput
78
import av
89
import io
@@ -13,6 +14,74 @@
1314
from .._util import VideoContainer, VideoCodec, VideoComponents
1415

1516

17+
class _ReentrantBytesReader(io.RawIOBase):
    """Read-only, seekable binary stream over a shared byte buffer.

    Each instance keeps its own cursor, so several readers can be created
    over the same ``bytes`` object and consumed independently: the data is
    viewed through a ``memoryview`` rather than copied, and moving one
    reader's position never affects another reader.
    """

    def __init__(self, data: bytes):
        """Wrap ``data`` without copying it; the cursor starts at offset 0."""
        super().__init__()
        self._view = memoryview(data)
        self._pos = 0

    def _ensure_open(self) -> None:
        """Raise ``ValueError`` if the stream has already been closed."""
        if self.closed:
            raise ValueError("I/O operation on closed file.")

    def readable(self) -> bool:
        """Report that the stream supports reading."""
        return True

    def writable(self) -> bool:
        """Report that the stream is read-only."""
        return False

    def seekable(self) -> bool:
        """Report that the stream supports random access."""
        return True

    def tell(self) -> int:
        """Return the current cursor offset.

        Raises:
            ValueError: If the stream is closed (same contract as ``seek``
                and ``read``; a closed reader must not report a stale offset).
        """
        self._ensure_open()
        return self._pos

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """Move the cursor and return the new absolute offset.

        Args:
            offset: Displacement relative to the anchor selected by ``whence``.
            whence: ``io.SEEK_SET``, ``io.SEEK_CUR`` or ``io.SEEK_END``.

        Raises:
            ValueError: If the stream is closed, ``whence`` is not one of the
                three supported anchors, or the resulting offset is negative.
        """
        self._ensure_open()
        if whence == io.SEEK_SET:
            new_pos = offset
        elif whence == io.SEEK_CUR:
            new_pos = self._pos + offset
        elif whence == io.SEEK_END:
            new_pos = len(self._view) + offset
        else:
            raise ValueError(f"Invalid whence: {whence}")
        if new_pos < 0:
            raise ValueError("Negative seek position")
        # Seeking past the end is allowed; subsequent reads simply return no data.
        self._pos = new_pos
        return self._pos

    def readinto(self, b) -> int:
        """Copy bytes from the cursor into the writable buffer ``b``.

        Returns:
            The number of bytes copied; 0 once the cursor is at or past the end.

        Raises:
            ValueError: If the stream is closed.
            TypeError: If ``b`` is a read-only buffer.
        """
        self._ensure_open()
        target = memoryview(b)
        if target.readonly:
            raise TypeError("readinto() argument must be writable")
        # Normalise to a flat byte view so lengths and slices are in bytes.
        target = target.cast("B")
        available = max(len(self._view) - self._pos, 0)
        count = min(len(target), available)
        if count:
            target[:count] = self._view[self._pos:self._pos + count]
            self._pos += count
        return count

    def read(self, size: int = -1) -> bytes:
        """Return up to ``size`` bytes from the cursor and advance past them.

        Args:
            size: Maximum number of bytes to return; ``None`` or a negative
                value means "everything up to the end of the buffer".

        Returns:
            The bytes read; ``b""`` once the cursor is at or past the end.

        Raises:
            ValueError: If the stream is closed.
        """
        self._ensure_open()
        total = len(self._view)
        if size is None or size < 0:
            size = total - self._pos
        if self._pos >= total:
            return b""
        end = min(self._pos + size, total)
        chunk = self._view[self._pos:end].tobytes()
        self._pos = end
        return chunk

    def close(self) -> None:
        """Release the buffer view and mark the stream closed (idempotent)."""
        # close() can run on a partially constructed object (for example from
        # the base class finalizer after __init__ raised), so _view may be absent.
        view = getattr(self, "_view", None)
        if view is not None:
            # release() only raises BufferError (outstanding exports); dropping
            # the view is best-effort and must not prevent marking the stream closed.
            with suppress(BufferError):
                view.release()
        super().close()
83+
84+
1685
def container_to_output_format(container_format: str | None) -> str | None:
1786
"""
1887
A container's `format` may be a comma-separated list of formats.
@@ -57,21 +126,31 @@ class VideoFromFile(VideoInput):
57126
Class representing video input from a file.
58127
"""
59128

60-
def __init__(self, file: str | io.BytesIO):
129+
def __init__(self, file: str | io.BytesIO | bytes | bytearray | memoryview):
61130
"""
62131
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
63132
containing the file contents.
64133
"""
65-
self.__file = file
66-
67-
def get_stream_source(self) -> str | io.BytesIO:
134+
self.__path: Optional[str] = None
135+
self.__data: Optional[bytes] = None
136+
if isinstance(file, str):
137+
self.__path = file
138+
elif isinstance(file, io.BytesIO):
139+
# Snapshot to immutable bytes once to ensure re-entrant, parallel-safe readers.
140+
self.__data = file.getbuffer().tobytes()
141+
elif isinstance(file, (bytes, bytearray, memoryview)):
142+
self.__data = bytes(file)
143+
else:
144+
raise TypeError(f"Unsupported video source type: {type(file)!r}")
145+
146+
def get_stream_source(self) -> str | IO[bytes] | io.RawIOBase:
68147
"""
69148
Return the underlying file source for efficient streaming.
70149
This avoids unnecessary memory copies when the source is already a file path.
71150
"""
72-
if isinstance(self.__file, io.BytesIO):
73-
self.__file.seek(0)
74-
return self.__file
151+
if self.__path is not None:
152+
return self.__path
153+
return _ReentrantBytesReader(self.__data)
75154

76155
def get_dimensions(self) -> tuple[int, int]:
77156
"""
@@ -80,14 +159,12 @@ def get_dimensions(self) -> tuple[int, int]:
80159
Returns:
81160
Tuple of (width, height)
82161
"""
83-
if isinstance(self.__file, io.BytesIO):
84-
self.__file.seek(0) # Reset the BytesIO object to the beginning
85-
with av.open(self.__file, mode='r') as container:
162+
with self._open_source() as src, av.open(src, mode="r") as container:
86163
for stream in container.streams:
87164
if stream.type == 'video':
88165
assert isinstance(stream, av.VideoStream)
89166
return stream.width, stream.height
90-
raise ValueError(f"No video stream found in file '{self.__file}'")
167+
raise ValueError(f"No video stream found in {self._source_label()}")
91168

92169
def get_duration(self) -> float:
93170
"""
@@ -96,9 +173,7 @@ def get_duration(self) -> float:
96173
Returns:
97174
Duration in seconds
98175
"""
99-
if isinstance(self.__file, io.BytesIO):
100-
self.__file.seek(0)
101-
with av.open(self.__file, mode="r") as container:
176+
with self._open_source() as src, av.open(src, mode="r") as container:
102177
if container.duration is not None:
103178
return float(container.duration / av.time_base)
104179

@@ -119,17 +194,14 @@ def get_duration(self) -> float:
119194
if frame_count > 0:
120195
return float(frame_count / video_stream.average_rate)
121196

122-
raise ValueError(f"Could not determine duration for file '{self.__file}'")
197+
raise ValueError(f"Could not determine duration for file '{self._source_label()}'")
123198

124199
def get_frame_count(self) -> int:
125200
"""
126201
Returns the number of frames in the video without materializing them as
127202
torch tensors.
128203
"""
129-
if isinstance(self.__file, io.BytesIO):
130-
self.__file.seek(0)
131-
132-
with av.open(self.__file, mode="r") as container:
204+
with self._open_source() as src, av.open(src, mode="r") as container:
133205
video_stream = self._get_first_video_stream(container)
134206
# 1. Prefer the frames field if available
135207
if video_stream.frames and video_stream.frames > 0:
@@ -160,18 +232,15 @@ def get_frame_count(self) -> int:
160232
frame_count += 1
161233

162234
if frame_count == 0:
163-
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
235+
raise ValueError(f"Could not determine frame count for file '{self._source_label()}'")
164236
return frame_count
165237

166238
def get_frame_rate(self) -> Fraction:
167239
"""
168240
Returns the average frame rate of the video using container metadata
169241
without decoding all frames.
170242
"""
171-
if isinstance(self.__file, io.BytesIO):
172-
self.__file.seek(0)
173-
174-
with av.open(self.__file, mode="r") as container:
243+
with self._open_source() as src, av.open(src, mode="r") as container:
175244
video_stream = self._get_first_video_stream(container)
176245
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
177246
if video_stream.average_rate:
@@ -193,9 +262,7 @@ def get_container_format(self) -> str:
193262
Returns:
194263
Container format as string
195264
"""
196-
if isinstance(self.__file, io.BytesIO):
197-
self.__file.seek(0)
198-
with av.open(self.__file, mode='r') as container:
265+
with self._open_source() as src, av.open(src, mode='r') as container:
199266
return container.format.name
200267

201268
def get_components_internal(self, container: InputContainer) -> VideoComponents:
@@ -239,11 +306,8 @@ def get_components_internal(self, container: InputContainer) -> VideoComponents:
239306
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
240307

241308
def get_components(self) -> VideoComponents:
242-
if isinstance(self.__file, io.BytesIO):
243-
self.__file.seek(0) # Reset the BytesIO object to the beginning
244-
with av.open(self.__file, mode='r') as container:
309+
with self._open_source() as src, av.open(src, mode='r') as container:
245310
return self.get_components_internal(container)
246-
raise ValueError(f"No video stream found in file '{self.__file}'")
247311

248312
def save_to(
249313
self,
@@ -252,9 +316,7 @@ def save_to(
252316
codec: VideoCodec = VideoCodec.AUTO,
253317
metadata: Optional[dict] = None
254318
):
255-
if isinstance(self.__file, io.BytesIO):
256-
self.__file.seek(0) # Reset the BytesIO object to the beginning
257-
with av.open(self.__file, mode='r') as container:
319+
with self._open_source() as src, av.open(src, mode='r') as container:
258320
container_format = container.format.name
259321
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
260322
reuse_streams = True
@@ -306,9 +368,25 @@ def save_to(
306368
def _get_first_video_stream(self, container: InputContainer):
307369
video_stream = next((s for s in container.streams if s.type == "video"), None)
308370
if video_stream is None:
309-
raise ValueError(f"No video stream found in file '{self.__file}'")
371+
raise ValueError(f"No video stream found in file '{self._source_label()}'")
310372
return video_stream
311373

374+
def _source_label(self) -> str:
    """Return a short description of the video source for error messages.

    The file path is returned as-is when the video is backed by a path;
    otherwise a placeholder reporting the size of the in-memory data is built.
    """
    path = self.__path
    if path is None:
        return f"<in-memory video: {len(self.__data)} bytes>"
    return path
378+
379+
@contextmanager
380+
def _open_source(self) -> Iterator[str | IO[bytes]]:
381+
"""Internal helper to ensure file-like sources are closed after use."""
382+
src = self.get_stream_source()
383+
try:
384+
yield src
385+
finally:
386+
if not isinstance(src, str):
387+
with suppress(Exception):
388+
src.close()
389+
312390

313391
class VideoFromComponents(VideoInput):
314392
"""
@@ -381,3 +459,5 @@ def save_to(
381459

382460
# Flush encoder
383461
output.mux(audio_stream.encode(None))
462+
463+

0 commit comments

Comments
 (0)