5 changes: 3 additions & 2 deletions test/common_utils.py
@@ -400,8 +400,9 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
return to_cvcuda_tensor(make_image(*args, **kwargs))
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
# default batch_dims to (1,) explicitly, since to_cvcuda_tensor requires a batch dimension (ndim == 4)
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
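A minimal sketch of what the updated helper yields, assuming CV-CUDA is installed and `torchvision.transforms.v2.functional` is imported as `F` (`F.cvcuda_to_tensor` is the inverse conversion used by the tests below):

```python
# Sketch: make_image_cvcuda now defaults batch_dims to (1,), so the wrapped
# image always has the batch dimension that to_cvcuda_tensor requires.
cv_img = make_image_cvcuda(dtype=torch.float32)

t = F.cvcuda_to_tensor(cv_img)          # back to a 4D torch.Tensor
assert t.ndim == 4 and t.shape[0] == 1  # batch dimension preserved
```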
59 changes: 54 additions & 5 deletions test/test_transforms_v2.py
@@ -5517,7 +5517,17 @@ def test_kernel_image_inplace(self, device):
def test_kernel_video(self):
check_kernel(F.normalize_video, make_video(dtype=torch.float32), mean=self.MEAN, std=self.STD)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_functional(self, make_input):
check_functional(F.normalize, make_input(dtype=torch.float32), mean=self.MEAN, std=self.STD)

@@ -5527,9 +5537,16 @@ def test_functional(self, make_input):
(F.normalize_image, torch.Tensor),
(F.normalize_image, tv_tensors.Image),
(F.normalize_video, tv_tensors.Video),
pytest.param(
F._misc._normalize_cvcuda,
"cvcuda.Tensor",
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if input_type == "cvcuda.Tensor":
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.normalize, kernel=kernel, input_type=input_type)

def test_functional_error(self):
@@ -5555,7 +5572,17 @@ def _sample_input_adapter(self, transform, input, device):
adapted_input[key] = value
return adapted_input

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_transform(self, make_input):
check_transform(
transforms.Normalize(mean=self.MEAN, std=self.STD),
@@ -5570,14 +5597,36 @@ def _reference_normalize_image(self, image, *, mean, std):

@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
def test_correctness_image(self, mean, std, dtype, fn):
image = make_image(dtype=dtype)
def test_correctness_image(self, mean, std, dtype, make_input, fn):
if make_input == make_image_cvcuda and dtype != torch.float32:
pytest.skip("CVCUDA only supports float32 for normalize")

image = make_input(dtype=dtype)

actual = fn(image, mean=mean, std=std)

if make_input == make_image_cvcuda:
image = F.cvcuda_to_tensor(image).to(device="cpu")
image = image.squeeze(0)
actual = F.cvcuda_to_tensor(actual).to(device="cpu")
actual = actual.squeeze(0)

expected = self._reference_normalize_image(image, mean=mean, std=std)

assert_equal(actual, expected)
if make_input == make_image_cvcuda:
torch.testing.assert_close(actual, expected, rtol=0, atol=1e-6)
else:
assert_equal(actual, expected)


class TestClampBoundingBoxes:
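For intuition, the round trip exercised by the new CV-CUDA test rows looks roughly like this (a sketch assuming CV-CUDA is installed; the mean/std values are placeholders):

```python
img = make_image_cvcuda(dtype=torch.float32)           # cvcuda.Tensor, NHWC
out = F.normalize(img, mean=[0.5] * 3, std=[0.5] * 3)  # dispatches to the CV-CUDA kernel

# Convert back, move to CPU, and drop the batch dim before comparing
# against the single-image torch reference.
out_t = F.cvcuda_to_tensor(out).to(device="cpu").squeeze(0)
```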
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_transform.py
@@ -8,7 +8,7 @@
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor
from torchvision.utils import _log_api_usage_once

from .functional._utils import _get_kernel
@@ -23,7 +23,7 @@ class Transform(nn.Module):

# Class attribute defining transformed types. Other types are passed-through without any transformation
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)

def __init__(self) -> None:
super().__init__()
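`_transformed_types` mixes plain types with callables, and `check_type` treats the callables as predicates; that is how `is_cvcuda_tensor` can sit next to `torch.Tensor` and `PIL.Image.Image`. Roughly (a sketch of the existing mechanism, not code added by this PR):

```python
def check_type(obj, types_or_checks):
    for type_or_check in types_or_checks:
        # A type entry is matched with isinstance; anything else is
        # called as a predicate on the input.
        if isinstance(type_or_check, type):
            if isinstance(obj, type_or_check):
                return True
        elif type_or_check(obj):
            return True
    return False
```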
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
@@ -15,7 +15,7 @@
from torchvision._utils import sequence_to_str

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


Expand Down Expand Up @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
@@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
is_cvcuda_tensor,
),
)
}
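With `is_cvcuda_tensor` added to these check lists, `query_chw` and `query_size` also read dimensions from CV-CUDA tensors. Note this relies on `get_dimensions`/`get_size` having CV-CUDA kernels registered elsewhere in this series; a hedged usage sketch:

```python
# Sketch, assuming get_size dispatches on cvcuda.Tensor (registered in a
# separate change, not in this diff).
cv_img = make_image_cvcuda()
h, w = query_size([cv_img])  # same (H, W) a tv_tensors.Image would report
```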
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/__init__.py
@@ -1,6 +1,6 @@
from torchvision.transforms import InterpolationMode # usort: skip

from ._utils import is_pure_tensor, register_kernel # usort: skip
from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip

from ._meta import (
clamp_bounding_boxes,
49 changes: 47 additions & 2 deletions torchvision/transforms/v2/functional/_misc.py
@@ -1,5 +1,5 @@
import math
from typing import Optional
from typing import Optional, TYPE_CHECKING

import PIL.Image
import torch
@@ -13,7 +13,14 @@

from ._meta import _convert_bounding_box_format

from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor

CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda() # noqa: F811


def normalize(
@@ -72,6 +79,44 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in
return normalize_image(video, mean, std, inplace=inplace)


def _normalize_cvcuda(
image: "cvcuda.Tensor",
mean: list[float],
std: list[float],
inplace: bool = False,
) -> "cvcuda.Tensor":
cvcuda = _import_cvcuda()
if inplace:
raise ValueError("Inplace normalization is not supported for CVCUDA.")

# CV-CUDA supports signed int and float tensors, whereas torchvision supports uint and float.
# CV-CUDA does not currently expose float16, so only float32 is accepted here;
# add float16 once CV-CUDA exposes it.
if image.dtype != cvcuda.Type.F32:
raise ValueError(f"Input tensor should be a float tensor. Got {image.dtype}.")

channels = image.shape[3]
if isinstance(mean, float | int):
mean = [mean] * channels
elif len(mean) != channels:
raise ValueError(f"Mean should have {channels} elements. Got {len(mean)}.")
if isinstance(std, float | int):
std = [std] * channels
elif len(std) != channels:
raise ValueError(f"Std should have {channels} elements. Got {len(std)}.")

mt = torch.as_tensor(mean, dtype=torch.float32).reshape(1, 1, 1, channels).cuda()
st = torch.as_tensor(std, dtype=torch.float32).reshape(1, 1, 1, channels).cuda()
mean_cv = cvcuda.as_tensor(mt, cvcuda.TensorLayout.NHWC)
std_cv = cvcuda.as_tensor(st, cvcuda.TensorLayout.NHWC)

return cvcuda.normalize(image, base=mean_cv, scale=std_cv, flags=cvcuda.NormalizeFlags.SCALE_IS_STDDEV)


if CVCUDA_AVAILABLE:
_register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_cvcuda)


def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.GaussianBlur` for details."""
if torch.jit.is_scripting():
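With `NormalizeFlags.SCALE_IS_STDDEV`, CV-CUDA interprets `scale` as a standard deviation rather than a multiplier, so the kernel computes the same `(x - mean) / std` as the torch reference (up to the operator's epsilon, 0 by default). A torch sketch of the math in CV-CUDA's NHWC layout:

```python
def normalize_nhwc_ref(x, mean, std):
    # Per-channel mean/std broadcast over the trailing channel axis,
    # mirroring the (1, 1, 1, C) tensors built in _normalize_cvcuda.
    m = torch.as_tensor(mean, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1)
    s = torch.as_tensor(std, dtype=x.dtype, device=x.device).reshape(1, 1, 1, -1)
    return (x - m) / s
```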
7 changes: 7 additions & 0 deletions torchvision/transforms/v2/functional/_utils.py
@@ -169,3 +169,10 @@ def _is_cvcuda_available():
return True
except ImportError:
return False


def is_cvcuda_tensor(inpt: Any) -> bool:
if _is_cvcuda_available():
cvcuda = _import_cvcuda()
return isinstance(inpt, cvcuda.Tensor)
return False
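The predicate is safe to call whether or not CV-CUDA is installed, since the availability check runs before the import. Usage sketch:

```python
from torchvision.transforms.v2.functional import is_cvcuda_tensor

is_cvcuda_tensor(torch.rand(1, 3, 8, 8))  # False: a plain torch.Tensor
# With CV-CUDA installed, passing a cvcuda.Tensor returns True.
```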