diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py
index 711c2441f34b..f9eacc11d75f 100644
--- a/tests/distributed/test_shm_broadcast.py
+++ b/tests/distributed/test_shm_broadcast.py
@@ -9,7 +9,7 @@
 
 from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
 from vllm.distributed.utils import StatelessProcessGroup
-from vllm.utils import get_ip, get_open_port, update_environment_variables
+from vllm.utils import get_open_port, update_environment_variables
 
 
 def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
@@ -60,12 +60,12 @@ def worker_fn():
     rank = dist.get_rank()
     if rank == 0:
         port = get_open_port()
-        ip = get_ip()
+        ip = '127.0.0.1'
         dist.broadcast_object_list([ip, port], src=0)
     else:
         recv = [None, None]
         dist.broadcast_object_list(recv, src=0)
-        ip, port = recv
+        ip, port = recv  # type: ignore
     stateless_pg = StatelessProcessGroup.create(ip, port, rank,
                                                 dist.get_world_size())
 
@@ -107,10 +107,10 @@ def worker_fn():
 
         if pg == dist.group.WORLD:
             dist.barrier()
-            print("torch distributed passed the test!")
+            print(f"torch distributed passed the test! Rank {rank}")
         else:
             pg.barrier()
-            print("StatelessProcessGroup passed the test!")
+            print(f"StatelessProcessGroup passed the test! Rank {rank}")
 
 
 def test_shm_broadcast():
diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py
index fa944407a703..40e57e6624d1 100644
--- a/vllm/distributed/device_communicators/shm_broadcast.py
+++ b/vllm/distributed/device_communicators/shm_broadcast.py
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 import pickle
-import sys
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
@@ -19,7 +17,7 @@
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 import vllm.envs as envs
-from vllm.distributed.utils import StatelessProcessGroup
+from vllm.distributed.utils import StatelessProcessGroup, sched_yield
 from vllm.logger import init_logger
 from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
                         is_valid_ipv6_address)
@@ -28,20 +26,6 @@
 logger = init_logger(__name__)
 
-# We prefer to use os.sched_yield as it results in tighter polling loops,
-# measured to be around 3e-7 seconds. However on earlier versions of Python
-# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
-USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
-                   or (sys.version_info[:2] == (3, 10)
-                       and sys.version_info[2] >= 8))
-
-
-def sched_yield():
-    if USE_SCHED_YIELD:
-        os.sched_yield()
-    else:
-        time.sleep(0)
-
 
 class ShmRingBuffer:
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 6bb323d79d64..93a069d36c4b 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -6,9 +6,12 @@
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 import dataclasses
 import datetime
+import os
 import pickle
 import socket
+import sys
 import time
+import uuid
 from collections import deque
 from collections.abc import Sequence
 from typing import Any, Optional
@@ -27,6 +30,20 @@
 logger = init_logger(__name__)
 
+# We prefer to use os.sched_yield as it results in tighter polling loops,
+# measured to be around 3e-7 seconds. However on earlier versions of Python
+# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0).
+USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
+                   or (sys.version_info[:2] == (3, 10)
+                       and sys.version_info[2] >= 8))
+
+
+def sched_yield():
+    if USE_SCHED_YIELD:
+        os.sched_yield()
+    else:
+        time.sleep(0)
+
 
 def ensure_divisibility(numerator, denominator):
     """Ensure that numerator is divisible by the denominator."""
@@ -212,10 +229,141 @@ def all_gather_obj(self, obj: Any) -> list[Any]:
             gathered_objs.append(recv_obj)
         return gathered_objs
 
-    def barrier(self):
-        """A barrier to synchronize all ranks."""
+    def barrier(self, timeout: float = 30.0):
+        """A robust barrier to synchronize all ranks.
+
+        Uses a multi-phase approach to ensure all processes reach the
+        barrier before proceeding:
+
+        1. Each process signals that it has reached the barrier.
+
+        2. Each process confirms the arrival of all other ranks before
+           signaling its departure.
+
+        3. Rank 0 waits for all other ranks to signal their departure, to
+           ensure that all ranks have left the barrier first.
+
+        Args:
+            timeout: Maximum time in seconds to wait for each phase.
+
+        Raises:
+            RuntimeError: If coordination fails or times out.
+        """
+        # Generate a barrier ID that is globally unique
+        try:
+            if self.rank == 0:
+                barrier_id = f"barrier_{uuid.uuid4()}"
+                self.broadcast_obj(barrier_id, src=0)
+            else:
+                barrier_id = self.broadcast_obj(None, src=0)
+        except Exception as e:
+            raise RuntimeError("Failed to broadcast barrier_id") from e
+
+        # Phase 1: Signal arrival at the barrier, then wait for all other
+        # processes to arrive. Every rank confirming the arrival of every
+        # other rank is the key synchronization point.
+        arrival_key = f"arrival_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(arrival_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier arrival") from e
+
+        start_time = time.time()
+        processes_arrived: set[int] = set()
+
+        while len(processes_arrived) < self.world_size:
+            # Check for timeout
+            if time.time() - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_arrived:
+                    continue
+
+                key = f"arrival_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a value.
+                    # If it doesn't exist, an exception is raised.
+                    self.store.get(key)
+                    processes_arrived.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s", check_e)
+                    sched_yield()
+
+            # Short sleep to avoid tight polling
+            if len(processes_arrived) < self.world_size:
+                sched_yield()
+
+        # Phase 2: Signal departure from the barrier. We only need to block
+        # at this stage on rank 0, which runs the server side of the
+        # TCPStore. We want to make sure that all clients have departed the
+        # barrier before rank 0, in case the next thing after the barrier
+        # is a shutdown that tears down the TCPStore. Other ranks can exit
+        # the barrier immediately after signaling their departure.
+        departure_key = f"departure_{barrier_id}_{self.rank}"
+        try:
+            self.store.set(departure_key, b"1")
+        except Exception as e:
+            raise RuntimeError("Failed to signal barrier departure") from e
+
+        if self.rank != 0:
+            return
+
+        # Make rank 0 wait for all processes to signal departure
+        start_time = time.time()
+        processes_departed: set[int] = set()
+
+        while len(processes_departed) < self.world_size:
+            # Check for timeout
+            if time.time() - start_time > timeout:
+                raise RuntimeError(
+                    f"Barrier departure timed out after {timeout} seconds")
+
+            # Check for each process
+            for i in range(self.world_size):
+                if i in processes_departed:
+                    continue
+
+                key = f"departure_{barrier_id}_{i}"
+                try:
+                    # Try to get the key - if it exists, we'll get a value.
+                    # If it doesn't exist, an exception is raised.
+                    self.store.get(key)
+                    processes_departed.add(i)
+                except KeyError:
+                    # Key doesn't exist yet
+                    pass
+                except Exception as check_e:
+                    logger.debug("Error checking key existence: %s", check_e)
+                    sched_yield()
+
+            # Short sleep to avoid tight polling
+            if len(processes_departed) < self.world_size:
+                sched_yield()
+
+        # Clean up keys to avoid leaking memory in the store
         for i in range(self.world_size):
-            self.broadcast_obj(None, src=i)
+            try:
+                self.store.delete_key(f"arrival_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: %s",
+                             f"arrival_{barrier_id}_{i}")
+
+            try:
+                self.store.delete_key(f"departure_{barrier_id}_{i}")
+            except Exception:
+                logger.debug("Error deleting key: %s",
+                             f"departure_{barrier_id}_{i}")
 
     @staticmethod
     def create(
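
For reviewers who want to poke at the handshake in isolation, here is a minimal, self-contained sketch of the same arrival/departure pattern against a bare `torch.distributed.TCPStore`. It is not part of the patch: the names `demo_barrier` and `_worker` and the port number are made up for illustration, and it uses the blocking `Store.wait()` call instead of the patch's `get()` + `sched_yield()` polling loop, trading tight-latency polling for brevity.

```python
# Standalone demo of the two-phase barrier handshake (illustrative only).
import multiprocessing
from datetime import timedelta

from torch.distributed import TCPStore

WORLD_SIZE = 2
PORT = 29701  # assumed to be free on localhost


def demo_barrier(store: TCPStore, rank: int, world_size: int,
                 barrier_id: str = "demo") -> None:
    # Phase 1: announce arrival, then block until every rank has arrived.
    store.set(f"arrival_{barrier_id}_{rank}", b"1")
    store.wait([f"arrival_{barrier_id}_{i}" for i in range(world_size)])

    # Phase 2: announce departure. Only rank 0, which hosts the TCPStore
    # server, waits for the departures, so every client is out of the
    # barrier before the server can be torn down.
    store.set(f"departure_{barrier_id}_{rank}", b"1")
    if rank == 0:
        store.wait(
            [f"departure_{barrier_id}_{i}" for i in range(world_size)])


def _worker(rank: int) -> None:
    store = TCPStore("127.0.0.1", PORT, WORLD_SIZE,
                     is_master=(rank == 0), timeout=timedelta(seconds=30))
    demo_barrier(store, rank, WORLD_SIZE)
    print(f"rank {rank} passed the barrier")


if __name__ == "__main__":
    procs = [
        multiprocessing.Process(target=_worker, args=(r, ))
        for r in range(WORLD_SIZE)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```

The asymmetric departure wait mirrors the rationale in the patch: the rank that hosts the store must be the last one out, since tearing the store down while other ranks are still polling barrier keys would turn a clean shutdown into connection errors.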
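
The "around 3e-7 seconds" figure in the comment above `USE_SCHED_YIELD` is machine-dependent. A rough micro-benchmark along these lines (the iteration count is arbitrary) reproduces that kind of measurement and shows how the two yield strategies compare on a given host:

```python
# Per-call cost of the two strategies behind sched_yield().
# Note: os.sched_yield() is POSIX-only.
import os
import time

N = 1_000_000  # arbitrary iteration count

start = time.perf_counter()
for _ in range(N):
    os.sched_yield()
print(f"os.sched_yield(): {(time.perf_counter() - start) / N:.1e} s/call")

start = time.perf_counter()
for _ in range(N):
    time.sleep(0)
print(f"time.sleep(0):    {(time.perf_counter() - start) / N:.1e} s/call")
```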