|
8 | 8 | import json |
9 | 9 | import numbers |
10 | 10 | from pathlib import Path |
11 | | -from typing import Literal, Optional, Tuple, Union |
| 11 | +from typing import List, Literal, Optional, Sequence, Tuple, Union |
12 | 12 |
|
13 | 13 | import torch |
14 | | -from torch import device as torch_device, Tensor |
| 14 | +from torch import device as torch_device, nn, Tensor |
15 | 15 |
|
16 | 16 | from torchcodec import _core as core, Frame, FrameBatch |
17 | 17 | from torchcodec.decoders._decoder_utils import ( |
18 | 18 | _get_cuda_backend, |
19 | 19 | create_decoder, |
20 | 20 | ERROR_REPORTING_INSTRUCTIONS, |
21 | 21 | ) |
| 22 | +from torchcodec.transforms import DecoderTransform, Resize |
22 | 23 |
|
23 | 24 |
|
24 | 25 | class VideoDecoder: |
@@ -67,6 +68,11 @@ class VideoDecoder: |
67 | 68 | probably is. Default: "exact". |
68 | 69 | Read more about this parameter in: |
69 | 70 | :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py` |
| 71 | + transforms (sequence of transform objects, optional): Sequence of transforms to be |
| 72 | + applied to the decoded frames by the decoder itself, in order. Accepts both |
| 73 | + :class:`~torchcodec.transforms.DecoderTransform` and |
| 74 | + :class:`~torchvision.transforms.v2.Transform` |
| 75 | + objects. Read more about this parameter in: TODO_DECODER_TRANSFORMS_TUTORIAL. |
70 | 76 | custom_frame_mappings (str, bytes, or file-like object, optional): |
71 | 77 | Mapping of frames to their metadata, typically generated via ffprobe. |
72 | 78 | This enables accurate frame seeking without requiring a full video scan. |
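A minimal usage sketch for the new transforms parameter, assuming a local "video.mp4" exists and that torchcodec.transforms.Resize accepts a (height, width) size argument (both are illustrative assumptions, not confirmed by this diff):

    from torchcodec.decoders import VideoDecoder
    from torchcodec.transforms import Resize

    # Decoder-side resize with TorchCodec's own transform object;
    # frames come back from the decoder already resized.
    decoder = VideoDecoder("video.mp4", transforms=[Resize(size=(224, 224))])
    frame = decoder[0]

    # Equivalent call with a TorchVision v2 transform (requires torchvision).
    from torchvision.transforms import v2
    decoder = VideoDecoder("video.mp4", transforms=[v2.Resize(size=(224, 224))])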
@@ -105,6 +111,7 @@ def __init__( |
105 | 111 | num_ffmpeg_threads: int = 1, |
106 | 112 | device: Optional[Union[str, torch_device]] = None, |
107 | 113 | seek_mode: Literal["exact", "approximate"] = "exact", |
| 114 | + transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]] = None, |
108 | 115 | custom_frame_mappings: Optional[ |
109 | 116 | Union[str, bytes, io.RawIOBase, io.BufferedReader] |
110 | 117 | ] = None, |
@@ -151,13 +158,16 @@ def __init__( |
151 | 158 |
|
152 | 159 | device_variant = _get_cuda_backend() |
153 | 160 |
|
| 161 | + transform_specs = _make_transform_specs(transforms) |
| 162 | + |
154 | 163 | core.add_video_stream( |
155 | 164 | self._decoder, |
156 | 165 | stream_index=stream_index, |
157 | 166 | dimension_order=dimension_order, |
158 | 167 | num_threads=num_ffmpeg_threads, |
159 | 168 | device=device, |
160 | 169 | device_variant=device_variant, |
| 170 | + transform_specs=transform_specs, |
161 | 171 | custom_frame_mappings=custom_frame_mappings_data, |
162 | 172 | ) |
163 | 173 |
|
@@ -435,6 +445,78 @@ def _get_and_validate_stream_metadata( |
435 | 445 | ) |
436 | 446 |
|
437 | 447 |
|
| 448 | +def _convert_to_decoder_transforms( |
| 449 | + transforms: Sequence[Union[DecoderTransform, nn.Module]], |
| 450 | +) -> List[DecoderTransform]: |
| 451 | + """Convert a sequence of transforms that may contain TorchVision transform |
| 452 | + objects into a list of only TorchCodec transform objects. |
| 453 | +
|
| 454 | + Args: |
| 455 | + transforms: Sequence of transform objects. The objects can be one of two |
| 456 | + types: |
| 457 | + 1. torchcodec.transforms.DecoderTransform |
| 458 | + 2. torchvision.transforms.v2.Transform, but our type annotation |
| 459 | + only mentions its base, nn.Module. We don't want to take a |
| 460 | + hard dependency on TorchVision. |
| 461 | +
|
| 462 | + Returns: |
| 463 | + List of DecoderTransform objects. |
| 464 | + """ |
| 465 | + try: |
| 466 | + from torchvision.transforms import v2 |
| 467 | + |
| 468 | + tv_available = True |
| 469 | + except ImportError: |
| 470 | + tv_available = False |
| 471 | + |
| 472 | + converted_transforms: List[DecoderTransform] = [] |
| 473 | + for transform in transforms: |
| 474 | + if not isinstance(transform, DecoderTransform): |
| 475 | + if not tv_available: |
| 476 | + raise ValueError( |
| 477 | + f"The supplied transform, {transform}, is not a TorchCodec " |
| 478 | + "DecoderTransform. TorchCodec also accepts TorchVision " |
| 479 | + "v2 transforms, but TorchVision is not installed." |
| 480 | + ) |
| 481 | + elif isinstance(transform, v2.Resize): |
| 482 | + converted_transforms.append(Resize._from_torchvision(transform)) |
| 483 | + else: |
| 484 | + raise ValueError( |
| 485 | + f"Unsupported transform: {transform}. Transforms must be " |
| 486 | + "either a TorchCodec DecoderTransform or a TorchVision " |
| 487 | + "v2 transform." |
| 488 | + ) |
| 489 | + else: |
| 490 | + converted_transforms.append(transform) |
| 491 | + |
| 492 | + return converted_transforms |
| 493 | + |
| 494 | + |
| 495 | +def _make_transform_specs( |
| 496 | + transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], |
| 497 | +) -> str: |
| 498 | + """Given a sequence of transforms, turn those into the specification string |
| 499 | + the core API expects. |
| 500 | +
|
| 501 | + Args: |
| 502 | + transforms: Optional sequence of transform objects. The objects can be |
| 503 | + one of two types: |
| 504 | + 1. torchcodec.transforms.DecoderTransform |
| 505 | + 2. torchvision.transforms.v2.Transform, but our type annotation |
| 506 | + only mentions its base, nn.Module. We don't want to take a |
| 507 | + hard dependency on TorchVision. |
| 508 | +
|
| 509 | + Returns: |
| 510 | + String of transforms in the format the core API expects: transform |
| 511 | + specifications separated by semicolons. |
| 512 | + """ |
| 513 | + if transforms is None: |
| 514 | + return "" |
| 515 | + |
| 516 | + transforms = _convert_to_decoder_transforms(transforms) |
| 517 | + return ";".join([t._make_transform_spec() for t in transforms]) |
| 518 | + |
| 519 | + |
438 | 520 | def _read_custom_frame_mappings( |
439 | 521 | custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader] |
440 | 522 | ) -> tuple[Tensor, Tensor, Tensor]: |
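A rough sketch of how the two new helpers fit together, assuming TorchVision is installed and that Resize takes a size argument (illustrative); the exact string returned by _make_transform_spec() is defined by the transform classes in torchcodec.transforms and is not shown in this diff:

    from torchvision.transforms import v2
    from torchcodec.transforms import Resize

    # Mixed input: the TorchCodec transform passes through unchanged, while the
    # TorchVision v2.Resize is converted via Resize._from_torchvision().
    mixed = [Resize(size=(224, 224)), v2.Resize(size=(128, 128))]
    decoder_transforms = _convert_to_decoder_transforms(mixed)

    # Each transform contributes one spec string; they are joined with ";" and
    # passed to core.add_video_stream(transform_specs=...) by __init__.
    specs = _make_transform_specs(mixed)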
|