
Commit

Merge pull request #23 from DifferentiableUniverseInitiative/fix_halo_again

Fix halo again
ASKabalan authored Jul 19, 2024
2 parents c428f92 + 81a4095 commit cc642fb
Showing 5 changed files with 71 additions and 16 deletions.
2 changes: 1 addition & 1 deletion include/grid_descriptor_mgr.h
@@ -52,7 +52,7 @@ class GridDescriptorManager {

AsyncLogger m_Tracer;
bool isInitialized = false;

int isMPIalreadyInitialized = false;
cudecompHandle_t m_Handle;

std::unordered_map<fftDescriptor, std::shared_ptr<FourierExecutor<double>>, std::hash<fftDescriptor>, std::equal_to<>>
6 changes: 5 additions & 1 deletion include/halo.h
@@ -36,19 +36,23 @@ template <typename real_t> class HaloExchange {
friend class GridDescriptorManager;

public:
HaloExchange() = default;
HaloExchange() : m_Tracer("JAXDECOMP") {}
// Grid descriptors are handled by the GridDescriptorManager
// No memory should be cleaned up here; everything is handled by the GridDescriptorManager
~HaloExchange() = default;

HRESULT get_halo_descriptor(cudecompHandle_t handle, size_t& work_size, haloDescriptor_t& halo_desc);
HRESULT halo_exchange(cudecompHandle_t handle, haloDescriptor_t desc, cudaStream_t stream, void** buffers);

private:
AsyncLogger m_Tracer;

cudecompGridDesc_t m_GridConfig;
cudecompGridDescConfig_t m_GridDescConfig;
cudecompPencilInfo_t m_PencilInfo;

int64_t m_WorkSize;
HRESULT cleanUp(cudecompHandle_t handle);
};

} // namespace jaxdecomp
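
For orientation, here is a minimal sketch of how a caller might drive the two public entry points declared above, once the GridDescriptorManager (a friend of this class) has configured the grid descriptor. The wrapper name run_halo_exchange, the cudaMalloc stand-in for XLA's workspace allocation, and the use of E_FAIL are assumptions for illustration, not part of this commit.

// Sketch only: query the workspace size, provide the two buffers, run the
// exchange. In jaxDecomp the workspace buffer is allocated by XLA; the
// cudaMalloc below merely stands in for that.
HRESULT run_halo_exchange(cudecompHandle_t handle,
                          jaxdecomp::HaloExchange<double>& exchange, // already configured
                          haloDescriptor_t& desc, void* data_d, cudaStream_t stream) {
  size_t work_size = 0;

  // Ask cuDecomp how many workspace bytes this halo configuration needs.
  HRESULT hr = exchange.get_halo_descriptor(handle, work_size, desc);
  if (hr != S_OK) return hr;

  // buffers[0] = pencil data, buffers[1] = workspace (see src/halo.cu).
  void* work_d = nullptr;
  if (cudaMalloc(&work_d, work_size) != cudaSuccess) return E_FAIL;
  void* buffers[2] = {data_d, work_d};

  hr = exchange.halo_exchange(handle, desc, stream, buffers);
  cudaFree(work_d);
  return hr;
}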
2 changes: 1 addition & 1 deletion jaxdecomp/_src/halo.py
@@ -1,3 +1,4 @@
import os
from functools import partial
from typing import Tuple

@@ -202,7 +203,6 @@ def per_shard_impl(x: Array, halo_extents: Tuple[int, int, int],
pdims=pdims,
global_shape=global_shape,
)

return output

@staticmethod
61 changes: 55 additions & 6 deletions src/grid_descriptor_mgr.cc
@@ -19,9 +19,8 @@ GridDescriptorManager::GridDescriptorManager() : m_Tracer("JAXDECOMP") {
MPI_Comm mpi_comm = MPI_COMM_WORLD;

// Check if MPI has already been initialized by the user (maybe with mpi4py)
int is_initialized;
CHECK_MPI_EXIT(MPI_Initialized(&is_initialized));
if (!is_initialized) { CHECK_MPI_EXIT(MPI_Init(nullptr, nullptr)); }
CHECK_MPI_EXIT(MPI_Initialized(&isMPIalreadyInitialized));
if (!isMPIalreadyInitialized) { CHECK_MPI_EXIT(MPI_Init(nullptr, nullptr)); }
// Initialize cuDecomp
CHECK_CUDECOMP_EXIT(cudecompInit(&m_Handle, mpi_comm));
isInitialized = true;
@@ -144,11 +143,12 @@ HRESULT GridDescriptorManager::createTransposeExecutor(transposeDescriptor& desc
return hr;
}

// TODO(Wassim): This could be cleaned up using some polymorphism
void GridDescriptorManager::finalize() {
if (!isInitialized) return;

StartTraceInfo(m_Tracer) << "JaxDecomp shut down" << std::endl;
// Destroy grid descriptors
// Destroy grid descriptors for FFTs
for (auto& descriptor : m_Descriptors64) {
auto& executor = descriptor.second;
// TODO(wassim): Clean up cuDecomp resources
@@ -175,13 +175,62 @@ void GridDescriptorManager::finalize() {
executor->clearPlans();
}

// Destroy Halo descriptors
for (auto& descriptor : m_HaloDescriptors64) {
auto& executor = descriptor.second;
// Cleanup cudecomp resources
// CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc_c, work)); This can
// be used instead of requesting XLA to allocate the memory
cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig);
if (CUDECOMP_RESULT_SUCCESS != err) {
StartTraceInfo(m_Tracer) << "CUDECOMP error at exit (" << err << ")" << std::endl;
}
executor->cleanUp(m_Handle);
}

for (auto& descriptor : m_HaloDescriptors32) {
auto& executor = descriptor.second;
// Cleanup cudecomp resources
// CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc_c, work)); This can
// be used instead of requesting XLA to allocate the memory
cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig);
if (CUDECOMP_RESULT_SUCCESS != err) {
StartTraceInfo(m_Tracer) << "CUDECOMP error at exit (" << err << ")" << std::endl;
}
executor->cleanUp(m_Handle);
}

// Destroy Transpose descriptors
for (auto& descriptor : m_TransposeDescriptors64) {
auto& executor = descriptor.second;
// Cleanup cudecomp resources
// CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc_c, work)); This can
// be used instead of requesting XLA to allocate the memory
cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig);
if (CUDECOMP_RESULT_SUCCESS != err) {
StartTraceInfo(m_Tracer) << "CUDECOMP error at exit (" << err << ")" << std::endl;
}
}

for (auto& descriptor : m_TransposeDescriptors32) {
auto& executor = descriptor.second;
// Cleanup cudecomp resources
// CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc_c, work)); This can
// be used instead of requesting XLA to allocate the memory
cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig);
if (CUDECOMP_RESULT_SUCCESS != err) {
StartTraceInfo(m_Tracer) << "CUDECOMP error at exit (" << err << ")" << std::endl;
}
}

// TODO(wassim): Clean up cuDecomp resources
// There is an issue with mpi4py calling finalize at py_exit before this
cudecompFinalize(m_Handle);
// Clean finish
CHECK_CUDA_EXIT(cudaDeviceSynchronize());
// MPI is finalized by the mpi4py runtime (I wish it wasn't)
// CHECK_MPI_EXIT(MPI_Finalize());
// If jaxDecomp initialized MPI, finalize it
// Otherwise mpi4py will finalize its own MPI WORLD Communicator
if (!isMPIalreadyInitialized) { CHECK_MPI_EXIT(MPI_Finalize()); }
isInitialized = false;
}
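
As a hedged sketch of the cleanup the TODO above asks for (not part of this commit, and using a templated helper rather than runtime polymorphism), the four nearly identical loops could funnel through one helper plus a per-type lambda; every identifier used below already appears in finalize().

// Sketch for the TODO above: destroy the cuDecomp grid descriptor of every
// executor in a map, then run an executor-specific extra step
// (cleanUp, clearPlans, or nothing).
template <typename DescriptorMap, typename ExtraCleanup>
void destroyDescriptors(cudecompHandle_t handle, AsyncLogger& tracer,
                        DescriptorMap& descriptors, ExtraCleanup&& extra) {
  for (auto& descriptor : descriptors) {
    auto& executor = descriptor.second;
    cudecompResult_t err = cudecompGridDescDestroy(handle, executor->m_GridConfig);
    if (CUDECOMP_RESULT_SUCCESS != err) {
      StartTraceInfo(tracer) << "CUDECOMP error at exit (" << err << ")" << std::endl;
    }
    extra(executor); // e.g. executor->cleanUp(handle) for halo executors
  }
}

// finalize() could then reduce to calls such as:
//   destroyDescriptors(m_Handle, m_Tracer, m_HaloDescriptors64,
//                      [this](auto& e) { e->cleanUp(m_Handle); });
//   destroyDescriptors(m_Handle, m_Tracer, m_TransposeDescriptors32, [](auto&) {});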

16 changes: 9 additions & 7 deletions src/halo.cu
@@ -24,23 +24,20 @@ HRESULT HaloExchange<real_t>::get_halo_descriptor(cudecompHandle_t handle, size_
CHECK_CUDECOMP_EXIT(
cudecompGetPencilInfo(handle, m_GridConfig, &m_PencilInfo, halo_desc.axis, halo_desc.halo_extents.data()));

cudecompPencilInfo_t no_halo;

// Get pencil information for the specified axis
CHECK_CUDECOMP_EXIT(cudecompGetPencilInfo(handle, m_GridConfig, &no_halo, halo_desc.axis, nullptr));

// Get workspace size
int64_t workspace_num_elements;
CHECK_CUDECOMP_EXIT(cudecompGetHaloWorkspaceSize(handle, m_GridConfig, halo_desc.axis, m_PencilInfo.halo_extents,
&workspace_num_elements));

// TODO(Wassim) Handle complex numbers
int64_t dtype_size;
if (halo_desc.double_precision)
CHECK_CUDECOMP_EXIT(cudecompGetDataTypeSize(CUDECOMP_DOUBLE, &dtype_size));
else
CHECK_CUDECOMP_EXIT(cudecompGetDataTypeSize(CUDECOMP_FLOAT, &dtype_size));

work_size = dtype_size * workspace_num_elements;
m_WorkSize = dtype_size * workspace_num_elements;
work_size = m_WorkSize;

return S_OK;
}
@@ -51,7 +48,6 @@ HRESULT HaloExchange<real_t>::halo_exchange(cudecompHandle_t handle, haloDescrip
void* data_d = buffers[0];
void* work_d = buffers[1];

// desc.axis = 2;
// Perform halo exchange along the three dimensions
for (int i = 0; i < 3; ++i) {
switch (desc.axis) {
@@ -73,6 +69,12 @@ return S_OK;
return S_OK;
};

template <typename real_t> HRESULT HaloExchange<real_t>::cleanUp(cudecompHandle_t handle) {
// XLA is doing the allocation
// nothing to clean up
return S_OK;
}

template class HaloExchange<float>;
template class HaloExchange<double>;
} // namespace jaxdecomp
