Skip to content

Commit

Permalink
Merge pull request #22 from DifferentiableUniverseInitiative/push-to-…
Browse files Browse the repository at this point in the history
…pypi

[Packaging]  Migrating to scikit-build core
  • Loading branch information
EiffL authored Jul 8, 2024
2 parents d551967 + d6640e8 commit 1be36aa
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 106 deletions.
28 changes: 21 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
cmake_minimum_required(VERSION 3.19...3.25)

find_program(NVHPC_CXX_BIN "nvc++" REQUIRED)
set(CMAKE_CXX_COMPILER ${NVHPC_CXX_BIN})

find_program(NVHPC_C_BIN "nvc" REQUIRED)
set(CMAKE_C_COMPILER ${NVHPC_C_BIN})

project(jaxdecomp LANGUAGES CXX CUDA)

# NVCC 12 does not support C++20
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

# Latest JAX v0.4.26 no longer supports cuda 11.8
# By default, build for CUDA 12.2, users can override this with -DNVHPC_CUDA_VERSION=11.8
set(NVHPC_CUDA_VERSION 12.2 CACHE STRING "CUDA version to build for" )
find_package(CUDAToolkit REQUIRED VERSION 12)
set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR})

# Build debug
# set(CMAKE_BUILD_TYPE Debug)
add_subdirectory(third_party/cuDecomp)
message(STATUS "Using CUDA ${NVHPC_CUDA_VERSION}")
# Build Release by default
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.")

project(jaxdecomp LANGUAGES CXX CUDA)
add_subdirectory(third_party/cuDecomp)

option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF)
option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF)
Expand All @@ -34,7 +43,7 @@ find_library(NCCL_LIBRARY
NAMES nccl
HINTS ${NVHPC_NCCL_LIBRARY_DIR}
)
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})


message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
Expand Down Expand Up @@ -68,4 +77,9 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA)
target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY})
target_link_libraries(_jaxdecomp PRIVATE cudecomp)
target_link_libraries(_jaxdecomp PRIVATE stdc++fs)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)

set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")

install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)
6 changes: 3 additions & 3 deletions jaxdecomp/_src/halo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class HaloPrimitive(BasePrimitive):

name = "halo_exchange"
multiple_results = False
impl_static_args = (1, 2, 3)
impl_static_args = (1, 2)
inner_primitive = None
outer_primitive = None

Expand Down Expand Up @@ -310,7 +310,7 @@ def halo_p_lower(x: Array, halo_extents: Tuple[int, int, int],


# Custom Partitioning
@partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3))
@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
def halo_exchange(x: Array, halo_extents: Tuple[int, int, int],
halo_periods: Tuple[bool, bool, bool]) -> Array:
"""
Expand Down Expand Up @@ -385,4 +385,4 @@ def _halo_bwd_rule(halo_extents: Tuple[int, int, int],
halo_exchange.defvjp(_halo_fwd_rule, _halo_bwd_rule)

# JIT compile the halo_exchange operation
halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2, 3))
halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2))
36 changes: 36 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
[build-system]
requires = [ "scikit-build-core","pybind11"]
build-backend = "scikit_build_core.build"

[project]
name = "jaxdecomp"
version = "0.1.0"
description = "JAX bindings for the cuDecomp library"
authors = [
{ name = "Wassim Kabalan" },
{ name = "Francois Lanusse"}
]
urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/jaxDecomp" }
readme = "README.md"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent"
]
dependencies = []

[project.optional-dependencies]
test = ["pytest"]

[tool.scikit-build]
minimum-version = "0.8"
cmake.version = ">=3.25"
build-dir = "build/{wheel_tag}"
wheel.py-api = "py3"
cmake.build-type = "Release"
# Add any additional configurations for scikit-build if necessary
wheel.install-dir = "jaxdecomp/_src"

[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""
96 changes: 0 additions & 96 deletions setup.py

This file was deleted.

0 comments on commit 1be36aa

Please sign in to comment.