Skip to content

Commit

Permalink
group: optimize grouputil
Browse files Browse the repository at this point in the history
* Add check_map_is_strided to detect strided pattern and convert a map into a
strided pmap.

* In MPIR_Group_check_subset, use MPIR_Group_lpid_to_rank rather than a
manual linear search.

* Move internal static routines to the bottom of grouputil.c.
  • Loading branch information
hzhou committed Feb 7, 2025
1 parent 2aba5cd commit 1535d80
Showing 1 changed file with 116 additions and 74 deletions.
190 changes: 116 additions & 74 deletions src/mpi/group/grouputil.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ int MPIR_Group_create(int nproc, MPIR_Group ** new_group_ptr)
return mpi_errno;
}

static bool check_map_is_strided(int size, MPIR_Lpid * map,
MPIR_Lpid * offset_out, MPIR_Lpid * stride_out,
MPIR_Lpid * blocksize_out);

int MPIR_Group_create_map(int size, int rank, MPIR_Session * session_ptr, MPIR_Lpid * map,
MPIR_Group ** new_group_ptr)
{
Expand All @@ -104,10 +108,16 @@ int MPIR_Group_create_map(int size, int rank, MPIR_Session * session_ptr, MPIR_L
newgrp->rank = rank;
MPIR_Group_set_session_ptr(newgrp, session_ptr);

newgrp->pmap.use_map = true;
newgrp->pmap.u.map = map;
if (check_map_is_strided(size, map, &newgrp->pmap.u.stride.offset,
&newgrp->pmap.u.stride.stride, &newgrp->pmap.u.stride.blocksize)) {
newgrp->pmap.use_map = false;
MPL_free(map);
} else {
newgrp->pmap.use_map = true;
newgrp->pmap.u.map = map;
/* TODO: build hash to accelerate MPIR_Group_lpid_to_rank */
}

/* TODO: build hash to accelerate MPIR_Group_lpid_to_rank */
*new_group_ptr = newgrp;
}

Expand Down Expand Up @@ -152,50 +162,7 @@ int MPIR_Group_create_stride(int size, int rank, MPIR_Session * session_ptr,
goto fn_exit;
}

static MPIR_Lpid pmap_rank_to_lpid(struct MPIR_Pmap *pmap, int rank)
{
if (rank < 0 || rank >= pmap->size) {
return MPI_UNDEFINED;
}

if (pmap->use_map) {
return pmap->u.map[rank];
} else {
MPIR_Lpid i_blk = rank / pmap->u.stride.blocksize;
MPIR_Lpid r_blk = rank % pmap->u.stride.blocksize;
return pmap->u.stride.offset + i_blk * pmap->u.stride.stride + r_blk;
}
}

static int pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid)
{
if (pmap->use_map) {
/* Use linear search for now.
* Optimization: build hash map in MPIR_Group_create_map and do O(1) hash lookup
*/
for (int rank = 0; rank < pmap->size; rank++) {
if (pmap->u.map[rank] == lpid) {
return rank;
}
}
return MPI_UNDEFINED;
} else {
lpid -= pmap->u.stride.offset;
MPIR_Lpid i_blk = lpid / pmap->u.stride.stride;
MPIR_Lpid r_blk = lpid % pmap->u.stride.stride;

if (r_blk >= pmap->u.stride.blocksize) {
return MPI_UNDEFINED;
}

int rank = i_blk * pmap->u.stride.blocksize + r_blk;
if (rank >= 0 && rank < pmap->size) {
return rank;
} else {
return MPI_UNDEFINED;
}
}
}
static int pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid);

