@@ -103,10 +103,6 @@ struct ompi_comm_allreduce_context_t {
103103 ompi_comm_cid_context_t * cid_context ;
104104 int * tmpbuf ;
105105
106- /* for intercomm allreduce */
107- int * rcounts ;
108- int * rdisps ;
109-
110106 /* for group allreduce */
111107 int peers_comm [3 ];
112108};
@@ -121,8 +117,6 @@ static void ompi_comm_allreduce_context_construct (ompi_comm_allreduce_context_t
121117static void ompi_comm_allreduce_context_destruct (ompi_comm_allreduce_context_t * context ) /* releases the heap buffer held in context->tmpbuf */
122118{
123119 free (context -> tmpbuf ); /* free(NULL) is a no-op; NOTE(review): assumes tmpbuf is NULL whenever it was never allocated - confirm against the constructor */
124- free (context -> rcounts );
125- free (context -> rdisps );
126120}
127121
128122OBJ_CLASS_INSTANCE (ompi_comm_allreduce_context_t , opal_object_t ,
@@ -602,7 +596,7 @@ static int ompi_comm_allreduce_intra_nb (int *inbuf, int *outbuf, int count, str
602596/* Non-blocking version of ompi_comm_allreduce_inter */
603597static int ompi_comm_allreduce_inter_leader_exchange (ompi_comm_request_t * request );
604598static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t * request );
605- static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t * request );
599+ static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t * request );
606600
607601static int ompi_comm_allreduce_inter_nb (int * inbuf , int * outbuf ,
608602 int count , struct ompi_op_t * op ,
@@ -636,18 +630,19 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
636630 rsize = ompi_comm_remote_size (intercomm );
637631 local_rank = ompi_comm_rank (intercomm );
638632
639- context -> tmpbuf = ( int * ) calloc ( count , sizeof ( int ));
640- context -> rdisps = (int * ) calloc (rsize , sizeof (int ));
641- context -> rcounts = ( int * ) calloc ( rsize , sizeof ( int ));
642- if ( OPAL_UNLIKELY ( NULL == context -> tmpbuf || NULL == context -> rdisps || NULL == context -> rcounts )) {
643- ompi_comm_request_return ( request ) ;
644- return OMPI_ERR_OUT_OF_RESOURCE ;
633+ if ( 0 == local_rank ) {
634+ context -> tmpbuf = (int * ) calloc (count , sizeof (int ));
635+ if ( OPAL_UNLIKELY ( NULL == context -> tmpbuf )) {
636+ ompi_comm_request_return ( request );
637+ return OMPI_ERR_OUT_OF_RESOURCE ;
638+ }
645639 }
646640
647641 /* Execute the inter-allreduce: the result from the local group will be in the buffer of the remote group
648642 * and vice versa. */
649- rc = intercomm -> c_coll .coll_iallreduce (inbuf , context -> tmpbuf , count , MPI_INT , op , intercomm ,
650- & subreq , intercomm -> c_coll .coll_iallreduce_module );
643+ rc = intercomm -> c_local_comm -> c_coll .coll_ireduce (inbuf , context -> tmpbuf , count , MPI_INT , op , 0 , /* 0 = root rank of the reduce on the local communicator */
644+ intercomm -> c_local_comm , & subreq ,
645+ intercomm -> c_local_comm -> c_coll .coll_ireduce_module );
651646 if (OPAL_UNLIKELY (OMPI_SUCCESS != rc )) {
652647 ompi_comm_request_return (request );
653648 return rc ;
@@ -656,7 +651,7 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
656651 if (0 == local_rank ) {
657652 ompi_comm_request_schedule_append (request , ompi_comm_allreduce_inter_leader_exchange , & subreq , 1 );
658653 } else {
659- ompi_comm_request_schedule_append (request , ompi_comm_allreduce_inter_allgather , & subreq , 1 );
654+ ompi_comm_request_schedule_append (request , ompi_comm_allreduce_inter_bcast , & subreq , 1 );
660655 }
661656
662657 ompi_comm_request_start (request );
@@ -696,33 +691,20 @@ static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request
696691
697692 ompi_op_reduce (context -> op , context -> tmpbuf , context -> outbuf , context -> count , MPI_INT );
698693
699- return ompi_comm_allreduce_inter_allgather (request );
694+ return ompi_comm_allreduce_inter_bcast (request );
700695}
701696
702697
703- static int ompi_comm_allreduce_inter_allgather (ompi_comm_request_t * request )
698+ static int ompi_comm_allreduce_inter_bcast (ompi_comm_request_t * request )
704699{
705700 ompi_comm_allreduce_context_t * context = (ompi_comm_allreduce_context_t * ) request -> context ;
706- ompi_communicator_t * intercomm = context -> cid_context -> comm ;
701+ ompi_communicator_t * comm = context -> cid_context -> comm -> c_local_comm ;
707702 ompi_request_t * subreq ;
708703 int scount = 0 , rc ;
709704
710- /* distribute the overall result to all processes in the other group.
711- Instead of using bcast, we are using here allgatherv, to avoid the
712- possible deadlock. Else, we need an algorithm to determine,
713- which group sends first in the inter-bcast and which receives
714- the result first.
715- */
716-
717- if (0 != ompi_comm_rank (intercomm )) {
718- context -> rcounts [0 ] = context -> count ;
719- } else {
720- scount = context -> count ;
721- }
722-
723- rc = intercomm -> c_coll .coll_iallgatherv (context -> outbuf , scount , MPI_INT , context -> outbuf ,
724- context -> rcounts , context -> rdisps , MPI_INT , intercomm ,
725- & subreq , intercomm -> c_coll .coll_iallgatherv_module );
705+ /* Both roots have the same result; broadcast it to the local group. */
706+ rc = comm -> c_coll .coll_ibcast (context -> outbuf , context -> count , MPI_INT , 0 , comm ,
707+ & subreq , comm -> c_coll .coll_ibcast_module );
726708 if (OMPI_SUCCESS != rc ) {
727709 return rc ;
728710 }
0 commit comments