@@ -49,8 +49,7 @@ static inline int coll_allreduce_decision_fixed(int comm_size, size_t msg_size)
4949    return  alg ;
5050}
5151
52- #ifdef  HAVE_XPMEM_H 
53- static  inline  int  mca_coll_acoll_reduce_xpmem_h (const  void  * sbuf , void  * rbuf , size_t  count ,
52+ static  inline  int  mca_coll_acoll_reduce_smsc_h (const  void  * sbuf , void  * rbuf , size_t  count ,
5453                                                struct  ompi_datatype_t  * dtype , struct  ompi_op_t  * op ,
5554                                                struct  ompi_communicator_t  * comm ,
5655                                                mca_coll_base_module_t  * module ,
@@ -79,9 +78,9 @@ static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, si
7978    int  l2_local_rank  =  data -> l2_local_rank ;
8079    char  * tmp_sbuf  =  NULL ;
8180    char  * tmp_rbuf  =  NULL ;
82-     if  (!subc -> xpmem_use_sr_buf ) {
81+     if  (!subc -> smsc_use_sr_buf ) {
8382        tmp_rbuf  =  (char  * ) data -> scratch ;
84-         tmp_sbuf  =  (char  * ) data -> scratch  +  (subc -> xpmem_buf_size ) / 2 ;
83+         tmp_sbuf  =  (char  * ) data -> scratch  +  (subc -> smsc_buf_size ) / 2 ;
8584        if  ((MPI_IN_PLACE  ==  sbuf )) {
8685            memcpy (tmp_sbuf , rbuf , total_dsize );
8786        } else  {
@@ -112,7 +111,10 @@ static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, si
112111        return  err ;
113112    }
114113
115-     register_and_cache (size , total_dsize , rank , data );
114+     err  =  register_mem_with_smsc (rank , size , total_dsize , data , comm );
115+     if  (err  !=  MPI_SUCCESS ) {
116+         return  err ;
117+     }
116118
117119    /* reduce to the local group leader */ 
118120    size_t  chunk  =  count  / l1_gp_size ;
@@ -123,21 +125,21 @@ static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, si
123125            memcpy (tmp_rbuf , sbuf , my_count_size  *  dsize );
124126
125127        for  (int  i  =  1 ; i  <  l1_gp_size ; i ++ ) {
126-             ompi_op_reduce (op , (char  * ) data -> xpmem_saddr [l1_gp [i ]] +  chunk  *  l1_local_rank  *  dsize ,
128+             ompi_op_reduce (op , (char  * ) data -> smsc_saddr [l1_gp [i ]] +  chunk  *  l1_local_rank  *  dsize ,
127129                           (char  * ) tmp_rbuf  +  chunk  *  l1_local_rank  *  dsize , my_count_size , dtype );
128130        }
129131    } else  {
130132        ompi_3buff_op_reduce (op ,
131-                              (char  * ) data -> xpmem_saddr [l1_gp [0 ]] +  chunk  *  l1_local_rank  *  dsize ,
133+                              (char  * ) data -> smsc_saddr [l1_gp [0 ]] +  chunk  *  l1_local_rank  *  dsize ,
132134                             (char  * ) tmp_sbuf  +  chunk  *  l1_local_rank  *  dsize ,
133-                              (char  * ) data -> xpmem_raddr [l1_gp [0 ]] +  chunk  *  l1_local_rank  *  dsize ,
135+                              (char  * ) data -> smsc_raddr [l1_gp [0 ]] +  chunk  *  l1_local_rank  *  dsize ,
134136                             my_count_size , dtype );
135137        for  (int  i  =  1 ; i  <  l1_gp_size ; i ++ ) {
136138            if  (i  ==  l1_local_rank ) {
137139                continue ;
138140            }
139-             ompi_op_reduce (op , (char  * ) data -> xpmem_saddr [l1_gp [i ]] +  chunk  *  l1_local_rank  *  dsize ,
140-                            (char  * ) data -> xpmem_raddr [l1_gp [0 ]] +  chunk  *  l1_local_rank  *  dsize ,
141+             ompi_op_reduce (op , (char  * ) data -> smsc_saddr [l1_gp [i ]] +  chunk  *  l1_local_rank  *  dsize ,
142+                            (char  * ) data -> smsc_raddr [l1_gp [0 ]] +  chunk  *  l1_local_rank  *  dsize ,
141143                           my_count_size , dtype );
142144        }
143145    }
@@ -155,7 +157,7 @@ static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, si
155157
156158        if  (0  ==  l2_local_rank ) {
157159            for  (int  i  =  1 ; i  <  local_size ; i ++ ) {
158-                 ompi_op_reduce (op , (char  * ) data -> xpmem_raddr [l2_gp [i ]], (char  * ) tmp_rbuf ,
160+                 ompi_op_reduce (op , (char  * ) data -> smsc_raddr [l2_gp [i ]], (char  * ) tmp_rbuf ,
159161                               my_count_size , dtype );
160162            }
161163        } else  {
@@ -165,24 +167,26 @@ static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, si
165167                }
166168
167169                ompi_op_reduce (op ,
168-                                (char  * ) data -> xpmem_raddr [l2_gp [i ]] +  chunk  *  l2_local_rank  *  dsize ,
169-                                (char  * ) data -> xpmem_raddr [0 ] +  chunk  *  l2_local_rank  *  dsize ,
170+                                (char  * ) data -> smsc_raddr [l2_gp [i ]] +  chunk  *  l2_local_rank  *  dsize ,
171+                                (char  * ) data -> smsc_raddr [0 ] +  chunk  *  l2_local_rank  *  dsize ,
170172                               my_count_size , dtype );
171173            }
172174            ompi_op_reduce (op , (char  * ) tmp_rbuf  +  chunk  *  l2_local_rank  *  dsize ,
173-                            (char  * ) data -> xpmem_raddr [0 ] +  chunk  *  l2_local_rank  *  dsize ,
175+                            (char  * ) data -> smsc_raddr [0 ] +  chunk  *  l2_local_rank  *  dsize ,
174176                           my_count_size , dtype );
175177        }
176178    }
177179
178180    err  =  ompi_coll_base_barrier_intra_tree (comm , module );
179-     if  (!subc -> xpmem_use_sr_buf ) {
181+     if  (!subc -> smsc_use_sr_buf ) {
180182        memcpy (rbuf , tmp_rbuf , total_dsize );
181183    }
184+     // Note: neither unmap nor deregister has any effect here; the call is kept only for consistency 
185+     unmap_mem_with_smsc (rank , size , data );
182186    return  err ;
183187}
184188
185- static  inline  int  mca_coll_acoll_allreduce_xpmem_f (const  void  * sbuf , void  * rbuf , size_t  count ,
189+ static  inline  int  mca_coll_acoll_allreduce_smsc_f (const  void  * sbuf , void  * rbuf , size_t  count ,
186190                                                   struct  ompi_datatype_t  * dtype ,
187191                                                   struct  ompi_op_t  * op ,
188192                                                   struct  ompi_communicator_t  * comm ,
@@ -204,9 +208,9 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf,
204208
205209    char  * tmp_sbuf  =  NULL ;
206210    char  * tmp_rbuf  =  NULL ;
207-     if  (!subc -> xpmem_use_sr_buf ) {
211+     if  (!subc -> smsc_use_sr_buf ) {
208212        tmp_rbuf  =  (char  * ) data -> scratch ;
209-         tmp_sbuf  =  (char  * ) data -> scratch  +  (subc -> xpmem_buf_size ) / 2 ;
213+         tmp_sbuf  =  (char  * ) data -> scratch  +  (subc -> smsc_buf_size ) / 2 ;
210214        if  ((MPI_IN_PLACE  ==  sbuf )) {
211215            memcpy (tmp_sbuf , rbuf , total_dsize );
212216        } else  {
@@ -238,15 +242,18 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf,
238242        return  err ;
239243    }
240244
241-     register_and_cache (size , total_dsize , rank , data );
245+     err  =  register_mem_with_smsc (rank , size , total_dsize , data , comm );
246+     if  (err  !=  MPI_SUCCESS ) {
247+         return  err ;
248+     }
242249
243250    size_t  chunk  =  count  / size ;
244251    size_t  my_count_size  =  (rank  ==  (size  -  1 )) ? (count  / size ) +  count  % size  : count  / size ;
245252    if  (0  ==  rank ) {
246253        if  (sbuf  !=  MPI_IN_PLACE )
247254            memcpy (tmp_rbuf , sbuf , my_count_size  *  dsize );
248255    } else  {
249-         ompi_3buff_op_reduce (op , (char  * ) data -> xpmem_saddr [0 ] +  chunk  *  rank  *  dsize ,
256+         ompi_3buff_op_reduce (op , (char  * ) data -> smsc_saddr [0 ] +  chunk  *  rank  *  dsize ,
250257                             (char  * ) tmp_sbuf  +  chunk  *  rank  *  dsize ,
251258                             (char  * ) tmp_rbuf  +  chunk  *  rank  *  dsize , my_count_size , dtype );
252259    }
@@ -260,7 +267,7 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf,
260267        if  (rank  ==  i ) {
261268            continue ;
262269        }
263-         ompi_op_reduce (op , (char  * ) data -> xpmem_saddr [i ] +  chunk  *  rank  *  dsize ,
270+         ompi_op_reduce (op , (char  * ) data -> smsc_saddr [i ] +  chunk  *  rank  *  dsize ,
264271                       (char  * ) tmp_rbuf  +  chunk  *  rank  *  dsize , my_count_size , dtype );
265272    }
266273    err  =  ompi_coll_base_barrier_intra_tree (comm , module );
@@ -270,21 +277,23 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf,
270277
271278    size_t  tmp  =  chunk  *  dsize ;
272279    for  (int  i  =  0 ; i  <  size ; i ++ ) {
273-         if  (subc -> xpmem_use_sr_buf  &&  (rank  ==  i )) {
280+         if  (subc -> smsc_use_sr_buf  &&  (rank  ==  i )) {
274281            continue ;
275282        }
276283        my_count_size  =  (i  ==  (size  -  1 )) ? (count  / size ) +  count  % size  : count  / size ;
277284        size_t  tmp1  =  i  *  tmp ;
278285        char  * dst  =  (char  * ) rbuf  +  tmp1 ;
279-         char  * src  =  (char  * ) data -> xpmem_raddr [i ] +  tmp1 ;
286+         char  * src  =  (char  * ) data -> smsc_raddr [i ] +  tmp1 ;
280287        memcpy (dst , src , my_count_size  *  dsize );
281288    }
282289
283290    err  =  ompi_coll_base_barrier_intra_tree (comm , module );
284291
292+     // Note: neither unmap nor deregister has any effect here; the call is kept only for consistency 
293+     unmap_mem_with_smsc (rank , size , data );
294+ 
285295    return  err ;
286296}
287- #endif 
288297
289298void  mca_coll_acoll_sync (coll_acoll_data_t  * data , int  offset , int  * group , int  gp_size , int  rank ,
290299                         int  up )
@@ -450,7 +459,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
450459    ompi_datatype_type_size (dtype , & dsize );
451460    total_dsize  =  dsize  *  count ;
452461
453-     /* Disable shm/xpmem based optimizations if: */ 
462+     /* Disable smsc/shm/xpmem based optimizations if: */ 
454463    /* - datatype is not a predefined type */ 
455464    /* - it's a gpu buffer */ 
456465    uint64_t  flags  =  0 ;
@@ -481,6 +490,10 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
481490    coll_acoll_subcomms_t  * subc  =  NULL ;
482491    err  =  check_and_create_subc (comm , acoll_module , & subc );
483492
493+     if  (MPI_SUCCESS  !=  err ) {
494+         return  err ;
495+     }
496+ 
484497    /* Fall back to redscat_allgather if subc is not obtained */ 
485498    if  (NULL  ==  subc ) {
486499        return  ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype , op , comm ,
@@ -518,42 +531,27 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
518531                                                                     comm , module , 0 );
519532            }
520533        } else  if  (total_dsize  <  4194304 ) {
521- #ifdef  HAVE_XPMEM_H 
522-             if  (((subc -> xpmem_use_sr_buf  !=  0 ) ||  (subc -> xpmem_buf_size  >  2  *  total_dsize )) &&  (subc -> without_xpmem  !=  1 ) &&  is_opt ) {
523-                 return  mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module , subc );
534+             if  (((subc -> smsc_use_sr_buf  !=  0 ) ||  (subc -> smsc_buf_size  >  2  *  total_dsize )) &&  (subc -> without_smsc  !=  1 ) &&  is_opt ) {
535+                 return  mca_coll_acoll_allreduce_smsc_f (sbuf , rbuf , count , dtype , op , comm , module , subc );
524536            } else  {
525537                return  ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype ,
526538                                                                        op , comm , module );
527539            }
528- #else 
529-             return  ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype , op ,
530-                                                                     comm , module );
531- #endif 
532540        } else  if  (total_dsize  <= 16777216 ) {
533- #ifdef  HAVE_XPMEM_H 
534-             if  (((subc -> xpmem_use_sr_buf  !=  0 ) ||  (subc -> xpmem_buf_size  >  2  *  total_dsize )) &&  (subc -> without_xpmem  !=  1 ) &&  is_opt ) {
535-                 mca_coll_acoll_reduce_xpmem_h (sbuf , rbuf , count , dtype , op , comm , module , subc );
541+             if  (((subc -> smsc_use_sr_buf  !=  0 ) ||  (subc -> smsc_buf_size  >  2  *  total_dsize )) &&  (subc -> without_smsc  !=  1 ) &&  is_opt ) {
542+                 mca_coll_acoll_reduce_smsc_h (sbuf , rbuf , count , dtype , op , comm , module , subc );
536543                return  mca_coll_acoll_bcast (rbuf , count , dtype , 0 , comm , module );
537544            } else  {
538545                return  ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype ,
539546                                                                        op , comm , module );
540547            }
541- #else 
542-             return  ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype , op ,
543-                                                                     comm , module );
544- #endif 
545548        } else  {
546- #ifdef  HAVE_XPMEM_H 
547-             if  (((subc -> xpmem_use_sr_buf  !=  0 ) ||  (subc -> xpmem_buf_size  >  2  *  total_dsize )) &&  (subc -> without_xpmem  !=  1 ) &&  is_opt ) {
548-                 return  mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module , subc );
549+             if  (((subc -> smsc_use_sr_buf  !=  0 ) ||  (subc -> smsc_buf_size  >  2  *  total_dsize )) &&  (subc -> without_smsc  !=  1 ) &&  is_opt ) {
550+                 return  mca_coll_acoll_allreduce_smsc_f (sbuf , rbuf , count , dtype , op , comm , module , subc );
549551            } else  {
550552                return  ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype ,
551553                                                                        op , comm , module );
552554            }
553- #else 
554-             return  ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype , op ,
555-                                                                     comm , module );
556- #endif 
557555        }
558556
559557    } else  {
0 commit comments