diff --git a/CMakeLists.txt b/CMakeLists.txt index fced2db..bf4c2e2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,17 +1,26 @@ cmake_minimum_required(VERSION 3.19...3.25) +find_program(NVHPC_CXX_BIN "nvc++" REQUIRED) +set(CMAKE_CXX_COMPILER ${NVHPC_CXX_BIN}) + +find_program(NVHPC_C_BIN "nvc" REQUIRED) +set(CMAKE_C_COMPILER ${NVHPC_C_BIN}) + +project(jaxdecomp LANGUAGES CXX CUDA) + # NVCC 12 does not support C++20 set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) + # Latest JAX v0.4.26 no longer supports cuda 11.8 -# By default, build for CUDA 12.2, users can override this with -DNVHPC_CUDA_VERSION=11.8 -set(NVHPC_CUDA_VERSION 12.2 CACHE STRING "CUDA version to build for" ) +find_package(CUDAToolkit REQUIRED VERSION 12) +set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}) -# Build debug -# set(CMAKE_BUILD_TYPE Debug) -add_subdirectory(third_party/cuDecomp) +message(STATUS "Using CUDA ${NVHPC_CUDA_VERSION}") +# Build Release by default +set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.") -project(jaxdecomp LANGUAGES CXX CUDA) +add_subdirectory(third_party/cuDecomp) option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF) option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF) @@ -34,7 +43,7 @@ find_library(NCCL_LIBRARY NAMES nccl HINTS ${NVHPC_NCCL_LIBRARY_DIR} ) - string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR}) +string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR}) message(STATUS "Using NCCL library: ${NCCL_LIBRARY}") @@ -68,4 +77,9 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR) target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA) target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY}) target_link_libraries(_jaxdecomp PRIVATE cudecomp) +target_link_libraries(_jaxdecomp PRIVATE stdc++fs) set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX) + +set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib") + +install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .) diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py index ddcf5cc..3ce2602 100644 --- a/jaxdecomp/_src/halo.py +++ b/jaxdecomp/_src/halo.py @@ -25,7 +25,7 @@ class HaloPrimitive(BasePrimitive): name = "halo_exchange" multiple_results = False - impl_static_args = (1, 2, 3) + impl_static_args = (1, 2) inner_primitive = None outer_primitive = None @@ -310,7 +310,7 @@ def halo_p_lower(x: Array, halo_extents: Tuple[int, int, int], # Custom Partitioning -@partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3)) +@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) def halo_exchange(x: Array, halo_extents: Tuple[int, int, int], halo_periods: Tuple[bool, bool, bool]) -> Array: """ @@ -385,4 +385,4 @@ def _halo_bwd_rule(halo_extents: Tuple[int, int, int], halo_exchange.defvjp(_halo_fwd_rule, _halo_bwd_rule) # JIT compile the halo_exchange operation -halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2, 3)) +halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2)) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..349d6b8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = [ "scikit-build-core","pybind11"] +build-backend = "scikit_build_core.build" + +[project] +name = "jaxdecomp" +version = "0.1.0" +description = "JAX bindings for the cuDecomp library" +authors = [ + { name = "Wassim Kabalan" }, + { name = "Francois Lanusse"} +] +urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/jaxDecomp" } +readme = "README.md" +license = { file = "LICENSE" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent" +] +dependencies = [] + +[project.optional-dependencies] +test = ["pytest"] + +[tool.scikit-build] +minimum-version = "0.8" +cmake.version = ">=3.25" +build-dir = "build/{wheel_tag}" +wheel.py-api = "py3" +cmake.build-type = "Release" +# Add any additional configurations for scikit-build if necessary +wheel.install-dir = "jaxdecomp/_src" + +[tool.scikit-build.cmake.define] +CMAKE_LIBRARY_OUTPUT_DIRECTORY = "" diff --git a/setup.py b/setup.py deleted file mode 100644 index b07affb..0000000 --- a/setup.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import subprocess -import sys -from pathlib import Path - -from setuptools import Extension, find_packages, setup -from setuptools.command.build_ext import build_ext - - -# A CMakeExtension needs a sourcedir instead of a file list. -# The name must be the _single_ output extension from the CMake build. -# If you need multiple extensions, see scikit-build. -class CMakeExtension(Extension): - - def __init__(self, name: str, sourcedir: str = "") -> None: - super().__init__(name, sources=[]) - self.sourcedir = os.fspath(Path(sourcedir).resolve()) - - -class CMakeBuild(build_ext): - - def build_extension(self, ext: CMakeExtension) -> None: - # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ - ext_fullpath = Path.cwd() / self.get_ext_fullpath( - ext.name) # type: ignore[no-untyped-call] - extdir = ext_fullpath.parent.resolve() - - # Using this requires trailing slash for auto-detection & inclusion of - # auxiliary "native" libs - - debug = int(os.environ.get("DEBUG", - 0)) if self.debug is None else self.debug - cfg = "Debug" if debug else "Release" - - # CMake lets you override the generator - we need to check this. - # Can be set with Conda-Build, for example. - cmake_generator = os.environ.get("CMAKE_GENERATOR", "") - - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON - # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code - # from Python. - cmake_args = [ - f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", - f"-DPYTHON_EXECUTABLE={sys.executable}", - f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm - ] - build_args = [] - # Adding CMake arguments set as environment variable - # (needed e.g. to build for ARM OSx on conda-forge) - if "CMAKE_ARGS" in os.environ: - cmake_args += [ - item for item in os.environ["CMAKE_ARGS"].split(" ") if item - ] - - # Single config generators are handled "normally" - single_config = any(x in cmake_generator for x in {"NMake", "Ninja"}) - - # CMake allows an arch-in-generator style for backward compatibility - contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"}) - - # Multi-config generators have a different way to specify configs - if not single_config: - cmake_args += [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] - build_args += ["--config", cfg] - - # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level - # across all generators. - if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: - # self.parallel is a Python 3 only way to set parallel jobs by hand - # using -j in the build_ext call, not supported by pip or PyPA-build. - if hasattr(self, "parallel") and self.parallel: - # CMake 3.12+ only. - build_args += [f"-j{self.parallel}"] - - build_temp = Path(self.build_temp) / ext.name - if not build_temp.exists(): - build_temp.mkdir(parents=True) - - subprocess.run( - ["cmake", ext.sourcedir] + cmake_args, cwd=build_temp, check=True) - - subprocess.run( - ["cmake", "--build", "."] + build_args, cwd=build_temp, check=True) - - -setup( - name='jaxDecomp', - url='https://github.com/DifferentiableUniverseInitiative/jaxDecomp', - author='Wassim Kabalan, Francois Lanusse', - description='JAX bindings for the cuDecomp library', - ext_modules=[CMakeExtension("jaxdecomp/_src/_jaxdecomp")], - cmdclass={"build_ext": CMakeBuild}, - packages=find_packages(), - include_package_data=True, - use_scm_version=True, - setup_requires=["setuptools_scm"])