12
12
from torch ._inductor .codegen .triton import FixedTritonConfig , TritonKernel
13
13
from torch ._inductor .test_case import TestCase
14
14
from torch ._inductor .utils import run_and_get_code
15
+ from torch .testing import assert_close
15
16
from torch .testing ._internal .common_cuda import IS_SM89
16
17
from torch .testing ._internal .common_utils import (
17
18
instantiate_parametrized_tests ,
@@ -33,19 +34,99 @@ def setUp(self):
33
34
torch ._inductor .metrics .generated_kernel_count = 0
34
35
torch ._dynamo .reset ()
35
36
36
def run_and_check(self, fn, args, dtype=None, *, expect_kernel_count=1):
    """Compile ``fn``, compare against an eager reference, and check that a
    cooperative reduction kernel was generated.

    Args:
        fn: callable under test, invoked as ``fn(*args)``.
        args: sequence of input tensors.
        dtype: dtype the outputs are compared in. ``torch.float16`` inputs
            are promoted to ``torch.float64`` for the eager reference to
            avoid fp16 accumulation error; when ``dtype`` is ``None`` the
            comparison is done in ``torch.float64`` (e.g. bool reductions).
        expect_kernel_count: expected value of
            ``torch._inductor.metrics.generated_kernel_count`` (the check is
            skipped when multi-kernel autotuning is present in the output).

    Returns:
        The generated kernel source code, for further assertions by callers.
    """
    # Fixed tolerances for every tensor comparison in this helper.
    rtol = 1e-5
    atol = 1e-6

    # Compute the reference in higher precision when inputs are fp16 so the
    # eager baseline does not suffer the same rounding as the kernel.
    ref_dtype = torch.float64 if dtype == torch.float16 else dtype
    args_ref = [tensor.to(ref_dtype) for tensor in args]
    raw_expected = fn(*args_ref)

    # Single comparison dtype: the caller's ``dtype`` if given, else float64.
    target = dtype if dtype is not None else torch.float64

    def _cast(value):
        # Cast tensors — bare or inside a tuple/list — to ``target``,
        # preserving the container type; non-tensor elements pass through.
        if isinstance(value, torch.Tensor):
            return value.to(target)
        if isinstance(value, (tuple, list)):
            return type(value)(
                t.to(target) if isinstance(t, torch.Tensor) else t
                for t in value
            )
        return value

    expected = _cast(raw_expected)

    fn_compiled = torch.compile(fn, fullgraph=True)
    result, (source_code,) = run_and_get_code(fn_compiled, *args)

    # Normalize the compiled result's container shape to the reference's
    # before element-wise comparison.
    if isinstance(expected, (tuple, list)):
        if isinstance(result, torch.Tensor):
            result = (result,)
        elif not isinstance(result, type(expected)):
            result = type(expected)(result)
    result = _cast(result)

    # Tensor (or tensor-element) comparisons use assert_close with the fixed
    # tolerances; anything else falls back to assertEqual.
    if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
        assert_close(result, expected, rtol=rtol, atol=atol)
    elif isinstance(result, (tuple, list)) and isinstance(expected, (tuple, list)):
        for r_item, e_item in zip(result, expected):
            if isinstance(r_item, torch.Tensor) and isinstance(e_item, torch.Tensor):
                assert_close(r_item, e_item, rtol=rtol, atol=atol)
            else:
                self.assertEqual(r_item, e_item)
    else:
        self.assertEqual(result, expected)

    # Fixed-config kernels launch via the cooperative reduction grid;
    # otherwise the cooperative-reduction heuristic itself must appear.
    if "@triton_heuristics.fixed_config" in source_code:
        self.assertIn("cooperative_reduction_grid", source_code)
    else:
        self.assertIn("@triton_heuristics.cooperative_reduction", source_code)
    if "async_compile.multi_kernel" not in source_code:
        self.assertEqual(
            torch._inductor.metrics.generated_kernel_count, expect_kernel_count
        )
    return source_code
48
-
49
130
@parametrize (
50
131
"name" ,
51
132
[
0 commit comments