Skip to content

Commit 2cd9487

Browse files
committed
init_shm: shortcut the trivial case of local_size is 1
It's simpler to shortcut the trivial case. Set global MPIDU_Init_shm_local_size and MPIDU_Init_shm_local_rank. This prepares the flexibility that later we can extend Init_shm to dynamic processes.
1 parent bfaa659 commit 2cd9487

File tree

2 files changed

+85
-79
lines changed

2 files changed

+85
-79
lines changed

src/mpid/common/shm/mpidu_init_shm.c

+53-52
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313
static int init_shm_initialized;
1414

15+
int MPIDU_Init_shm_local_size;
16+
int MPIDU_Init_shm_local_rank;
17+
1518
#ifdef ENABLE_NO_LOCAL
1619
/* shared memory disabled, just stubs */
1720

@@ -55,8 +58,6 @@ typedef struct Init_shm_barrier {
5558
MPL_atomic_int_t wait;
5659
} Init_shm_barrier_t;
5760

58-
static int local_size;
59-
static int my_local_rank;
6061
static MPIDU_shm_seg_t memory;
6162
static Init_shm_barrier_t *barrier;
6263
static void *baseaddr;
@@ -88,12 +89,12 @@ static int Init_shm_barrier(void)
8889

8990
MPIR_FUNC_ENTER;
9091

91-
if (local_size == 1)
92+
if (MPIDU_Init_shm_local_size == 1)
9293
goto fn_exit;
9394

9495
MPIR_ERR_CHKINTERNAL(!barrier_init, mpi_errno, "barrier not initialized");
9596

