From 99f9863da9e49187971b8334417a5da066bf0c20 Mon Sep 17 00:00:00 2001 From: zhengchenyu Date: Tue, 18 Nov 2025 15:18:04 +0800 Subject: [PATCH 1/2] Keep the training data continuous and the total batch size constant regardless of changes in the replica world size. --- torchft/data.py | 107 +++++++++++++++++- torchft/data_test.py | 61 ++++++++++- torchft/manager.py | 92 +++++++++++++++- torchft/manager_test.py | 128 +++++++++++++++++++++- torchft/optim.py | 4 +- train_ddp2.py | 237 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 619 insertions(+), 10 deletions(-) create mode 100644 train_ddp2.py diff --git a/torchft/data.py b/torchft/data.py index 02e5b3be..5f1b6a55 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -14,11 +14,114 @@ dataloader frequently to avoid duplicate batches. """ -from typing import Optional - +import torch import torch.distributed as dist +from torch.utils.data.dataset import Dataset +from torch.utils.data.sampler import Sampler from torch.utils import data +import math +from collections.abc import Iterator +from typing import Optional, TypeVar + +_T_co = TypeVar("_T_co", covariant=True) + +class SkipDistributedSampler(Sampler[_T_co]): + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + skip_samples: int = 0, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" + ) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + self.skip_samples = skip_samples + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + (len(self.dataset) - self.skip_samples - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil((len(self.dataset) - self.skip_samples) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[_T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + indices = indices[self.skip_samples: len(indices)] + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] + else: + # remove tail of data to make it evenly divisible. 
+ indices = indices[self.skip_samples : self.skip_samples + self.total_size] + if len(indices) != self.total_size: + raise AssertionError( + f"Number of indices ({len(indices)}) does not match total_size ({self.total_size})" + ) + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + if len(indices) != self.num_samples: + raise AssertionError( + f"Number of subsampled indices ({len(indices)}) does not match num_samples ({self.num_samples})" + ) + + # pyrefly: ignore # bad-return + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r""" + Set the epoch for this sampler. + + When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch # pyre-fixme[24]: expected generic parameter class DistributedSampler(data.distributed.DistributedSampler): diff --git a/torchft/data_test.py b/torchft/data_test.py index 8dae190e..37a6a331 100644 --- a/torchft/data_test.py +++ b/torchft/data_test.py @@ -8,7 +8,7 @@ from torch.utils.data import Dataset -from torchft.data import DistributedSampler +from torchft.data import DistributedSampler, SkipDistributedSampler class DummyDataset(Dataset): @@ -37,3 +37,62 @@ def test_distributed_sampler(self) -> None: sampler_iter = iter(sampler) self.assertEqual(next(sampler_iter), 500) + + def test_skip_distributed_sampler(self): + dataset_length = 100 + dataset = DummyDataset(dataset_length) + + # Case 1: sample is not skipped + for drop_last in [True, False]: + num_replicas = 7 + for rank in range(num_replicas): + sampler = SkipDistributedSampler(dataset=dataset, num_replicas=num_replicas, + rank=rank, shuffle=False, drop_last=drop_last) + cur = rank + for idx in sampler: + self.assertEqual(idx, (cur % dataset_length), f"idx={idx}, cur={cur}") + cur += num_replicas + # If drop_last is True, read ceil((100-7)/7)*7=98 samples totally. + # If drop_last is False, read ceil(100/7)*7=105 samples totally. + if drop_last: + self.assertEqual(cur, 98 + rank, f"rank={rank}, cur={cur}") + else: + self.assertEqual(cur, 105 + rank, f"rank={rank}, cur={cur}") + + # Case 2: sample is skipped + for drop_last in [True, False]: + num_replicas = 7 + skip_samples = 10 + for rank in range(num_replicas): + sampler = SkipDistributedSampler(dataset=dataset, num_replicas=num_replicas, + rank=rank, shuffle=False, drop_last=drop_last, + skip_samples=skip_samples) + cur = rank + for idx in sampler: + expected = ((cur + skip_samples) % dataset_length + skip_samples) \ + if (cur + skip_samples) >= dataset_length else (cur + skip_samples) + self.assertEqual(idx, expected, f"idx={idx}, expected={expected}") + cur += num_replicas + # If drop_last is True, read ceil((100-10-7)/7)*7=84 samples totally. + # If drop_last is False, read ceil((100-10)/7)*7=91 samples totally. + if drop_last: + self.assertEqual(cur, 84 + rank, f"rank={rank}, cur={cur}") + else: + self.assertEqual(cur, 91 + rank, f"rank={rank}, cur={cur}") + + # Case 3: drop last is False and padding size is larger than number of indices + # If skip_samples is 90, and num_replicas is 31, then the indices is [90, 92, ..., 99]. + # It means only 10 samples are left, so padding size is 21 which is larger than 10. 
+ num_replicas = 31 + skip_samples = 90 + expected = list(range(90, 100)) + expected = (expected * 4)[:31] + for rank in range(num_replicas): + sampler = SkipDistributedSampler(dataset=dataset, num_replicas=num_replicas, + rank=rank, shuffle=False, drop_last=False, + skip_samples=skip_samples) + cnt = 0 + for idx in sampler: + self.assertEqual(idx, expected[rank], f"idx={idx}, rank={rank}, expected={expected}") + cnt += 1 + self.assertTrue(cnt, 1) diff --git a/torchft/manager.py b/torchft/manager.py index 7e785846..b8ae26be 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -26,6 +26,7 @@ """ import concurrent.futures +import gc import logging import os import socket @@ -185,6 +186,7 @@ def __init__( init_sync: bool = True, max_retries: Optional[int] = None, quorum_retries: int = 0, + dataloader_fn: Optional[Callable[[int, int, int], None]] = None, ) -> None: """ Args: @@ -365,6 +367,17 @@ def __init__( self._update_fr_path() + # The number of batches committed in the current epoch.Compare to _batches_committed, + # _current_batches_committed will reset to 0 when next epoch starts. + self._current_batches_committed = 0 + self._epoch = 0 + self._loaded_epoch = 0 + self._loaded_current_batches_committed = 0 + self._dataloader_fn = dataloader_fn + self._dataloader_dirty = False + self._dataloader_iter = None + self._accumulation_steps = 1 + def allow_state_dict_read(self) -> None: if self._is_state_dict_read_allowed: return @@ -438,6 +451,12 @@ def allreduce( return _DummyWork(tensor) self.wait_quorum() + + # If dirty, the result will not be committed, so return empty tensor. + if self._dataloader_dirty: + work = _DummyWork(torch.zeros_like(tensor)) + return _ManagedWork(self, work, tensor) + num_participants: int = self.num_participants() if not self.is_participating(): @@ -678,6 +697,8 @@ def _async_quorum( if self._use_async_quorum or not allow_heal else (replica_rank, replica_world_size) ) + self._replica_rank = replica_rank + self._replica_world_size = replica_world_size # For fixed with spares we need to ensure that we don't have more # participating replicas than the min replica size. @@ -691,6 +712,7 @@ def _async_quorum( ): self._participating_replica_rank = None + quorum_changed = False if quorum_id != self._quorum_id: self.quorum_logger.info( "", @@ -737,6 +759,7 @@ def _async_quorum( self._logger.exception(f"got exception in pg configure: {e}") self.report_error(e) return + quorum_changed = True if allow_heal: # run recovery on the recovery stream if available @@ -807,6 +830,38 @@ def _async_quorum( else None ) + # reconfigure dataloader after healing so that we can get offset from other replica group + if quorum_changed and self._dataloader_fn: + self.reconfigure_dataloader() + self._dataloader_dirty = True + + def get_batch_samples(self, epoch=0, num_batches=None, batch_size=None, total_batch_size=None): + # In general, `start_quorum` might not have been called during the first loop, + # and the dataloader might not have been initialized yet. In this case, we should + # return immediately and set the dirty flag to avoid computation and commit. + if not self._dataloader_iter: + self._dataloader_dirty = True + return [] + # If the recovery worker is behind the current epoch, we should skip computation and commit. 
+ if epoch < self._loaded_epoch: + return None + + if total_batch_size != None and batch_size != None: + num_batches = total_batch_size // (batch_size * self._replica_world_size) + + assert num_batches is not None, ("num_batches must be specified or " + "total_batch_size and batch_size must be specified") + + batch_samples = [] + for _ in range(num_batches): + try: + batch_samples.append(next(self._dataloader_iter)) + except StopIteration: + break + self._dataloader_dirty = False + self._accumulation_steps = len(batch_samples) + return batch_samples if batch_samples else None + def _update_fr_path(self) -> None: """ Update the path that flight recorder will dump the traces to. @@ -921,9 +976,14 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool: # decide whether we're in a healthy state to increase the step count if should_commit: - self._step += 1 - self._batches_committed += self.num_participants() self._commit_failures = 0 # Reset failure counter on success + if not self._dataloader_dirty: + self._step += 1 + self._batches_committed += self.num_participants() * self._accumulation_steps + self._current_batches_committed += self.num_participants() * self._accumulation_steps + return True + else: + return False else: self._commit_failures += 1 # Check if we've hit max retries @@ -934,8 +994,7 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool: msg = f"should_commit failed {self._commit_failures} times consecutively, exceeding max_retries={self._max_retries}" self._logger.exception(msg) raise RuntimeError(msg) - - return should_commit + return False def load_state_dict(self, state_dict: Dict[str, int]) -> None: """ @@ -948,6 +1007,11 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None: """ self._step = state_dict["step"] self._batches_committed = state_dict["batches_committed"] + self._loaded_epoch = state_dict["epoch"] + self._loaded_current_batches_committed = state_dict["current_batches_committed"] + if self._loaded_epoch == 0: + self._epoch = 0 + self._current_batches_committed = self._loaded_current_batches_committed def _manager_state_dict(self) -> Dict[str, object]: with self._state_dict_lock.r_lock(): @@ -969,7 +1033,8 @@ def state_dict(self) -> Dict[str, int]: Returns: the state dict for this manager """ - return {"step": self._step, "batches_committed": self._batches_committed} + return {"step": self._step, "batches_committed": self._batches_committed, + "epoch": self._epoch, "current_batches_committed": self._current_batches_committed} def current_step(self) -> int: """ @@ -1047,6 +1112,23 @@ def is_participating(self) -> bool: return False return True + def reconfigure_dataloader(self): + dataloader = self._dataloader_fn(self._replica_world_size, + self._replica_rank, self._current_batches_committed) + dataloader.sampler.set_epoch(self._epoch) + self._dataloader_iter = iter(dataloader) + # cleanup for old dataloader + gc.collect() + + def next_epoch(self): + self._epoch += 1 + if self._loaded_epoch == self._epoch: + self._current_batches_committed = self._loaded_current_batches_committed + else: + self._current_batches_committed = 0 + if self._dataloader_fn: + self.reconfigure_dataloader() + self._dataloader_dirty = False class _ManagerLogger: def __init__(self, manager: Manager, replica_id: str, group_rank: int) -> None: diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 4d2dc42c..95873c7b 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -8,16 +8,19 @@ import threading import time from 
datetime import timedelta -from typing import Optional +from typing import Optional, Callable from unittest import TestCase from unittest.mock import create_autospec, MagicMock, patch import torch from torch.distributed import ReduceOp, TCPStore +from torch.utils.data import DataLoader from torchft._torchft import QuorumResult from torchft.checkpointing._rwlock import RWLock from torchft.checkpointing.transport import CheckpointTransport +from torchft.data import SkipDistributedSampler +from torchft.data_test import DummyDataset from torchft.manager import Manager, MANAGER_ADDR_KEY, REPLICA_ID_KEY, WorldSizeMode from torchft.process_group import ProcessGroup from torchft.work import _DummyWork @@ -47,6 +50,7 @@ def _create_manager( timeout: timedelta = timedelta(seconds=10), init_sync: bool = True, max_retries: Optional[int] = None, + dataloader_fn: Optional[Callable[[int, int, int], None]] = None, ) -> Manager: pg = create_autospec(ProcessGroup) pg.errored.return_value = None @@ -76,6 +80,7 @@ def _create_manager( timeout=timeout, init_sync=init_sync, max_retries=max_retries, + dataloader_fn=dataloader_fn, ) self.manager = manager return manager @@ -909,3 +914,124 @@ def test_manager_state_dict_with_lock(self, client_mock: MagicMock) -> None: # Restore the original lock manager._state_dict_lock = original_lock + + @patch("torchft.manager.ManagerClient", autospec=True) + def test_dataloader_after_quorum(self, client_mock: MagicMock) -> None: + # 1 Initial + dataset_len = 1000 + batch_size = 4 + dataset = DummyDataset(dataset_len) + committed_batches = 0 + store = TCPStore( + host_name="localhost", port=0, is_master=True, wait_for_workers=False + ) + + def dataloader_fn(replica_world_size, replica_rank, batches_committed): + sampler = SkipDistributedSampler( + dataset=dataset, + num_replicas=replica_world_size, + rank=replica_rank, + shuffle=False, + seed=0, + drop_last=False, + skip_samples=batches_committed * batch_size, + ) + dataloader = DataLoader(dataset, batch_size=batch_size, + num_workers=replica_world_size, sampler=sampler) + return dataloader + + def exptected_samples(world_size, rank, committed_batches, expected_len=None): + expected = [] + expected_len = expected_len if expected_len is not None else batch_size + for i in range(expected_len): + expected.append(committed_batches * batch_size + rank + i * world_size) + expected = [x % dataset_len for x in expected] + return expected + + # Create manager + manager = self._create_manager(dataloader_fn=dataloader_fn) + manager.set_epoch(0) + + # mock for should_commit + client_mock().should_commit = mock_should_commit + + # mock for quorum + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{store.port}" + quorum.max_step = 1 + quorum.max_replica_rank = 1 + quorum.max_world_size = 2 + quorum.heal = False + + # 2 The initial state has 2 replicas + quorum.replica_world_size = 2 + client_mock()._quorum.return_value = quorum + + # 2.1 Get sampler first time without quorum, then will got empty batches + batches = manager.get_batch_samples(1) + self.assertNotEqual(batches, None) + self.assertEqual(len(batches), 0) + + # 2.2 Start quorum, then reinit dataloader + manager.start_quorum() + manager.wait_quorum() + batches = manager.get_batch_samples(1) + self.assertEqual(len(batches), 1) + for inputs in batches: + self.assertTrue(len(inputs), 4) + self.assertEqual(inputs.tolist(), + 
exptected_samples(quorum.replica_world_size, quorum.replica_rank, committed_batches)) + + # 2.3 Call should commit to increment committed batches, then get samples + manager.should_commit() + committed_batches += quorum.replica_world_size + batches = manager.get_batch_samples(1) + self.assertEqual(len(batches), 1) + for inputs in batches: + self.assertTrue(len(inputs), 4) + self.assertEqual(inputs.tolist(), + exptected_samples(quorum.replica_world_size, quorum.replica_rank, committed_batches)) + + # 3 Start quorum to increment step and replica world size to 3 + quorum.quorum_id = 124 + quorum.replica_world_size = 3 + client_mock()._quorum.return_value = quorum + manager.start_quorum() + manager.wait_quorum() + + # 3.1 Get sample after quorum with 3 replicas, and set dirty flag to mock dataloader is reinit. + batches = manager.get_batch_samples(1) + manager._dataloader_dirty = True + self.assertEqual(len(batches), 1) + for inputs in batches: + self.assertTrue(len(inputs), 4) + self.assertEqual(inputs.tolist(), + exptected_samples(quorum.replica_world_size, quorum.replica_rank, committed_batches)) + # When the dataloader is dirty, should not commit + self.assertFalse(manager.should_commit()) + # reset the dirty flag + manager._dataloader_dirty = False + + # 3.2 Call should commit to increment committed batches + manager.should_commit() + committed_batches += quorum.replica_world_size + batches = manager.get_batch_samples(1) + self.assertEqual(len(batches), 1) + for inputs in batches: + self.assertTrue(len(inputs), 4) + self.assertEqual(inputs.tolist(), + exptected_samples(quorum.replica_world_size, quorum.replica_rank, committed_batches)) + + # 3.3 Continue to get samples until the dataloader is exhausted + while (batches := manager.get_batch_samples()) != None: + committed_batches += quorum.replica_world_size + self.assertEqual(len(batches), 1) + for inputs in batches: + self.assertTrue(len(inputs), 4) + self.assertEqual(inputs.tolist(), + exptected_samples(quorum.replica_world_size, quorum.replica_rank, + committed_batches, expected_len=len(inputs.tolist()))) diff --git a/torchft/optim.py b/torchft/optim.py index a2884392..4b68e338 100644 --- a/torchft/optim.py +++ b/torchft/optim.py @@ -49,10 +49,12 @@ def zero_grad(self, set_to_none: bool = True) -> None: self.manager.start_quorum() self.optim.zero_grad(set_to_none) - def step(self, closure: Optional[object] = None) -> None: + def step(self, closure: Optional[object] = None) -> bool: assert closure is None, "optimizers that use closures are not supported" if self.manager.should_commit(): self.optim.step() + return True + return False @property def param_groups(self) -> List[Dict[str, Any]]: diff --git a/train_ddp2.py b/train_ddp2.py new file mode 100644 index 00000000..239e861c --- /dev/null +++ b/train_ddp2.py @@ -0,0 +1,237 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import os +from datetime import timedelta + +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +from torchft.data import SkipDistributedSampler + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) +os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID) + +import torch +import torchvision +import torchvision.transforms as transforms +from torch import nn, optim +from torch.distributed.elastic.multiprocessing.errors import record + +from torchft import ( + DistributedDataParallel, + Manager, + Optimizer, + ProcessGroupGloo, + ProcessGroupNCCL, + ProcessGroupXCCL, +) +from torchft.checkpointing.pg_transport import PGTransport + +logging.basicConfig(level=logging.INFO) + +NUM_EPOCHS = 10 +BATCH_SIZE = 16 +TOTAL_BATCH_SIZE = BATCH_SIZE * 6 +CHECKPOINT_ENABLED = True +CHECKPOINT_PATH = "./tmp/train_ddp2_checkpoint/ckpt" + +def save_model(m, optimizer, manager): + state_dict_to_save = { + "model": m.state_dict(), + "optim": optimizer.state_dict(), + "torchft": manager.state_dict(), + } + # Save the checkpoint path by step and save the latest step to latest file + step_checkpoint_path = f"{CHECKPOINT_PATH}_step_{manager.current_step()}" + torch.save(state_dict_to_save, step_checkpoint_path) + latest_path = f"{CHECKPOINT_PATH}_latest" + with open(latest_path, "w") as f: + f.write(step_checkpoint_path) + # Delete the older checkpoints + for filename in os.listdir("./tmp/train_ddp2_checkpoint/"): + if filename.startswith("ckpt_step_"): + step_str = filename.split("_")[-1] + try: + step_num = int(step_str) + if step_num < manager.current_step() - 1000: + os.remove(os.path.join("./tmp/train_ddp2_checkpoint/", filename)) + except ValueError: + continue + +def load_model(m, optimizer, manager): + if os.path.exists(f"{CHECKPOINT_PATH}_latest"): + with open(f"{CHECKPOINT_PATH}_latest", "r") as f: + latest_checkpoint_path = f.read().strip() + print(f"Loading checkpoint from {latest_checkpoint_path}") + loaded_state_dict = torch.load(latest_checkpoint_path) + m.load_state_dict(loaded_state_dict["model"]) + optimizer.load_state_dict(loaded_state_dict["optim"]) + manager.load_state_dict(loaded_state_dict["torchft"]) + +@record +def main() -> None: + REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + trainset = torchvision.datasets.CIFAR10( + root="./cifar", train=True, download=True, transform=transform + ) + + def load_state_dict(state_dict): + m.load_state_dict(state_dict["model"]) + optimizer.load_state_dict(state_dict["optim"]) + + def state_dict(): + return { + "model": m.state_dict(), + "optim": optimizer.state_dict(), + } + + if torch.cuda.is_available(): + device = "cuda" + pg = ProcessGroupNCCL(timeout=timedelta(seconds=30)) + elif torch.xpu.is_available(): + device = "xpu" + pg = ProcessGroupXCCL(timeout=timedelta(seconds=30)) + else: + device = "cpu" + pg = ProcessGroupGloo(timeout=timedelta(seconds=5)) + + transport = PGTransport( + pg, + timeout=timedelta(seconds=10), + device=( + "cuda" + if torch.cuda.is_available() + else "xpu" + if torch.xpu.is_available() + else "cpu" + ), + ) + + def dataloader_fn(replica_world_size, replica_rank, current_batches_committed): + sampler = SkipDistributedSampler( + dataset=trainset, + num_replicas=replica_world_size, + rank=replica_rank, + shuffle=True, + seed=0, + drop_last=True, + 
skip_samples=current_batches_committed * BATCH_SIZE, + ) + + # drop_last to ensure all replicas have the same number of batches + dataloader = DataLoader(trainset, batch_size=BATCH_SIZE, + num_workers=0, sampler=sampler, + drop_last=True) + return dataloader + + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=load_state_dict, + state_dict=state_dict, + replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=30), + checkpoint_transport=transport, + dataloader_fn=dataloader_fn + ) + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.cnn = nn.Sequential( + nn.Conv2d(3, 6, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + nn.Conv2d(6, 16, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + ) + + final_dim = 10 + # We add a useless 1GB intermediate layer so we spend more time in dist + # communication so injected failures are more likely to cause issues + # if they exist. + target_size = 1_000_000_000 + self.useless = nn.Embedding(target_size // final_dim // 4, final_dim) + + self.classifier = nn.Sequential( + nn.Linear(16 * 5 * 5, 120), + nn.ReLU(), + nn.Linear(120, 84), + nn.ReLU(), + nn.Linear(84, final_dim), + ) + + def forward(self, x): + x = self.cnn(x) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = self.classifier(x) + x += self.useless.weight[0] + return x + + m = Net().to(device) + m = DistributedDataParallel(manager, m) + optimizer = Optimizer(manager, optim.AdamW(m.parameters())) + criterion = nn.CrossEntropyLoss() + if CHECKPOINT_ENABLED: + load_model(m, optimizer, manager) + + print(m) + num_params = sum(p.numel() for p in m.parameters()) + print(f"Total number of parameters: {num_params}") + + loss_writer = SummaryWriter(log_dir='./tmp/loss_train_ddp2') + for epoch in range(NUM_EPOCHS): + while (batches := manager.get_batch_samples(epoch=epoch, + batch_size=BATCH_SIZE, total_batch_size=TOTAL_BATCH_SIZE)) is not None: + optimizer.zero_grad() + total_loss = 0.0 + for inputs, labels in batches: + inputs = inputs.to(device) + labels = labels.to(device) + out = m(inputs) + loss = criterion(out, labels) + loss.backward() + total_loss += loss.item() + # If errored, the optimizer step will be a no-op, and the parameter will not be updated. + # Although it is possible to use new pg to compute old batches, it is still safe. 
+ if not optimizer.step(): + continue + + # all reduce the loss across all replicas + total_loss /= len(batches) + loss_tensor = torch.tensor(total_loss, device=device) + manager.allreduce(loss_tensor).wait() + avg_loss = loss_tensor.item() + if manager.participating_rank() == 0: + loss_writer.add_scalar('Training Loss', avg_loss, global_step=manager.batches_committed()) + if manager.current_step() % 100 == 0: + print(f"Epoch {epoch + 1}, step = {manager.current_step()}, batch_committed {manager.batches_committed()}, Loss: {avg_loss:.4f}") + if CHECKPOINT_ENABLED and manager.current_step() % 200 == 0 and manager.participating_rank() == 0: + save_model(m, optimizer, manager) + print(f"Epoch {epoch + 1} completed, batches_committed {manager.batches_committed()}.") + manager.next_epoch() + loss_writer.close() + +if __name__ == "__main__": + main() + + +# 1 启动 torchft lighthouse +# RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 + +# 2 启动任务 +## cd /work/zhengchenyu/ml-examples/ +## NCCL_HOSTID=0 RUST_LOG=INFO NCCL_SOCKET_IFNAME=eth0 CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 REPLICA_GROUP_ID=0 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 pytorch_tutorials/torchft/train_ddp2.py +## NCCL_HOSTID=1 RUST_LOG=INFO NCCL_SOCKET_IFNAME=eth0 CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 REPLICA_GROUP_ID=1 torchrun --master_port 29502 --nnodes 1 --nproc_per_node 1 pytorch_tutorials/torchft/train_ddp2.py +## NCCL_HOSTID=2 RUST_LOG=INFO NCCL_SOCKET_IFNAME=eth0 CUDA_VISIBLE_DEVICES=2 TORCHFT_LIGHTHOUSE=http://localhost:29510 REPLICA_GROUP_ID=2 torchrun --master_port 29503 --nnodes 1 --nproc_per_node 1 pytorch_tutorials/torchft/train_ddp2.py From 31cecf08c0b8371bed537d484e705e61b6aab05f Mon Sep 17 00:00:00 2001 From: zhengchenyu Date: Tue, 18 Nov 2025 15:52:16 +0800 Subject: [PATCH 2/2] fix style --- torchft/data.py | 21 ++++++++++------ torchft/data_test.py | 46 ++++++++++++++++++++++++++--------- torchft/manager.py | 34 +++++++++++++++++++------- torchft/manager_test.py | 54 ++++++++++++++++++++++++++++++----------- train_ddp2.py | 53 ++++++++++++++++++++++++---------------- 5 files changed, 144 insertions(+), 64 deletions(-) diff --git a/torchft/data.py b/torchft/data.py index 5f1b6a55..dcc7a0f1 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -14,18 +14,19 @@ dataloader frequently to avoid duplicate batches. """ +import math +from collections.abc import Iterator +from typing import Optional, TypeVar + import torch import torch.distributed as dist +from torch.utils import data from torch.utils.data.dataset import Dataset from torch.utils.data.sampler import Sampler -from torch.utils import data - -import math -from collections.abc import Iterator -from typing import Optional, TypeVar _T_co = TypeVar("_T_co", covariant=True) + class SkipDistributedSampler(Sampler[_T_co]): def __init__( self, @@ -62,10 +63,13 @@ def __init__( # This is to ensure each rank receives the same amount of data when # using this Sampler. 
self.num_samples = math.ceil( - (len(self.dataset) - self.skip_samples - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + (len(self.dataset) - self.skip_samples - self.num_replicas) + / self.num_replicas # type: ignore[arg-type] ) else: - self.num_samples = math.ceil((len(self.dataset) - self.skip_samples) / self.num_replicas) # type: ignore[arg-type] + self.num_samples = math.ceil( + (len(self.dataset) - self.skip_samples) / self.num_replicas + ) # type: ignore[arg-type] self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed @@ -80,7 +84,7 @@ def __iter__(self) -> Iterator[_T_co]: indices = list(range(len(self.dataset))) # type: ignore[arg-type] if not self.drop_last: - indices = indices[self.skip_samples: len(indices)] + indices = indices[self.skip_samples : len(indices)] # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) if padding_size <= len(indices): @@ -123,6 +127,7 @@ def set_epoch(self, epoch: int) -> None: """ self.epoch = epoch + # pyre-fixme[24]: expected generic parameter class DistributedSampler(data.distributed.DistributedSampler): """ diff --git a/torchft/data_test.py b/torchft/data_test.py index 37a6a331..dcd456fb 100644 --- a/torchft/data_test.py +++ b/torchft/data_test.py @@ -46,11 +46,18 @@ def test_skip_distributed_sampler(self): for drop_last in [True, False]: num_replicas = 7 for rank in range(num_replicas): - sampler = SkipDistributedSampler(dataset=dataset, num_replicas=num_replicas, - rank=rank, shuffle=False, drop_last=drop_last) + sampler = SkipDistributedSampler( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=False, + drop_last=drop_last, + ) cur = rank for idx in sampler: - self.assertEqual(idx, (cur % dataset_length), f"idx={idx}, cur={cur}") + self.assertEqual( + idx, (cur % dataset_length), f"idx={idx}, cur={cur}" + ) cur += num_replicas # If drop_last is True, read ceil((100-7)/7)*7=98 samples totally. # If drop_last is False, read ceil(100/7)*7=105 samples totally. @@ -64,13 +71,21 @@ def test_skip_distributed_sampler(self): num_replicas = 7 skip_samples = 10 for rank in range(num_replicas): - sampler = SkipDistributedSampler(dataset=dataset, num_replicas=num_replicas, - rank=rank, shuffle=False, drop_last=drop_last, - skip_samples=skip_samples) + sampler = SkipDistributedSampler( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=False, + drop_last=drop_last, + skip_samples=skip_samples, + ) cur = rank for idx in sampler: - expected = ((cur + skip_samples) % dataset_length + skip_samples) \ - if (cur + skip_samples) >= dataset_length else (cur + skip_samples) + expected = ( + ((cur + skip_samples) % dataset_length + skip_samples) + if (cur + skip_samples) >= dataset_length + else (cur + skip_samples) + ) self.assertEqual(idx, expected, f"idx={idx}, expected={expected}") cur += num_replicas # If drop_last is True, read ceil((100-10-7)/7)*7=84 samples totally. 
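As a quick illustration of the arithmetic in the test comments above, a minimal sketch of driving the sampler directly; it assumes the 100-element DummyDataset defined in data_test.py, and the concrete numbers follow from skip_samples=10 and num_replicas=7 with drop_last=True:

from torchft.data import SkipDistributedSampler
from torchft.data_test import DummyDataset

sampler = SkipDistributedSampler(
    dataset=DummyDataset(100),
    num_replicas=7,
    rank=0,
    shuffle=False,
    drop_last=True,
    skip_samples=10,
)
# ceil((100 - 10 - 7) / 7) = 12 indices per replica, i.e. 84 samples across 7 replicas.
assert len(sampler) == 12
# Rank 0 resumes at sample 10 and then strides by num_replicas.
assert list(sampler)[:3] == [10, 17, 24]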
@@ -88,11 +103,18 @@ def test_skip_distributed_sampler(self): expected = list(range(90, 100)) expected = (expected * 4)[:31] for rank in range(num_replicas): - sampler = SkipDistributedSampler(dataset=dataset, num_replicas=num_replicas, - rank=rank, shuffle=False, drop_last=False, - skip_samples=skip_samples) + sampler = SkipDistributedSampler( + dataset=dataset, + num_replicas=num_replicas, + rank=rank, + shuffle=False, + drop_last=False, + skip_samples=skip_samples, + ) cnt = 0 for idx in sampler: - self.assertEqual(idx, expected[rank], f"idx={idx}, rank={rank}, expected={expected}") + self.assertEqual( + idx, expected[rank], f"idx={idx}, rank={rank}, expected={expected}" + ) cnt += 1 self.assertTrue(cnt, 1) diff --git a/torchft/manager.py b/torchft/manager.py index b8ae26be..8e6160f0 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -835,7 +835,9 @@ def _async_quorum( self.reconfigure_dataloader() self._dataloader_dirty = True - def get_batch_samples(self, epoch=0, num_batches=None, batch_size=None, total_batch_size=None): + def get_batch_samples( + self, epoch=0, num_batches=None, batch_size=None, total_batch_size=None + ): # In general, `start_quorum` might not have been called during the first loop, # and the dataloader might not have been initialized yet. In this case, we should # return immediately and set the dirty flag to avoid computation and commit. @@ -849,8 +851,10 @@ def get_batch_samples(self, epoch=0, num_batches=None, batch_size=None, total_ba if total_batch_size != None and batch_size != None: num_batches = total_batch_size // (batch_size * self._replica_world_size) - assert num_batches is not None, ("num_batches must be specified or " - "total_batch_size and batch_size must be specified") + assert num_batches is not None, ( + "num_batches must be specified or " + "total_batch_size and batch_size must be specified" + ) batch_samples = [] for _ in range(num_batches): @@ -979,8 +983,12 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool: self._commit_failures = 0 # Reset failure counter on success if not self._dataloader_dirty: self._step += 1 - self._batches_committed += self.num_participants() * self._accumulation_steps - self._current_batches_committed += self.num_participants() * self._accumulation_steps + self._batches_committed += ( + self.num_participants() * self._accumulation_steps + ) + self._current_batches_committed += ( + self.num_participants() * self._accumulation_steps + ) return True else: return False @@ -1033,8 +1041,12 @@ def state_dict(self) -> Dict[str, int]: Returns: the state dict for this manager """ - return {"step": self._step, "batches_committed": self._batches_committed, - "epoch": self._epoch, "current_batches_committed": self._current_batches_committed} + return { + "step": self._step, + "batches_committed": self._batches_committed, + "epoch": self._epoch, + "current_batches_committed": self._current_batches_committed, + } def current_step(self) -> int: """ @@ -1113,8 +1125,11 @@ def is_participating(self) -> bool: return True def reconfigure_dataloader(self): - dataloader = self._dataloader_fn(self._replica_world_size, - self._replica_rank, self._current_batches_committed) + dataloader = self._dataloader_fn( + self._replica_world_size, + self._replica_rank, + self._current_batches_committed, + ) dataloader.sampler.set_epoch(self._epoch) self._dataloader_iter = iter(dataloader) # cleanup for old dataloader @@ -1130,6 +1145,7 @@ def next_epoch(self): self.reconfigure_dataloader() self._dataloader_dirty = False + 
class _ManagerLogger: def __init__(self, manager: Manager, replica_id: str, group_rank: int) -> None: self._logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 95873c7b..5728a204 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -8,7 +8,7 @@ import threading import time from datetime import timedelta -from typing import Optional, Callable +from typing import Callable, Optional from unittest import TestCase from unittest.mock import create_autospec, MagicMock, patch @@ -936,8 +936,12 @@ def dataloader_fn(replica_world_size, replica_rank, batches_committed): drop_last=False, skip_samples=batches_committed * batch_size, ) - dataloader = DataLoader(dataset, batch_size=batch_size, - num_workers=replica_world_size, sampler=sampler) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=replica_world_size, + sampler=sampler, + ) return dataloader def exptected_samples(world_size, rank, committed_batches, expected_len=None): @@ -983,8 +987,12 @@ def exptected_samples(world_size, rank, committed_batches, expected_len=None): self.assertEqual(len(batches), 1) for inputs in batches: self.assertTrue(len(inputs), 4) - self.assertEqual(inputs.tolist(), - exptected_samples(quorum.replica_world_size, quorum.replica_rank, committed_batches)) + self.assertEqual( + inputs.tolist(), + exptected_samples( + quorum.replica_world_size, quorum.replica_rank, committed_batches + ), + ) # 2.3 Call should commit to increment committed batches, then get samples manager.should_commit() @@ -993,8 +1001,12 @@ def exptected_samples(world_size, rank, committed_batches, expected_len=None): self.assertEqual(len(batches), 1) for inputs in batches: self.assertTrue(len(inputs), 4) - self.assertEqual(inputs.tolist(), - exptected_samples(quorum.replica_world_size, quorum.replica_rank, committed_batches)) + self.assertEqual( + inputs.tolist(), + exptected_samples( + quorum.replica_world_size, quorum.replica_rank, committed_batches + ), + ) # 3 Start quorum to increment step and replica world size to 3 quorum.quorum_id = 124 @@ -1009,8 +1021,12 @@ def exptected_samples(world_size, rank, committed_batches, expected_len=None): self.assertEqual(len(batches), 1) for inputs in batches: self.assertTrue(len(inputs), 4) - self.assertEqual(inputs.tolist(), - exptected_samples(quorum.replica_world_size, quorum.replica_rank, committed_batches)) + self.assertEqual( + inputs.tolist(), + exptected_samples( + quorum.replica_world_size, quorum.replica_rank, committed_batches + ), + ) # When the dataloader is dirty, should not commit self.assertFalse(manager.should_commit()) # reset the dirty flag @@ -1023,8 +1039,12 @@ def exptected_samples(world_size, rank, committed_batches, expected_len=None): self.assertEqual(len(batches), 1) for inputs in batches: self.assertTrue(len(inputs), 4) - self.assertEqual(inputs.tolist(), - exptected_samples(quorum.replica_world_size, quorum.replica_rank, committed_batches)) + self.assertEqual( + inputs.tolist(), + exptected_samples( + quorum.replica_world_size, quorum.replica_rank, committed_batches + ), + ) # 3.3 Continue to get samples until the dataloader is exhausted while (batches := manager.get_batch_samples()) != None: @@ -1032,6 +1052,12 @@ def exptected_samples(world_size, rank, committed_batches, expected_len=None): self.assertEqual(len(batches), 1) for inputs in batches: self.assertTrue(len(inputs), 4) - self.assertEqual(inputs.tolist(), - exptected_samples(quorum.replica_world_size, 
quorum.replica_rank, - committed_batches, expected_len=len(inputs.tolist()))) + self.assertEqual( + inputs.tolist(), + exptected_samples( + quorum.replica_world_size, + quorum.replica_rank, + committed_batches, + expected_len=len(inputs.tolist()), + ), + ) diff --git a/train_ddp2.py b/train_ddp2.py index 239e861c..8f2d26a6 100644 --- a/train_ddp2.py +++ b/train_ddp2.py @@ -41,6 +41,7 @@ CHECKPOINT_ENABLED = True CHECKPOINT_PATH = "./tmp/train_ddp2_checkpoint/ckpt" + def save_model(m, optimizer, manager): state_dict_to_save = { "model": m.state_dict(), @@ -64,6 +65,7 @@ def save_model(m, optimizer, manager): except ValueError: continue + def load_model(m, optimizer, manager): if os.path.exists(f"{CHECKPOINT_PATH}_latest"): with open(f"{CHECKPOINT_PATH}_latest", "r") as f: @@ -74,6 +76,7 @@ def load_model(m, optimizer, manager): optimizer.load_state_dict(loaded_state_dict["optim"]) manager.load_state_dict(loaded_state_dict["torchft"]) + @record def main() -> None: REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) @@ -129,9 +132,13 @@ def dataloader_fn(replica_world_size, replica_rank, current_batches_committed): ) # drop_last to ensure all replicas have the same number of batches - dataloader = DataLoader(trainset, batch_size=BATCH_SIZE, - num_workers=0, sampler=sampler, - drop_last=True) + dataloader = DataLoader( + trainset, + batch_size=BATCH_SIZE, + num_workers=0, + sampler=sampler, + drop_last=True, + ) return dataloader manager = Manager( @@ -142,7 +149,7 @@ def dataloader_fn(replica_world_size, replica_rank, current_batches_committed): replica_id=f"train_ddp_{REPLICA_GROUP_ID}", timeout=timedelta(seconds=30), checkpoint_transport=transport, - dataloader_fn=dataloader_fn + dataloader_fn=dataloader_fn, ) class Net(nn.Module): @@ -190,10 +197,13 @@ def forward(self, x): num_params = sum(p.numel() for p in m.parameters()) print(f"Total number of parameters: {num_params}") - loss_writer = SummaryWriter(log_dir='./tmp/loss_train_ddp2') + loss_writer = SummaryWriter(log_dir="./tmp/loss_train_ddp2") for epoch in range(NUM_EPOCHS): - while (batches := manager.get_batch_samples(epoch=epoch, - batch_size=BATCH_SIZE, total_batch_size=TOTAL_BATCH_SIZE)) is not None: + while ( + batches := manager.get_batch_samples( + epoch=epoch, batch_size=BATCH_SIZE, total_batch_size=TOTAL_BATCH_SIZE + ) + ) is not None: optimizer.zero_grad() total_loss = 0.0 for inputs, labels in batches: @@ -214,24 +224,25 @@ def forward(self, x): manager.allreduce(loss_tensor).wait() avg_loss = loss_tensor.item() if manager.participating_rank() == 0: - loss_writer.add_scalar('Training Loss', avg_loss, global_step=manager.batches_committed()) + loss_writer.add_scalar( + "Training Loss", avg_loss, global_step=manager.batches_committed() + ) if manager.current_step() % 100 == 0: - print(f"Epoch {epoch + 1}, step = {manager.current_step()}, batch_committed {manager.batches_committed()}, Loss: {avg_loss:.4f}") - if CHECKPOINT_ENABLED and manager.current_step() % 200 == 0 and manager.participating_rank() == 0: + print( + f"Epoch {epoch + 1}, step = {manager.current_step()}, batch_committed {manager.batches_committed()}, Loss: {avg_loss:.4f}" + ) + if ( + CHECKPOINT_ENABLED + and manager.current_step() % 200 == 0 + and manager.participating_rank() == 0 + ): save_model(m, optimizer, manager) - print(f"Epoch {epoch + 1} completed, batches_committed {manager.batches_committed()}.") + print( + f"Epoch {epoch + 1} completed, batches_committed {manager.batches_committed()}." 
+        )
         manager.next_epoch()
     loss_writer.close()
 
+
 if __name__ == "__main__":
     main()
-
-
-# 1 Start the torchft lighthouse
-# RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
-
-# 2 Start the training jobs
-## cd /work/zhengchenyu/ml-examples/
-## NCCL_HOSTID=0 RUST_LOG=INFO NCCL_SOCKET_IFNAME=eth0 CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 REPLICA_GROUP_ID=0 torchrun --master_port 29501 --nnodes 1 --nproc_per_node 1 pytorch_tutorials/torchft/train_ddp2.py
-## NCCL_HOSTID=1 RUST_LOG=INFO NCCL_SOCKET_IFNAME=eth0 CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 REPLICA_GROUP_ID=1 torchrun --master_port 29502 --nnodes 1 --nproc_per_node 1 pytorch_tutorials/torchft/train_ddp2.py
-## NCCL_HOSTID=2 RUST_LOG=INFO NCCL_SOCKET_IFNAME=eth0 CUDA_VISIBLE_DEVICES=2 TORCHFT_LIGHTHOUSE=http://localhost:29510 REPLICA_GROUP_ID=2 torchrun --master_port 29503 --nnodes 1 --nproc_per_node 1 pytorch_tutorials/torchft/train_ddp2.py
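To make the headline behavior of this patch concrete, below is a small, self-contained sketch of the skip and batch-size arithmetic that SkipDistributedSampler and Manager.get_batch_samples() implement together. The 960-element DummyDataset, the batches_for helper, and the constants are illustrative assumptions rather than part of the patch, and shuffle is disabled so the resumed indices are easy to read:

from torchft.data import SkipDistributedSampler
from torchft.data_test import DummyDataset

BATCH_SIZE = 16
TOTAL_BATCH_SIZE = BATCH_SIZE * 6  # held constant regardless of the replica world size
dataset = DummyDataset(960)

def batches_for(replica_world_size: int, replica_rank: int, batches_committed: int):
    # Mirrors what Manager.reconfigure_dataloader() passes to dataloader_fn: skip every
    # sample that was already committed, then size the accumulation window so that
    # BATCH_SIZE * num_batches * replica_world_size == TOTAL_BATCH_SIZE.
    sampler = SkipDistributedSampler(
        dataset=dataset,
        num_replicas=replica_world_size,
        rank=replica_rank,
        shuffle=False,
        drop_last=True,
        skip_samples=batches_committed * BATCH_SIZE,
    )
    num_batches = TOTAL_BATCH_SIZE // (BATCH_SIZE * replica_world_size)
    return sampler, num_batches

# With 2 replicas, each step consumes 3 micro-batches per replica (2 * 3 * 16 == 96).
sampler, num_batches = batches_for(replica_world_size=2, replica_rank=0, batches_committed=0)
assert num_batches == 3
assert list(sampler)[:4] == [0, 2, 4, 6]

# After two committed steps (12 micro-batches == 192 samples) the group grows to 3 replicas:
# each step now consumes 2 micro-batches per replica (3 * 2 * 16 == 96) and rank 0 resumes
# at sample 192, so the data stream stays continuous across the resize.
sampler, num_batches = batches_for(replica_world_size=3, replica_rank=0, batches_committed=12)
assert num_batches == 2
assert list(sampler)[:4] == [192, 195, 198, 201]

In train_ddp2.py above, the same arithmetic is driven through Manager: dataloader_fn builds the sampler from the committed-batch counter, and manager.get_batch_samples(epoch=..., batch_size=BATCH_SIZE, total_batch_size=TOTAL_BATCH_SIZE) returns the per-step list of micro-batches to accumulate before optimizer.step().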