diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index f9a65d9..2bf58d9 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -25,7 +25,7 @@ jobs: sudo apt-get install -y protobuf-compiler - pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 + pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128 pip install .[dev] -v pip install -r docs/requirements.txt diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 08ba6f2..370a780 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -23,7 +23,7 @@ jobs: sudo apt-get install -y protobuf-compiler - pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 + pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128 pip install .[dev] -v # install recent version of Rust via rustup diff --git a/.github/workflows/unittest-mac.yaml b/.github/workflows/unittest-mac.yaml index 6743378..324371e 100644 --- a/.github/workflows/unittest-mac.yaml +++ b/.github/workflows/unittest-mac.yaml @@ -12,7 +12,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - + - name: Setup miniconda uses: pytorch/test-infra/.github/actions/setup-miniconda@main with: @@ -39,7 +39,7 @@ jobs: python -m pip install --upgrade pip - pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu + pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cpu pip install -e .[dev] -v diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index d1f5bd5..9404328 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -41,10 +41,10 @@ jobs: # Optionally install torch nightly, pulls latest CUDA from pip otherwise if [ "${{ matrix.torch-version }}" = "nightly" ]; then - pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128 + pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128 fi if [ "${{ matrix.torch-version }}" = "test" ]; then - pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 + pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/test/cu128 fi # Install dependencies diff --git a/.pyre_configuration b/.pyre_configuration index 7913bbe..ffe1e60 100644 --- a/.pyre_configuration +++ b/.pyre_configuration @@ -13,6 +13,9 @@ }, { "site-package": "parameterized" + }, + { + "site-package": "torchcomms" } ] } diff --git a/torchft/__init__.py b/torchft/__init__.py index b07e128..d93100e 100644 --- a/torchft/__init__.py +++ b/torchft/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from torchft.comms import TorchCommGloo, TorchCommNCCL from torchft.data import DistributedSampler from torchft.ddp import DistributedDataParallel from torchft.manager import Manager @@ -31,4 +32,6 @@ "ProcessGroupBabyNCCL", "ProcessGroupBabyXCCL", "ProcessGroupGloo", + "TorchCommNCCL", + "TorchCommGloo", ) diff --git a/torchft/comms.py b/torchft/comms.py new file mode 100644 index 0000000..c3d3104 --- /dev/null +++ b/torchft/comms.py @@ -0,0 +1,709 @@ +# pyre-strict +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +TorchComm Integration Library +=============================== + +This module provides a wrapper around torchcomms that is used for cross replica +communication within torchft. + +This uses torchcomms directly while providing a compatible interface for +reconfiguration and collective operations. + +Usage: + # For Gloo backend + comm_gloo = TorchCommGloo(timeout=timedelta(seconds=60)) + comm_gloo.configure(store_addr, replica_id, rank, world_size) + + # For NCCL backend + comm_nccl = TorchCommNCCL(timeout=timedelta(seconds=60)) + comm_nccl.configure(store_addr, replica_id, rank, world_size) +""" + +import logging +import os +import warnings +from contextlib import contextmanager +from datetime import timedelta +from typing import Dict, Generator, List, Optional, Union + +import torch +import torch.distributed as dist +import torchcomms +from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp +from torchft.futures import context_timeout, stream_timeout +from torchft.process_group import ( + create_store_client, + TORCHFT_TRIGGER_FR_ON_ABORT, + trigger_nccl_fr_trace_through_pipe, +) + +logger: logging.Logger = logging.getLogger(__name__) + + +class TorchWork(dist._Work): + """ + Timeout wrapper for TorchWork that wraps TorchWork objects to + add timeout handling for wait operations. + + Args: + comm: The TorchComm instance to abort on timeout + work: The TorchWork object to wrap + timeout: The timeout duration for operations + """ + + def __init__( + self, + comm: "TorchComm", + work: torchcomms.TorchWork, + value: object, + timeout: timedelta, + ) -> None: + super().__init__() + self._comm: "TorchComm" = comm + self._work: torchcomms.TorchWork = work + self._value: object = value + self._timeout: timedelta = timeout + + self._fut: torch.futures.Future[object] = torch.futures.Future() + self._fut.set_result(self._value) + + def wait(self, timeout: Optional[timedelta] = None) -> bool: + """ + Wait for the work to complete with timeout handling. + + Args: + timeout: Optional timeout override + """ + async_timeout = timeout or self._timeout + with self._stream_timeout(self._comm, async_timeout): + if self._work is not None: + self._work.wait() + + # Always use cuda stream for timeout to avoid ProcessGroupNCCL + # watchdog firing and crashing the process. 
+ if timeout is not None: + torch.cuda.synchronize() + + return True + + def get_future( + self, + ) -> torch.futures.Future[object]: + return self._fut + + def is_completed(self) -> bool: + """Check if the work is completed.""" + return self._work.is_completed() if self._work is not None else True + + def block_current_stream(self, timeout: Optional[timedelta] = None) -> None: + raise NotImplementedError("The method is not supposed to be called") + + def synchronize(self) -> None: + raise NotImplementedError("The method is not supposed to be called") + + @classmethod + @contextmanager + def _stream_timeout( + cls, comm: "TorchComm", timeout: timedelta + ) -> Generator[None, None, None]: + """ + Set a timeout on the CUDA stream for the given comm. + + This does not hold a reference to self to avoid holding the work + object/tensors longer than necessary. + + Args: + comm: The TorchComm to call abort on. + timeout: The timeout to set on the CUDA stream. + """ + + def callback() -> None: + logger.error(f"aborting after {timeout}!") + comm.abort() + + # make sure .wait() can be cancelled if it blocks i.e. in barrier + with context_timeout(callback, timeout): + yield + + # Cancel work if the cuda stream doesn't complete + stream_timeout(callback, timeout) + + +class TorchComm: + """ + Base wrapper for torchcomms providing a process group-like interface. + + This provides the common implementation for both Gloo and NCCL backends + using torchcomms as the underlying communication library. + + Args: + backend: torchcomms backend name (e.g., "gloo", "nccl") + timeout: default timeout for operations + device: torch device to use (e.g., "cpu", "cuda") + """ + + def __init__( + self, + backend: str, + timeout: timedelta, + device: torch.device, + ) -> None: + self._backend = backend + self._timeout = timeout + self._device = device + self._comm: Optional[torchcomms.TorchComm] = None + self._replica_id: Optional[str] = None + self._rank: Optional[int] = None + self._world_size: Optional[int] = None + self._quorum_id: Optional[int] = None + self._group_rank: Optional[int] = None + self._group_world_size: Optional[int] = None + self._global_ranks: Optional[List[int]] = None + self._errored: Optional[Exception] = None + + self.errors_logger: logging.Logger = logging.getLogger("torchft_errors") + + def _wrap_work(self, work: torchcomms.TorchWork, value: object) -> TorchWork: + """ + Wrap work object to allow intercepting wait/synchronization. + + Subclasses can override this to provide custom work wrapping, + such as adding timeouts or error handling. + + Args: + work: The work object to wrap + value: The tensor or value associated with this work + + Returns: + The wrapped work object (or original if no wrapping needed) + """ + return TorchWork(self, work, value, self._timeout) + + @contextmanager + def _run_context(self) -> Generator[None, None, None]: + """ + Context manager for running collective operations. + + Subclasses can override this to provide custom behavior around + collective operations, such as timeout management or error handling. + + Yields: + None + """ + yield + + def configure( + self, + store_addr: str, + replica_id: str, + rank: int, + world_size: int, + quorum_id: Optional[int] = None, + group_rank: Optional[int] = None, + group_world_size: Optional[int] = None, + global_ranks: Optional[List[int]] = None, + ) -> None: + """ + Reconfigure the communication group with new parameters. 
+ + Args: + store_addr: address of the store to use (host:port/prefix) + replica_id: the replica_id for this group + rank: rank of this process + world_size: world size of this communication group + quorum_id: current quorum's identifier + group_rank: local rank within the replica group + group_world_size: the number of ranks within a replica + global_ranks: the global ranks part of this group + """ + self._replica_id = replica_id + self._rank = rank + self._world_size = world_size + self._quorum_id = quorum_id + self._group_rank = group_rank + self._group_world_size = group_world_size + self._global_ranks = global_ranks + + # Shutdown existing comm if present + if self._comm is not None: + self.shutdown() + + store = create_store_client(store_addr, self._timeout) + + # Build communication name and hints + comm_name = f"torchft_{replica_id}_q{quorum_id}_r{rank}" + hints: Dict[str, str] = {} + + # TODO: unused currently but this can be used to set metadata for + # flight recorder + if self._global_ranks: + hints["global_ranks"] = ",".join(str(r) for r in self._global_ranks) + if self._group_rank is not None and self._group_world_size is not None: + hints["group_name"] = ( + f"torchft_quorum_{self._quorum_id}_" + f"rank_{self._group_rank % self._group_world_size}" + ) + + # Set the ranks properly for the cross replica process group + os.environ["TORCHCOMM_RANK"] = str(rank) + os.environ["TORCHCOMM_SIZE"] = str(world_size) + + # Create torchcomms communicator + self._comm = torchcomms.new_comm( + backend=self._backend, + device=self._device, + abort_process_on_timeout_or_error=False, + timeout=self._timeout, + store=store, + name=comm_name, + hints=hints, + ) + + self._errored = None + + def shutdown(self) -> None: + """Shutdown the communication group.""" + if self._comm is not None: + self._comm.finalize() + self._comm = None + + def abort(self) -> None: + """ + Abort the communication group with error logging. + + This logs the error before shutting down the communicator. + """ + self._errored = RuntimeError("aborted") + + self.errors_logger.info( + "", + extra={ + "job_id": os.environ.get("JOB_ID", "unknown"), + "replica_id": self._replica_id, + "rank": self._rank, + "quorum_id": self._quorum_id, + "error": "torchcomm_abort", + }, + ) + + # Trigger NCCL flight recorder trace if enabled + if ( + os.environ.get(TORCHFT_TRIGGER_FR_ON_ABORT, "0") == "1" + and self._rank is not None + ): + trigger_nccl_fr_trace_through_pipe(self._rank) + + self.shutdown() + + def errored(self) -> Optional[Exception]: + """Check if an error has occurred (torchcomms compatible method).""" + return self._errored + + # Collective operations - all operations use async_op=True and return TorchWork + # Users should call .wait() on the returned work object to synchronize + + def allgather( + self, + output_tensors: List[torch.Tensor], + input_tensor: torch.Tensor, + ) -> TorchWork: + """ + Gather tensors from all ranks into output_tensors. + + Args: + output_tensors: List of output tensors, one per rank + input_tensor: Input tensor to gather from this rank + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.all_gather( + tensor_list=output_tensors, + tensor=input_tensor, + async_op=True, + ) + return self._wrap_work(work, input_tensor) + + def allgather_single( + self, + output: torch.Tensor, + input: torch.Tensor, + ) -> TorchWork: + """ + Gather tensors from all ranks into a single output tensor. 
+ + Args: + output: Output tensor (size: world_size * input.size()) + input: Input tensor to gather from this rank + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.all_gather_single( + output=output, + input=input, + async_op=True, + ) + return self._wrap_work(work, input) + + def allreduce( + self, + tensors: list[torch.Tensor], + opts: Union[AllreduceOptions, ReduceOp, torchcomms.ReduceOp], + ) -> TorchWork: + """ + Reduce a tensor across all ranks. + + Args: + tensors: Single-element list containing the tensor to reduce in place + opts: Reduction operation or allreduce options; SUM and AVG are supported + + Returns: + TorchWork object that can be waited on + """ + assert len(tensors) == 1 + + if isinstance(opts, ReduceOp): + if opts == ReduceOp.SUM: + tc_opts = torchcomms.ReduceOp.SUM + elif opts == ReduceOp.AVG: + tc_opts = torchcomms.ReduceOp.AVG + else: + raise AssertionError("unsupported reduce op") + elif isinstance(opts, AllreduceOptions): + if opts.reduceOp == ReduceOp.SUM: + tc_opts = torchcomms.ReduceOp.SUM + elif opts.reduceOp == ReduceOp.AVG: + tc_opts = torchcomms.ReduceOp.AVG + else: + raise AssertionError("unsupported reduce op") + elif isinstance(opts, torchcomms.ReduceOp): + tc_opts = opts + else: + raise AssertionError("unsupported reduce option type") + + with self._run_context(): + assert self._comm is not None + work = self._comm.all_reduce( + tensor=tensors[0], + op=tc_opts, + async_op=True, + ) + return self._wrap_work(work, tensors[0]) + + def alltoall_single( + self, + output: torch.Tensor, + input: torch.Tensor, + ) -> TorchWork: + """ + All-to-all scatter/gather operation with single tensors. + + Args: + output: Output tensor + input: Input tensor + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.all_to_all_single( + output=output, + input=input, + async_op=True, + ) + return self._wrap_work(work, input) + + def alltoall_v_single( + self, + output: torch.Tensor, + input: torch.Tensor, + output_split_sizes: List[int], + input_split_sizes: List[int], + ) -> TorchWork: + """ + All-to-all scatter/gather operation with variable sizes. + + Args: + output: Output tensor + input: Input tensor + output_split_sizes: Sizes for splitting output + input_split_sizes: Sizes for splitting input + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.all_to_all_v_single( + output=output, + input=input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + async_op=True, + ) + return self._wrap_work(work, input) + + def barrier(self) -> TorchWork: + """ + Synchronize all processes. + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.barrier(async_op=True) + return self._wrap_work(work, None) + + def broadcast( + self, + tensor: torch.Tensor, + root: int, + ) -> TorchWork: + """ + Broadcast tensor from root to all other ranks.
+ + Args: + tensor: Tensor to broadcast + root: Root rank + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.broadcast( + tensor=tensor, + root=root, + async_op=True, + ) + return self._wrap_work(work, tensor) + + def reduce_scatter( + self, + output: torch.Tensor, + input_list: List[torch.Tensor], + op: torchcomms.ReduceOp = torchcomms.ReduceOp.SUM, + ) -> TorchWork: + """ + Reduce and scatter tensors across all ranks. + + Args: + output: Output tensor + input_list: List of input tensors + op: Reduction operation (default: SUM) + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.reduce_scatter( + output=output, + input_list=input_list, + op=op, + async_op=True, + ) + return self._wrap_work(work, output) + + def reduce_scatter_single( + self, + output: torch.Tensor, + input: torch.Tensor, + op: torchcomms.ReduceOp = torchcomms.ReduceOp.SUM, + ) -> TorchWork: + """ + Reduce and scatter with single tensors. + + Args: + output: Output tensor + input: Input tensor + op: Reduction operation (default: SUM) + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.reduce_scatter_single( + output=output, + input=input, + op=op, + async_op=True, + ) + return self._wrap_work(work, output) + + def send( + self, + tensor: torch.Tensor, + dst: int, + ) -> TorchWork: + """ + Send tensor to destination rank. + + Args: + tensor: Tensor to send + dst: Destination rank + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.send( + tensor=tensor, + dst=dst, + async_op=True, + ) + return self._wrap_work(work, tensor) + + def recv( + self, + tensor: torch.Tensor, + src: int, + ) -> TorchWork: + """ + Receive tensor from source rank. + + Args: + tensor: Tensor to receive into + src: Source rank + + Returns: + TorchWork object that can be waited on + """ + with self._run_context(): + assert self._comm is not None + work = self._comm.recv( + tensor=tensor, + src=src, + async_op=True, + ) + return self._wrap_work(work, tensor) + + +class TorchCommGloo(TorchComm): + """ + Gloo backend wrapper for torchcomms. + + This provides a drop-in replacement for ProcessGroupGloo using torchcomms. + + Args: + timeout: Default timeout for operations (default: 60 seconds) + + Example: + comm = TorchCommGloo() + comm.configure(store_addr="localhost:1234/prefix", replica_id="r0", + rank=0, world_size=4) + tensor = torch.randn(10) + work = comm.allreduce([tensor], ReduceOp.SUM) + work.wait() + """ + + def __init__(self, timeout: timedelta = timedelta(seconds=60)) -> None: + super().__init__( + backend="gloo", + timeout=timeout, + device=torch.device("cpu"), + ) + + def reduce_scatter( + self, + output: torch.Tensor, + input_list: List[torch.Tensor], + op: torchcomms.ReduceOp = torchcomms.ReduceOp.SUM, + ) -> TorchWork: + """ + Gloo backend does not support reduce_scatter. + + Raises: + NotImplementedError: Always raised + """ + raise NotImplementedError("Gloo backend does not support reduce_scatter") + + def reduce_scatter_single( + self, + output: torch.Tensor, + input: torch.Tensor, + op: torchcomms.ReduceOp = torchcomms.ReduceOp.SUM, + ) -> TorchWork: + """ + Gloo backend does not support reduce_scatter_single.
+ + Raises: + NotImplementedError: Always raised + """ + raise NotImplementedError("Gloo backend does not support reduce_scatter_single") + + +class TorchCommNCCL(TorchComm): + """ + NCCL backend wrapper for torchcomms. + + This provides a drop-in replacement for ProcessGroupNCCL using torchcomms. + + If you are using a supported version of NCCL (NCCL >= 2.26, torch >= 2.7) + this will attempt to use ncclCommAbort to recover from any timeouts. + + Args: + timeout: Default timeout for operations (default: 60 seconds) + + Example: + comm = TorchCommNCCL() + comm.configure(store_addr="localhost:1234/prefix", replica_id="r0", + rank=0, world_size=4) + tensor = torch.randn(10).cuda() + work = comm.allreduce([tensor], ReduceOp.SUM) + work.wait() + """ + + def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None: + super().__init__( + backend="nccl", + timeout=timeout, + device=torch.device("cuda", torch.cuda.current_device()), + ) + self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25) + + NONBLOCKING_TIMEOUT_ENV = "TORCH_NCCL_NONBLOCKING_TIMEOUT" + if NONBLOCKING_TIMEOUT_ENV not in os.environ: + warnings.warn( + f"{NONBLOCKING_TIMEOUT_ENV} is not set, defaulting to {timeout}. " + "If any nonblocking NCCL operations have already run this may " + "result in the default timeout of 30 minutes and hangs on error.", + stacklevel=2, + ) + os.environ[NONBLOCKING_TIMEOUT_ENV] = str(timeout.total_seconds()) + + @contextmanager + def _run_context(self) -> Generator[None, None, None]: + """ + Context manager for running collective operations with timeout. + + Yields: + None + """ + if not self._use_abort: + yield + return + + timeout: timedelta = self._timeout + + def callback() -> None: + logger.error(f"aborting after {timeout}!") + self.abort() + + # when running in blocking mode we need to make sure collectives can time out + with context_timeout(callback, timeout): + yield diff --git a/torchft/manager.py b/torchft/manager.py index 7e78584..6311794 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -57,6 +57,7 @@ from torchft._torchft import ManagerClient, ManagerServer from torchft.checkpointing import CheckpointTransport, HTTPTransport from torchft.checkpointing._rwlock import RWLock +from torchft.comms import TorchComm from torchft.futures import future_timeout from torchft.utils import get_stream_context, synchronize from torchft.work import _DummyWork @@ -163,7 +164,7 @@ class Manager: def __init__( self, - pg: "ProcessGroup", + pg: Union["ProcessGroup", TorchComm], load_state_dict: Optional[Callable[[T], None]], state_dict: Optional[Callable[[], T]], min_replica_size: int, @@ -188,6 +189,7 @@ def __init__( ) -> None: """ Args: + pg: process group or torchcomms wrapper to use for communication. load_state_dict: function to load the state dict when recovering state_dict: function to save the state dict with recovering min_replica_size: minimum number of replicas on each step @@ -221,7 +223,7 @@ def __init__( replica_id: if rank==0, the replica_id for this group hostname: if rank==0, the hostname to advertise to the lighthouse server checkpoint_transport: the checkpoint transport to use for - transfering checkpoints to recovering replicas, defaults to HTTPTransport + transferring checkpoints to recovering replicas, defaults to HTTPTransport init_sync: whether to synchronize the model weights on step 0. If all of the model weights are initialized identically via ``torch.set_seed`` you should set this to False.
@@ -456,7 +458,9 @@ def allreduce( try: # Run the allreduce async and save the work object so we can wait on # it later. + # TODO: Support quantization with torchcomms if should_quantize and IS_TRITON_AVAILABLE: + assert isinstance(self._pg, ProcessGroup) work = allreduce_quantized( [tensor], pg_reduce_op, @@ -473,16 +477,16 @@ def allreduce( # on the Future @torch.profiler.record_function("torchft::manager::allreduce::callback") def callback( - fut: torch.futures.Future[torch.Tensor], + fut: torch.futures.Future[list[torch.Tensor]], ) -> torch.Tensor: nonlocal tensor if reduce_op == ReduceOp.AVG: tensor /= num_participants return tensor - managed_work = _ManagedWork(self, work, tensor) + managed_work = _ManagedWork(self, work, [tensor]) fut = managed_work.get_future() - fut = cast(torch.futures.Future[torch.Tensor], fut) + fut = cast(torch.futures.Future[list[torch.Tensor]], fut) fut = fut.then(callback) return managed_work @@ -1223,13 +1227,13 @@ def __init__( ) -> None: super().__init__() # Underlying `Work` retruned from process group operations - self._work = work + self._work: dist._Work = work # Used to report errors to the manager through `wrap_future()` - self._manager = manager + self._manager: Manager = manager # The value returned by the final future in the callback chain - self._value = value + self._value: object = value # The head of the callback chain self._managed_fut_head = _ManagedFuture[object](weakref.ref(self)) diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 4d2dc42..08db9e3 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -387,6 +387,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: self.assertTrue(manager._errored) # this should be skipped due to error manager.allreduce(torch.tensor([1.0])).wait() + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.call_count, 2) # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1) @@ -408,12 +409,14 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: bad_fut.set_exception(RuntimeError("injected failure")) manager._pg.allreduce.return_value.get_future.return_value = bad_fut manager.allreduce(torch.tensor([1.0])).wait() + # pyre-ignore[16]: _pg is mocked self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 2) self.assertTrue(manager._errored) self.assertFalse(manager.should_commit()) self.assertTrue(manager._errored) # cleanup + # pyre-ignore[16]: _pg is mocked manager._pg.allreduce.reset_mock(return_value=True) # recover on next step diff --git a/torchft/process_group.py b/torchft/process_group.py index 87fd599..269d240 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -42,6 +42,7 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp +import torchcomms # pyre-fixme[21]: no attribute ProcessGroupGloo from torch.distributed import ( @@ -167,7 +168,7 @@ def allgather_into_tensor_coalesced( def allreduce( self, tensors: List[torch.Tensor], - opts: Union[AllreduceOptions, ReduceOp], + opts: Union[AllreduceOptions, ReduceOp, torchcomms.ReduceOp], ) -> Work: """ Reduces the tensor data across all machines in such a way that all get the final result. 
@@ -555,7 +556,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: ) def allreduce_coalesced( - self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp] + self, + tensors: List[torch.Tensor], + opts: Union[AllreduceOptions, ReduceOp, torchcomms.ReduceOp], ) -> Work: with self._run_context(): return self._wrap_work( @@ -1068,7 +1071,9 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work: return res def allreduce_coalesced( - self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp] + self, + tensors: List[torch.Tensor], + opts: Union[AllreduceOptions, ReduceOp, torchcomms.ReduceOp], ) -> Work: res = _DummyWork(tensors) self._work.append(res) @@ -1331,6 +1336,7 @@ class ManagedProcessGroup(ProcessGroupWrapper): """ def __init__(self, manager: "Manager") -> None: + assert isinstance(manager._pg, ProcessGroup) super().__init__(pg=manager._pg) self._manager = manager @@ -1350,6 +1356,7 @@ def size(self) -> int: return self._manager.num_participants() def getBackendName(self) -> str: + assert isinstance(self._manager._pg, ProcessGroup) return self._manager._pg.getBackendName() @@ -1827,7 +1834,7 @@ def allgather_into_tensor_coalesced( def allreduce( self, tensors: List[torch.Tensor], - opts: Union[dist.AllreduceOptions, dist.ReduceOp], + opts: Union[AllreduceOptions, ReduceOp, torchcomms.ReduceOp], ) -> Work: _assert_list(tensors) _maybe_share_tensors(tensors) diff --git a/train_diloco.py b/train_diloco.py index bb61599..10f49be 100644 --- a/train_diloco.py +++ b/train_diloco.py @@ -15,6 +15,7 @@ USE_STREAMING = os.getenv("USE_STREAMING", "False") == "True" USE_NCCL = os.getenv("USE_NCCL", "False") == "True" +USE_PG = os.getenv("USE_PG", "False") == "True" import torch import torch.nn.functional as F @@ -31,6 +32,8 @@ ProcessGroupBabyNCCL, ProcessGroupGloo, ProcessGroupNCCL, + TorchCommGloo, + TorchCommNCCL, ) from torchft.checkpointing.http_transport import HTTPTransport from torchft.local_sgd import DiLoCo @@ -60,13 +63,21 @@ def state_dict(): } device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - pg = ( - ProcessGroupNCCL( - timeout=timedelta(seconds=10), + + if USE_PG: + pg = ( + ProcessGroupNCCL( + timeout=timedelta(seconds=10), + ) + if torch.cuda.is_available() and USE_NCCL + else ProcessGroupGloo(timeout=timedelta(seconds=10)) + ) + else: + pg = ( + TorchCommNCCL(timeout=timedelta(seconds=10)) + if USE_NCCL + else TorchCommGloo(timeout=timedelta(seconds=10)) ) - if torch.cuda.is_available() and USE_NCCL - else ProcessGroupGloo(timeout=timedelta(seconds=10)) - ) transport = HTTPTransport( timeout=timedelta(seconds=10), @@ -231,7 +242,7 @@ def trace_handler(p): if manager.current_step() % 100 == 0: print(f"[{manager.current_step()}] loss = {loss.item()}") - if manager.current_step() >= 15: + if manager.current_step() >= 150: # complete training prof.stop() writer.flush()
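A minimal sketch of how the new wrappers plug into the Manager, mirroring the backend selection added to train_diloco.py above. It assumes the usual torchft lighthouse/store environment variables are already set; the state-dict callbacks and the min_replica_size value are placeholders, and any Manager arguments not shown keep their defaults.

from datetime import timedelta

import torch

from torchft import Manager, TorchCommGloo, TorchCommNCCL


def state_dict() -> dict:
    # placeholder: return whatever state a recovering replica needs
    return {}


def load_state_dict(state: dict) -> None:
    # placeholder: restore state on a recovering replica
    pass


# Pick the NCCL wrapper when CUDA is available, otherwise fall back to Gloo
# (train_diloco.py makes a similar choice via the USE_NCCL flag).
pg = (
    TorchCommNCCL(timeout=timedelta(seconds=10))
    if torch.cuda.is_available()
    else TorchCommGloo(timeout=timedelta(seconds=10))
)

# Manager now accepts either a ProcessGroup or a TorchComm wrapper as pg.
manager = Manager(
    pg=pg,
    load_state_dict=load_state_dict,
    state_dict=state_dict,
    min_replica_size=1,  # placeholder value
)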
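A sketch of driving a wrapper directly, outside the Manager, to show the collective signatures defined in torchft/comms.py: allreduce takes a single-element tensor list plus a reduce op (torch.distributed.ReduceOp or torchcomms.ReduceOp, with SUM and AVG mapped internally) and returns a work object to wait on. The store address, replica id, ranks, and timeout below are placeholders; configure() expects a reachable store at a "host:port/prefix" style address.

from datetime import timedelta

import torch
from torch.distributed import ReduceOp

from torchft.comms import TorchCommGloo

comm = TorchCommGloo(timeout=timedelta(seconds=60))
comm.configure(
    store_addr="localhost:29500/torchft_q0",  # hypothetical store address
    replica_id="replica_0",
    rank=0,
    world_size=2,
)

t = torch.ones(8)
# Single-element list plus reduce op; SUM/AVG are translated to the
# equivalent torchcomms.ReduceOp values inside allreduce().
work = comm.allreduce([t], ReduceOp.SUM)
work.wait()

# If a collective timed out, the wrapper aborts the communicator and
# errored() reports the failure; reconfigure on the next quorum instead of
# reusing the aborted comm.
if comm.errored() is not None:
    print("comm aborted:", comm.errored())

comm.shutdown()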