96-
if (MPL_atomic_fetch_add_int(&barrier->val, 1) == local_size - 1) {
97+
if (MPL_atomic_fetch_add_int(&barrier->val, 1) == MPIDU_Init_shm_local_size - 1) {
9798
MPL_atomic_store_int(&barrier->val, 0);
9899
MPL_atomic_store_int(&barrier->wait, 1 - sense);
99100
} else {
@@ -112,39 +113,35 @@ static int Init_shm_barrier(void)
112113
int MPIDU_Init_shm_init(void)
113114
{
114115
int mpi_errno = MPI_SUCCESS, mpl_err = 0;
115-
MPIR_CHKPMEM_DECL();
116116
MPIR_CHKLMEM_DECL();
117117

118118
MPIR_FUNC_ENTER;
119119

120-
local_size = MPIR_Process.local_size;
121-
my_local_rank = MPIR_Process.local_rank;
122-
123-
size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN + sizeof(MPIDU_Init_shm_block_t) * local_size;
124-
125-
char *serialized_hnd = NULL;
126-
int serialized_hnd_size = 0;
120+
MPIDU_Init_shm_local_size = MPIR_Process.local_size;
121+
MPIDU_Init_shm_local_rank = MPIR_Process.local_rank;
127122

128-
mpl_err = MPL_shm_hnd_init(&(memory.hnd));
129-
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
123+
if (MPIDU_Init_shm_local_size == 1) {
124+
/* We'll special case this trivial case */
130125

131-
memory.segment_len = segment_len;
126+
/* All processes need call MPIR_pmi_bcast. This is because we may need call MPIR_pmi_barrier
127+
* inside depend on PMI versions, and all processes need participate.
128+
*/
129+
int dummy;
130+
mpi_errno = MPIR_pmi_bcast(&dummy, sizeof(int), MPIR_PMI_DOMAIN_LOCAL);
131+
MPIR_ERR_CHECK(mpi_errno);
132+
} else {
133+
size_t segment_len = MPIDU_SHM_CACHE_LINE_LEN +
134+
sizeof(MPIDU_Init_shm_block_t) * MPIDU_Init_shm_local_size;
132135

133-
if (local_size == 1) {
134-
char *addr;
136+
char *serialized_hnd = NULL;
137+
int serialized_hnd_size = 0;
135138

136-
MPIR_CHKPMEM_MALLOC(addr, segment_len + MPIDU_SHM_CACHE_LINE_LEN, MPL_MEM_SHM);
139+
mpl_err = MPL_shm_hnd_init(&(memory.hnd));
140+
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
137141

138-
memory.base_addr = addr;
139-
baseaddr =
140-
(char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) &
141-
(~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1)));
142-
memory.symmetrical = 0;
142+
memory.segment_len = segment_len;
143143

144-
mpi_errno = Init_shm_barrier_init(TRUE);
145-
MPIR_ERR_CHECK(mpi_errno);
146-
} else {
147-
if (my_local_rank == 0) {
144+
if (MPIDU_Init_shm_local_rank == 0) {
148145
/* root prepare shm segment */
149146
mpl_err = MPL_shm_seg_create_and_attach(memory.hnd, memory.segment_len,
150147
(void **) &(memory.base_addr), 0);
@@ -164,15 +161,13 @@ int MPIDU_Init_shm_init(void)
164161
serialized_hnd_size = MPIR_pmi_max_val_size();
165162
MPIR_CHKLMEM_MALLOC(serialized_hnd, serialized_hnd_size);
166163
}
167-
}
168-
/* All processes need call MPIR_pmi_bcast. This is because we may need call MPIR_pmi_barrier
169-
* inside depend on PMI versions, and all processes need participate.
170-
*/
171-
mpi_errno = MPIR_pmi_bcast(serialized_hnd, serialized_hnd_size, MPIR_PMI_DOMAIN_LOCAL);
172-
MPIR_ERR_CHECK(mpi_errno);
173-
if (local_size != 1) {
174-
MPIR_Assert(local_size > 1);
175-
if (my_local_rank > 0) {
164+
/* All processes need call MPIR_pmi_bcast. This is because we may need call MPIR_pmi_barrier
165+
* inside depend on PMI versions, and all processes need participate.
166+
*/
167+
mpi_errno = MPIR_pmi_bcast(serialized_hnd, serialized_hnd_size, MPIR_PMI_DOMAIN_LOCAL);
168+
MPIR_ERR_CHECK(mpi_errno);
169+
170+
if (MPIDU_Init_shm_local_rank > 0) {
176171
/* non-root attach shm segment */
177172
mpl_err = MPL_shm_hnd_deserialize(memory.hnd, serialized_hnd, strlen(serialized_hnd));
178173
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**alloc_shar_mem");
@@ -188,17 +183,17 @@ int MPIDU_Init_shm_init(void)
188183
mpi_errno = Init_shm_barrier();
189184
MPIR_ERR_CHECK(mpi_errno);
190185

191-
if (my_local_rank == 0) {
186+
if (MPIDU_Init_shm_local_rank == 0) {
192187
/* memory->hnd no longer needed */
193188
mpl_err = MPL_shm_seg_remove(memory.hnd);
194189
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem");
195190
}
196191

197192
baseaddr = memory.base_addr + MPIDU_SHM_CACHE_LINE_LEN;
198193
memory.symmetrical = 0;
199-
}
200194

201-
mpi_errno = Init_shm_barrier();
195+
mpi_errno = Init_shm_barrier();
196+
}
202197

203198
init_shm_initialized = 1;
204199

@@ -207,7 +202,6 @@ int MPIDU_Init_shm_init(void)
207202
MPIR_FUNC_EXIT;
208203
return mpi_errno;
209204
fn_fail:
210-
MPIR_CHKPMEM_REAP();
211205
goto fn_exit;
212206
}
213207

@@ -217,16 +211,12 @@ int MPIDU_Init_shm_finalize(void)
217211

218212
MPIR_FUNC_ENTER;
219213

220-
if (!init_shm_initialized) {
214+
if (!init_shm_initialized || MPIDU_Init_shm_local_size == 1) {
221215
goto fn_exit;
222216
}
223217

224-
if (local_size == 1)
225-
MPL_free(memory.base_addr);
226-
else {
227-
mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len);
228-
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");
229-
}
218+
mpl_err = MPL_shm_seg_detach(memory.hnd, (void **) &(memory.base_addr), memory.segment_len);
219+
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");
230220

231221
MPL_shm_hnd_finalize(&(memory.hnd));
232222

@@ -245,7 +235,9 @@ int MPIDU_Init_shm_barrier(void)
245235

246236
MPIR_FUNC_ENTER;
247237

248-
mpi_errno = Init_shm_barrier();
238+
if (MPIDU_Init_shm_local_size > 1) {
239+
mpi_errno = Init_shm_barrier();
240+
}
249241

250242
MPIR_FUNC_EXIT;
251243

@@ -258,8 +250,11 @@ int MPIDU_Init_shm_put(void *orig, size_t len)
258250

259251
MPIR_FUNC_ENTER;
260252

261-
MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t));
262-
MPIR_Memcpy((char *) baseaddr + my_local_rank * sizeof(MPIDU_Init_shm_block_t), orig, len);
253+
if (MPIDU_Init_shm_local_size > 1) {
254+
MPIR_Assert(len <= sizeof(MPIDU_Init_shm_block_t));
255+
MPIR_Memcpy((char *) baseaddr + MPIDU_Init_shm_local_rank * sizeof(MPIDU_Init_shm_block_t),
256+
orig, len);
257+
}
263258

264259
MPIR_FUNC_EXIT;
265260

@@ -272,7 +267,10 @@ int MPIDU_Init_shm_get(int local_rank, size_t len, void *target)
272267

273268
MPIR_FUNC_ENTER;
274269

275-
MPIR_Assert(local_rank < local_size && len <= sizeof(MPIDU_Init_shm_block_t));
270+
/* a single process should not get its own put */
271+
MPIR_Assert(MPIDU_Init_shm_local_size > 1);
272+
273+
MPIR_Assert(local_rank < MPIDU_Init_shm_local_size && len <= sizeof(MPIDU_Init_shm_block_t));
276274
MPIR_Memcpy(target, (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t), len);
277275

278276
MPIR_FUNC_EXIT;
@@ -286,7 +284,10 @@ int MPIDU_Init_shm_query(int local_rank, void **target_addr)
286284

287285
MPIR_FUNC_ENTER;
288286

289-
MPIR_Assert(local_rank < local_size);
287+
/* a single process should not get its own put */
288+
MPIR_Assert(MPIDU_Init_shm_local_size > 1);
289+
290+
MPIR_Assert(local_rank < MPIDU_Init_shm_local_size);
290291
*target_addr = (char *) baseaddr + local_rank * sizeof(MPIDU_Init_shm_block_t);
291292

292293
MPIR_FUNC_EXIT;

src/mpid/common/shm/mpidu_init_shm_alloc.c

+32-27
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include <sys/shm.h>
2020
#endif
2121

22+
extern int MPIDU_Init_shm_local_size;
23+
extern int MPIDU_Init_shm_local_rank;
24+
2225
typedef struct memory_list {
2326
void *ptr;
2427
MPIDU_shm_seg_t *memory;
@@ -39,8 +42,6 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)
3942
int mpi_errno = MPI_SUCCESS, mpl_err = 0;
4043
void *current_addr;
4144
size_t segment_len = len;
42-
int local_rank = MPIR_Process.local_rank;
43-
int num_local = MPIR_Process.local_size;
4445
MPIDU_shm_seg_t *memory = NULL;
4546
memory_list_t *memory_node = NULL;
4647
MPIR_CHKPMEM_DECL();
@@ -49,6 +50,12 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)
4950

5051
MPIR_Assert(segment_len > 0);
5152

53+
if (MPIDU_Init_shm_local_size == 1) {
54+
*ptr = MPL_aligned_alloc(MPL_CACHELINE_SIZE, len, MPL_MEM_SHM);
55+
MPIR_ERR_CHKANDJUMP(!*ptr, mpi_errno, MPI_ERR_OTHER, "**nomem");
56+
goto fn_exit;
57+
}
58+
5259
MPIR_CHKPMEM_MALLOC(memory, sizeof(*memory), MPL_MEM_SHM);
5360

5461
mpl_err = MPL_shm_hnd_init(&(memory->hnd));
@@ -58,19 +65,9 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)
5865

