
Commit 9015dfd

k-artem and pragupta authored
[release/2.7] Skip&Fix some testcases for Navi4x (#2645)
Manual cherry-pick of #2401. Fixes #SWDEV-548314.

Co-authored-by: Prachi Gupta <[email protected]>
1 parent a033df6 commit 9015dfd

10 files changed: 39 additions, 5 deletions


functorch/experimental/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 # PyTorch forward-mode is not mature yet
-from functorch import functionalize
+from torch._functorch.deprecated import functionalize
 from torch._functorch.apis import chunk_vmap
 from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from torch._functorch.eager_transforms import hessian, jacfwd, jvp
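Context for the change above (not part of the diff): the public entry point functorch.experimental.functionalize keeps working; it is now just re-exported from torch._functorch.deprecated, which points users toward torch.func.functionalize. A minimal sketch of the call, assuming the standard functionalize semantics:

import torch
from functorch.experimental import functionalize  # now backed by torch._functorch.deprecated

def f(x):
    y = x.clone()
    y.add_(1)  # in-place op that functionalize rewrites into an out-of-place one
    return y

g = functionalize(f)
print(g(torch.zeros(3)))  # tensor([1., 1., 1.])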

test/distributed/_composable/fsdp/test_fully_shard_training.py

Lines changed: 11 additions & 1 deletion
@@ -27,7 +27,10 @@
 )
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
+    TEST_CUDA,
+)
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
@@ -41,7 +44,9 @@
 )
 from torch.testing._internal.common_utils import (
     get_cycles_per_ms,
+    NAVI4_ARCH,
     run_tests,
+    skipIfRocmArch,
     wrapSwapTensorsTest,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -94,6 +99,7 @@ def world_size(self) -> int:
         return 4

     @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skipIfRocmArch(NAVI4_ARCH)  # Supported in future releases
     def test_param_registration_after_forward(self):
         """Tests the parameter registration after forward."""
         device = torch.device("cuda", 0)
@@ -200,6 +206,7 @@ def world_size(self) -> int:

     @unittest.skipIf(not TEST_CUDA, "no cuda")
     @wrapSwapTensorsTest(True)
+    @skipIfRocmArch(NAVI4_ARCH)  # Supported in future releases
     def test_to_float64_after_init(self):
         """Tests that the user can cast the module to float64 after init."""
         # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
@@ -310,6 +317,9 @@ def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:

     @skip_if_lt_x_gpu(2)
     @compiled_fsdp_test(compile_compute_on_module=Transformer)
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
+    )
     def test_train_parity_multi_group(self):
         """
         Tests train parity against DDP when using multiple parameter groups for
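For readers unfamiliar with the decorators added above: skipIfRocmArch(NAVI4_ARCH) skips a test only when running on ROCm Navi4x hardware, while the unittest.skipIf guard keys off a capability flag instead of an architecture name. A minimal standalone sketch using the same helpers this diff imports (the test bodies are illustrative only):

import unittest

from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
from torch.testing._internal.common_utils import (
    NAVI4_ARCH,
    run_tests,
    skipIfRocmArch,
    TestCase,
)


class ExampleSkips(TestCase):
    @skipIfRocmArch(NAVI4_ARCH)  # skipped only on ROCm Navi4x devices
    def test_arch_gated(self):
        self.assertTrue(True)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
    )
    def test_capability_gated(self):
        self.assertTrue(True)


if __name__ == "__main__":
    run_tests()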

test/distributed/_tools/test_sac_ilp.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
     get_optimal_checkpointing_policy_per_module,
     sac_milp,
 )
-from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_cuda import TEST_CUDA, PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import (
     run_tests,
     skipIfTorchDynamo,
@@ -180,7 +180,7 @@ def test_sac_ilp_case1(self):

     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
-    @skipIfRocmArch(NAVI_ARCH)
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_sac_ilp_case2(self):
         """
         This is a case where the memory budget is not binding, meaning that no
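A note on the swap above (my reading of the change, not stated in the commit message): instead of skipping test_sac_ilp_case2 on a hard-coded architecture, the guard now asks whether the platform supports flash attention at all, so any backend without fused SDPA is skipped automatically. In isolation the pattern looks like this (the test class and body are placeholders):

import unittest

from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION


class ExampleSDPAGate(unittest.TestCase):
    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
    def test_needs_flash_attention(self):
        # Placeholder body for illustration only.
        self.assertTrue(True)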

test/distributed/elastic/test_control_plane.py

Lines changed: 7 additions & 1 deletion
@@ -15,7 +15,12 @@
     TORCH_WORKER_SERVER_SOCKET,
     worker_main,
 )
-from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
+from torch.testing._internal.common_utils import (
+    requires_cuda,
+    run_tests,
+    skipIfRocm,
+    TestCase,
+)


 class UnixHTTPConnection(HTTPConnection):
@@ -151,6 +156,7 @@ def test_dump_nccl_trace_pickle_with_json(self) -> None:
         )
         self.assertEqual(resp.status, 200)

+    @skipIfRocm  # skipped upstream too
     def test_tcp(self) -> None:
         import requests

test/distributed/fsdp/test_fsdp_core.py

Lines changed: 4 additions & 0 deletions
@@ -35,8 +35,11 @@
     TransformerWithSharedParams,
 )
 from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    NAVI4_ARCH,
     parametrize,
     run_tests,
+    skipIfRocmArch,
     TEST_HPU,
     TEST_WITH_DEV_DBG_ASAN,
 )
@@ -160,6 +163,7 @@ def test_nested_always_wrap_model(

     @skip_if_lt_x_gpu(2)
     @parametrize(params, configs, subtest_name)
+    @skipIfRocmArch(NAVI4_ARCH)  # Supported in future releases
     def test_transformer(
         self,
         cpu_offload: CPUOffload,

test/distributed/fsdp/test_fsdp_hybrid_shard.py

Lines changed: 5 additions & 0 deletions
@@ -6,6 +6,7 @@
 from enum import auto, Enum
 from functools import partial
 from typing import Optional
+import unittest

 import torch
 import torch.distributed as dist
@@ -31,6 +32,9 @@
     FSDPTest,
     TransformerWithSharedParams,
 )
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+)
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     run_tests,
@@ -227,6 +231,7 @@ def test_invalid_pg_specification_raises(self):
     # resharded after forward.

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
     def test_fsdp_hybrid_shard_basic_setup(self):
         """
         Tests basic functionality of HYBRID_SHARD and _HYBRID_SHARD_ZERO2:

test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
 from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MEM_EFF_ATTENTION
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     DEVICEInitMode,
@@ -236,6 +237,9 @@ def _build_model_and_optim(
         return model, optim, ref_model, ref_optim

     @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
+    )
     def test_sharded_grad_scaler_found_inf(self):
         self.run_subtests(
             {

test/distributed/optim/test_zero_redundancy_optimizer.py

Lines changed: 2 additions & 0 deletions
@@ -917,6 +917,8 @@ def closure_sharded(input_tensor=input_tensor):
         torch.testing.assert_close(
             loss_ddp,
             loss_sharded_optim,
+            atol=1.6e-3,
+            rtol=3e-6,
             msg="Losses differ between local optimizer and ZeRO",
         )
         self._check_same_model_params(
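For reference, the explicit tolerances added above feed directly into torch.testing.assert_close, which checks that |actual - expected| <= atol + rtol * |expected| elementwise. A small self-contained sketch with illustrative values:

import torch

actual = torch.tensor([1.0000, 2.0000])
expected = torch.tensor([1.0010, 2.0005])

# Passes with the loosened bounds: 1e-3 <= 1.6e-3 + 3e-6 * |expected|.
torch.testing.assert_close(actual, expected, atol=1.6e-3, rtol=3e-6)

# With the float32 defaults (rtol=1.3e-6, atol=1e-5) the same comparison raises.
# torch.testing.assert_close(actual, expected)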

test/distributed/tensor/parallel/test_tp_examples.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@
     Transformer,
     with_comms,
 )
+from unittest import skipIf


 c10d_functional = torch.ops.c10d_functional

test/test_linalg.py

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
+    skipIfRocmArch, NAVI4_ARCH,
     setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest,
     runOnRocmArch, MI300_ARCH)
 from torch.testing._internal.common_device_type import \
@@ -7149,6 +7150,7 @@ def test_baddbmm_input_dtypes_compatibility(self, device, dtype):

     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyCUDA
+    @skipIfRocmArch(NAVI4_ARCH)
     def test_matmul_45724(self, device):
         # https://github.com/pytorch/pytorch/issues/45724
         a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
