Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
cmake_minimum_required(VERSION 3.19...3.25)

find_program(NVHPC_CXX_BIN "nvc++" REQUIRED)
set(CMAKE_CXX_COMPILER ${NVHPC_CXX_BIN})

find_program(NVHPC_C_BIN "nvc" REQUIRED)
set(CMAKE_C_COMPILER ${NVHPC_C_BIN})

project(jaxdecomp LANGUAGES CXX CUDA)

# NVCC 12 does not support C++20
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

# Latest JAX v0.4.26 no longer supports cuda 11.8
set(NVHPC_CUDA_VERSION 12.2)
# Build debug
# set(CMAKE_BUILD_TYPE Debug)
add_subdirectory(third_party/cuDecomp)
find_package(CUDAToolkit REQUIRED VERSION 12)
set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR})

project(jaxdecomp LANGUAGES CXX CUDA)
message(STATUS "Using CUDA ${NVHPC_CUDA_VERSION}")
# Build Release by default
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.")

add_subdirectory(third_party/cuDecomp)

option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF)
option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF)
Expand All @@ -32,7 +43,7 @@ find_library(NCCL_LIBRARY
NAMES nccl
HINTS ${NVHPC_NCCL_LIBRARY_DIR}
)
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})


message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
Expand Down Expand Up @@ -66,4 +77,9 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA)
target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY})
target_link_libraries(_jaxdecomp PRIVATE cudecomp)
target_link_libraries(_jaxdecomp PRIVATE stdc++fs)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)

set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")

install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)
36 changes: 36 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[build-system]
requires = [ "scikit-build-core","pybind11"]
build-backend = "scikit_build_core.build"

[project]
name = "jaxdecomp"
version = "0.1.0"
description = "JAX bindings for the cuDecomp library"
authors = [
{ name = "Wassim Kabalan" },
{ name = "Francois Lanusse"}
]
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/jaxDecomp" }
readme = "README.md"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent"
]
dependencies = []

[project.optional-dependencies]
test = ["pytest"]

[tool.scikit-build]
minimum-version = "0.8"
cmake.version = ">=3.19"
build-dir = "build/{wheel_tag}"
wheel.py-api = "py3"
cmake.build-type = "Release"
# Add any additional configurations for scikit-build if necessary
wheel.install-dir = "jaxdecomp/_src"

[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""
96 changes: 0 additions & 96 deletions setup.py

This file was deleted.