diff --git a/tests/test_uvm_tensor.py b/tests/test_uvm_tensor.py index dbd7140..da4053a 100644 --- a/tests/test_uvm_tensor.py +++ b/tests/test_uvm_tensor.py @@ -20,10 +20,12 @@ @pytest.mark.cpu_and_gpu def test_uvm_tensor() -> None: if torch.cuda.is_available() and _UVM_TENSOR_AVAILABLE: + device = torch.device("cuda:0") + torch.cuda.set_device(device) uvm_tensor = torch.rand( (64, 64), out=new_managed_tensor( - torch.empty(0, dtype=torch.float32, device="cuda:0"), + torch.empty(0, dtype=torch.float32, device=device), [64, 64], ), )