
Commit 24a6c29

[SWDEV-523736] Skip&Fix some testcases for Navi4x
1 parent 9663f2d commit 24a6c29

File tree: 12 files changed, +169 -93 lines


aten/src/ATen/native/cuda/Blas.cpp
Lines changed: 2 additions & 0 deletions

@@ -1048,9 +1048,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
   TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
   TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
+#ifndef USE_ROCM
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
         "Multiplication of two Float8_e5m2 matrices is not supported");
+#endif
   if (bias) {
     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
     TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
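
For context, a minimal Python-side sketch (illustrative, not part of the commit) of what the guarded check means for callers of torch._scaled_mm. The shapes, scale handling, and keyword names below assume a recent torch._scaled_mm signature; on CUDA builds cuBLASLt still rejects e5m2 x e5m2 via the check above, while ROCm builds now compile the check out and leave the decision to the backend.

    import torch

    a = torch.randn(32, 64, device="cuda").to(torch.float8_e5m2)
    # mat2 is expected column-major; a transposed view of a row-major tensor works.
    b = torch.randn(32, 64, device="cuda").to(torch.float8_e5m2).t()
    scale = torch.tensor(1.0, device="cuda")  # per-tensor scales

    try:
        out = torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
    except RuntimeError as err:
        # On non-ROCm builds this is the message from the TORCH_CHECK above.
        print(err)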

test/distributed/_tools/test_sac_ilp.py
Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@
     get_optimal_checkpointing_policy_per_module,
     sac_milp,
 )
-from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_cuda import TEST_CUDA, PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import (
     run_tests,
     skipIfTorchDynamo,
@@ -181,7 +181,7 @@ def test_sac_ilp_case1(self):

     @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
     @unittest.skipIf(not TEST_CUDA, "CUDA not available")
-    @skipIfRocmArch(NAVI_ARCH)
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_sac_ilp_case2(self):
         """
         This is a case where the memory budget is not binding, meaning that no
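
The same gating pattern recurs in the tensor-parallel and flex-decoding diffs below: instead of hard-coding a NAVI arch skip, the test is skipped wherever flash-attention-backed SDPA is unavailable. A minimal standalone sketch of the pattern (illustrative, not part of the commit):

    import unittest

    import torch
    from torch.testing._internal.common_cuda import TEST_CUDA, PLATFORM_SUPPORTS_FLASH_ATTENTION


    class SdpaGatingExample(unittest.TestCase):
        @unittest.skipIf(not TEST_CUDA, "CUDA not available")
        @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
        def test_sdpa_smoke(self):
            # Runs only where flash-attention kernels are expected to exist
            # (skipped on e.g. ROCm Navi parts without flash support).
            q = k = v = torch.randn(1, 2, 8, 16, device="cuda", dtype=torch.float16)
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            self.assertEqual(out.shape, q.shape)


    if __name__ == "__main__":
        unittest.main()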

test/distributed/tensor/parallel/test_tp_examples.py
Lines changed: 4 additions & 0 deletions

@@ -27,6 +27,7 @@
     RowwiseParallel,
 )
 from torch.distributed.tensor.parallel.input_reshard import input_reshard
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -41,6 +42,7 @@
     Transformer,
     with_comms,
 )
+from unittest import skipIf


 c10d_functional = torch.ops.c10d_functional
@@ -412,6 +414,8 @@ def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype):
             + f"{str(dtype).split('.')[-1]}_"
             + f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}",
     )
+
+    @skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts):
         # Sample a subset of `requires_grad` patterns

test/dynamo/test_graph_deduplication.py
Lines changed: 48 additions & 67 deletions
(Large diffs are not rendered by default.)

test/dynamo/test_graph_region_tracker.py
Lines changed: 3 additions & 6 deletions

@@ -70,8 +70,7 @@ def fn(x, y):
                 torch.rand(10, 10),
                 torch.ones(10, 20),
             ),
-            """[[['y0', 'x0', 'sum_2', 'sum_1', 'z'], \
-['y0_1', 'x0_1', 'sum_4', 'sum_3', 'z_1'], ['y0_2', 'x0_2', 'sum_6', 'sum_5', 'z_2']]]""",
+            """[[['x0', 'y0', 'sum_1', 'sum_2', 'z'], ['x0_1', 'y0_1', 'sum_3', 'sum_4', 'z_1'], ['x0_2', 'y0_2', 'sum_5', 'sum_6', 'z_2']]]""",
         )

     def test_get_regions_multiple_region_groups(self):
@@ -104,8 +103,7 @@ def fn(x, y):
                 torch.rand(10, 10),
                 torch.ones(10, 20),
             ),
-            """[[['y1', 'x1', 'sum_3', 'sum_2', 'z'], ['y1_1', 'x1_1', 'sum_5', 'sum_4', 'z_1'], \
-['y1_2', 'x1_2', 'sum_8', 'sum_7', 'z_2']], [['b', 'cos_1', 'sum_1', 'a', 'c'], ['b_1', 'cos_2', 'sum_6', 'a_1', 'c_1']]]""",
+            """[[['x1', 'y1', 'sum_2', 'sum_3', 'z'], ['x1_1', 'y1_1', 'sum_4', 'sum_5', 'z_1'], ['x1_2', 'y1_2', 'sum_7', 'sum_8', 'z_2']], [['a', 'b', 'cos_1', 'sum_1', 'c'], ['a_1', 'b_1', 'cos_2', 'sum_6', 'c_1']]]""",
         )

     def test_no_single_node_regions(self):
@@ -177,8 +175,7 @@ def fn(x, y):
                 torch.rand(10, 10),
                 torch.ones(10, 20),
             ),
-            """[[['y1', 'sum_1', 'x1', 'o0'], ['y1_1', 'sum_2', 'x1_1', 'o2'], \
-['y1_2', 'sum_3', 'x1_2', 'o4'], ['y1_3', 'sum_4', 'x1_3', 'o5']]]""",
+            """[[['x1', 'y1', 'sum_1', 'o0'], ['x1_1', 'y1_1', 'sum_2', 'o2'], ['x1_2', 'y1_2', 'sum_3', 'o4'], ['x1_3', 'y1_3', 'sum_4', 'o5']]]""",
         )

     def test_nested_args(self):

test/inductor/test_cooperative_reductions.py
Lines changed: 89 additions & 8 deletions

@@ -12,6 +12,7 @@
 from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel
 from torch._inductor.test_case import TestCase
 from torch._inductor.utils import run_and_get_code
+from torch.testing import assert_close
 from torch.testing._internal.common_cuda import IS_SM89
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -33,19 +34,99 @@ def setUp(self):
         torch._inductor.metrics.generated_kernel_count = 0
         torch._dynamo.reset()

-    def run_and_check(self, fn, args, *, expect_kernel_count=1):
-        args_cpu = [tensor.cpu().to(torch.float32) for tensor in args]
-        expected = fn(*args_cpu).to(torch.float16)
-        fn = torch.compile(fn, fullgraph=True)
-        result, (source_code,) = run_and_get_code(fn, *args)
-        self.assertEqual(result, expected)
-        self.assertIn("@triton_heuristics.cooperative_reduction", source_code)
+    def run_and_check(self, fn, args, dtype=None, *, expect_kernel_count=1):
+        # Define fixed tolerances
+        RTOL = 1e-5
+        ATOL = 1e-6
+
+        # calculate reference value in higher precision when input dtype is float16
+        ref_dtype = dtype
+        if dtype == torch.float16:
+            ref_dtype = torch.float64
+
+        # Cast to the determined reference dtype
+        args_ref = [tensor.to(ref_dtype) for tensor in args]
+
+        # Calculate expected output
+        raw_expected = fn(*args_ref)
+
+        if isinstance(raw_expected, (tuple, list)):
+            # If it's a tuple or list, apply .to(dtype) to each tensor within it
+            # Also, handle cases where dtype might not be provided (e.g., for bool reductions)
+            if dtype is not None:
+                expected = type(raw_expected)(
+                    [
+                        t.to(dtype) if isinstance(t, torch.Tensor) else t
+                        for t in raw_expected
+                    ]
+                )
+            else:
+                expected = type(raw_expected)(
+                    [
+                        t.to(torch.float64) if isinstance(t, torch.Tensor) else t
+                        for t in raw_expected
+                    ]
+                )
+        else:
+            # If it's a single tensor
+            if dtype is not None:
+                expected = raw_expected.to(dtype)
+            else:
+                expected = raw_expected.to(torch.float64)
+
+        fn_compiled = torch.compile(fn, fullgraph=True)
+        result, (source_code,) = run_and_get_code(fn_compiled, *args)
+
+        # For comparison, ensure result is also a tuple/list if expected is
+        if isinstance(expected, (tuple, list)):
+            if isinstance(result, torch.Tensor):
+                result = (result,)
+            elif not isinstance(result, type(expected)):
+                result = type(expected)(result)
+
+            if dtype is not None:
+                result = type(result)(
+                    [t.to(dtype) if isinstance(t, torch.Tensor) else t for t in result]
+                )
+            else:
+                result = type(result)(
+                    [
+                        t.to(torch.float64) if isinstance(t, torch.Tensor) else t
+                        for t in result
+                    ]
+                )
+        else:
+            if dtype is not None and isinstance(result, torch.Tensor):
+                result = result.to(dtype)
+            elif isinstance(result, torch.Tensor):
+                result = result.to(torch.float64)
+
+        # Apply assert_close with fixed tolerances for tensor comparisons
+        if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
+            assert_close(result, expected, rtol=RTOL, atol=ATOL)
+        elif isinstance(result, (tuple, list)) and isinstance(expected, (tuple, list)):
+            # Iterate through elements for comparison
+            for r_item, e_item in zip(result, expected):
+                if isinstance(r_item, torch.Tensor) and isinstance(
+                    e_item, torch.Tensor
+                ):
+                    assert_close(r_item, e_item, rtol=RTOL, atol=ATOL)
+                else:
+                    # Fallback to assertEqual for non-tensor elements (e.g., bool, int)
+                    self.assertEqual(r_item, e_item)
+        else:
+            # Fallback to assertEqual for other types not handled by assert_close
+            self.assertEqual(result, expected)
+
+        if "@triton_heuristics.fixed_config" in source_code:
+            self.assertIn("cooperative_reduction_grid", source_code)
+        else:
+            self.assertIn("@triton_heuristics.cooperative_reduction", source_code)
         if "async_compile.multi_kernel" not in source_code:
             self.assertEqual(
                 torch._inductor.metrics.generated_kernel_count, expect_kernel_count
             )
         return source_code
-
     @parametrize(
         "name",
         [
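
A hedged usage sketch of the reworked helper (illustrative only; the test name, `fn`, and shapes are placeholders, not from the diff). A test method inside this file's TestCase would now pass the input dtype so that the reference is computed in float64 and the comparison uses the fixed rtol/atol defined above:

    def test_sum_mean_fp16(self):
        def fn(x):
            # two reductions fused into one cooperative-reduction kernel
            return x.sum(dim=-1), x.mean(dim=-1)

        args = [torch.randn(4, 100000, device="cuda", dtype=torch.float16)]
        self.run_and_check(fn, args, dtype=torch.float16)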

test/inductor/test_cuda_repro.py
Lines changed: 2 additions & 0 deletions

@@ -34,6 +34,7 @@
     IS_FBCODE,
     skipIfRocm,
     TEST_WITH_ASAN,
+    xfailIfPy312Plus,
 )


@@ -1568,6 +1569,7 @@ def get_input() -> torch.Tensor:
         self.assertEqual(result, a + b)
         self.assertIn("znumel", code)

+    @xfailIfPy312Plus  # https://github.com/pytorch/pytorch/issues/142032
     def test_repeated_masked_load(self):
         target_size = (8, 2)
         mem_eff_temporal_upsampling_interp_chunks = 2

test/inductor/test_flex_decoding.py
Lines changed: 1 addition & 0 deletions

@@ -1332,6 +1332,7 @@ def mask_mod(b, h, q, kv):
         self.assertEqual(query.grad[:, :, M:, :].sum(), 0)

     @supported_platform
+    @skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA")
     def test_windowed_no_mask_vs_sdpa(self):
         score_mod = _generate_windowed(1000)
         attention = functools.partial(flex_attention, score_mod=score_mod)

test/test_license.py
Lines changed: 5 additions & 1 deletion

@@ -45,7 +45,11 @@ def test_distinfo_license(self):
             'Found too many "torch-*dist-info" directories '
             f'in "{site_packages}, expected only one'
         )
-        with open(os.path.join(os.path.join(distinfo[0], "LICENSE"))) as fid:
+        # setuptools renamed *dist-info/LICENSE to *dist-info/licenses/LICENSE since 77.0
+        license_file = os.path.join(distinfo[0], "licenses", "LICENSE")
+        if not os.path.exists(license_file):
+            license_file = os.path.join(distinfo[0], "LICENSE")
+        with open(license_file) as fid:
             txt = fid.read()
         self.assertTrue(starting_txt in txt)

test/test_linalg.py
Lines changed: 2 additions & 0 deletions

@@ -19,6 +19,7 @@
     TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
     make_fullrank_matrices_with_distinct_singular_values,
     freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
+    skipIfRocmArch, NAVI4_ARCH,
     setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver,
@@ -6440,6 +6441,7 @@ def test_baddbmm_input_dtypes_compatibility(self, device, dtype):

     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyCUDA
+    @skipIfRocmArch(NAVI4_ARCH)
     def test_matmul_45724(self, device):
         # https://github.com/pytorch/pytorch/issues/45724
         a = torch.rand(65537, 22, 64, device=device, dtype=torch.half)
