diff --git a/test/run_tests.sh b/test/run_tests.sh index 243f8ee365c..aaf4425da69 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -150,6 +150,7 @@ function run_xla_op_tests1 { run_dynamic "$CDIR/ds/test_dynamic_shapes.py" run_dynamic "$CDIR/ds/test_dynamic_shape_models.py" "$@" --verbosity=$VERBOSITY run_eager_debug "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY + PJRT_DEVICE=CPU python3 "$CDIR/test_operations.py" -v -k test_rand_on_xla_cpu run_test "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_test_without_functionalization "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_pt_xla_debug "$CDIR/debug_tool/test_pt_xla_debug.py" diff --git a/test/test_operations.py b/test/test_operations.py index 82f6c7d5dd4..92ad8277338 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -93,6 +93,11 @@ def onlyOnCUDA(fn): return unittest.skipIf(accelerator != "cuda", "PJRT_DEVICE=CUDA required")(fn) +def onlyOnCPU(fn): + accelerator = os.environ.get("PJRT_DEVICE").lower() + return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn) + + def onlyIfXLAExperimentalContains(feat): experimental = os.environ.get("XLA_EXPERIMENTAL", "").split(":") return unittest.skipIf(feat not in experimental, @@ -2424,6 +2429,12 @@ def test_isneginf_no_fallback(self): t = t.to(torch.float16) self._test_no_fallback(torch.isneginf, (t,)) + @onlyOnCPU + def test_rand_on_xla_cpu(self): + t = torch.rand(10, device=xm.xla_device()) + t_cpu = t.cpu() + self.assertEqual(t_cpu, t.cpu()) + class MNISTComparator(nn.Module):