From 605fc8a07aa32b52cdb7fe9a28d6f4865d8b7278 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Fri, 28 Jun 2024 12:23:56 +0200 Subject: [PATCH 1/8] Migrate to Scikit-build tools --- CMakeLists.txt | 4 +++ pyproject.toml | 36 +++++++++++++++++++ setup.py | 96 -------------------------------------------------- 3 files changed, 40 insertions(+), 96 deletions(-) create mode 100644 pyproject.toml delete mode 100644 setup.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e0d58b..c70f143 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,3 +67,7 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA) target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY}) target_link_libraries(_jaxdecomp PRIVATE cudecomp) set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX) + + +install(TARGETS cudecomp LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}) +install(TARGETS _jaxdecomp LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ea6b09d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = [ "scikit-build-core","pybind11"] +build-backend = "scikit_build_core.build" + +[project] +name = "jaxdecomp" +version = "0.1.0" +description = "JAX bindings for the cuDecomp library" +authors = [ + { name = "Wassim Kabalan" }, + { name = "Francois Lanusse"} +] +urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/jaxDecomp" } +readme = "README.md" +license = { file = "LICENSE" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent" +] +dependencies = [] + +[project.optional-dependencies] +test = ["pytest"] + +[tool.scikit-build] +minimum-version = "0.8" +cmake.version = ">=3.19" +build-dir = "build/{wheel_tag}" +wheel.py-api = "py3" +cmake.build-type = "Release" +# Add any additional configurations for scikit-build if necessary +cmake.args = [ + "-DCMAKE_INSTALL_PREFIX:PATH=jaxdecomp/_src/_jaxdecomp/", + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY:PATH=jaxdecomp/_src/_jaxdecomp/", +] diff --git a/setup.py b/setup.py deleted file mode 100644 index b07affb..0000000 --- a/setup.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import subprocess -import sys -from pathlib import Path - -from setuptools import Extension, find_packages, setup -from setuptools.command.build_ext import build_ext - - -# A CMakeExtension needs a sourcedir instead of a file list. -# The name must be the _single_ output extension from the CMake build. -# If you need multiple extensions, see scikit-build. -class CMakeExtension(Extension): - - def __init__(self, name: str, sourcedir: str = "") -> None: - super().__init__(name, sources=[]) - self.sourcedir = os.fspath(Path(sourcedir).resolve()) - - -class CMakeBuild(build_ext): - - def build_extension(self, ext: CMakeExtension) -> None: - # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ - ext_fullpath = Path.cwd() / self.get_ext_fullpath( - ext.name) # type: ignore[no-untyped-call] - extdir = ext_fullpath.parent.resolve() - - # Using this requires trailing slash for auto-detection & inclusion of - # auxiliary "native" libs - - debug = int(os.environ.get("DEBUG", - 0)) if self.debug is None else self.debug - cfg = "Debug" if debug else "Release" - - # CMake lets you override the generator - we need to check this. - # Can be set with Conda-Build, for example. - cmake_generator = os.environ.get("CMAKE_GENERATOR", "") - - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON - # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code - # from Python. - cmake_args = [ - f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", - f"-DPYTHON_EXECUTABLE={sys.executable}", - f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm - ] - build_args = [] - # Adding CMake arguments set as environment variable - # (needed e.g. to build for ARM OSx on conda-forge) - if "CMAKE_ARGS" in os.environ: - cmake_args += [ - item for item in os.environ["CMAKE_ARGS"].split(" ") if item - ] - - # Single config generators are handled "normally" - single_config = any(x in cmake_generator for x in {"NMake", "Ninja"}) - - # CMake allows an arch-in-generator style for backward compatibility - contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"}) - - # Multi-config generators have a different way to specify configs - if not single_config: - cmake_args += [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] - build_args += ["--config", cfg] - - # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level - # across all generators. - if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: - # self.parallel is a Python 3 only way to set parallel jobs by hand - # using -j in the build_ext call, not supported by pip or PyPA-build. - if hasattr(self, "parallel") and self.parallel: - # CMake 3.12+ only. - build_args += [f"-j{self.parallel}"] - - build_temp = Path(self.build_temp) / ext.name - if not build_temp.exists(): - build_temp.mkdir(parents=True) - - subprocess.run( - ["cmake", ext.sourcedir] + cmake_args, cwd=build_temp, check=True) - - subprocess.run( - ["cmake", "--build", "."] + build_args, cwd=build_temp, check=True) - - -setup( - name='jaxDecomp', - url='https://github.com/DifferentiableUniverseInitiative/jaxDecomp', - author='Wassim Kabalan, Francois Lanusse', - description='JAX bindings for the cuDecomp library', - ext_modules=[CMakeExtension("jaxdecomp/_src/_jaxdecomp")], - cmdclass={"build_ext": CMakeBuild}, - packages=find_packages(), - include_package_data=True, - use_scm_version=True, - setup_requires=["setuptools_scm"]) From 015d1cc209fa0d88d084ba04738cf3903ef2723d Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Fri, 28 Jun 2024 17:05:37 +0200 Subject: [PATCH 2/8] fix --- pyproject.toml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ea6b09d..80e9360 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,8 @@ build-dir = "build/{wheel_tag}" wheel.py-api = "py3" cmake.build-type = "Release" # Add any additional configurations for scikit-build if necessary -cmake.args = [ - "-DCMAKE_INSTALL_PREFIX:PATH=jaxdecomp/_src/_jaxdecomp/", - "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY:PATH=jaxdecomp/_src/_jaxdecomp/", -] +wheel.install-dir = "jaxdecomp/_src" + + +[tool.scikit-build.cmake.define] +CMAKE_LIBRARY_OUTPUT_DIRECTORY = "" From 815b15505f77e4120e3f4aaf8dbc8d4db4518067 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Fri, 28 Jun 2024 12:07:22 -0400 Subject: [PATCH 3/8] fix: some CMake fixes --- CMakeLists.txt | 17 +++++++++++------ pyproject.toml | 4 ---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c70f143..c7ef9c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,15 +1,19 @@ cmake_minimum_required(VERSION 3.19...3.25) +project(jaxdecomp LANGUAGES CXX CUDA) + # NVCC 12 does not support C++20 set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) + # Latest JAX v0.4.26 no longer supports cuda 11.8 -set(NVHPC_CUDA_VERSION 12.2) +find_package(CUDAToolkit REQUIRED VERSION 12) +set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}) + # Build debug # set(CMAKE_BUILD_TYPE Debug) -add_subdirectory(third_party/cuDecomp) -project(jaxdecomp LANGUAGES CXX CUDA) +add_subdirectory(third_party/cuDecomp) option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF) option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF) @@ -32,7 +36,7 @@ find_library(NCCL_LIBRARY NAMES nccl HINTS ${NVHPC_NCCL_LIBRARY_DIR} ) - string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR}) +string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR}) message(STATUS "Using NCCL library: ${NCCL_LIBRARY}") @@ -66,8 +70,9 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR) target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA) target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY}) target_link_libraries(_jaxdecomp PRIVATE cudecomp) +target_link_libraries(_jaxdecomp PRIVATE stdc++fs) set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX) +set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib") -install(TARGETS cudecomp LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}) -install(TARGETS _jaxdecomp LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}) +install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .) diff --git a/pyproject.toml b/pyproject.toml index 80e9360..d016159 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,3 @@ wheel.py-api = "py3" cmake.build-type = "Release" # Add any additional configurations for scikit-build if necessary wheel.install-dir = "jaxdecomp/_src" - - -[tool.scikit-build.cmake.define] -CMAKE_LIBRARY_OUTPUT_DIRECTORY = "" From 6e26027610c3009f84c19f69398704a9ca66807d Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 4 Jul 2024 10:17:44 +0200 Subject: [PATCH 4/8] Migrating to Scikit build core --- pyproject.toml | 37 +++++++++++++++++++ setup.py | 96 -------------------------------------------------- 2 files changed, 37 insertions(+), 96 deletions(-) create mode 100644 pyproject.toml delete mode 100644 setup.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..80e9360 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,37 @@ +[build-system] +requires = [ "scikit-build-core","pybind11"] +build-backend = "scikit_build_core.build" + +[project] +name = "jaxdecomp" +version = "0.1.0" +description = "JAX bindings for the cuDecomp library" +authors = [ + { name = "Wassim Kabalan" }, + { name = "Francois Lanusse"} +] +urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/jaxDecomp" } +readme = "README.md" +license = { file = "LICENSE" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent" +] +dependencies = [] + +[project.optional-dependencies] +test = ["pytest"] + +[tool.scikit-build] +minimum-version = "0.8" +cmake.version = ">=3.19" +build-dir = "build/{wheel_tag}" +wheel.py-api = "py3" +cmake.build-type = "Release" +# Add any additional configurations for scikit-build if necessary +wheel.install-dir = "jaxdecomp/_src" + + +[tool.scikit-build.cmake.define] +CMAKE_LIBRARY_OUTPUT_DIRECTORY = "" diff --git a/setup.py b/setup.py deleted file mode 100644 index b07affb..0000000 --- a/setup.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import subprocess -import sys -from pathlib import Path - -from setuptools import Extension, find_packages, setup -from setuptools.command.build_ext import build_ext - - -# A CMakeExtension needs a sourcedir instead of a file list. -# The name must be the _single_ output extension from the CMake build. -# If you need multiple extensions, see scikit-build. -class CMakeExtension(Extension): - - def __init__(self, name: str, sourcedir: str = "") -> None: - super().__init__(name, sources=[]) - self.sourcedir = os.fspath(Path(sourcedir).resolve()) - - -class CMakeBuild(build_ext): - - def build_extension(self, ext: CMakeExtension) -> None: - # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ - ext_fullpath = Path.cwd() / self.get_ext_fullpath( - ext.name) # type: ignore[no-untyped-call] - extdir = ext_fullpath.parent.resolve() - - # Using this requires trailing slash for auto-detection & inclusion of - # auxiliary "native" libs - - debug = int(os.environ.get("DEBUG", - 0)) if self.debug is None else self.debug - cfg = "Debug" if debug else "Release" - - # CMake lets you override the generator - we need to check this. - # Can be set with Conda-Build, for example. - cmake_generator = os.environ.get("CMAKE_GENERATOR", "") - - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON - # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code - # from Python. - cmake_args = [ - f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", - f"-DPYTHON_EXECUTABLE={sys.executable}", - f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm - ] - build_args = [] - # Adding CMake arguments set as environment variable - # (needed e.g. to build for ARM OSx on conda-forge) - if "CMAKE_ARGS" in os.environ: - cmake_args += [ - item for item in os.environ["CMAKE_ARGS"].split(" ") if item - ] - - # Single config generators are handled "normally" - single_config = any(x in cmake_generator for x in {"NMake", "Ninja"}) - - # CMake allows an arch-in-generator style for backward compatibility - contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"}) - - # Multi-config generators have a different way to specify configs - if not single_config: - cmake_args += [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] - build_args += ["--config", cfg] - - # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level - # across all generators. - if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: - # self.parallel is a Python 3 only way to set parallel jobs by hand - # using -j in the build_ext call, not supported by pip or PyPA-build. - if hasattr(self, "parallel") and self.parallel: - # CMake 3.12+ only. - build_args += [f"-j{self.parallel}"] - - build_temp = Path(self.build_temp) / ext.name - if not build_temp.exists(): - build_temp.mkdir(parents=True) - - subprocess.run( - ["cmake", ext.sourcedir] + cmake_args, cwd=build_temp, check=True) - - subprocess.run( - ["cmake", "--build", "."] + build_args, cwd=build_temp, check=True) - - -setup( - name='jaxDecomp', - url='https://github.com/DifferentiableUniverseInitiative/jaxDecomp', - author='Wassim Kabalan, Francois Lanusse', - description='JAX bindings for the cuDecomp library', - ext_modules=[CMakeExtension("jaxdecomp/_src/_jaxdecomp")], - cmdclass={"build_ext": CMakeBuild}, - packages=find_packages(), - include_package_data=True, - use_scm_version=True, - setup_requires=["setuptools_scm"]) From 91bd00941007ca94d4924999cb81720bfabb0c2b Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Thu, 4 Jul 2024 18:29:43 +0200 Subject: [PATCH 5/8] Set compilers manually to avoid endless CMake loop --- CMakeLists.txt | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c7ef9c9..96754d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,11 @@ cmake_minimum_required(VERSION 3.19...3.25) +find_program(NVHPC_CXX_BIN "nvc++" REQUIRED) +set(CMAKE_CXX_COMPILER ${NVHPC_CXX_BIN}) + +find_program(NVHPC_C_BIN "nvc" REQUIRED) +set(CMAKE_C_COMPILER ${NVHPC_C_BIN}) + project(jaxdecomp LANGUAGES CXX CUDA) # NVCC 12 does not support C++20 @@ -10,8 +16,9 @@ set(CMAKE_CUDA_STANDARD 17) find_package(CUDAToolkit REQUIRED VERSION 12) set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}) -# Build debug -# set(CMAKE_BUILD_TYPE Debug) +message(STATUS "Using CUDA ${NVHPC_CUDA_VERSION}") +# Build Release by default +set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.") add_subdirectory(third_party/cuDecomp) From 34db3cf0b8f1f9e9645506c244379776d7e64964 Mon Sep 17 00:00:00 2001 From: EiffL Date: Sun, 7 Jul 2024 21:21:03 -0400 Subject: [PATCH 6/8] Bump cmake version to avoid NVHPCConfig issues --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f76b47f..349d6b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ test = ["pytest"] [tool.scikit-build] minimum-version = "0.8" -cmake.version = ">=3.19" +cmake.version = ">=3.25" build-dir = "build/{wheel_tag}" wheel.py-api = "py3" cmake.build-type = "Release" From f63d8cfefdb6cebaa256d5ef4eb3583df15400f2 Mon Sep 17 00:00:00 2001 From: EiffL Date: Sun, 7 Jul 2024 22:05:56 -0400 Subject: [PATCH 7/8] fix argument numbers after removing halo_reduce --- jaxdecomp/_src/halo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jaxdecomp/_src/halo.py b/jaxdecomp/_src/halo.py index ddcf5cc..3ce2602 100644 --- a/jaxdecomp/_src/halo.py +++ b/jaxdecomp/_src/halo.py @@ -25,7 +25,7 @@ class HaloPrimitive(BasePrimitive): name = "halo_exchange" multiple_results = False - impl_static_args = (1, 2, 3) + impl_static_args = (1, 2) inner_primitive = None outer_primitive = None @@ -310,7 +310,7 @@ def halo_p_lower(x: Array, halo_extents: Tuple[int, int, int], # Custom Partitioning -@partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3)) +@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) def halo_exchange(x: Array, halo_extents: Tuple[int, int, int], halo_periods: Tuple[bool, bool, bool]) -> Array: """ @@ -385,4 +385,4 @@ def _halo_bwd_rule(halo_extents: Tuple[int, int, int], halo_exchange.defvjp(_halo_fwd_rule, _halo_bwd_rule) # JIT compile the halo_exchange operation -halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2, 3)) +halo_exchange = jax.jit(halo_exchange, static_argnums=(1, 2)) From d6640e83bf9591249b913a90ac7c1c33b8b936cd Mon Sep 17 00:00:00 2001 From: EiffL Date: Sun, 7 Jul 2024 22:33:02 -0400 Subject: [PATCH 8/8] reverting precommit change --- .pre-commit-config.yaml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f44eaca..c7141cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,3 +15,10 @@ repos: hooks: - id: isort name: isort (python) +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v18.1.4 + hooks: + - id: clang-format + files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$' + exclude: '^third_party/|/pybind11/' + name: clang-format