//
// MODIFICATION NOTE (2024-09-25): added SM75 support (https://github.com/pytorch/ao/pull/942):
// - Modified the TilingConfig parameters for SM75 to deal with smaller shared memory
- // - Added proper architecture check at both host and device level
//
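For context on the shared-memory constraint mentioned in the note above: SM75 (Turing) caps opt-in shared memory per block at 64 KB, while SM80-class parts allow considerably more, which is why the SM75 path needs a smaller TilingConfig. Below is a minimal host-side sketch of the budget arithmetic; the TILE_* values are placeholders for illustration, not the kernel's actual TilingConfig parameters.

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    int device = 0, smem_per_block_optin = 0;
    cudaGetDevice(&device);
    // Maximum dynamic shared memory a block may opt in to (64 KB on SM75, larger on SM80+).
    cudaDeviceGetAttribute(&smem_per_block_optin,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    // Placeholder tile sizes: a double-buffered (2-stage) pipeline holding a
    // TILE_M x TILE_K A-tile and a TILE_K x TILE_N B-tile of fp16 elements.
    const int TILE_M = 128, TILE_N = 128, TILE_K = 64, STAGES = 2;
    const size_t smem_needed =
        (size_t)(TILE_M * TILE_K + TILE_K * TILE_N) * 2 /* bytes per fp16 */ * STAGES;

    // With these placeholder sizes the footprint is exactly 64 KB, i.e. it just
    // fills Turing's limit; a larger tiling would not fit and a smaller
    // TilingConfig must be chosen on SM75.
    printf("opt-in smem/block: %d B, estimated tile footprint: %zu B\n",
           smem_per_block_optin, smem_needed);
    return 0;
}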
@@ -99,24 +98,7 @@ void fpx_linear_kernel(cudaStream_t stream,
static_assert(std::is_same<InputDataType, half>::value || std::is_same<InputDataType, __nv_bfloat16>::value, "Type must be 'half' or '__nv_bfloat16'");
assert(M_Global % 256 == 0);
assert(K_Global % 64 == 0);
- assert(N_Global > 0);
-
- // Check GPU Compute Capability before proceeding
- int device, major, minor;
- CHECK_CUDA(cudaGetDevice(&device));
- CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
- CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
-
- // Early exit with error for unsupported architectures
- if ((major < 7) || (major == 7 && minor < 5)) {
-     TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
-                        "Your current device has SM", major, minor, " which is not supported.");
- }
-
- const bool is_sm75_gpu = (major == 7) && (minor == 5);
- if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value) {
-     TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
- }
+ assert(N_Global>0);

// Work around to support more N shapes:
size_t N_PowerOf2;
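The asserts above are the kernel's shape contract: the weight's output dimension M_Global must be a multiple of 256 and the reduction dimension K_Global a multiple of 64, while the batch dimension N_Global only has to be positive. A small host-side pre-check along those lines; the helper name is illustrative, not taken from the repo.

#include <stdexcept>
#include <string>

// Illustrative pre-check mirroring the kernel's assertions.
inline void check_fpx_linear_shapes(long M_Global, long K_Global, long N_Global) {
    if (M_Global % 256 != 0)
        throw std::invalid_argument("M_Global must be a multiple of 256, got " + std::to_string(M_Global));
    if (K_Global % 64 != 0)
        throw std::invalid_argument("K_Global must be a multiple of 64, got " + std::to_string(K_Global));
    if (N_Global <= 0)
        throw std::invalid_argument("N_Global must be positive");
}

// Example: a 4096x4096 linear layer applied to 512 tokens passes:
// check_fpx_linear_shapes(4096, 4096, 512);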
@@ -127,6 +109,17 @@ void fpx_linear_kernel(cudaStream_t stream,
if (N_Global>64 && N_Global<=128) N_PowerOf2 = 128;
if (N_Global>128) N_PowerOf2 = ((N_Global-1)/128+1) * 128;

+ // Check GPU Compute Capability
+ int device, major, minor;
+ CHECK_CUDA(cudaGetDevice(&device));
+ CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
+ CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
+ const bool is_sm75_gpu = (major == 7) && (minor == 5);
+ if (is_sm75_gpu && std::is_same<InputDataType, __nv_bfloat16>::value)
+     TORCH_CHECK(false, "Bfloat16 inputs are not supported for SM75");
+ if ((major < 7) || (major == 7 && minor < 5))
+     TORCH_CHECK(false, "FP6LLM_API Error: FP6LLM requires GPU with SM75 or higher!\n");
+
if (is_sm75_gpu && (N_PowerOf2 == 64 || N_PowerOf2 == 128 || N_PowerOf2 % 128 == 0)) {
// For SM75 and N >= 64, we use a different TilingConfig to deal with smaller shared memory.
if (Split_K == 1) {
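N_PowerOf2 rounds the runtime batch dimension up to a shape the tiling configurations can launch: powers of two up to 128, then multiples of 128. A self-contained sketch of that rounding with a few sample values; the branches for N <= 64 are not shown in this hunk, so they are assumed here to follow the same power-of-two pattern.

#include <cstdio>

// Mirrors the rounding shown above; the <= 64 branches are assumed.
static size_t round_up_N(size_t N) {
    if (N <= 8)   return 8;
    if (N <= 16)  return 16;
    if (N <= 32)  return 32;
    if (N <= 64)  return 64;
    if (N <= 128) return 128;
    return ((N - 1) / 128 + 1) * 128;   // next multiple of 128
}

int main() {
    const size_t samples[] = {1, 13, 100, 128, 129, 300};
    for (size_t n : samples)
        printf("N=%zu -> N_PowerOf2=%zu\n", n, round_up_N(n));
    // Prints: 1 -> 8, 13 -> 16, 100 -> 128, 128 -> 128, 129 -> 256, 300 -> 384.
    return 0;
}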
@@ -143,7 +136,7 @@ void fpx_linear_kernel(cudaStream_t stream,
case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
-     TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
+     TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
}
Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, InputDataType, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break;
}
@@ -156,7 +149,7 @@ void fpx_linear_kernel(cudaStream_t stream,
case 64:  Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
case 128: Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
default: if (N_PowerOf2 % 128 != 0) {
-     TORCH_CHECK(false, "Quant-LLM Error: Unsupported N dimension ", N_PowerOf2);
+     TORCH_CHECK(false, "FP6LLM_API Error: Unsupported N dimension ", N_PowerOf2);
}
Kernel_Ex<TilingConfig<4, 1, 8>, InputDataType, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break;
}
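In the Split_K > 1 path above, each of the Split_K slices writes fp32 partial results into Reduction_Workspace instead of into C, and a later pass sums the slices and casts back to the input type. Below is a minimal sketch of such a reduction, assuming the workspace holds Split_K contiguous M_Global x N_Global fp32 matrices; the repo's actual reduction kernel and workspace layout may differ.

#include <cuda_fp16.h>

// One thread per output element: sum Split_K fp32 partials, cast to half.
// Assumed layout: workspace[k * MN + idx] is slice k's partial for element idx.
__global__ void splitk_reduce_sketch(half* __restrict__ C,
                                     const float* __restrict__ workspace,
                                     size_t MN, int Split_K) {
    size_t idx = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
    if (idx >= MN) return;
    float acc = 0.0f;
    for (int k = 0; k < Split_K; ++k)
        acc += workspace[(size_t)k * MN + idx];
    C[idx] = __float2half(acc);
}

// Launch example:
// splitk_reduce_sketch<<<(MN + 255) / 256, 256, 0, stream>>>(C, Reduction_Workspace, MN, Split_K);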
@@ -217,23 +210,6 @@ torch::Tensor fp_eXmY_linear_forward_cuda(
torch::Tensor _scales,
int64_t splitK=1)
{
- // Check GPU Compute Capability before proceeding
- int device, major, minor;
- CHECK_CUDA(cudaGetDevice(&device));
- CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
- CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
-
- // Early exit with error for unsupported architectures
- if ((major < 7) || (major == 7 && minor < 5)) {
-     TORCH_CHECK(false, "Quant-LLM Error: This kernel requires GPU with SM75 (Turing) or higher architecture. "
-                        "Your current device has SM", major, minor, " which is not supported.");
- }
-
- const bool is_sm75_gpu = (major == 7) && (minor == 5);
- if (is_sm75_gpu && _in_feats.scalar_type() == at::ScalarType::BFloat16) {
-     TORCH_CHECK(false, "Quant-LLM Error: BFloat16 inputs are not supported on SM75 (Turing) GPUs.");
- }
-
const int64_t NBITS = 1 + EXPONENT + MANTISSA;
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
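NBITS here is simply the sign bit plus the exponent and mantissa bits of the packed fpx format, e.g. 1 + 3 + 2 = 6 for the FP6 (E3M2) format used by FP6-LLM. A compile-time restatement:

// Sign + exponent + mantissa bits of an fpx format.
constexpr int nbits(int exponent, int mantissa) { return 1 + exponent + mantissa; }
static_assert(nbits(3, 2) == 6, "FP6 (E3M2)");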