Skip to content

Commit 1be36aa

Browse files
authored
Merge pull request #22 from DifferentiableUniverseInitiative/push-to-pypi
[Packaging] Migrating to scikit-build core
2 parents d551967 + d6640e8 commit 1be36aa

File tree

4 files changed

+60
-106
lines changed

4 files changed

+60
-106
lines changed

CMakeLists.txt

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
11
cmake_minimum_required(VERSION 3.19...3.25)
22

3+
find_program(NVHPC_CXX_BIN "nvc++" REQUIRED)
4+
set(CMAKE_CXX_COMPILER ${NVHPC_CXX_BIN})
5+
6+
find_program(NVHPC_C_BIN "nvc" REQUIRED)
7+
set(CMAKE_C_COMPILER ${NVHPC_C_BIN})
8+
9+
project(jaxdecomp LANGUAGES CXX CUDA)
10+
311
# NVCC 12 does not support C++20
412
set(CMAKE_CXX_STANDARD 17)
513
set(CMAKE_CUDA_STANDARD 17)
14+
615
# Latest JAX v0.4.26 no longer supports cuda 11.8
7-
# By default, build for CUDA 12.2, users can override this with -DNVHPC_CUDA_VERSION=11.8
8-
set(NVHPC_CUDA_VERSION 12.2 CACHE STRING "CUDA version to build for" )
16+
find_package(CUDAToolkit REQUIRED VERSION 12)
17+
set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR})
918

10-
# Build debug
11-
# set(CMAKE_BUILD_TYPE Debug)
12-
add_subdirectory(third_party/cuDecomp)
19+
message(STATUS "Using CUDA ${NVHPC_CUDA_VERSION}")
20+
# Build Release by default
21+
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.")
1322

14-
project(jaxdecomp LANGUAGES CXX CUDA)
23+
add_subdirectory(third_party/cuDecomp)
1524

1625
option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF)
1726
option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF)
@@ -34,7 +43,7 @@ find_library(NCCL_LIBRARY
3443
NAMES nccl
3544
HINTS ${NVHPC_NCCL_LIBRARY_DIR}
3645
)
37-
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
46+
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
3847

3948

4049
message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
@@ -68,4 +77,9 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR)
6877
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA)
6978
target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY})
7079
target_link_libraries(_jaxdecomp PRIVATE cudecomp)
80+
target_link_libraries(_jaxdecomp PRIVATE stdc++fs)
7181
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)
82+
83+
set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")
84+
85+
install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)

jaxdecomp/_src/halo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class HaloPrimitive(BasePrimitive):
2525

2626
name = "halo_exchange"
2727
multiple_results = False
28-
impl_static_args = (1, 2, 3)
28+
impl_static_args = (1, 2)
2929
inner_primitive = None
3030
outer_primitive = None
3131

@@ -310,7 +310,7 @@ def halo_p_lower(x: Array, halo_extents: Tuple[int, int, int],
310310

311311

312312
# Custom Partitioning
313-
@partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3))
313+
@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
314314
def halo_exchange(x: Array, halo_extents: Tuple[int, int, int],
315315
halo_periods: Tuple[bool, bool, bool]) -> Array:
316316
"""
@@ -385,4 +385,4 @@ def _halo_bwd_rule(halo_extents: Tuple[int, int, int],
385385
halo_exchange.defvjp(_halo_fwd_rule, _halo_bwd_rule)
386386

387387
# JIT compile the halo_exchange operation
388-
halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2, 3))
388+
halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2))

pyproject.toml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
[build-system]
2+
requires = [ "scikit-build-core","pybind11"]
3+
build-backend = "scikit_build_core.build"
4+
5+
[project]
6+
name = "jaxdecomp"
7+
version = "0.1.0"
8+
description = "JAX bindings for the cuDecomp library"
9+
authors = [
10+
{ name = "Wassim Kabalan" },
11+
{ name = "Francois Lanusse"}
12+
]
13+
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/jaxDecomp" }
14+
readme = "README.md"
15+
license = { file = "LICENSE" }
16+
classifiers = [
17+
"Programming Language :: Python :: 3",
18+
"License :: OSI Approved :: MIT License",
19+
"Operating System :: OS Independent"
20+
]
21+
dependencies = []
22+
23+
[project.optional-dependencies]
24+
test = ["pytest"]
25+
26+
[tool.scikit-build]
27+
minimum-version = "0.8"
28+
cmake.version = ">=3.25"
29+
build-dir = "build/{wheel_tag}"
30+
wheel.py-api = "py3"
31+
cmake.build-type = "Release"
32+
# Add any additional configurations for scikit-build if necessary
33+
wheel.install-dir = "jaxdecomp/_src"
34+
35+
[tool.scikit-build.cmake.define]
36+
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""

setup.py

Lines changed: 0 additions & 96 deletions
This file was deleted.

0 commit comments

Comments
 (0)