
Make _get_perspective_coeffs device agnostic #9076

Open
ptrblck opened this issue May 18, 2025 · 1 comment

ptrblck commented May 18, 2025

### 🐛 Describe the bug

Currently, `_get_perspective_coeffs` creates its internal tensors on the CPU (as seen here) and therefore fails in the `torch.linalg.lstsq` call if `startpoints`/`endpoints` is a tensor containing data on the GPU (or another device).

The docs state that a list of lists of ints is expected, yet tensor inputs are still accepted without raising an error.

One fix would be to create `a_matrix` using the device attribute of the input. The alternatives would be to move the points to the host, which would introduce a device synchronization and prevent graph capture, or to raise an error if tensors are passed.
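A minimal sketch of the first option follows. It is a hypothetical helper (`get_perspective_coeffs_sketch` is an illustrative name, not torchvision's code or the code from any PR). It converts both point sets to `float64` tensors on the device of `startpoints` and builds the 8x8 system with tensor ops instead of filling a CPU tensor row by row. It assumes exactly four `[x, y]` correspondences and returns the eight coefficients `(a, b, c, d, e, f, g, h)` as a tensor.

```python
import torch


def get_perspective_coeffs_sketch(startpoints, endpoints):
    """Hypothetical device-agnostic variant of a perspective-coefficient solver.

    Accepts lists of [x, y] pairs or (4, 2) tensors and returns the eight
    coefficients (a, b, c, d, e, f, g, h) as a float32 tensor on the device
    of `startpoints`.
    """
    # Keep all solver inputs on the device of the start points.
    start = torch.as_tensor(startpoints, dtype=torch.float64)
    end = torch.as_tensor(endpoints, dtype=torch.float64, device=start.device)
    if start.shape != (4, 2) or end.shape != (4, 2):
        raise ValueError("Expected four [x, y] points for start and end.")

    x, y = end[:, 0], end[:, 1]      # points in the output image
    u, v = start[:, 0], start[:, 1]  # corresponding points in the input image
    ones, zeros = torch.ones_like(x), torch.zeros_like(x)

    # Each correspondence contributes two rows of the 8x8 linear system.
    row_u = torch.stack([x, y, ones, zeros, zeros, zeros, -u * x, -u * y], dim=1)
    row_v = torch.stack([zeros, zeros, zeros, x, y, ones, -v * x, -v * y], dim=1)
    # Interleave rows as [u0, v0, u1, v1, ...] to match the right-hand side.
    a_matrix = torch.stack([row_u, row_v], dim=1).reshape(8, 8)
    b_matrix = start.reshape(8, 1)

    # "gels" is the only lstsq driver accepted for CUDA inputs; it also runs on CPU.
    solution = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution
    return solution.reshape(8).to(torch.float32)
```

The result stays on the input's device; converting it to a Python list afterwards would still require a device-to-host copy, which is the synchronization the alternatives above try to avoid.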

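If we want to accept tensor inputs, the tensor copy in `b_matrix = torch.tensor(startpoints, dtype=torch.float64).view(8)` should also be fixed, since calling `torch.tensor()` on an existing tensor emits a warning recommending `clone().detach()` instead.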
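A small sketch of how the right-hand side could be built for both input types. `make_b_matrix` is a hypothetical helper for illustration only: tensor inputs are copied explicitly (keeping their device), while lists still go through `torch.tensor()`.

```python
import torch


def make_b_matrix(startpoints):
    """Hypothetical helper: build the (8,) right-hand side from a list of
    [x, y] pairs or a (4, 2) tensor, keeping a tensor input's device."""
    if isinstance(startpoints, torch.Tensor):
        # Explicit copy of a tensor input instead of torch.tensor(tensor).
        return startpoints.detach().clone().to(torch.float64).reshape(8)
    # List input: a new CPU tensor, as in the current code.
    return torch.tensor(startpoints, dtype=torch.float64).reshape(8)
```

If no copy is needed, `torch.as_tensor(startpoints, dtype=torch.float64)` would also handle both cases without a warning.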

The original error was reported on the discussion board and can be reproduced with:

```python
import torch
import torchvision.transforms.functional as TF


device = "cuda"
reference_image = torch.randn(1, 3, 224, 224, device=device)
B, C, H, W = reference_image.shape

W = 200
H = 200
# Define source points (original corners of the target image)
src_points = torch.tensor([
    [0, 0],  # Top-left
    [W - 1, 0],  # Top-right
    [W - 1, H - 1],  # Bottom-right
    [0, H - 1]  # Bottom-left
], dtype=torch.float32, device=reference_image.device)
src_points = src_points.unsqueeze(0).repeat(B, 1, 1)  # (B, 4, 2)

# Define destination points (corners after the perspective warp)
predicted_points = torch.tensor([
    [0, 0],  # Top-left
    [W - 10, 0],  # Top-right
    [W - 10, H - 10],  # Bottom-right
    [0, H - 10]  # Bottom-left
], dtype=torch.float32, device=reference_image.device)
predicted_points = predicted_points.unsqueeze(0).repeat(B, 1, 1)  # (B, 4, 2)

warped_images = []
for i in range(B):
    warped = TF.perspective(
        reference_image[i],
        src_points[i],
        predicted_points[i],
        interpolation=TF.InterpolationMode.BILINEAR,
        fill=0,
    )
    warped_images.append(warped)

# RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument b in method wrapper_CUDA_out_linalg_lstsq_out)
```

### Versions


`0.22.0.dev20250404+cu128`

AntoineSimoulin (Member) commented

Hey @ptrblck, thanks for your time and for opening the issue! I proposed changes in #9082 that follow your suggestions:

  • Create the `a_matrix` using the device attribute of the input;
  • Modify how the `b_matrix` is created as well;
  • Modify how the `a_matrix` values are filled so that the function remains compatible with TorchScript and the tests still pass.

Let me know what you think!
