1
1
cmake_minimum_required (VERSION 3.19...3.25 )
2
2
3
+ find_program (NVHPC_CXX_BIN "nvc++" REQUIRED )
4
+ set (CMAKE_CXX_COMPILER ${NVHPC_CXX_BIN} )
5
+
6
+ find_program (NVHPC_C_BIN "nvc" REQUIRED )
7
+ set (CMAKE_C_COMPILER ${NVHPC_C_BIN} )
8
+
9
+ project (jaxdecomp LANGUAGES CXX CUDA )
10
+
3
11
# NVCC 12 does not support C++20
4
12
set (CMAKE_CXX_STANDARD 17 )
5
13
set (CMAKE_CUDA_STANDARD 17 )
14
+
6
15
# Latest JAX v0.4.26 no longer supports cuda 11.8
7
- # By default, build for CUDA 12.2, users can override this with -DNVHPC_CUDA_VERSION=11.8
8
- set (NVHPC_CUDA_VERSION 12.2 CACHE STRING "CUDA version to build for" )
16
+ find_package ( CUDAToolkit REQUIRED VERSION 12 )
17
+ set (NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR} .${CUDAToolkit_VERSION_MINOR} )
9
18
10
- # Build debug
11
- # set(CMAKE_BUILD_TYPE Debug)
12
- add_subdirectory ( third_party/cuDecomp )
19
+ message ( STATUS "Using CUDA ${NVHPC_CUDA_VERSION} " )
20
+ # Build Release by default
21
+ set ( CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." )
13
22
14
- project ( jaxdecomp LANGUAGES CXX CUDA )
23
+ add_subdirectory ( third_party/cuDecomp )
15
24
16
25
option (CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF )
17
26
option (CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF )
@@ -34,7 +43,7 @@ find_library(NCCL_LIBRARY
34
43
NAMES nccl
35
44
HINTS ${NVHPC_NCCL_LIBRARY_DIR}
36
45
)
37
- string (REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR} )
46
+ string (REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR} )
38
47
39
48
40
49
message (STATUS "Using NCCL library: ${NCCL_LIBRARY} " )
@@ -68,4 +77,9 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR)
68
77
target_link_libraries (_jaxdecomp PRIVATE NVHPC::CUDA )
69
78
target_link_libraries (_jaxdecomp PRIVATE ${NCCL_LIBRARY} )
70
79
target_link_libraries (_jaxdecomp PRIVATE cudecomp )
80
+ target_link_libraries (_jaxdecomp PRIVATE stdc++fs )
71
81
set_target_properties (_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX )
82
+
83
+ set_target_properties (_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib" )
84
+
85
+ install (TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION . )
0 commit comments