Make _get_perspective_coeffs device agnostic #9082

Open · wants to merge 1 commit into main
29 changes: 29 additions & 0 deletions test/test_functional_tensor.py
@@ -434,6 +434,35 @@ def test_perspective_batch(device, dims_and_points, dt):
    )


@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
def test_perspective_tensor_input(device, dims_and_points, dt):

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    data_dims, (spoints, epoints) = dims_and_points
    print(spoints, epoints)

    batch_tensors = _create_data_batch(*data_dims, num_samples=4, device=device)
    if dt is not None:
        batch_tensors = batch_tensors.to(dtype=dt)

    # Ignore the equivalence between scripted and regular function on float16 cuda. The pixels at
    # the border may be entirely different due to small rounding errors.
    scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8
    _test_fn_on_batch(
        batch_tensors,
        F.perspective,
        scripted_fn_atol=scripted_fn_atol,
        startpoints=torch.tensor(spoints, device=device, dtype=dt),
        endpoints=torch.tensor(epoints, device=device, dtype=dt),
        interpolation=NEAREST,
    )


def test_perspective_interpolation_type():
    spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
    epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
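For context, a minimal sketch of the call pattern the new test covers: passing the corner points to `F.perspective` as tensors living on the same device as the image, which is what this PR enables. The image size and corner values below are illustrative (the points are borrowed from `test_perspective_interpolation_type` above), not taken from the parametrized test helpers.

```python
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

device = "cuda" if torch.cuda.is_available() else "cpu"
img = torch.randint(0, 256, (3, 26, 34), dtype=torch.uint8, device=device)

# Four corners [top-left, top-right, bottom-right, bottom-left], given as
# tensors on the image's device instead of nested Python lists.
startpoints = torch.tensor([[0, 0], [33, 0], [33, 25], [0, 25]], device=device)
endpoints = torch.tensor([[3, 2], [32, 3], [30, 24], [2, 25]], device=device)

out = F.perspective(img, startpoints, endpoints, interpolation=InterpolationMode.NEAREST)
print(out.shape)  # torch.Size([3, 26, 34])
```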
31 changes: 19 additions & 12 deletions torchvision/transforms/functional.py
@@ -671,32 +671,39 @@ def hflip(img: Tensor) -> Tensor:
    return F_t.hflip(img)


def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:
def _get_perspective_coeffs(startpoints: List[List[int]] | Tensor, endpoints: List[List[int]] | Tensor) -> List[float]:
    """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.

    In Perspective Transform each pixel (x, y) in the original image gets transformed as,
    (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )

    Args:
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
        startpoints (list of list of ints or Tensor): List containing four lists of two integers, or a Tensor of shape ``(4, 2)``, corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
        endpoints (list of list of ints or Tensor): List containing four lists of two integers, or a Tensor of shape ``(4, 2)``, corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.

    Returns:
        octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
    """

    startpoints = startpoints if isinstance(startpoints, Tensor) else torch.tensor(startpoints, dtype=torch.float64)
    endpoints = endpoints if isinstance(endpoints, Tensor) else torch.tensor(endpoints, dtype=torch.float64)

    if len(startpoints) != 4 or len(endpoints) != 4:
        raise ValueError(
            f"Please provide exactly four corners, got {len(startpoints)} startpoints and {len(endpoints)} endpoints."
        )
    a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float64)

    for i, (p1, p2) in enumerate(zip(endpoints, startpoints)):
        a_matrix[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])
    a_matrix = torch.zeros(2 * len(startpoints), 8, dtype=torch.float64, device=startpoints.device)
    a_matrix[::2, :2] = endpoints
    a_matrix[1::2, 3:5] = endpoints
    a_matrix[::2, 2] = 1
    a_matrix[1::2, 5] = 1
    a_matrix[::2, 6:] = -startpoints[:, 0:1] * endpoints
    a_matrix[1::2, 6:] = -startpoints[:, 1:2] * endpoints

    b_matrix = torch.tensor(startpoints, dtype=torch.float64).view(8)
    b_matrix = startpoints.to(dtype=torch.float64).view(8)
    # do least squares in double precision to prevent numerical issues
    res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution.to(torch.float32)

@@ -706,8 +713,8 @@ def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]:

def perspective(
    img: Tensor,
    startpoints: List[List[int]],
    endpoints: List[List[int]],
    startpoints: List[List[int]] | Tensor,
    endpoints: List[List[int]] | Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    fill: Optional[List[float]] = None,
) -> Tensor:
@@ -717,9 +724,9 @@

    Args:
        img (PIL Image or Tensor): Image to be transformed.
        startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
        startpoints (list of list of ints or Tensor): List containing four lists of two integers, or a Tensor of shape ``(4, 2)``, corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
        endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
        endpoints (list of list of ints or Tensor): List containing four lists of two integers, or a Tensor of shape ``(4, 2)``, corresponding to four corners
            ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
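As a sanity check on the vectorized rewrite of `_get_perspective_coeffs`, here is a standalone sketch (not part of the PR) that builds the 8x8 least-squares system `A @ coeffs = b` both with the removed loop and with the new slice assignments, then verifies the two matrices agree. The corner values are illustrative.

```python
import torch

# Illustrative corner points, in the double precision the solver uses.
startpoints = torch.tensor([[0, 0], [33, 0], [33, 25], [0, 25]], dtype=torch.float64)
endpoints = torch.tensor([[3, 2], [32, 3], [30, 24], [2, 25]], dtype=torch.float64)

# Loop-based construction (the code this PR removes): each corner pair
# contributes two rows of the system.
a_loop = torch.zeros(8, 8, dtype=torch.float64)
for i, (p1, p2) in enumerate(zip(endpoints.tolist(), startpoints.tolist())):
    a_loop[2 * i, :] = torch.tensor([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
    a_loop[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

# Vectorized construction (the code this PR adds): the same rows filled by
# slice assignment, allocated on the inputs' device so it also works on GPU.
a_vec = torch.zeros(8, 8, dtype=torch.float64, device=startpoints.device)
a_vec[::2, :2] = endpoints
a_vec[1::2, 3:5] = endpoints
a_vec[::2, 2] = 1
a_vec[1::2, 5] = 1
a_vec[::2, 6:] = -startpoints[:, 0:1] * endpoints
a_vec[1::2, 6:] = -startpoints[:, 1:2] * endpoints

assert torch.equal(a_loop, a_vec)

b = startpoints.view(8)
coeffs = torch.linalg.lstsq(a_vec, b, driver="gels").solution
print(coeffs)  # the eight coefficients (a, b, c, d, e, f, g, h)
```

Besides dropping the Python loop, the device-agnostic part of the change is allocating `a_matrix` on `startpoints.device` rather than implicitly on the CPU, so the helper no longer forces a device transfer when the corner points come from the GPU.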