
Commit 88af214

janeyx99 authored and facebook-github-bot committed
Add build option to split torch_cuda library into torch_cuda_cu and torch_cuda_cpp (pytorch#49050)
Summary: Because of the size of our `libtorch_cuda.so`, linking it into other hefty binaries presents a problem where 32-bit relocation markers are too small and end up overflowing. This PR attempts to break up `torch_cuda` into `torch_cuda_cu` and `torch_cuda_cpp`.

`torch_cuda_cu`: all the files previously in `Caffe2_GPU_SRCS` that are
* pure `.cu` files in `aten`
* all the BLAS files
* all the THC files, except for THCAllocator.cpp, THCCachingHostAllocator.cpp and THCGeneral.cpp
* all files in `detail`
* LegacyDefinitions.cpp and LegacyTHFunctionsCUDA.cpp
* Register*CUDA.cpp
* CUDAHooks.cpp
* CUDASolver.cpp
* TensorShapeCUDA.cpp

`torch_cuda_cpp`: all other files in `Caffe2_GPU_SRCS`

Accordingly, TORCH_CUDA_API and TORCH_CUDA_BUILD_MAIN_LIB usages are split as well, into TORCH_CUDA_CU_API and TORCH_CUDA_CPP_API (and the corresponding `*_BUILD_MAIN_LIB` defines).

To test this locally, run `export BUILD_SPLIT_CUDA=ON && python setup.py develop`. In your `build/lib` folder, you should find binaries for both `torch_cuda_cpp` and `torch_cuda_cu`. To confirm that the split option was toggled, grep the summary output of running cmake and make sure `Split CUDA` is ON.

This build option is tested on CI for CUDA 11.1 builds (Linux for now, Windows soon).

Pull Request resolved: pytorch#49050
Reviewed By: walterddr
Differential Revision: D26114310
Pulled By: janeyx99
fbshipit-source-id: 0180f2519abb5a9cdde16a6fb7dd3171cff687a6
1 parent 87ad77e commit 88af214

19 files changed: +239, -74 lines
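As the summary above notes, TORCH_CUDA_API usages are split into TORCH_CUDA_CU_API and TORCH_CUDA_CPP_API. As a minimal sketch of what that means for a header author (the function names below are hypothetical and not part of this commit), a declaration picks its macro based on which of the two libraries its definition is compiled into:

```cpp
#include <c10/macros/Export.h>

// Hypothetical declarations, for illustration only.

// Defined in a translation unit that lands in libtorch_cuda_cu
// (e.g. a pure .cu file or a Register*CUDA.cpp file):
TORCH_CUDA_CU_API void example_cu_side_function();

// Defined in a translation unit that lands in libtorch_cuda_cpp
// (e.g. THC or cuDNN C++ code):
TORCH_CUDA_CPP_API void example_cpp_side_function();
```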

.jenkins/pytorch/build.sh (+5)

@@ -53,6 +53,11 @@ if [[ "$BUILD_ENVIRONMENT" == *coverage* ]]; then
 export USE_CPP_CODE_COVERAGE=ON
 fi

+if [[ "$BUILD_ENVIRONMENT" == *cuda11* ]]; then
+# enable split torch_cuda build option in CMake
+export BUILD_SPLIT_CUDA=ON
+fi
+
 # TODO: Don't run this...
 pip_install -r requirements.txt || true

.jenkins/pytorch/win-test-helpers/build_pytorch.bat (+4, -1)

@@ -110,6 +110,10 @@ if "%REBUILD%" == "" (
 aws s3 cp "s3://ossci-windows/Restore PyTorch Environment.lnk" "C:\Users\circleci\Desktop\Restore PyTorch Environment.lnk"
 )
 )
+:: tests if BUILD_ENVIRONMENT contains cuda11 as a substring
+if not x%BUILD_ENVIRONMENT:cuda11=%==x%BUILD_ENVIRONMENT% (
+set BUILD_SPLIT_CUDA=ON
+)

 python setup.py install --cmake && sccache --show-stats && (
 if "%BUILD_ENVIRONMENT%"=="" (
@@ -118,4 +122,3 @@ python setup.py install --cmake && sccache --show-stats && (
 7z a %TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\torch %CONDA_PARENT_DIR%\Miniconda3\Lib\site-packages\caffe2 && copy /Y "%TMP_DIR_WIN%\%IMAGE_COMMIT_TAG%.7z" "%PYTORCH_FINAL_PACKAGE_DIR%\"
 )
 )
-

CMakeLists.txt (+7)

@@ -161,6 +161,13 @@ option(COLORIZE_OUTPUT "Colorize output during compilation" ON)
 option(USE_ASAN "Use Address Sanitizer" OFF)
 option(USE_TSAN "Use Thread Sanitizer" OFF)
 option(USE_CUDA "Use CUDA" ON)
+# BUILD_SPLIT_CUDA must also be exported as an environment variable before building, with
+# `export BUILD_SPLIT_CUDA=1` because cpp_extension.py can only work properly if this variable
+# also exists in the environment.
+# This option is incompatible with CUDA_SEPARABLE_COMPILATION.
+cmake_dependent_option(
+  BUILD_SPLIT_CUDA "Split torch_cuda library into torch_cuda_cu and torch_cuda_cpp" OFF
+  "USE_CUDA AND NOT CUDA_SEPARABLE_COMPILATION" OFF)
 option(USE_FAST_NVCC "Use parallel NVCC build" OFF)
 option(USE_ROCM "Use ROCm" ON)
 option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)

aten/src/ATen/cuda/CUDAContext.cpp (+3, -3)

@@ -29,9 +29,9 @@ void initDeviceProperty(DeviceIndex device_index) {

 } // anonymous namespace

-// We need this function to force the linking against torch_cuda on Windows.
-// If you need to modify this function, please specify a new function and apply the changes
-// according to https://github.com/pytorch/pytorch/pull/34288.
+// We need this function to force the linking against torch_cuda(_cpp) on Windows.
+// If you need to modify this function, please specify a new function and apply
+// the changes according to https://github.com/pytorch/pytorch/pull/34288.
 // Related issue: https://github.com/pytorch/pytorch/issues/31611.
 /* Device info */
 int warp_size() {
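The comment in this hunk refers to the pattern of keeping an exported symbol that downstream code can call so the Windows linker does not drop the dependency on torch_cuda(_cpp). A minimal consumer-side sketch (illustrative only; the wrapper function name is hypothetical, while at::cuda::warp_size() is the exported function the comment is attached to):

```cpp
#include <ATen/cuda/CUDAContext.h>

// Hypothetical consumer code: calling an exported symbol such as
// at::cuda::warp_size() creates a link-time dependency on the library
// that defines it, so the linker cannot discard that import.
int force_torch_cuda_cpp_link() {
  return at::cuda::warp_size();
}
```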

aten/src/ATen/cudnn/Descriptors.h (+8, -8)

@@ -80,7 +80,7 @@ struct DescriptorDeleter {
 // initialized the first time you call set() or any other initializing
 // function.
 template <typename T, cudnnStatus_t (*ctor)(T**), cudnnStatus_t (*dtor)(T*)>
-class TORCH_CUDA_CU_API Descriptor {
+class TORCH_CUDA_CPP_API Descriptor {
 public:
 // TODO: Figure out why const-correctness doesn't work here

@@ -108,7 +108,7 @@ class TORCH_CUDA_CU_API Descriptor {
 std::unique_ptr<T, DescriptorDeleter<T, dtor>> desc_;
 };

-class TORCH_CUDA_CU_API TensorDescriptor : public Descriptor<
+class TORCH_CUDA_CPP_API TensorDescriptor : public Descriptor<
 cudnnTensorStruct,
 &cudnnCreateTensorDescriptor,
 &cudnnDestroyTensorDescriptor> {

@@ -147,7 +147,7 @@ class TORCH_CUDA_CU_API TensorDescriptor : public Descriptor<

 std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d);

-class TORCH_CUDA_CU_API FilterDescriptor : public Descriptor<
+class TORCH_CUDA_CPP_API FilterDescriptor : public Descriptor<
 cudnnFilterStruct,
 &cudnnCreateFilterDescriptor,
 &cudnnDestroyFilterDescriptor> {

@@ -163,7 +163,7 @@ class TORCH_CUDA_CU_API FilterDescriptor : public Descriptor<

 std::ostream& operator<<(std::ostream & out, const FilterDescriptor& d);

-struct TORCH_CUDA_CU_API ConvolutionDescriptor
+struct TORCH_CUDA_CPP_API ConvolutionDescriptor
 : public Descriptor<
 cudnnConvolutionStruct,
 &cudnnCreateConvolutionDescriptor,

@@ -186,7 +186,7 @@ struct TORCH_CUDA_CU_API ConvolutionDescriptor
 }
 };

-struct TORCH_CUDA_CU_API SpatialTransformerDescriptor
+struct TORCH_CUDA_CPP_API SpatialTransformerDescriptor
 : public Descriptor<
 cudnnSpatialTransformerStruct,
 &cudnnCreateSpatialTransformerDescriptor,

@@ -196,7 +196,7 @@ struct TORCH_CUDA_CU_API SpatialTransformerDescriptor
 }
 };

-struct TORCH_CUDA_CU_API DropoutDescriptor
+struct TORCH_CUDA_CPP_API DropoutDescriptor
 : public Descriptor<
 cudnnDropoutStruct,
 &cudnnCreateDropoutDescriptor,

@@ -235,7 +235,7 @@ struct TORCH_CUDA_CU_API DropoutDescriptor
 }
 };

-struct TORCH_CUDA_CU_API RNNDescriptor : public Descriptor<
+struct TORCH_CUDA_CPP_API RNNDescriptor : public Descriptor<
 cudnnRNNStruct,
 &cudnnCreateRNNDescriptor,
 &cudnnDestroyRNNDescriptor> {

@@ -282,7 +282,7 @@ struct TORCH_CUDA_CU_API RNNDescriptor : public Descriptor<
 }
 };

-struct TORCH_CUDA_CU_API CTCLossDescriptor
+struct TORCH_CUDA_CPP_API CTCLossDescriptor
 : public Descriptor<
 cudnnCTCLossStruct,
 &cudnnCreateCTCLossDescriptor,

aten/src/ATen/cudnn/Handle.h (+1, -1)

@@ -5,5 +5,5 @@

 namespace at { namespace native {

-TORCH_CUDA_CU_API cudnnHandle_t getCudnnHandle();
+TORCH_CUDA_CPP_API cudnnHandle_t getCudnnHandle();
 }} // namespace at::native

aten/src/ATen/cudnn/Types.h (+1, -1)

@@ -5,7 +5,7 @@

 namespace at { namespace native {

-TORCH_CUDA_CU_API cudnnDataType_t
+TORCH_CUDA_CPP_API cudnnDataType_t
 getCudnnDataTypeFromScalarType(const at::ScalarType dtype);
 cudnnDataType_t getCudnnDataType(const at::Tensor& tensor);

aten/src/ATen/native/cuda/Bucketization.cu (+1)

@@ -126,6 +126,7 @@ Tensor& searchsorted_out_cuda(Tensor& result, const Tensor& sorted_sequence, con
 return result;
 }

+// We need this function to force the linking against torch_cuda_cu on Windows.
 Tensor searchsorted_cuda(const Tensor& sorted_sequence, const Tensor& self, bool out_int32, bool right) {
 ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
 c10::TensorOptions options = TensorOptions().device(self.options().device()).dtype(scalar_type);

aten/src/ATen/native/cudnn/RNN.cpp (+1, -1)

@@ -774,7 +774,7 @@ namespace {
 // Utilities exposed in RNNUtils.h
 namespace cudnn_rnn {

-TORCH_CUDA_CU_API std::tuple<Tensor, std::vector<Tensor>>
+TORCH_CUDA_CPP_API std::tuple<Tensor, std::vector<Tensor>>
 copy_weights_to_flat_buf_views(
 TensorList weight_arr,
 int64_t weight_stride0,

aten/src/ATen/native/cudnn/RNNUtils.h (+1, -1)

@@ -8,7 +8,7 @@ namespace at {
 namespace native {
 namespace cudnn_rnn {

-TORCH_CUDA_CU_API std::tuple<Tensor, std::vector<Tensor>>
+TORCH_CUDA_CPP_API std::tuple<Tensor, std::vector<Tensor>>
 copy_weights_to_flat_buf_views(
 TensorList weight_arr,
 int64_t weight_stride0,

aten/src/ATen/test/cuda_distributions_test.cu (+9, -8)

@@ -79,7 +79,7 @@ TEST(DistributionsTest, TestPhiloxIncrementSmallUniformTensor) {

 // get 4 randoms from uniform_(), philox offset is now incremented to 4 by this call
 at::empty({4}, at::TensorOptions(at::kCUDA)).uniform_();
-
+
 // expected uniforms will start from counter offset of 4
 assert_with_expected_uniforms(4);
 }

@@ -97,12 +97,13 @@ TEST(DistributionsTest, TestPhiloxIncrementBigUniformTensor) {
 // greater the number of threads launched), it hits the unroll loop in
 // the uniform_ kernel.
 // - Hence, we set the size of the tensor in this test to be 8 times the
-// maximum number of threads we can launch. This means that, each thread will
-// be yielding 8 elements, and as a result, curand_uniform4 will be called twice
-// and all the 8 elements in a thread will consume all the float4 from the
-// two calls of curand_unfiorm4 as a result of the unroll loop. Therefore,
-// after this call to the unform_, counter_offset for the next call to uniform_
-// will start from 8. This is what we test next.
+// maximum number of threads we can launch. This means that, each thread
+// will be yielding 8 elements, and as a result, curand_uniform4 will be
+// called twice and all the 8 elements in a thread will consume all the
+// float4 from the two calls of curand_uniform4 as a result of the unroll
+// loop. Therefore, after this call to the uniform_, counter_offset for
+// the next call to uniform_ will start from 8. This is what we test
+// next.
 // - assert that call to uniform_ will start from counter_offset of 8

 // if cuda not available, return

@@ -121,7 +122,7 @@ TEST(DistributionsTest, TestPhiloxIncrementBigUniformTensor) {

 // get numel randoms from uniform_(), philox offset is now incremented to 8 by this call
 at::empty({numel}, at::TensorOptions(at::kCUDA)).uniform_();
-
+
 // expected uniforms will start from counter offset of 8
 assert_with_expected_uniforms(8);
 }

aten/src/THC/THCAllocator.h (+1, -1)

@@ -5,7 +5,7 @@

 // IPC doesn't support (re)allocation

-class TORCH_CUDA_CU_API THCIpcDeleter {
+class TORCH_CUDA_CPP_API THCIpcDeleter {
 public:
 THCIpcDeleter(std::shared_ptr<void> basePtr);
 ~THCIpcDeleter();

aten/src/THC/THCCachingHostAllocator.h (+3, -3)

@@ -21,14 +21,14 @@
 // Note that this allocator does not split larger allocations into smaller
 // blocks, unlike the caching device allocator.
 //
-TORCH_CUDA_CU_API c10::Allocator* getTHCCachingHostAllocator(void);
+TORCH_CUDA_CPP_API c10::Allocator* getTHCCachingHostAllocator(void);

 // Records an event in the specified stream. The allocation 'ptr' will not be
 // re-used until the event has occurred.
-TORCH_CUDA_CU_API cudaError_t
+TORCH_CUDA_CPP_API cudaError_t
 THCCachingHostAllocator_recordEvent(void* ptr, at::cuda::CUDAStream stream);

 // Releases cached pinned memory allocations via cudaHostFree
-TORCH_CUDA_CU_API void THCCachingHostAllocator_emptyCache(void);
+TORCH_CUDA_CPP_API void THCCachingHostAllocator_emptyCache(void);

 #endif

aten/src/THC/THCGeneral.h.in (+15, -15)

@@ -31,39 +31,39 @@ typedef struct _THCCudaResourcesPerDevice {
 size_t scratchSpacePerStream;
 } THCCudaResourcesPerDevice;

-TORCH_CUDA_CU_API THCState* THCState_alloc(void);
-TORCH_CUDA_CU_API void THCState_free(THCState* state);
+TORCH_CUDA_CPP_API THCState* THCState_alloc(void);
+TORCH_CUDA_CPP_API void THCState_free(THCState* state);

-TORCH_CUDA_CU_API void THCudaInit(THCState* state);
-TORCH_CUDA_CU_API void THCudaShutdown(THCState* state);
+TORCH_CUDA_CPP_API void THCudaInit(THCState* state);
+TORCH_CUDA_CPP_API void THCudaShutdown(THCState* state);

 /* If device `dev` can access allocations on device `devToAccess`, this will return */
 /* 1; otherwise, 0. */
-TORCH_CUDA_CU_API int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess);
+TORCH_CUDA_CPP_API int THCState_getPeerToPeerAccess(THCState* state, int dev, int devToAccess);

-TORCH_CUDA_CU_API c10::Allocator* THCState_getCudaHostAllocator(THCState* state);
+TORCH_CUDA_CPP_API c10::Allocator* THCState_getCudaHostAllocator(THCState* state);

-TORCH_CUDA_CU_API void THCMagma_init(THCState *state);
+TORCH_CUDA_CPP_API void THCMagma_init(THCState *state);

 /* For the current device and stream, returns the allocated scratch space */
-TORCH_CUDA_CU_API size_t THCState_getCurrentDeviceScratchSpaceSize(THCState* state);
+TORCH_CUDA_CPP_API size_t THCState_getCurrentDeviceScratchSpaceSize(THCState* state);

 #define THCAssertSameGPU(expr) if (!expr) THError("arguments are located on different GPUs")
 #define THCudaCheck(err) __THCudaCheck(err, __FILE__, __LINE__)
 #define THCudaCheckWarn(err) __THCudaCheckWarn(err, __FILE__, __LINE__)
 #define THCublasCheck(err) __THCublasCheck(err, __FILE__, __LINE__)
 #define THCusparseCheck(err) __THCusparseCheck(err, __FILE__, __LINE__)

-TORCH_CUDA_CU_API void __THCudaCheck(cudaError_t err, const char *file, const int line);
-TORCH_CUDA_CU_API void __THCudaCheckWarn(cudaError_t err, const char *file, const int line);
-TORCH_CUDA_CU_API void __THCublasCheck(cublasStatus_t status, const char *file, const int line);
-TORCH_CUDA_CU_API void __THCusparseCheck(cusparseStatus_t status, const char *file, const int line);
+TORCH_CUDA_CPP_API void __THCudaCheck(cudaError_t err, const char *file, const int line);
+TORCH_CUDA_CPP_API void __THCudaCheckWarn(cudaError_t err, const char *file, const int line);
+TORCH_CUDA_CPP_API void __THCublasCheck(cublasStatus_t status, const char *file, const int line);
+TORCH_CUDA_CPP_API void __THCusparseCheck(cusparseStatus_t status, const char *file, const int line);

-TORCH_CUDA_CU_API void* THCudaMalloc(THCState *state, size_t size);
-TORCH_CUDA_CU_API void THCudaFree(THCState *state, void* ptr);
+TORCH_CUDA_CPP_API void* THCudaMalloc(THCState *state, size_t size);
+TORCH_CUDA_CPP_API void THCudaFree(THCState *state, void* ptr);

 at::DataPtr THCudaHostAlloc(THCState *state, size_t size);

-TORCH_CUDA_CU_API void THCudaHostRecord(THCState *state, void *ptr);
+TORCH_CUDA_CPP_API void THCudaHostRecord(THCState *state, void *ptr);

 #endif

c10/macros/Export.h (+22, -7)

@@ -100,15 +100,30 @@

 // NB: For now, HIP is overloaded to use the same macro, but ideally
 // HIPify should translate TORCH_CUDA_API to TORCH_HIP_API
-#if defined(TORCH_CUDA_BUILD_MAIN_LIB) || defined(TORCH_HIP_BUILD_MAIN_LIB)
-#define TORCH_CUDA_API C10_EXPORT
-#else
-#define TORCH_CUDA_API C10_IMPORT
+// JX: I removed the || defined(TORCH_HIP_BUILD_MAIN_LIB) check for TORCH_CUDA_*_API
+// since TORCH_HIP_API seems properly initialized below
+// libtorch_cuda_cu.so
+#ifdef TORCH_CUDA_CU_BUILD_MAIN_LIB
+#define TORCH_CUDA_CU_API C10_EXPORT
+#elif defined(BUILD_SPLIT_CUDA)
+#define TORCH_CUDA_CU_API C10_IMPORT
+#endif
+
+// libtorch_cuda_cpp.so
+#ifdef TORCH_CUDA_CPP_BUILD_MAIN_LIB
+#define TORCH_CUDA_CPP_API C10_EXPORT
+#elif defined(BUILD_SPLIT_CUDA)
+#define TORCH_CUDA_CPP_API C10_IMPORT
 #endif

-// This is in preparation for the imminent torch_cuda split
-#define TORCH_CUDA_CU_API TORCH_CUDA_API
-#define TORCH_CUDA_CPP_API TORCH_CUDA_API
+// libtorch_cuda.so (where torch_cuda_cu and torch_cuda_cpp are a part of the same api)
+#ifdef TORCH_CUDA_BUILD_MAIN_LIB
+#define TORCH_CUDA_CPP_API C10_EXPORT
+#define TORCH_CUDA_CU_API C10_EXPORT
+#elif !defined(BUILD_SPLIT_CUDA)
+#define TORCH_CUDA_CPP_API C10_IMPORT
+#define TORCH_CUDA_CU_API C10_IMPORT
+#endif

 #if defined(TORCH_HIP_BUILD_MAIN_LIB)
 #define TORCH_HIP_API C10_EXPORT
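To make the three configurations above easier to follow, here is a hedged sketch of how one of the new macros resolves in each build mode (the C10_EXPORT/C10_IMPORT definitions below are simplified stand-ins for the real ones earlier in this header, and the declared function is hypothetical):

```cpp
// Simplified stand-ins for the real C10 visibility macros.
#if defined(_WIN32)
#define C10_EXPORT __declspec(dllexport)
#define C10_IMPORT __declspec(dllimport)
#else
#define C10_EXPORT __attribute__((__visibility__("default")))
#define C10_IMPORT
#endif

// Split build, compiling libtorch_cuda_cpp itself:
//   TORCH_CUDA_CPP_BUILD_MAIN_LIB is defined  -> TORCH_CUDA_CPP_API expands to C10_EXPORT.
// Split build, compiling a consumer:
//   only BUILD_SPLIT_CUDA is defined          -> TORCH_CUDA_CPP_API expands to C10_IMPORT.
// Unsplit build:
//   TORCH_CUDA_BUILD_MAIN_LIB (or neither define) drives both macros the same way.
#if defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB)
#define TORCH_CUDA_CPP_API C10_EXPORT
#else
#define TORCH_CUDA_CPP_API C10_IMPORT
#endif

// Hypothetical declaration that would live in the torch_cuda_cpp half of the split:
TORCH_CUDA_CPP_API int example_cpp_symbol();
```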