5966
char *serialized_hnd = NULL;
6067
int serialized_hnd_size = 0;
61-
/* if there is only one process on this processor, don't use shared memory */
62-
if (num_local == 1) {
63-
char *addr;
64-
65-
MPIR_CHKPMEM_MALLOC(addr, segment_len + MPIDU_SHM_CACHE_LINE_LEN, MPL_MEM_SHM);
6668

67-
memory->base_addr = addr;
68-
current_addr =
69-
(char *) (((uintptr_t) addr + (uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1) &
70-
(~((uintptr_t) MPIDU_SHM_CACHE_LINE_LEN - 1)));
71-
memory->symmetrical = 1;
72-
} else {
73-
if (local_rank == 0) {
69+
{
70+
if (MPIDU_Init_shm_local_rank == 0) {
7471
/* root prepare shm segment */
7572
mpl_err = MPL_shm_seg_create_and_attach(memory->hnd, memory->segment_len,
7673
(void **) &(memory->base_addr), 0);
@@ -98,7 +95,7 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)
9895

9996
MPIDU_Init_shm_barrier();
10097

101-
if (local_rank == 0) {
98+
if (MPIDU_Init_shm_local_rank == 0) {
10299
/* memory->hnd no longer needed */
103100
mpl_err = MPL_shm_seg_remove(memory->hnd);
104101
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**remove_shar_mem");
@@ -124,8 +121,10 @@ int MPIDU_Init_shm_alloc(size_t len, void **ptr)
124121
return mpi_errno;
125122
fn_fail:
126123
/* --BEGIN ERROR HANDLING-- */
127-
MPL_shm_seg_remove(memory->hnd);
128-
MPL_shm_hnd_finalize(&(memory->hnd));
124+
if (MPIDU_Init_shm_local_size > 1) {
125+
MPL_shm_seg_remove(memory->hnd);
126+
MPL_shm_hnd_finalize(&(memory->hnd));
127+
}
129128
MPIR_CHKPMEM_REAP();
130129
goto fn_exit;
131130
/* --END ERROR HANDLING-- */
@@ -140,6 +139,11 @@ int MPIDU_Init_shm_free(void *ptr)
140139

141140
MPIR_FUNC_ENTER;
142141

142+
if (MPIDU_Init_shm_local_size == 1) {
143+
MPL_free(ptr);
144+
goto fn_exit;
145+
}
146+
143147
/* retrieve memory handle for baseaddr */
144148
LL_FOREACH(memory_head, el) {
145149
if (el->ptr == ptr) {
@@ -152,17 +156,14 @@ int MPIDU_Init_shm_free(void *ptr)
152156

153157
MPIR_Assert(memory != NULL);
154158

155-
if (MPIR_Process.local_size == 1)
156-
MPL_free(memory->base_addr);
157-
else {
158-
mpl_err = MPL_shm_seg_detach(memory->hnd, (void **) &(memory->base_addr),
159-
memory->segment_len);
160-
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");
161-
}
159+
mpl_err = MPL_shm_seg_detach(memory->hnd, (void **) &(memory->base_addr), memory->segment_len);
160+
MPIR_ERR_CHKANDJUMP(mpl_err, mpi_errno, MPI_ERR_OTHER, "**detach_shar_mem");
162161

163162
fn_exit:
164-
MPL_shm_hnd_finalize(&(memory->hnd));
165-
MPL_free(memory);
163+
if (MPIDU_Init_shm_local_size > 1) {
164+
MPL_shm_hnd_finalize(&(memory->hnd));
165+
MPL_free(memory);
166+
}
166167
MPIR_FUNC_EXIT;
167168
return mpi_errno;
168169
fn_fail:
@@ -174,6 +175,10 @@ int MPIDU_Init_shm_is_symm(void *ptr)
174175
int ret = -1;
175176
memory_list_t *el;
176177

178+
if (MPIDU_Init_shm_local_size == 1) {
179+
return 1;
180+
}
181+
177182
/* retrieve memory handle for baseaddr */
178183
LL_FOREACH(memory_head, el) {
179184
if (el->ptr == ptr) {
@@ -196,7 +201,7 @@ static int check_alloc(MPIDU_shm_seg_t * memory)
196201

197202
MPIR_FUNC_ENTER;
198203

199-
if (MPIR_Process.local_rank == 0) {
204+
if (MPIDU_Init_shm_local_rank == 0) {
200205
MPIDU_Init_shm_put(memory->base_addr, sizeof(void *));
201206
}
202207

0 commit comments

Comments
 (0)