Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

init_shm: shortcut the trivial case of local_size is 1 #7251

Merged
merged 2 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/mpid/ch4/shm/posix/eager/iqueue/iqueue_init.c
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,22 @@ int MPIDI_POSIX_iqueue_post_init(void)
int mpi_errno = MPI_SUCCESS;

/* gather max_vcis */
int max_vcis = 0;
max_vcis = 0;
MPIDU_Init_shm_put(&MPIDI_POSIX_global.num_vcis, sizeof(int));
MPIDU_Init_shm_barrier();
for (int i = 0; i < MPIR_Process.local_size; i++) {
int num;
MPIDU_Init_shm_get(i, sizeof(int), &num);
if (max_vcis < num) {
max_vcis = num;
int max_vcis;
if (MPIR_Process.local_size == 1) {
max_vcis = MPIDI_POSIX_global.num_vcis;
} else {
max_vcis = 0;
MPIDU_Init_shm_put(&MPIDI_POSIX_global.num_vcis, sizeof(int));
MPIDU_Init_shm_barrier();
for (int i = 0; i < MPIR_Process.local_size; i++) {
int num;
MPIDU_Init_shm_get(i, sizeof(int), &num);
if (max_vcis < num) {
max_vcis = num;
}
}
MPIDU_Init_shm_barrier();
}
MPIDU_Init_shm_barrier();

MPIDI_POSIX_eager_iqueue_global.max_vcis = max_vcis;

Expand Down
105 changes: 53 additions & 52 deletions src/mpid/common/shm/mpidu_init_shm.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

static int init_shm_initialized;

int MPIDU_Init_shm_local_size;
int MPIDU_Init_shm_local_rank;

#ifdef ENABLE_NO_LOCAL
/* shared memory disabled, just stubs */

Expand Down Expand Up @@ -55,8 +58,6 @@ typedef struct Init_shm_barrier {
MPL_atomic_int_t wait;
} Init_shm_barrier_t;

static int local_size;
static int my_local_rank;
static MPIDU_shm_seg_t memory;
static Init_shm_barrier_t *barrier;
static void *baseaddr;
Expand Down Expand Up @@ -88,12 +89,12 @@ static int Init_shm_barrier(void)

MPIR_FUNC_ENTER;

if (local_size == 1)
if (MPIDU_Init_shm_local_size == 1)
goto fn_exit;

MPIR_ERR_CHKINTERNAL(!barrier_init, mpi_errno, "barrier not initialized");

if (MPL_atomic_fetch_add_int(&barrier->val, 1) == local_size - 1) {
if (MPL_atomic_fetch_add_int(&barrier->val, 1) == MPIDU_Init_shm_local_size - 1) {
MPL_atomic_store_int(&barrier->val, 0);
MPL_atomic_store_int(&barrier->wait, 1 - sense);
} else {
Expand All @@ -112,39 +113,35 @@ static int Init_shm_barrier(void)
int MPIDU_Init_shm_init(void)
{
int mpi_errno = MPI_SUCCESS, mpl_err = 0;
MPIR_CHKPMEM_DECL();
MPIR_CHKLMEM_DECL();

MPIR_FUNC_ENTER;

local_size = MPIR_Process.local_size;
my_local_rank = MPIR_Process.local_rank;

size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN + sizeof(MPIDU_Init_shm_block_t) * local_size;

char *serialized_hnd = NULL;
int serialized_hnd_size = 0;
MPIDU_Init_shm_local_size = MPIR_Process.local_size;
MPIDU_Init_shm_local_rank = MPIR_Process.local_rank;

mpl_err = MPL_shm_hnd_init(&(memory.hnd));
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
if (MPIDU_Init_shm_local_size == 1) {
/* We'll special case this trivial case */

memory.segment_len = segment_len;
/* All processes need to call MPIR_pmi_bcast. Depending on the PMI version, it may need to
* call MPIR_pmi_barrier internally, so all processes must participate.
*/
int dummy;
mpi_errno = MPIR_pmi_bcast(&dummy, sizeof(int), MPIR_PMI_DOMAIN_LOCAL);
MPIR_ERR_CHECK(mpi_errno);
} else {
size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN +
sizeof(MPIDU_Init_shm_block_t) * MPIDU_Init_shm_local_size;

if (local_size == 1) {
char *addr;
char *serialized_hnd = NULL;
int serialized_hnd_size = 0;

MPIR_CHKPMEM_MALLOC(addr, segment_len + MPIDU_SHM_CACHE_LINE_LEN, MPL_MEM_SHM);
mpl_err = MPL_shm_hnd_init(&(memory.hnd));
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");

memory.base_addr = addr;
baseaddr =
(char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) &
(~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1)));
memory.symmetrical = 0;
memory.segment_len = segment_len;

mpi_errno = Init_shm_barrier_init(TRUE);
MPIR_ERR_CHECK(mpi_errno);
} else {
if (my_local_rank == 0) {
if (MPIDU_Init_shm_local_rank == 0) {
/* root prepares the shm segment */
mpl_err = MPL_shm_seg_create_and_attach(memory.hnd, memory.segment_len,
(void **) &(memory.base_addr), 0);
Expand All @@ -164,15 +161,13 @@ int MPIDU_Init_shm_init(void)
serialized_hnd_size = MPIR_pmi_max_val_size();
MPIR_CHKLMEM_MALLOC(serialized_hnd, serialized_hnd_size);
}
}
/* All processes need to call MPIR_pmi_bcast. Depending on the PMI version, it may need to
* call MPIR_pmi_barrier internally, so all processes must participate.
*/
mpi_errno = MPIR_pmi_bcast(serialized_hnd, serialized_hnd_size, MPIR_PMI_DOMAIN_LOCAL);
MPIR_ERR_CHECK(mpi_errno);
if (local_size != 1) {
MPIR_Assert(local_size > 1);
if (my_local_rank > 0) {
/* All processes need to call MPIR_pmi_bcast. Depending on the PMI version, it may need to
* call MPIR_pmi_barrier internally, so all processes must participate.
*/
mpi_errno = MPIR_pmi_bcast(serialized_hnd, serialized_hnd_size, MPIR_PMI_DOMAIN_LOCAL);
MPIR_ERR_CHECK(mpi_errno);

if (MPIDU_Init_shm_local_rank > 0) {
/* non-root attach shm segment */
mpl_err = MPL_shm_hnd_deserialize(memory.hnd, serialized_hnd, strlen(serialized_hnd));
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
Expand All @@ -188,17 +183,17 @@ int MPIDU_Init_shm_init(void)
mpi_errno = Init_shm_barrier();
MPIR_ERR_CHECK(mpi_errno);

if (my_local_rank == 0) {
if (MPIDU_Init_shm_local_rank == 0) {
/* memory->hnd no longer needed */
mpl_err = MPL_shm_seg_remove(memory.hnd);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem");
}

baseaddr = memory.base_addr + MPIDU_SHM_CACHE_LINE_LEN;
memory.symmetrical = 0;
}

mpi_errno = Init_shm_barrier();
mpi_errno = Init_shm_barrier();
}

init_shm_initialized = 1;

Expand All @@ -207,7 +202,6 @@ int MPIDU_Init_shm_init(void)
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
MPIR_CHKPMEM_REAP();
goto fn_exit;
}

Expand All @@ -217,16 +211,12 @@ int MPIDU_Init_shm_finalize(void)

MPIR_FUNC_ENTER;

if (!init_shm_initialized) {
if (!init_shm_initialized || MPIDU_Init_shm_local_size == 1) {
goto fn_exit;
}

if (local_size == 1)
MPL_free(memory.base_addr);
else {
mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");
}
mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");

MPL_shm_hnd_finalize(&(memory.hnd));

Expand All @@ -245,7 +235,9 @@ int MPIDU_Init_shm_barrier(void)

MPIR_FUNC_ENTER;

mpi_errno = Init_shm_barrier();
if (MPIDU_Init_shm_local_size > 1) {
mpi_errno = Init_shm_barrier();
}

MPIR_FUNC_EXIT;

Expand All @@ -258,8 +250,11 @@ int MPIDU_Init_shm_put(void *orig, size_t len)

MPIR_FUNC_ENTER;

MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t));
MPIR_Memcpy((char *) baseaddr + my_local_rank * sizeof(MPIDU_Init_shm_block_t), orig, len);
if (MPIDU_Init_shm_local_size > 1) {
MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t));
MPIR_Memcpy((char *) baseaddr + MPIDU_Init_shm_local_rank * sizeof(MPIDU_Init_shm_block_t),
orig, len);
}

MPIR_FUNC_EXIT;

Expand All @@ -272,7 +267,10 @@ int MPIDU_Init_shm_get(int local_rank, size_t len, void *target)

MPIR_FUNC_ENTER;

MPIR_Assert(local_rank < local_size && len <= sizeof(MPIDU_Init_shm_block_t));
/* a single process should not get its own put */
MPIR_Assert(MPIDU_Init_shm_local_size > 1);

MPIR_Assert(local_rank < MPIDU_Init_shm_local_size && len <= sizeof(MPIDU_Init_shm_block_t));
MPIR_Memcpy(target, (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t), len);

MPIR_FUNC_EXIT;
Expand All @@ -286,7 +284,10 @@ int MPIDU_Init_shm_query(int local_rank, void **target_addr)

MPIR_FUNC_ENTER;

MPIR_Assert(local_rank < local_size);
/* a single process should not get its own put */
MPIR_Assert(MPIDU_Init_shm_local_size > 1);

MPIR_Assert(local_rank < MPIDU_Init_shm_local_size);
*target_addr = (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t);

MPIR_FUNC_EXIT;
Expand Down
59 changes: 32 additions & 27 deletions src/mpid/common/shm/mpidu_init_shm_alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
#include <sys/shm.h>
#endif

extern int MPIDU_Init_shm_local_size;
extern int MPIDU_Init_shm_local_rank;

typedef struct memory_list {
void *ptr;
MPIDU_shm_seg_t *memory;
Expand All @@ -39,8 +42,6 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)
int mpi_errno = MPI_SUCCESS, mpl_err = 0;
void *current_addr;
size_t segment_len = len;
int local_rank = MPIR_Process.local_rank;
int num_local = MPIR_Process.local_size;
MPIDU_shm_seg_t *memory = NULL;
memory_list_t *memory_node = NULL;
MPIR_CHKPMEM_DECL();
Expand All @@ -49,6 +50,12 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)

MPIR_Assert(segment_len > 0);

if (MPIDU_Init_shm_local_size == 1) {
*ptr = MPL_aligned_alloc(MPL_CACHELINE_SIZE, len, MPL_MEM_SHM);
MPIR_ERR_CHKANDJUMP(!*ptr, mpi_errno, MPI_ERR_OTHER, "**nomem");
goto fn_exit;
}

MPIR_CHKPMEM_MALLOC(memory, sizeof(*memory), MPL_MEM_SHM);

mpl_err = MPL_shm_hnd_init(&(memory->hnd));
Expand All @@ -58,19 +65,9 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)

char *serialized_hnd = NULL;
int serialized_hnd_size = 0;
/* if there is only one process on this processor, don't use shared memory */
if (num_local == 1) {
char *addr;

MPIR_CHKPMEM_MALLOC(addr, segment_len + MPIDU_SHM_CACHE_LINE_LEN, MPL_MEM_SHM);

memory->base_addr = addr;
current_addr =
(char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) &
(~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1)));
memory->symmetrical = 1;
} else {
if (local_rank == 0) {
{
if (MPIDU_Init_shm_local_rank == 0) {
/* root prepare shm segment */
mpl_err = MPL_shm_seg_create_and_attach(memory->hnd, memory->segment_len,
(void **) &(memory->base_addr), 0);
Expand Down Expand Up @@ -98,7 +95,7 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)

MPIDU_Init_shm_barrier();

if (local_rank == 0) {
if (MPIDU_Init_shm_local_rank == 0) {
/* memory->hnd no longer needed */
mpl_err = MPL_shm_seg_remove(memory->hnd);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem");
Expand All @@ -124,8 +121,10 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)
return mpi_errno;
fn_fail:
/* --BEGIN ERROR HANDLING-- */
MPL_shm_seg_remove(memory->hnd);
MPL_shm_hnd_finalize(&(memory->hnd));
if (MPIDU_Init_shm_local_size > 1) {
MPL_shm_seg_remove(memory->hnd);
MPL_shm_hnd_finalize(&(memory->hnd));
}
MPIR_CHKPMEM_REAP();
goto fn_exit;
/* --END ERROR HANDLING-- */
Expand All @@ -140,6 +139,11 @@ int MPIDU_Init_shm_free(void *ptr)

MPIR_FUNC_ENTER;

if (MPIDU_Init_shm_local_size == 1) {
MPL_free(ptr);
goto fn_exit;
}

/* retrieve memory handle for baseaddr */
LL_FOREACH(memory_head, el) {
if (el->ptr == ptr) {
Expand All @@ -152,17 +156,14 @@ int MPIDU_Init_shm_free(void *ptr)

MPIR_Assert(memory != NULL);

if (MPIR_Process.local_size == 1)
MPL_free(memory->base_addr);
else {
mpl_err = MPL_shm_seg_detach(memory->hnd, (void **) &(memory->base_addr),
memory->segment_len);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");
}
mpl_err = MPL_shm_seg_detach(memory->hnd, (void **) &(memory->base_addr), memory->segment_len);
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");

fn_exit:
MPL_shm_hnd_finalize(&(memory->hnd));
MPL_free(memory);
if (MPIDU_Init_shm_local_size > 1) {
MPL_shm_hnd_finalize(&(memory->hnd));
MPL_free(memory);
}
MPIR_FUNC_EXIT;
return mpi_errno;
fn_fail:
Expand All @@ -174,6 +175,10 @@ int MPIDU_Init_shm_is_symm(void *ptr)
int ret = -1;
memory_list_t *el;

if (MPIDU_Init_shm_local_size == 1) {
return 1;
}

/* retrieve memory handle for baseaddr */
LL_FOREACH(memory_head, el) {
if (el->ptr == ptr) {
Expand All @@ -196,7 +201,7 @@ static int check_alloc(MPIDU_shm_seg_t * memory)

MPIR_FUNC_ENTER;

if (MPIR_Process.local_rank == 0) {
if (MPIDU_Init_shm_local_rank == 0) {
MPIDU_Init_shm_put(memory->base_addr, sizeof(void *));
}

Expand Down