 )
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
+    TEST_CUDA,
+)
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,

 )
 from torch.testing._internal.common_utils import (
     get_cycles_per_ms,
+    NAVI4_ARCH,
     run_tests,
+    skipIfRocmArch,
     wrapSwapTensorsTest,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -94,6 +99,7 @@ def world_size(self) -> int:
         return 4

     @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skipIfRocmArch(NAVI4_ARCH)  # Supported in future releases
     def test_param_registration_after_forward(self):
         """Tests the parameter registration after forward."""
         device = torch.device("cuda", 0)
@@ -200,6 +206,7 @@ def world_size(self) -> int:

     @unittest.skipIf(not TEST_CUDA, "no cuda")
     @wrapSwapTensorsTest(True)
+    @skipIfRocmArch(NAVI4_ARCH)  # Supported in future releases
     def test_to_float64_after_init(self):
         """Tests that the user can cast the module to float64 after init."""
         # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
@@ -310,6 +317,9 @@ def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:

     @skip_if_lt_x_gpu(2)
     @compiled_fsdp_test(compile_compute_on_module=Transformer)
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
+    )
     def test_train_parity_multi_group(self):
         """
         Tests train parity against DDP when using multiple parameter groups for
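Note: the decorators added in this diff follow the usual gating pattern in the PyTorch test suite: skipIfRocmArch(NAVI4_ARCH) skips a test on ROCm Navi4 GPUs, and the PLATFORM_SUPPORTS_MEM_EFF_ATTENTION guard skips tests that need the memory-efficient SDPA backend. Below is a minimal sketch of that pattern, assuming a torch build whose torch.testing._internal helpers expose skipIfRocmArch, NAVI4_ARCH, TEST_CUDA, and PLATFORM_SUPPORTS_MEM_EFF_ATTENTION; the class and test names are hypothetical and not part of this patch.

import unittest

from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    TEST_CUDA,
)
from torch.testing._internal.common_utils import (
    NAVI4_ARCH,
    run_tests,
    skipIfRocmArch,
)


class ExampleSkipUsage(unittest.TestCase):  # hypothetical test class for illustration
    @unittest.skipIf(not TEST_CUDA, "no cuda")
    @skipIfRocmArch(NAVI4_ARCH)  # skip on ROCm Navi4 parts until supported
    def test_needs_cuda_but_not_navi4(self):
        # Runs only when CUDA is available and, on ROCm, the device is not a
        # Navi4 architecture.
        pass

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
    )
    def test_needs_mem_efficient_sdpa(self):
        # Runs only on platforms with the memory-efficient SDPA backend.
        pass


if __name__ == "__main__":
    run_tests()

Stacking the decorators keeps each skip condition independent: every guard can raise unittest.SkipTest on its own, so a test body runs only when all of its guards pass.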