@@ -611,6 +611,112 @@ def reduce_scatter_tensor_coalesced(
             )
 
 
+class _ParallelWork(Work):
+    """Aggregates the Work handles from each sub-process-group and completes
+    only when all of them have completed."""
+
+    def __init__(self, works: List[Work]) -> None:
+        super().__init__()
+        self._works = works
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        for work in self._works:
+            if timeout is not None:
+                work.wait(timeout=timeout)
+            else:
+                work.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        futures = [work.get_future() for work in self._works]
+        return torch.futures.collect_all(futures)
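+
+    # Note: torch.futures.collect_all() returns a Future that resolves to the
+    # list of the collected futures (List[Future]) once all of them complete,
+    # so get_future() here does not yield the raw collective results directly.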
+
+
+class ParallelProcessGroup(ProcessGroupWrapper):
+    """Runs collectives across ``count`` sub-process-groups in parallel by
+    splitting each tensor into ``count`` chunks."""
+
+    def __init__(
+        self,
+        base: ProcessGroupWrapper,
+        timeout: timedelta = timedelta(seconds=60),
+        count: int = 10,
+    ) -> None:
+        super().__init__(timeout=timeout)
+
+        self._base = base
+        self._count = count
+        self._pgs = []
+
+        self._create_pg = base._create_pg
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        # abort if already initialized
+        self.abort()
+
+        for i in range(self._count):
+            store = create_store_client(
+                f"{store_addr}/parallel{i}", timeout=self._timeout
+            )
+
+            self._pgs.append(self._create_pg(store, rank, world_size))
+
+        self._pg = self._pgs[0]
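+
+    # Note: each sub-group rendezvouses through its own store prefix
+    # ("{store_addr}/parallel{i}"), so the `count` groups are initialized
+    # independently. Assigning self._pg = self._pgs[0] appears to route any
+    # collective not overridden below through the first sub-group only.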
+
+    def getBackendName(self) -> str:
+        return f"{self._base.getBackendName()}-parallel"
+
+    def _split_tensors(self, tensors: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+        if not isinstance(tensors, (list, tuple)):
+            tensors = [tensors]
+
+        tensor_lists = [[] for _ in range(self._count)]
+        for t in tensors:
+            chunks = torch.tensor_split(t.view(-1), self._count, dim=0)
+            for i, chunk in enumerate(chunks):
+                tensor_lists[i].append(chunk)
+
+        return tensor_lists
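+
+    # Example: with count=4, a 10-element tensor is flattened and split into
+    # chunks of sizes 3, 3, 2, 2 (torch.tensor_split allows uneven splits).
+    # The chunks are views of the original tensor, so in-place collectives on
+    # the chunks update the original tensor's storage directly.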
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].allreduce(tensor_lists[i], self._opts_hook(opts))
+                )
+
+        return self._wrap_work(_ParallelWork(works), opts)
+
+    def reduce(self, tensors: List[torch.Tensor], dst: int, opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].reduce(tensor_lists[i], dst, self._opts_hook(opts))
+                )
+
+        return self._wrap_work(_ParallelWork(works), opts)
+
+    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].send(tensor_lists[i], dst_rank, tag))
+
+        return self._wrap_work(_ParallelWork(works), None)
+
+    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].recv(tensor_lists[i], src_rank, tag))
+
+        return self._wrap_work(_ParallelWork(works), None)
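+
+    # Usage sketch (hypothetical, assuming torchft's ProcessGroupGloo as the
+    # base wrapper and a reachable store at store_addr):
+    #
+    #   pg = ParallelProcessGroup(base=ProcessGroupGloo(), count=10)
+    #   pg.configure(store_addr, rank, world_size)
+    #   pg.allreduce([tensor], ReduceOp.SUM).wait()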
+
+
 class _WorkCUDATimeout(Work):
     def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
         super().__init__()