Skip to content

Commit 0535b00

Browse files
authored
Decoder-native resize public implementation (#1003)
1 parent ab4cf29 commit 0535b00

File tree

10 files changed

+321
-67
lines changed

10 files changed

+321
-67
lines changed

.github/workflows/lint.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
run: python -m pip install --upgrade pip
6363
- name: Install dependencies and FFmpeg
6464
run: |
65-
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
65+
python -m pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
6666
conda install "ffmpeg=7.0.1" pkg-config pybind11 -c conda-forge
6767
ffmpeg -version
6868
- name: Build and install torchcodec

docs/source/api_ref_transforms.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
.. _transforms:
2+
3+
=====================
4+
torchcodec.transforms
5+
=====================
6+
7+
.. currentmodule:: torchcodec.transforms
8+
9+
For a tutorial, see: TODO_DECODER_TRANSFORMS_TUTORIAL.
10+
11+
.. autosummary::
12+
:toctree: generated/
13+
:nosignatures:
14+
:template: dataclass.rst
15+
16+
DecoderTransform
17+
Resize

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def __call__(self, filename):
209209
intersphinx_mapping = {
210210
"python": ("https://docs.python.org/3/", None),
211211
"torch": ("https://pytorch.org/docs/stable/", None),
212+
"torchvision": ("https://docs.pytorch.org/vision/stable/", None),
212213
"numpy": ("https://numpy.org/doc/stable/", None),
213214
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
214215
"matplotlib": ("https://matplotlib.org/stable/", None),

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,4 @@ Encoding
125125
api_ref_decoders
126126
api_ref_encoders
127127
api_ref_samplers
128+
api_ref_transforms

mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ files = src/torchcodec
44
show_error_codes = True
55
pretty = True
66
allow_redefinition = True
7+
follow_untyped_imports = True

src/torchcodec/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
1010
# but that results in circular import.
1111
from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa
12-
from . import decoders, encoders, samplers # noqa
12+
from . import decoders, encoders, samplers, transforms # noqa
1313

1414
try:
1515
# Note that version.py is generated during install.

src/torchcodec/decoders/_video_decoder.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,18 @@
88
import json
99
import numbers
1010
from pathlib import Path
11-
from typing import Literal, Optional, Tuple, Union
11+
from typing import List, Literal, Optional, Sequence, Tuple, Union
1212

1313
import torch
14-
from torch import device as torch_device, Tensor
14+
from torch import device as torch_device, nn, Tensor
1515

1616
from torchcodec import _core as core, Frame, FrameBatch
1717
from torchcodec.decoders._decoder_utils import (
1818
_get_cuda_backend,
1919
create_decoder,
2020
ERROR_REPORTING_INSTRUCTIONS,
2121
)
22+
from torchcodec.transforms import DecoderTransform, Resize
2223

2324

2425
class VideoDecoder:
@@ -67,6 +68,11 @@ class VideoDecoder:
6768
probably is. Default: "exact".
6869
Read more about this parameter in:
6970
:ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
71+
transforms (sequence of transform objects, optional): Sequence of transforms to be
72+
applied to the decoded frames by the decoder itself, in order. Accepts both
73+
:class:`~torchcodec.transforms.DecoderTransform` and
74+
:class:`~torchvision.transforms.v2.Transform`
75+
objects. Read more about this parameter in: TODO_DECODER_TRANSFORMS_TUTORIAL.
7076
custom_frame_mappings (str, bytes, or file-like object, optional):
7177
Mapping of frames to their metadata, typically generated via ffprobe.
7278
This enables accurate frame seeking without requiring a full video scan.
@@ -105,6 +111,7 @@ def __init__(
105111
num_ffmpeg_threads: int = 1,
106112
device: Optional[Union[str, torch_device]] = None,
107113
seek_mode: Literal["exact", "approximate"] = "exact",
114+
transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]] = None,
108115
custom_frame_mappings: Optional[
109116
Union[str, bytes, io.RawIOBase, io.BufferedReader]
110117
] = None,
@@ -151,13 +158,16 @@ def __init__(
151158

152159
device_variant = _get_cuda_backend()
153160

161+
transform_specs = _make_transform_specs(transforms)
162+
154163
core.add_video_stream(
155164
self._decoder,
156165
stream_index=stream_index,
157166
dimension_order=dimension_order,
158167
num_threads=num_ffmpeg_threads,
159168
device=device,
160169
device_variant=device_variant,
170+
transform_specs=transform_specs,
161171
custom_frame_mappings=custom_frame_mappings_data,
162172
)
163173

@@ -435,6 +445,78 @@ def _get_and_validate_stream_metadata(
435445
)
436446

437447

448+
def _convert_to_decoder_transforms(
449+
transforms: Sequence[Union[DecoderTransform, nn.Module]],
450+
) -> List[DecoderTransform]:
451+
"""Convert a sequence of transforms that may contain TorchVision transform
452+
objects into a list of only TorchCodec transform objects.
453+
454+
Args:
455+
transforms: Squence of transform objects. The objects can be one of two
456+
types:
457+
1. torchcodec.transforms.DecoderTransform
458+
2. torchvision.transforms.v2.Transform, but our type annotation
459+
only mentions its base, nn.Module. We don't want to take a
460+
hard dependency on TorchVision.
461+
462+
Returns:
463+
List of DecoderTransform objects.
464+
"""
465+
try:
466+
from torchvision.transforms import v2
467+
468+
tv_available = True
469+
except ImportError:
470+
tv_available = False
471+
472+
converted_transforms: list[DecoderTransform] = []
473+
for transform in transforms:
474+
if not isinstance(transform, DecoderTransform):
475+
if not tv_available:
476+
raise ValueError(
477+
f"The supplied transform, {transform}, is not a TorchCodec "
478+
" DecoderTransform. TorchCodec also accept TorchVision "
479+
"v2 transforms, but TorchVision is not installed."
480+
)
481+
elif isinstance(transform, v2.Resize):
482+
converted_transforms.append(Resize._from_torchvision(transform))
483+
else:
484+
raise ValueError(
485+
f"Unsupported transform: {transform}. Transforms must be "
486+
"either a TorchCodec DecoderTransform or a TorchVision "
487+
"v2 transform."
488+
)
489+
else:
490+
converted_transforms.append(transform)
491+
492+
return converted_transforms
493+
494+
495+
def _make_transform_specs(
496+
transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
497+
) -> str:
498+
"""Given a sequence of transforms, turn those into the specification string
499+
the core API expects.
500+
501+
Args:
502+
transforms: Optional sequence of transform objects. The objects can be
503+
one of two types:
504+
1. torchcodec.transforms.DecoderTransform
505+
2. torchvision.transforms.v2.Transform, but our type annotation
506+
only mentions its base, nn.Module. We don't want to take a
507+
hard dependency on TorchVision.
508+
509+
Returns:
510+
String of transforms in the format the core API expects: transform
511+
specifications separate by semicolons.
512+
"""
513+
if transforms is None:
514+
return ""
515+
516+
transforms = _convert_to_decoder_transforms(transforms)
517+
return ";".join([t._make_transform_spec() for t in transforms])
518+
519+
438520
def _read_custom_frame_mappings(
439521
custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
440522
) -> tuple[Tensor, Tensor, Tensor]:
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from ._decoder_transforms import DecoderTransform, Resize # noqa
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from abc import ABC, abstractmethod
8+
from dataclasses import dataclass
9+
from types import ModuleType
10+
from typing import Sequence
11+
12+
from torch import nn
13+
14+
15+
@dataclass
16+
class DecoderTransform(ABC):
17+
"""Base class for all decoder transforms.
18+
19+
A *decoder transform* is a transform that is applied by the decoder before
20+
returning the decoded frame. Applying decoder transforms to frames
21+
should be both faster and more memory efficient than receiving normally
22+
decoded frames and applying the same kind of transform.
23+
24+
Most ``DecoderTransform`` objects have a complementary transform in TorchVision,
25+
specificially in `torchvision.transforms.v2 <https://docs.pytorch.org/vision/stable/transforms.html>`_. For such transforms, we
26+
ensure that:
27+
28+
1. The names are the same.
29+
2. Default behaviors are the same.
30+
3. The parameters for the ``DecoderTransform`` object are a subset of the
31+
TorchVision :class:`~torchvision.transforms.v2.Transform` object.
32+
4. Parameters with the same name control the same behavior and accept a
33+
subset of the same types.
34+
5. The difference between the frames returned by a decoder transform and
35+
the complementary TorchVision transform are such that a model should
36+
not be able to tell the difference.
37+
"""
38+
39+
@abstractmethod
40+
def _make_transform_spec(self) -> str:
41+
pass
42+
43+
44+
def import_torchvision_transforms_v2() -> ModuleType:
45+
try:
46+
from torchvision.transforms import v2
47+
except ImportError as e:
48+
raise RuntimeError(
49+
"Cannot import TorchVision; this should never happen, please report a bug."
50+
) from e
51+
return v2
52+
53+
54+
@dataclass
55+
class Resize(DecoderTransform):
56+
"""Resize the decoded frame to a given size.
57+
58+
Complementary TorchVision transform: :class:`~torchvision.transforms.v2.Resize`.
59+
Interpolation is always bilinear. Anti-aliasing is always on.
60+
61+
Args:
62+
size: (sequence of int): Desired output size. Must be a sequence of
63+
the form (height, width).
64+
"""
65+
66+
size: Sequence[int]
67+
68+
def _make_transform_spec(self) -> str:
69+
assert len(self.size) == 2
70+
return f"resize, {self.size[0]}, {self.size[1]}"
71+
72+
@classmethod
73+
def _from_torchvision(cls, resize_tv: nn.Module):
74+
v2 = import_torchvision_transforms_v2()
75+
76+
assert isinstance(resize_tv, v2.Resize)
77+
78+
if resize_tv.interpolation is not v2.InterpolationMode.BILINEAR:
79+
raise ValueError(
80+
"TorchVision Resize transform must use bilinear interpolation."
81+
)
82+
if resize_tv.antialias is False:
83+
raise ValueError(
84+
"TorchVision Resize transform must have antialias enabled."
85+
)
86+
if resize_tv.size is None:
87+
raise ValueError("TorchVision Resize transform must have a size specified.")
88+
if len(resize_tv.size) != 2:
89+
raise ValueError(
90+
"TorchVision Resize transform must have a (height, width) "
91+
f"pair for the size, got {resize_tv.size}."
92+
)
93+
return cls(size=resize_tv.size)

0 commit comments

Comments
 (0)