int MPIR_Group_lpid_to_rank(MPIR_Group * group, MPIR_Lpid lpid)
{
Expand Down Expand Up @@ -342,42 +309,19 @@ int MPIR_Group_check_valid_ranges(MPIR_Group * group_ptr, int ranges[][3], int n
int MPIR_Group_check_subset(MPIR_Group * group_ptr, MPIR_Comm * comm_ptr)
{
int mpi_errno = MPI_SUCCESS;
MPIR_CHKLMEM_DECL();

MPIR_Assert(group_ptr != NULL);

int vsize = comm_ptr->comm_kind == MPIR_COMM_KIND__INTERCOMM ? comm_ptr->local_size :
comm_ptr->remote_size;

/* Initialize the vmap */
MPIR_Lpid *vmap;
MPIR_CHKLMEM_MALLOC(vmap, vsize * sizeof(MPIR_Lpid));
for (int i = 0; i < vsize; i++) {
/* FIXME: MPID_Comm_get_lpid to be removed */
uint64_t dev_lpid;
MPID_Comm_get_lpid(comm_ptr, i, &dev_lpid, FALSE);
MPIR_Assert((dev_lpid >> 32) == 0);
vmap[i] = dev_lpid;
}

for (int rank = 0; rank < group_ptr->size; rank++) {
MPIR_Lpid lpid = MPIR_Group_rank_to_lpid(group_ptr, rank);
bool found = false;
for (int i = 0; i < vsize; i++) {
if (vmap[i] == lpid) {
found = true;
break;
}
}
if (!found) {
MPIR_ERR_SET1(mpi_errno, MPI_ERR_GROUP, "**groupnotincomm",
"**groupnotincomm %d", rank);
goto fn_fail;
int r = MPIR_Group_lpid_to_rank(comm_ptr->local_group, lpid);
if (r == MPI_UNDEFINED) {
MPIR_ERR_SETANDJUMP1(mpi_errno, MPI_ERR_GROUP, "**groupnotincomm",
"**groupnotincomm %d", rank);
}
}

fn_exit:
MPIR_CHKLMEM_FREEALL();
return mpi_errno;
fn_fail:
goto fn_exit;
Expand All @@ -394,3 +338,101 @@ void MPIR_Group_set_session_ptr(MPIR_Group * group_ptr, MPIR_Session * session_p
MPIR_Session_add_ref(session_ptr);
}
}

/* internal static routines */

static bool check_map_is_strided(int size, MPIR_Lpid * map,
MPIR_Lpid * offset_out, MPIR_Lpid * stride_out,
MPIR_Lpid * blocksize_out)
{
MPIR_Assert(size > 0);
if (size == 1) {
*offset_out = map[0];
*stride_out = 1;
*blocksize_out = 1;
return true;
} else {
MPIR_Lpid offset, stride, blocksize;
offset = map[0];

blocksize = 1;
for (int i = 1; i < size; i++) {
if (map[i] - map[i - 1] == 1) {
blocksize++;
} else {
break;
}
}
if (blocksize == size) {
/* consecutive */
*offset_out = offset;
*stride_out = 1;
*blocksize_out = 1;
return true;
} else {
stride = map[blocksize] - map[0];
int n_strides = (size + blocksize - 1) / blocksize;
int k = 0;
for (int i = 0; i < n_strides; i++) {
for (int j = 0; j < blocksize; j++) {
if (map[k] != offset + i * stride + j) {
return false;
}
k++;
if (k == size) {
break;
}
}
}
*offset_out = offset;
*stride_out = stride;
*blocksize_out = blocksize;
return true;
}
}
}

static MPIR_Lpid pmap_rank_to_lpid(struct MPIR_Pmap *pmap, int rank)
{
if (rank < 0 || rank >= pmap->size) {
return MPI_UNDEFINED;
}

if (pmap->use_map) {
return pmap->u.map[rank];
} else {
MPIR_Lpid i_blk = rank / pmap->u.stride.blocksize;
MPIR_Lpid r_blk = rank % pmap->u.stride.blocksize;
return pmap->u.stride.offset + i_blk * pmap->u.stride.stride + r_blk;
}
}

static int pmap_lpid_to_rank(struct MPIR_Pmap *pmap, MPIR_Lpid lpid)
{
if (pmap->use_map) {
/* Use linear search for now.
* Optimization: build hash map in MPIR_Group_create_map and do O(1) hash lookup
*/
for (int rank = 0; rank < pmap->size; rank++) {
if (pmap->u.map[rank] == lpid) {
return rank;
}
}
return MPI_UNDEFINED;
} else {
lpid -= pmap->u.stride.offset;
MPIR_Lpid i_blk = lpid / pmap->u.stride.stride;
MPIR_Lpid r_blk = lpid % pmap->u.stride.stride;

if (r_blk >= pmap->u.stride.blocksize) {
return MPI_UNDEFINED;
}

int rank = i_blk * pmap->u.stride.blocksize + r_blk;
if (rank >= 0 && rank < pmap->size) {
return rank;
} else {
return MPI_UNDEFINED;
}
}
}

0 comments on commit 1535d80

Please sign in to comment.