Skip to content

Commit

Permalink
group: optimize grouputil
Browse files Browse the repository at this point in the history
* Add check_map_is_strided to detect strided pattern and convert a map into a
strided pmap.

* Move internal static routines to the bottom of grouputil.c.
  • Loading branch information
hzhou committed Feb 11, 2025
1 parent ace3004 commit a78823b
Showing 1 changed file with 124 additions and 47 deletions.
171 changes: 124 additions & 47 deletions src/mpi/group/grouputil.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ int MPIR_Group_create(int nproc, MPIR_Group ** new_group_ptr)
return mpi_errno;
}

static bool check_map_is_strided(int size, MPIR_Lpid * map,
MPIR_Lpid * offset_out, MPIR_Lpid * stride_out,
MPIR_Lpid * blocksize_out);

int MPIR_Group_create_map(int size, int rank, MPIR_Session * session_ptr, MPIR_Lpid * map,
MPIR_Group ** new_group_ptr)
{
Expand All @@ -104,10 +108,16 @@ int MPIR_Group_create_map(int size, int rank, MPIR_Session * session_ptr, MPIR_L
newgrp->rank = rank;
MPIR_Group_set_session_ptr(newgrp, session_ptr);

newgrp->pmap.use_map = true;
newgrp->pmap.u.map = map;
if (check_map_is_strided(size, map, &newgrp->pmap.u.stride.offset,
&newgrp->pmap.u.stride.stride, &newgrp->pmap.u.stride.blocksize)) {
newgrp->pmap.use_map = false;
MPL_free(map);
} else {
newgrp->pmap.use_map = true;
newgrp->pmap.u.map = map;
/* TODO: build hash to accelerate MPIR_Group_lpid_to_rank */
}

/* TODO: build hash to accelerate MPIR_Group_lpid_to_rank */
*new_group_ptr = newgrp;
}

Expand Down Expand Up @@ -155,50 +165,8 @@ int MPIR_Group_create_stride(int size, int rank, MPIR_Session * session_ptr,
goto fn_exit;
}

static MPIR_Lpid pmap_rank_to_lpid(struct MPIR_Pmap *pmap, int rank)
{
if (rank < 0 || rank >= pmap->size) {
return MPI_UNDEFINED;
}

if (pmap->use_map) {
return pmap->u.map[rank];
} else {
MPIR_Lpid i_blk = rank / pmap->u.stride.blocksize;
MPIR_Lpid r_blk = rank % pmap->u.stride.blocksize;
return pmap->u.stride.offset + i_blk * pmap->u.stride.stride + r_blk;
}
}

static int pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid)
{
if (pmap->use_map) {
/* Use linear search for now.
* Optimization: build hash map in MPIR_Group_create_map and do O(1) hash lookup
*/
for (int rank = 0; rank < pmap->size; rank++) {
if (pmap->u.map[rank] == lpid) {
return rank;
}
}
return MPI_UNDEFINED;
} else {
lpid -= pmap->u.stride.offset;
MPIR_Lpid i_blk = lpid / pmap->u.stride.stride;
MPIR_Lpid r_blk = lpid % pmap->u.stride.stride;

if (r_blk >= pmap->u.stride.blocksize) {
return MPI_UNDEFINED;
}

int rank = i_blk * pmap->u.stride.blocksize + r_blk;
if (rank >= 0 && rank < pmap->size) {
return rank;
} else {
return MPI_UNDEFINED;
}
}
}
static int pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid);
static MPIR_Lpid pmap_rank_to_lpid(struct MPIR_Pmap *pmap, int rank);

int MPIR_Group_lpid_to_rank(MPIR_Group * group, MPIR_Lpid lpid)
{
Expand Down Expand Up @@ -397,3 +365,112 @@ void MPIR_Group_set_session_ptr(MPIR_Group * group_ptr, MPIR_Session * session_p
MPIR_Session_add_ref(session_ptr);
}
}

/* internal static routines */

static bool check_map_is_strided(int size, MPIR_Lpid * map,
MPIR_Lpid * offset_out, MPIR_Lpid * stride_out,
MPIR_Lpid * blocksize_out)
{
MPIR_Assert(size > 0);
if (size == 1) {
*offset_out = map[0];
*stride_out = 1;
*blocksize_out = 1;
return true;
} else {
MPIR_Lpid offset, stride, blocksize;
offset = map[0];

blocksize = 1;
for (int i = 1; i < size; i++) {
if (map[i] - map[i - 1] == 1) {
blocksize++;
} else {
break;
}
}
if (blocksize == size) {
/* consecutive */
*offset_out = offset;
*stride_out = 1;
*blocksize_out = 1;
return true;
} else {
/* NOTE: stride may be negative */
stride = map[blocksize] - map[0];
int n_strides = (size + blocksize - 1) / blocksize;
int k = 0;
for (int i = 0; i < n_strides; i++) {
for (int j = 0; j < blocksize; j++) {
if (map[k] != offset + i * stride + j) {
return false;
}
k++;
if (k == size) {
break;
}
}
}
*offset_out = offset;
*stride_out = stride;
*blocksize_out = blocksize;
return true;
}
}
}

static MPIR_Lpid pmap_rank_to_lpid(struct MPIR_Pmap *pmap, int rank)
{
if (rank < 0 || rank >= pmap->size) {
return MPI_UNDEFINED;
}

if (pmap->use_map) {
return pmap->u.map[rank];
} else {
MPIR_Lpid i_blk = rank / pmap->u.stride.blocksize;
MPIR_Lpid r_blk = rank % pmap->u.stride.blocksize;
return pmap->u.stride.offset + i_blk * pmap->u.stride.stride + r_blk;
}
}

static int pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid)
{
if (pmap->use_map) {
/* Use linear search for now.
* Optimization: build hash map in MPIR_Group_create_map and do O(1) hash lookup
*/
for (int rank = 0; rank < pmap->size; rank++) {
if (pmap->u.map[rank] == lpid) {
return rank;
}
}
return MPI_UNDEFINED;
} else {
lpid -= pmap->u.stride.offset;
MPIR_Lpid i_blk = lpid / pmap->u.stride.stride;
MPIR_Lpid r_blk = lpid % pmap->u.stride.stride;
/* NOTE: stride could be negative, in which case, make sure r_blk >= 0 */
if (r_blk < 0) {
MPIR_Assert(pmap->u.stride.stride < 0);
r_blk -= pmap->u.stride.stride;
i_blk += 1;
}

if (i_blk < 0) {
return MPI_UNDEFINED;
}

if (r_blk >= pmap->u.stride.blocksize) {
return MPI_UNDEFINED;
}

int rank = i_blk * pmap->u.stride.blocksize + r_blk;
if (rank >= 0 && rank < pmap->size) {
return rank;
} else {
return MPI_UNDEFINED;
}
}
}

0 comments on commit a78823b

Please sign in to comment.