Skip to content

Commit

Permalink
update wheels.yml toml and cmake
Browse files Browse the repository at this point in the history
  • Loading branch information
ASKabalan committed Oct 23, 2024
1 parent 6549b5a commit 6ab940f
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 10 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Build

on: [push]

jobs:
build_wheels:
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest]

steps:
- uses: actions/checkout@v4

# Used to host cibuildwheel
- uses: actions/setup-python@v5

- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.21.3

- name: Build wheels
run: python -m cibuildwheel --output-dir wheelhouse
# to supply options, put them in 'env', like:
# env:
# CIBW_SOME_OPTION: value

- uses: actions/upload-artifact@v4
with:
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
path: ./wheelhouse/*.whl
16 changes: 8 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,18 @@ if(CMAKE_CUDA_COMPILER AND JD_CUDECOMP_BACKEND)

# Add _jaxdecomp modulei
pybind11_add_module(_jaxdecomp
src/halo.cu
src/jaxdecomp.cc
src/grid_descriptor_mgr.cc
src/fft.cu
src/transpose.cu
src/csrc/halo.cu
src/csrc/jaxdecomp.cc
src/csrc/grid_descriptor_mgr.cc
src/csrc/fft.cu
src/csrc/transpose.cu
)

set_target_properties(_jaxdecomp PROPERTIES CUDA_ARCHITECTURES "${CUDECOMP_CUDA_CC_LIST}")

target_include_directories(_jaxdecomp
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/include
${CMAKE_CURRENT_LIST_DIR}/src/csrc/include
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cuDecomp/include
${NVHPC_CUDA_INCLUDE_DIR}
${MPI_CXX_INCLUDE_DIRS}
Expand All @@ -88,8 +88,8 @@ if(CMAKE_CUDA_COMPILER AND JD_CUDECOMP_BACKEND)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)
target_compile_definitions(_jaxdecomp PRIVATE JD_CUDECOMP_BACKEND)
else()
pybind11_add_module(_jaxdecomp src/jaxdecomp.cc)
target_include_directories(_jaxdecomp PRIVATE ${CMAKE_CURRENT_LIST_DIR}/include)
pybind11_add_module(_jaxdecomp src/csrc/jaxdecomp.cc)
target_include_directories(_jaxdecomp PRIVATE ${CMAKE_CURRENT_LIST_DIR}/src/csrc/include)
target_compile_definitions(_jaxdecomp PRIVATE JD_JAX_BACKEND)
endif()

Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = [ "scikit-build-core","pybind11"]
requires = ["scikit-build-core", "pybind11"]
build-backend = "scikit_build_core.build"

[project]
Expand All @@ -19,6 +19,8 @@ classifiers = [
"Operating System :: OS Independent"
]
dependencies = ["jaxtyping"]
packages = ["jaxdecomp"]
package-dir = {"" = "src"}

[project.optional-dependencies]
test = ["pytest"]
Expand All @@ -29,9 +31,9 @@ cmake.version = ">=3.25"
build-dir = "build/{wheel_tag}"
wheel.py-api = "py3"
cmake.build-type = "Release"
# Add any additional configurations for scikit-build if necessary
wheel.install-dir = "jaxdecomp/_src"

[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""
CMAKE_EXPORT_COMPILE_COMMANDS = "ON"

50 changes: 50 additions & 0 deletions src/csrc/include/jaxdecomp.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,53 @@
#ifndef _JAX_DECOMP_H_
#define _JAX_DECOMP_H_
#ifdef JD_CUDECOMP_BACKEND
#include "checks.h"
#include <cudecomp.h>
#endif
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#ifdef JD_JAX_BACKEND
enum cudecompTransposeCommBackend_t {
CUDECOMP_TRANSPOSE_COMM_MPI_P2P = 1, ///< MPI backend using peer-to-peer algorithm (i.e.,MPI_Isend/MPI_Irecv)
CUDECOMP_TRANSPOSE_COMM_MPI_P2P_PL = 2, ///< MPI backend using peer-to-peer algorithm with pipelining
CUDECOMP_TRANSPOSE_COMM_MPI_A2A = 3, ///< MPI backend using MPI_Alltoallv
CUDECOMP_TRANSPOSE_COMM_NCCL = 4, ///< NCCL backend
CUDECOMP_TRANSPOSE_COMM_NCCL_PL = 5, ///< NCCL backend with pipelining
CUDECOMP_TRANSPOSE_COMM_NVSHMEM = 6, ///< NVSHMEM backend
CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL = 7 ///< NVSHMEM backend with pipelining
};

/**
* @brief This enum lists the different available halo backend options.
*/
enum cudecompHaloCommBackend_t {
CUDECOMP_HALO_COMM_MPI = 1, ///< MPI backend
CUDECOMP_HALO_COMM_MPI_BLOCKING = 2, ///< MPI backend with blocking between each peer transfer
CUDECOMP_HALO_COMM_NCCL = 3, ///< NCCL backend
CUDECOMP_HALO_COMM_NVSHMEM = 4, ///< NVSHMEM backend
CUDECOMP_HALO_COMM_NVSHMEM_BLOCKING = 5 ///< NVSHMEM backend with blocking between each peer transfer
};
#endif

namespace jaxdecomp {

#ifdef JD_JAX_BACKEND

enum class TransposeType {
TRANSPOSE_XY,
TRANSPOSE_YZ,
TRANSPOSE_ZY,
TRANSPOSE_YX,
TRANSPOSE_XZ,
TRANSPOSE_ZX,
UNKNOWN_TRANSPOSE
};

enum Decomposition { slab_XY = 0, slab_YZ = 1, pencil = 2, no_decomp = 3 };

#endif

/**
* @brief A data structure defining configuration options for grid descriptor creation.
* Slightly adapted version of cudecompGridDescConfig_t which can be automatically translated by pybind11
Expand All @@ -27,6 +69,9 @@ typedef struct {
halo_comm_backend; ///< communication backend to use for halo communication (default: CUDECOMP_HALO_COMM_MPI)

} decompGridDescConfig_t;

#ifdef JD_CUDECOMP_BACKEND

void cudecompGridDescConfigSet(cudecompGridDescConfig_t* config, const decompGridDescConfig_t* source) {
// Initialize the config with the defaults
CHECK_CUDECOMP_EXIT(cudecompGridDescConfigSetDefaults(config));
Expand All @@ -36,6 +81,7 @@ void cudecompGridDescConfigSet(cudecompGridDescConfig_t* config, const decompGri
config->halo_comm_backend = source->halo_comm_backend;
config->transpose_comm_backend = source->transpose_comm_backend;
};
#endif

/**
* @brief A data structure containing geometry information about a pencil data buffer.
Expand All @@ -49,6 +95,8 @@ typedef struct {
std::array<int32_t, 3> halo_extents; ///< halo extents by dimension (in global order)
int64_t size; ///< number of elements in pencil (including halo elements)
} decompPencilInfo_t;

#ifdef JD_CUDECOMP_BACKEND
void decompPencilInfoSet(decompPencilInfo_t* info, const cudecompPencilInfo_t* source) {
for (int i = 0; i < 3; i++) {
info->hi[i] = source->hi[i];
Expand All @@ -59,6 +107,8 @@ void decompPencilInfoSet(decompPencilInfo_t* info, const cudecompPencilInfo_t* s
}
info->size = source->size;
};
#endif

}; // namespace jaxdecomp

#endif

0 comments on commit 6ab940f

Please sign in to comment.