
Commit aa8a9fa

colesbury authored and facebook-github-bot committed
Extend DispatchStub to support CUDA dispatch (pytorch#9664)
Summary: This is a modification of the strategy from pytorch#8919 and pytorch#9579.

Previously, the CPU architecture-specific kernels self-registered with the DispatchStub. When linking as part of a static library, this requires passing --whole-archive to the linker to ensure that the object files for the kernels are included; Caffe2 and TensorFlow use that strategy. We ran into issues with --whole-archive blowing up the binary size of some downstream projects at Facebook.

This PR avoids --whole-archive for CPU kernels. The downside is that the generic code needs to know whether kernels are compiled with AVX and with AVX2 (via HAVE_AVX_CPU_DEFINITION and HAVE_AVX2_CPU_DEFINITION). The CUDA kernels still self-register with DispatchStub, because the CPU library does not know whether the CUDA library will be available at runtime.

There are a few major changes to DispatchStub:

- The environment variable ATEN_CPU_CAPABILITY overrides the CPU capability detection code (previously ATEN_DISABLE_AVX and ATEN_DISABLE_AVX2).
- DispatchStub is defined in the generic native code instead of in the CPU_CAPABILITY_DEFAULT kernel.

Pull Request resolved: pytorch#9664
Differential Revision: D8943350
Pulled By: colesbury
fbshipit-source-id: 329229b0ee9ff94fc001b960287814bd734096ef
1 parent 3e9e3ef commit aa8a9fa

11 files changed (+252, -152)
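The registration change the summary describes can be illustrated with a small, self-contained mock (all names below are invented for illustration; none of this is ATen code). The first style needs --whole-archive when archived statically; the second is what this PR switches the CPU kernels to.

```cpp
// Illustrative mock only, not ATen code.
#include <cstdio>

using kernel_fn = void (*)();

// Style 1: self-registration through a static initializer (old CPU path, still
// used for CUDA). If this translation unit sits in a static archive and nothing
// references its symbols, the linker may drop it and the constructor never
// runs, which is why self-registering kernels need --whole-archive.
kernel_fn g_self_registered = nullptr;
struct SelfRegister {
  explicit SelfRegister(kernel_fn fn) { g_self_registered = fn; }
};
static void simd_kernel() { std::puts("self-registered kernel"); }
static SelfRegister register_simd_kernel(&simd_kernel);

// Style 2: the generic code declares a static member and refers to it
// directly, so the object file that defines it must be linked in (new CPU path).
struct Stub {
  static kernel_fn AVX2;  // defined by the capability-specific translation unit
};
static void avx2_kernel() { std::puts("explicitly wired kernel"); }
kernel_fn Stub::AVX2 = &avx2_kernel;  // in ATen: REGISTER_DISPATCH in the AVX2 build

int main() {
  if (g_self_registered) g_self_registered();  // only runs if the TU was linked
  Stub::AVX2();                                // always resolvable by the linker
}
```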

.jenkins/pytorch/test.sh

+3, -6
@@ -44,13 +44,10 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then
   (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_aten_asan(3)")
 fi
 
-export ATEN_DISABLE_AVX=
-export ATEN_DISABLE_AVX2=
 if [[ "${JOB_BASE_NAME}" == *-NO_AVX-* ]]; then
-  export ATEN_DISABLE_AVX=1
-fi
-if [[ "${JOB_BASE_NAME}" == *-NO_AVX2-* ]]; then
-  export ATEN_DISABLE_AVX2=1
+  export ATEN_CPU_CAPABILITY=default
+elif [[ "${JOB_BASE_NAME}" == *-NO_AVX2-* ]]; then
+  export ATEN_CPU_CAPABILITY=avx
 fi
 
 test_python_nn() {

aten/src/ATen/native/DispatchStub.cpp

+44
@@ -0,0 +1,44 @@
+#include "DispatchStub.h"
+
+#include <ATen/Error.h>
+
+#include <cpuinfo.h>
+#include <cstdlib>
+#include <cstring>
+
+namespace at { namespace native {
+
+static CPUCapability compute_cpu_capability() {
+  auto envar = std::getenv("ATEN_CPU_CAPABILITY");
+  if (envar) {
+    if (strcmp(envar, "avx2") == 0) {
+      return CPUCapability::AVX2;
+    }
+    if (strcmp(envar, "avx") == 0) {
+      return CPUCapability::AVX;
+    }
+    if (strcmp(envar, "default") == 0) {
+      return CPUCapability::DEFAULT;
+    }
+    AT_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar);
+  }
+
+#ifndef __powerpc__
+  if (cpuinfo_initialize()) {
+    if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) {
+      return CPUCapability::AVX2;
+    }
+    if (cpuinfo_has_x86_avx()) {
+      return CPUCapability::AVX;
+    }
+  }
+#endif
+  return CPUCapability::DEFAULT;
+}
+
+CPUCapability get_cpu_capability() {
+  static CPUCapability capability = compute_cpu_capability();
+  return capability;
+}
+
+}} // namespace at::native
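get_cpu_capability() caches its result in a function-local static, so the ATEN_CPU_CAPABILITY override is consulted exactly once, before the first dispatched call. A hypothetical standalone caller (not part of this diff) could pin the capability up front:

```cpp
// Hypothetical usage sketch, not part of the diff: force the generic (non-AVX)
// CPU kernels for this process. The variable must be set before the first call
// that goes through a DispatchStub, since the capability is computed only once.
#include <cstdlib>

int main() {
  setenv("ATEN_CPU_CAPABILITY", "default", /*overwrite=*/1);  // or "avx" / "avx2"
  // ... run ATen/PyTorch code; every stub's CPU path now uses the DEFAULT slot ...
}
```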

aten/src/ATen/native/DispatchStub.h

+125
@@ -0,0 +1,125 @@
+#pragma once
+
+#include <ATen/Error.h>
+#include <ATen/ScalarType.h>
+#include <type_traits>
+
+// Implements instruction set specific function dispatch.
+//
+// Kernels that may make use of specialized instruction sets (e.g. AVX) are
+// compiled multiple times with different compiler flags (e.g. -mavx). A
+// DispatchStub contains a table of function pointers for a kernel. At runtime,
+// the fastest available kernel is chosen based on the features reported by
+// cpuinfo.
+//
+// Example:
+//
+// In native/MyKernel.h:
+//   using fn_type = void(*)(const Tensor& x);
+//   DECLARE_DISPATCH(fn_type, stub);
+//
+// In native/MyKernel.cpp
+//   DEFINE_DISPATCH(stub);
+//
+// In native/cpu/MyKernel.cpp:
+//   void kernel(const Tensor& x) { ... }
+//   REGISTER_DISPATCH(stub, &kernel);
+//
+// To call:
+//   stub(kCPU, tensor);
+
+// ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
+#if defined(__clang__)
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wundefined-var-template"
+#endif
+
+namespace at { namespace native {
+
+enum class CPUCapability {
+  DEFAULT = 0,
+  AVX = 1,
+  AVX2 = 2,
+  NUM_OPTIONS
+};
+
+CPUCapability get_cpu_capability();
+
+template <typename FnPtr, typename T>
+struct DispatchStub {
+  static_assert(std::is_pointer<FnPtr>::value, "FnPtr should be a pointer type");
+
+  template <typename... ArgTypes>
+  void operator()(Backend backend, ArgTypes... args) {
+    if (backend == Backend::CPU) {
+      if (!cpu_dispatch_ptr) {
+        cpu_dispatch_ptr = choose_cpu_impl();
+      }
+      (*cpu_dispatch_ptr)(args...);
+    } else if (backend == Backend::CUDA) {
+      AT_ASSERTM(cuda_dispatch_ptr, "DispatchStub: missing CUDA kernel");
+      (*cuda_dispatch_ptr)(args...);
+    } else {
+      AT_ERROR("DispatchStub: unsupported backend", backend);
+    }
+  }
+
+  FnPtr choose_cpu_impl() {
+    auto capability = static_cast<int>(get_cpu_capability());
+    (void)capability;
+#ifdef HAVE_AVX2_CPU_DEFINITION
+    if (capability >= static_cast<int>(CPUCapability::AVX2)) {
+      AT_ASSERTM(AVX2, "DispatchStub: missing AVX2 kernel");
+      return AVX2;
+    }
+#endif
+#ifdef HAVE_AVX_CPU_DEFINITION
+    if (capability >= static_cast<int>(CPUCapability::AVX)) {
+      AT_ASSERTM(AVX, "DispatchStub: missing AVX kernel");
+      return AVX;
+    }
+#endif
+    AT_ASSERTM(DEFAULT, "DispatchStub: missing default kernel");
+    return DEFAULT;
+  }
+
+  FnPtr cpu_dispatch_ptr = nullptr;
+  FnPtr cuda_dispatch_ptr = nullptr;
+  static FnPtr DEFAULT;
+#ifdef HAVE_AVX_CPU_DEFINITION
+  static FnPtr AVX;
+#endif
+#ifdef HAVE_AVX2_CPU_DEFINITION
+  static FnPtr AVX2;
+#endif
+};
+
+namespace {
+template <typename FnPtr, typename T>
+struct RegisterDispatch {
+  RegisterDispatch(DispatchStub<FnPtr, T>& stub, FnPtr value) {
+    stub.cuda_dispatch_ptr = value;
+  }
+};
+} // anonymous namespace
+
+#define DECLARE_DISPATCH(fn, name) \
+  extern struct name : DispatchStub<fn, name> {} name
+
+#define DEFINE_DISPATCH(name) struct name name
+
+#if defined(__CUDACC__)
+#define REGISTER_DISPATCH(name, fn) \
+  static RegisterDispatch<decltype(fn), struct name> name ## __register(name, fn);
+#elif defined(CPU_CAPABILITY)
+#define REGISTER_DISPATCH(name, fn) \
+  template <> decltype(fn) DispatchStub<decltype(fn), struct name>::CPU_CAPABILITY = fn;
+#endif
+
+
+}} // namespace at::native
+
+
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
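The usage comment at the top of the header shows the intended wiring. Spelled out for a hypothetical kernel (the my_* names and file paths are illustrative, and this only compiles inside ATen where Tensor, kCPU and the macros are visible):

```cpp
// Sketch of the wiring for a hypothetical kernel, following the usage comment
// in DispatchStub.h. The my_* names and file placement are illustrative only.

// native/MyKernel.h -- visible to the generic code:
using my_fn = void (*)(const Tensor& x);
DECLARE_DISPATCH(my_fn, my_stub);    // extern struct my_stub : DispatchStub<my_fn, my_stub> {} my_stub;

// native/MyKernel.cpp -- compiled once, without -mavx/-mavx2, always linked:
DEFINE_DISPATCH(my_stub);            // struct my_stub my_stub;

// native/cpu/MyKernel.cpp -- compiled once per CPU_CAPABILITY (DEFAULT/AVX/AVX2):
static void my_kernel(const Tensor& x) { /* vectorizable work */ }
REGISTER_DISPATCH(my_stub, &my_kernel);  // fills in the matching static slot

// Call sites in generic code: the first CPU call picks DEFAULT/AVX/AVX2 via
// get_cpu_capability(); a CUDA call requires the CUDA build to have registered
// its pointer through the RegisterDispatch constructor.
//   my_stub(kCPU, tensor);
//   my_stub(kCUDA, tensor);
```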

aten/src/ATen/native/ReduceOps.cpp

+7, -4
@@ -17,6 +17,9 @@
 namespace at {
 namespace native {
 
+DEFINE_DISPATCH(sum_kernel);
+DEFINE_DISPATCH(prod_kernel);
+
 static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
   ScalarType scalarType = self.type().scalarType();
   ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType) ? ScalarType::Long : scalarType);
@@ -127,7 +130,7 @@ Tensor sum(const Tensor &self) {
 Tensor _sum_cpu(const Tensor& self) {
   if (self.is_contiguous()) {
     Tensor result = at::empty({}, self.type());
-    sum_kernel(result, self, at::nullopt);
+    sum_kernel(kCPU, result, self, at::nullopt);
     return result;
   }
   return self._sumall();
@@ -148,7 +151,7 @@ Tensor prod(const Tensor &self) {
 Tensor _prod_cpu(const Tensor &self) {
   if (self.is_contiguous()) {
     Tensor result = at::empty({}, self.type());
-    prod_kernel(result, self, at::nullopt);
+    prod_kernel(kCPU, result, self, at::nullopt);
     return result;
   }
   return self._prodall();
@@ -222,7 +225,7 @@ Tensor &_sum_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
     return result;
   if (self.is_contiguous() && result.is_contiguous()) {
     _dimreduce_setup(result, self, dim);
-    sum_kernel(result, self, dim);
+    sum_kernel(kCPU, result, self, dim);
     if (!keepdim) result.squeeze_(dim);
     return result;
   }
@@ -260,7 +263,7 @@ Tensor &_prod_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
     return result;
   if (self.is_contiguous() && result.is_contiguous()) {
     _dimreduce_setup(result, self, dim);
-    prod_kernel(result, self, dim);
+    prod_kernel(kCPU, result, self, dim);
     if (!keepdim) result.squeeze_(dim);
     return result;
   }
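The DEFINE_DISPATCH lines added here (and in SoftMax.cpp and UnaryOps.cpp below) are what replaces the old self-registration on the CPU side. Given the macros in DispatchStub.h they expand roughly as follows; sum_fn is a placeholder for the function-pointer type used by the real declaration, which lives in a header not shown in this diff:

```cpp
// Rough expansion of the new lines; "sum_fn" is a placeholder type.

// DECLARE_DISPATCH(sum_fn, sum_kernel);   // in the shared header
extern struct sum_kernel : DispatchStub<sum_fn, sum_kernel> {} sum_kernel;

// DEFINE_DISPATCH(sum_kernel);            // the line added above
struct sum_kernel sum_kernel;

// Old call site:  sum_kernel(result, self, at::nullopt);
// New call site:  sum_kernel(kCPU, result, self, at::nullopt);
// The extra kCPU argument is consumed by DispatchStub::operator(), which picks
// the CPU slot (DEFAULT/AVX/AVX2) or the CUDA pointer before forwarding args.
```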

aten/src/ATen/native/SoftMax.cpp

+10, -4
@@ -128,7 +128,7 @@ Tensor softmax_cpu(const Tensor& input_, const int64_t dim_) {
       dim >= 0 && dim < input.dim(),
       "dim must be non-negative and less than input dimensions");
   if (input.ndimension() > 0 && dim == input.ndimension() - 1) {
-    softmax_lastdim_kernel(output, input);
+    softmax_lastdim_kernel(kCPU, output, input);
   } else {
     AT_DISPATCH_FLOATING_TYPES(input.type(), "softmax", [&] {
       host_softmax<scalar_t, false>(output, input, dim);
@@ -147,7 +147,7 @@ Tensor log_softmax_cpu(const Tensor& input_, const int64_t dim_) {
       dim >= 0 && dim < input.dim(),
       "dim must be non-negative and less than input dimensions");
   if (input.ndimension() > 0 && dim == input.ndimension() - 1) {
-    log_softmax_lastdim_kernel(output, input);
+    log_softmax_lastdim_kernel(kCPU, output, input);
   } else {
     AT_DISPATCH_FLOATING_TYPES(input.type(), "log_softmax", [&] {
       host_softmax<scalar_t, true>(output, input, dim);
@@ -176,7 +176,7 @@ Tensor softmax_backward_cpu(
       dim >= 0 && dim < grad.dim(),
       "dim must be non-negative and less than input dimensions");
   if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) {
-    softmax_backward_lastdim_kernel(grad_input, grad, output);
+    softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output);
   } else {
     AT_DISPATCH_FLOATING_TYPES(grad.type(), "softmax_backward", [&] {
       host_softmax_backward<scalar_t, false>(grad_input, grad, output, dim);
@@ -205,13 +205,19 @@ Tensor log_softmax_backward_cpu(
       dim >= 0 && dim < grad.dim(),
       "dim must be non-negative and less than input dimensions");
   if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) {
-    log_softmax_backward_lastdim_kernel(grad_input, grad, output);
+    log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output);
   } else {
     AT_DISPATCH_FLOATING_TYPES(grad.type(), "log_softmax_backward", [&] {
       host_softmax_backward<scalar_t, true>(grad_input, grad, output, dim);
     });
   }
   return grad_input;
 }
+
+DEFINE_DISPATCH(softmax_lastdim_kernel);
+DEFINE_DISPATCH(log_softmax_lastdim_kernel);
+DEFINE_DISPATCH(softmax_backward_lastdim_kernel);
+DEFINE_DISPATCH(log_softmax_backward_lastdim_kernel);
+
 }
 }

aten/src/ATen/native/UnaryOps.cpp

+26, -2
@@ -92,14 +92,14 @@ Tensor& fill_(Tensor& self, const Tensor& value) {
   Tensor& _##op##__cpu(Tensor& self_) { \
     if (self_.numel() > 0) { \
       Tensor self = sort_strides(self_); \
-      op##Impl(self, self); \
+      op##Impl(kCPU, self, self); \
     } \
     return self_; \
   } \
   Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
     result.resize_(self.sizes()); \
     if (result.numel() > 0) { \
-      op##Impl(result, self); \
+      op##Impl(kCPU, result, self); \
     } \
     return result; \
   }
@@ -145,5 +145,29 @@ IMPLEMENT_UNARY_OP_VEC(tan)
 IMPLEMENT_UNARY_OP_VEC(tanh)
 IMPLEMENT_UNARY_OP_VEC(trunc)
 
+DEFINE_DISPATCH(absImpl);
+DEFINE_DISPATCH(acosImpl);
+DEFINE_DISPATCH(asinImpl);
+DEFINE_DISPATCH(atanImpl);
+DEFINE_DISPATCH(ceilImpl);
+DEFINE_DISPATCH(cosImpl);
+DEFINE_DISPATCH(erfImpl);
+DEFINE_DISPATCH(erfcImpl);
+DEFINE_DISPATCH(expImpl);
+DEFINE_DISPATCH(expm1Impl);
+DEFINE_DISPATCH(floorImpl);
+DEFINE_DISPATCH(logImpl);
+DEFINE_DISPATCH(log10Impl);
+DEFINE_DISPATCH(log1pImpl);
+DEFINE_DISPATCH(log2Impl);
+DEFINE_DISPATCH(roundImpl);
+DEFINE_DISPATCH(rsqrtImpl);
+DEFINE_DISPATCH(sigmoidImpl);
+DEFINE_DISPATCH(sinImpl);
+DEFINE_DISPATCH(sqrtImpl);
+DEFINE_DISPATCH(tanImpl);
+DEFINE_DISPATCH(tanhImpl);
+DEFINE_DISPATCH(truncImpl);
+
 }
 } // namespace at
