Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Packaging] Migrating to scikit-build core #22

Merged
merged 11 commits into from
Jul 8, 2024
7 changes: 0 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,3 @@ repos:
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/mirrors-clang-format
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This made my precommit hook fail for some reason...

rev: v18.1.4
hooks:
- id: clang-format
files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$'
exclude: '^third_party/|/pybind11/'
name: clang-format
28 changes: 21 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
cmake_minimum_required(VERSION 3.19...3.25)

find_program(NVHPC_CXX_BIN "nvc++" REQUIRED)
set(CMAKE_CXX_COMPILER ${NVHPC_CXX_BIN})

find_program(NVHPC_C_BIN "nvc" REQUIRED)
set(CMAKE_C_COMPILER ${NVHPC_C_BIN})
EiffL marked this conversation as resolved.
Show resolved Hide resolved

project(jaxdecomp LANGUAGES CXX CUDA)

# NVCC 12 does not support C++20
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

# Latest JAX v0.4.26 no longer supports cuda 11.8
# By default, build for CUDA 12.2, users can override this with -DNVHPC_CUDA_VERSION=11.8
set(NVHPC_CUDA_VERSION 12.2 CACHE STRING "CUDA version to build for" )
find_package(CUDAToolkit REQUIRED VERSION 12)
set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR})

# Build debug
# set(CMAKE_BUILD_TYPE Debug)
add_subdirectory(third_party/cuDecomp)
message(STATUS "Using CUDA ${NVHPC_CUDA_VERSION}")
# Build Release by default
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.")

project(jaxdecomp LANGUAGES CXX CUDA)
add_subdirectory(third_party/cuDecomp)

option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF)
option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF)
Expand All @@ -34,7 +43,7 @@ find_library(NCCL_LIBRARY
NAMES nccl
HINTS ${NVHPC_NCCL_LIBRARY_DIR}
)
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})


message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
Expand Down Expand Up @@ -68,4 +77,9 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA)
target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY})
target_link_libraries(_jaxdecomp PRIVATE cudecomp)
target_link_libraries(_jaxdecomp PRIVATE stdc++fs)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)

set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")

install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)
36 changes: 36 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[build-system]
requires = [ "scikit-build-core","pybind11"]
build-backend = "scikit_build_core.build"

[project]
name = "jaxdecomp"
version = "0.1.0"
description = "JAX bindings for the cuDecomp library"
authors = [
{ name = "Wassim Kabalan" },
{ name = "Francois Lanusse"}
]
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/jaxDecomp" }
readme = "README.md"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent"
]
dependencies = []

[project.optional-dependencies]
test = ["pytest"]

[tool.scikit-build]
minimum-version = "0.8"
cmake.version = ">=3.19"
build-dir = "build/{wheel_tag}"
wheel.py-api = "py3"
cmake.build-type = "Release"
# Add any additional configurations for scikit-build if necessary
wheel.install-dir = "jaxdecomp/_src"

[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""
96 changes: 0 additions & 96 deletions setup.py

This file was deleted.

Loading