From 88025215dc692ac28e1abb79e18892760d53f109 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 2 Apr 2025 14:07:18 -0700 Subject: [PATCH 01/12] [V1] DP scale-out (2/N): Decouple engine process management and comms Signed-off-by: Nick Hill --- vllm/config.py | 5 + vllm/engine/arg_utils.py | 7 + vllm/v1/engine/core.py | 98 +++++++++----- vllm/v1/engine/core_client.py | 247 ++++++++++++++++++++-------------- vllm/v1/utils.py | 22 ++- 5 files changed, 232 insertions(+), 147 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 2669d1a13b37..2d9854f865b3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1430,6 +1430,7 @@ class ParallelConfig: pipeline_parallel_size: int = 1 # Number of pipeline parallel groups. tensor_parallel_size: int = 1 # Number of tensor parallel groups. data_parallel_size: int = 1 # Number of data parallel groups. + data_parallel_size_local: int = 1 # Number of data parallel groups. data_parallel_rank: int = 0 # Rank of the data parallel group. # Local rank of the data parallel group, defaults to global rank. data_parallel_rank_local: Optional[int] = None @@ -1537,6 +1538,10 @@ def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size + if not (0 < self.data_parallel_size_local <= self.data_parallel_size): + raise ValueError( + "data_parallel_size_local must be <= data_parallel_size") + if self.data_parallel_size > 1: # Data parallel was specified in the engine args. self.data_parallel_master_port = get_open_port() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 89c9b67470e6..d6965a08ef04 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -116,6 +116,7 @@ class EngineArgs: pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 data_parallel_size: int = 1 + data_parallel_size_local: Optional[int] = None enable_expert_parallel: bool = False max_parallel_loading_workers: Optional[int] = None block_size: Optional[int] = None @@ -1186,10 +1187,16 @@ def create_engine_config( # but we should not do this here. placement_group = ray.util.get_current_placement_group() + # Local DP size defaults to global DP size if not set. 
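For orientation, a hedged sketch of how the new field composes with the global size, using hypothetical values for a two-node deployment (illustrative only, not part of the patch):

    # 4 engine ranks in total, 2 of them hosted by this node; omitting the
    # local size keeps the existing single-node behaviour where all ranks
    # are local.
    data_parallel_size = 4
    data_parallel_size_local = 2
    assert data_parallel_size_local <= data_parallel_size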
+ data_parallel_size_local = self.data_parallel_size if ( + self.data_parallel_size_local + is None) else self.data_parallel_size_local + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, data_parallel_size=self.data_parallel_size, + data_parallel_size_local=data_parallel_size_local, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f58c77e4f165..b7fc3eb6ff7d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -22,8 +22,8 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) -from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, - zmq_socket_ctx) +from vllm.utils import (get_exception_traceback, make_zmq_socket, + resolve_obj_by_qualname, zmq_socket_ctx) from vllm.v1.core.kv_cache_utils import (get_kv_cache_config, unify_kv_cache_configs) from vllm.v1.core.sched.interface import SchedulerInterface @@ -309,9 +309,9 @@ class EngineCoreProc(EngineCore): def __init__( self, - input_path: str, - output_path: str, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, engine_index: int = 0, @@ -323,6 +323,19 @@ def __init__( self.global_unfinished_reqs = False + # Create input socket. + input_ctx = zmq.Context() # type: ignore[attr-defined] + identity = engine_index.to_bytes(length=2, byteorder="little") + input_socket = make_zmq_socket(input_ctx, + input_address, + zmq.DEALER, + identity=identity, + bind=False) + + # Register engine with front-end. + output_address = self.startup_handshake(input_socket, on_head_node, + vllm_config.parallel_config) + # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, # and to overlap some serialization/deserialization with the @@ -332,12 +345,39 @@ def __init__( Any]] = queue.Queue() self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() threading.Thread(target=self.process_input_socket, - args=(input_path, engine_index), + args=(input_socket, ), daemon=True).start() threading.Thread(target=self.process_output_socket, - args=(output_path, engine_index), + args=(output_address, engine_index), daemon=True).start() + @staticmethod + def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, + parallel_config: ParallelConfig) -> str: + + # Send registration message. + input_socket.send( + msgspec.msgpack.encode({ + "local": on_head_node, + "status": "READY" + })) + + # Receive initialization message. 
+ logger.info("Waiting for init message from front-end.") + input_socket.poll(timeout=5 * 60 * 1000) + init_bytes = input_socket.recv() + init_message = msgspec.msgpack.decode(init_bytes) + logger.debug("Received init message: %s", init_message) + + output_socket_address = init_message["output_socket_address"] + #TBD maybe replace IP with configured head node address + + received_parallel_config = init_message["parallel_config"] + for key, value in received_parallel_config.items(): + setattr(parallel_config, key, value) + + return output_socket_address + @staticmethod def run_engine_core(*args, dp_rank: int = 0, @@ -472,35 +512,25 @@ def _convert_msgspec_args(method, args): and not isinstance(v, p.annotation) else v for v, p in zip(args, arg_types)) - def process_input_socket(self, input_path: str, engine_index: int): + def process_input_socket(self, input_socket: zmq.Socket): """Input socket IO thread.""" # Msgpack serialization decoding. add_request_decoder = MsgpackDecoder(EngineCoreRequest) generic_decoder = MsgpackDecoder() - identity = engine_index.to_bytes(length=2, byteorder="little") - - with zmq_socket_ctx(input_path, - zmq.DEALER, - identity=identity, - bind=False) as socket: - # Send ready message to front-end once input socket is connected. - socket.send(b'READY') - - while True: - # (RequestType, RequestData) - type_frame, data_frame = socket.recv_multipart(copy=False) - request_type = EngineCoreRequestType(bytes(type_frame.buffer)) + while True: + # (RequestType, RequestData) + type_frame, data_frame = input_socket.recv_multipart(copy=False) + request_type = EngineCoreRequestType(bytes(type_frame.buffer)) - # Deserialize the request data. - decoder = add_request_decoder if ( - request_type - == EngineCoreRequestType.ADD) else generic_decoder - request = decoder.decode(data_frame.buffer) + # Deserialize the request data. + decoder = add_request_decoder if ( + request_type == EngineCoreRequestType.ADD) else generic_decoder + request = decoder.decode(data_frame.buffer) - # Push to input queue for core busy loop. - self.input_queue.put_nowait((request_type, request)) + # Push to input queue for core busy loop. + self.input_queue.put_nowait((request_type, request)) def process_output_socket(self, output_path: str, engine_index: int): """Output socket IO thread.""" @@ -527,9 +557,9 @@ class DPEngineCoreProc(EngineCoreProc): def __init__( self, - input_path: str, - output_path: str, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, ): @@ -551,17 +581,17 @@ def __init__( from vllm.platforms import current_platform if current_platform.is_cuda_alike(): from vllm.platforms.cuda import device_id_to_physical_device_id - tp_size = vllm_config.parallel_config.tensor_parallel_size + world_size = vllm_config.parallel_config.world_size os.environ["CUDA_VISIBLE_DEVICES"] = ",".join( str(device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * tp_size, (local_dp_rank + 1) * - tp_size)) + for i in range(local_dp_rank * + world_size, (local_dp_rank + 1) * world_size)) self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() # Initialize the engine after setting up environment. - super().__init__(input_path, output_path, vllm_config, executor_class, - log_stats, dp_rank) + super().__init__(vllm_config, on_head_node, input_address, + executor_class, log_stats, dp_rank) # Counts forward-passes of the model so that we can synchronize # finished with DP peers every N steps. 
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index b94b0aa75386..ace6470faaea 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -10,18 +10,20 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable from concurrent.futures import Future -from dataclasses import dataclass, field +from dataclasses import dataclass from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union +import msgspec import zmq import zmq.asyncio -from vllm.config import VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, - kill_process_tree, make_zmq_socket) +from vllm.utils import (get_open_port, get_open_zmq_inproc_path, + get_open_zmq_ipc_path, kill_process_tree, + make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc @@ -255,46 +257,59 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) -class CoreEngine: +class CoreEngineProcManager: """One per data parallel rank.""" def __init__( self, + local_engine_count: int, + start_index: int, vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, executor_class: type[Executor], log_stats: bool, - input_path: str, - output_path: str, - index: int = 0, - local_dp_rank: int = 0, ): - self.index = index - self.identity = index.to_bytes(length=2, byteorder="little") + self.proc_handles = [] try: - # Start EngineCore in background process. - self.proc_handle = BackgroundProcHandle( - input_path=input_path, - output_path=output_path, - process_name=f"EngineCore_{index}", - target_fn=EngineCoreProc.run_engine_core, - process_kwargs={ - "vllm_config": vllm_config, - "dp_rank": index, - "local_dp_rank": local_dp_rank, - "executor_class": executor_class, - "log_stats": log_stats, - }) - - self.num_reqs_in_flight = 0 + for local_index in range(local_engine_count): + index = local_index + start_index + # Start EngineCore in background process. + self.proc_handles.append( + BackgroundProcHandle( + input_address=input_address, + process_name=f"EngineCore_{index}", + target_fn=EngineCoreProc.run_engine_core, + process_kwargs={ + "vllm_config": vllm_config, + "on_head_node": on_head_node, + "dp_rank": index, + "local_dp_rank": local_index, + "executor_class": executor_class, + "log_stats": log_stats, + })) finally: - if not hasattr(self, "num_reqs_in_flight"): - # Ensure socket is closed if process fails to start. 
+ if len(self.proc_handles) != local_engine_count: self.close() def close(self): - if proc_handle := getattr(self, "proc_handle", None): + for proc_handle in self.proc_handles: proc_handle.shutdown() + def finished_procs(self) -> dict[int, int]: + return { + handle.proc.name: handle.proc.exitcode + for handle in self.proc_handles if handle.proc.exitcode is not None + } + + +class CoreEngine: + """One per data parallel rank.""" + + def __init__(self, index: int = 0): + self.identity = index.to_bytes(length=2, byteorder="little") + self.num_reqs_in_flight = 0 + @dataclass class BackgroundResources: @@ -302,7 +317,7 @@ class BackgroundResources: circular reference back to the client object.""" ctx: Union[zmq.Context] - core_engines: list[CoreEngine] = field(default_factory=list) + local_engine_manager: Optional[CoreEngineProcManager] = None output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None shutdown_path: Optional[str] = None @@ -310,8 +325,8 @@ class BackgroundResources: def __call__(self): """Clean up background resources.""" - for core_engine in self.core_engines: - core_engine.close() + if self.local_engine_manager is not None: + self.local_engine_manager.close() # ZMQ context termination can hang if the sockets # aren't explicitly closed first. @@ -383,67 +398,111 @@ def sigusr1_handler(signum, frame): self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) - # Paths and sockets for IPC. - self.output_path = get_open_zmq_ipc_path() - input_path = get_open_zmq_ipc_path() - self.input_socket = make_zmq_socket(self.ctx, - input_path, - zmq.ROUTER, - bind=True) - self.resources.input_socket = self.input_socket - - new_core_engine = lambda index, local_dp_rank=None: CoreEngine( - vllm_config, executor_class, log_stats, input_path, self. - output_path, index, local_dp_rank) + # TODO + parallel_config = vllm_config.parallel_config + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local - # Start engine core process(es). - self._init_core_engines(vllm_config, new_core_engine, - self.resources.core_engines) + # TODO somewhere validate local count <= dp_size + if local_engine_count == dp_size: + input_address = get_open_zmq_ipc_path() + output_address = get_open_zmq_ipc_path() + else: + host = parallel_config.data_parallel_master_ip + input_port = 13345 # todo from arg/config + output_port = get_open_port() + input_address = f"tcp://{host}:{input_port}" + output_address = f"tcp://{host}:{output_port}" + + # Create input and output sockets. + self.input_socket = self.resources.input_socket = make_zmq_socket( + self.ctx, input_address, zmq.ROUTER, bind=True) + + self.resources.output_socket = make_zmq_socket(self.ctx, + output_address, + zmq.constants.PULL) + + # Start local engines. + if local_engine_count: + self.resources.local_engine_manager = CoreEngineProcManager( + vllm_config=vllm_config, + executor_class=executor_class, + log_stats=log_stats, + input_address=input_address, + on_head_node=True, + local_engine_count=local_engine_count, + start_index=0) + + self.core_engines = [CoreEngine(i) for i in range(dp_size)] + self.core_engine = self.core_engines[0] # Wait for engine core process(es) to start. 
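As a compact picture of the socket layout being set up here, a hedged sketch with hypothetical addresses (TCP is only chosen when some engines are remote; the identity bytes are the 2-byte little-endian DP rank used throughout the patch):

    import zmq

    ctx = zmq.Context()
    input_address = "tcp://10.0.0.1:13345"   # hypothetical head-node address
    output_address = "tcp://10.0.0.1:41234"  # hypothetical open port

    # Client side: one ROUTER for requests/handshakes, one PULL for outputs.
    router = ctx.socket(zmq.ROUTER)
    router.bind(input_address)
    outputs = ctx.socket(zmq.PULL)
    outputs.bind(output_address)

    # Each engine: a DEALER whose identity is its DP rank, plus a push-style
    # socket connected to the output address.
    rank = 3
    identity = rank.to_bytes(2, "little")    # rank 3 -> b'\x03\x00'
    dealer = ctx.socket(zmq.DEALER)
    dealer.setsockopt(zmq.IDENTITY, identity)
    dealer.connect(input_address)
    engine_out = ctx.socket(zmq.PUSH)
    engine_out.connect(output_address)
    # The client recovers the rank with int.from_bytes(identity, "little").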
- self._wait_for_engine_startup() + self._wait_for_engine_startup(output_address, parallel_config) self.utility_results: dict[int, AnyFuture] = {} - def _wait_for_engine_startup(self): + def _wait_for_engine_startup(self, output_address: str, + parallel_config: ParallelConfig): # Get a sync handle to the socket which can be sync or async. sync_input_socket = zmq.Socket.shadow(self.input_socket) # Wait for engine core process(es) to send ready messages. - identities = set(eng.index for eng in self.resources.core_engines) - while identities: + local_engine_count = parallel_config.data_parallel_size_local + # TODO offline case compatibility + local_indices = set(range(local_engine_count)) + remote_indices = set( + range(len(self.core_engines) - local_engine_count)) + while local_indices or remote_indices: while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS): - logger.info("Waiting for %d core engine proc(s) to start: %s", - len(identities), identities) - eng_id_bytes, msg = sync_input_socket.recv_multipart() - eng_id = int.from_bytes(eng_id_bytes, byteorder="little") - if eng_id not in identities: - raise RuntimeError(f"Unexpected or duplicate engine: {eng_id}") - if msg != b'READY': - raise RuntimeError(f"Engine {eng_id} failed: {msg.decode()}") - logger.info("Core engine process %d ready.", eng_id) - identities.discard(eng_id) + local_count = len(local_indices) + if remote_indices: + remote_count = len(remote_indices) + logger.info( + "Waiting for %d local and %d remote core engine " + "proc(s) to start: %s, %s", local_count, remote_count, + local_indices, remote_indices) + else: + logger.info( + "Waiting for %d local core engine proc(s) " + "to start: %s", local_count, local_indices) + eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() + ready_msg = msgspec.msgpack.decode(ready_msg_bytes) + local, status = ready_msg["local"], ready_msg["status"] + eng_index = int.from_bytes(eng_identity, byteorder="little") + if status != "READY": + raise RuntimeError(f"{'Local' if local else 'Remote'} engine " + f"{eng_index} failed: {status}") + + index_set = local_indices if local else remote_indices + if eng_index not in index_set: + raise RuntimeError( + f"Unexpected or duplicate " + f"{'local' if local else 'remote'} engine: {eng_index}") + + # Send init message with DP config info. + init_message = self.encoder.encode({ + "output_socket_address": output_address, + "parallel_config": { + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": parallel_config.data_parallel_size, + }, + }) + + sync_input_socket.send_multipart((eng_identity, init_message), + copy=False) + + logger.debug("%s core engine process %d ready.", + "Local" if local else "Remote", eng_index) + index_set.discard(eng_index) # Double check that the process are running. - for engine in self.resources.core_engines: - proc = engine.proc_handle.proc - if proc.exitcode is not None: - raise RuntimeError(f"Engine proc {proc.name} not running") - - def _init_core_engines( - self, - vllm_config: VllmConfig, - new_core_engine: Callable[[int, Optional[int]], CoreEngine], - core_engines: list[CoreEngine], - ) -> None: - - # Default case - single core engine. 
- dp_rank = vllm_config.parallel_config.data_parallel_rank - local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local - core_engine = new_core_engine( - dp_rank, local_dp_rank if local_dp_rank is not None else dp_rank) - core_engines.append(core_engine) - self.core_engine = core_engine + engine_manager = self.resources.local_engine_manager + if engine_manager and (procs := engine_manager.finished_procs()): + raise RuntimeError( + f"Local engine proc(s) exited unexpectedly: {procs}") def shutdown(self): self._finalizer() @@ -476,7 +535,8 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], # Ensure that the outputs socket processing thread does not have # a ref to the client which prevents gc. ctx = self.ctx - output_path = self.output_path + out_socket = self.resources.output_socket + assert out_socket is not None decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue @@ -486,7 +546,6 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], def process_outputs_socket(): shutdown_socket = ctx.socket(zmq.PAIR) - out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) try: shutdown_socket.bind(shutdown_path) poller = zmq.Poller() @@ -518,6 +577,9 @@ def process_outputs_socket(): daemon=True) self.output_queue_thread.start() + # The thread takes on responsibility for closing the socket. + self.resources.output_socket = None + def get_output(self) -> EngineCoreOutputs: return self.outputs_queue.get() @@ -621,10 +683,8 @@ def _ensure_output_queue_task(self): outputs_queue = self.outputs_queue output_handler = self.outputs_handler _self_ref = weakref.ref(self) if output_handler else None - output_path = self.output_path - output_socket = make_zmq_socket(self.ctx, output_path, - zmq.constants.PULL) - self.resources.output_socket = output_socket + output_socket = self.resources.output_socket + assert output_socket is not None async def process_outputs_socket(): while True: @@ -762,21 +822,6 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], self.outputs_handler = DPAsyncMPClient.process_engine_outputs # type: ignore[assignment] - def _init_core_engines( - self, - vllm_config: VllmConfig, - new_core_engine: Callable[[int, Optional[int]], CoreEngine], - core_engines: list[CoreEngine], - ) -> None: - - # Launch a core engine for each data parallel rank. - dp_size = vllm_config.parallel_config.data_parallel_size - for i in range(dp_size): - # Multi-node not yet supported so local_dp_rank == dp_rank. - core_engines.append(new_core_engine(i, i)) - - self.core_engines = core_engines - async def call_utility_async(self, method: str, *args) -> Any: # Only the result from the first engine is returned. return (await asyncio.gather(*[ diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index fed5761b04b6..470e2a572ed8 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -98,25 +98,22 @@ class BackgroundProcHandle: def __init__( self, - input_path: str, - output_path: str, + input_address: str, process_name: str, target_fn: Callable, process_kwargs: dict[Any, Any], ): context = get_mp_context() - assert ("input_path" not in process_kwargs - and "output_path" not in process_kwargs) - process_kwargs["input_path"] = input_path - process_kwargs["output_path"] = output_path + assert "input_address" not in process_kwargs + process_kwargs["input_address"] = input_address # Run busy loop in background process. 
self.proc = context.Process(target=target_fn, kwargs=process_kwargs, name=process_name) self._finalizer = weakref.finalize(self, shutdown, self.proc, - input_path, output_path) + input_address) self.proc.start() def shutdown(self): @@ -125,7 +122,7 @@ def shutdown(self): # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. -def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): +def shutdown(proc: multiprocessing.Process, input_address: str): # Shutdown the process. if proc.is_alive(): proc.terminate() @@ -135,11 +132,12 @@ def shutdown(proc: multiprocessing.Process, input_path: str, output_path: str): kill_process_tree(proc.pid) # Remove zmq ipc socket files. - ipc_sockets = [output_path, input_path] + ipc_sockets = (input_address, ) for ipc_socket in ipc_sockets: - socket_file = ipc_socket.replace("ipc://", "") - if os and os.path.exists(socket_file): - os.remove(socket_file) + if ipc_socket.startswith("ipc://"): + socket_file = ipc_socket.replace("ipc://", "") + if os and os.path.exists(socket_file): + os.remove(socket_file) def bind_kv_cache( From e86938050a9f22b4c50446cb7f26a0ad7ac3d781 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 3 Apr 2025 11:19:27 -0700 Subject: [PATCH 02/12] Headless mode Signed-off-by: Nick Hill --- vllm/config.py | 3 +- vllm/engine/arg_utils.py | 33 ++++++++ vllm/entrypoints/cli/serve.py | 62 +++++++++++++- vllm/v1/engine/core.py | 52 ++++++++---- vllm/v1/engine/core_client.py | 155 +++++++++++++++------------------- vllm/v1/utils.py | 101 ++++++++++++++++------ 6 files changed, 270 insertions(+), 136 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 2d9854f865b3..57ab0ef05964 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1436,6 +1436,7 @@ class ParallelConfig: data_parallel_rank_local: Optional[int] = None # IP of the data parallel master. data_parallel_master_ip: str = "127.0.0.1" + data_parallel_rpc_port: int = 29550 # Port for data parallel messaging. data_parallel_master_port: int = 29500 # Port of the data parallel master. enable_expert_parallel: bool = False # Use EP instead of TP for MoE layers. 
@@ -1538,7 +1539,7 @@ def __post_init__(self) -> None: self.world_size = self.pipeline_parallel_size * \ self.tensor_parallel_size - if not (0 < self.data_parallel_size_local <= self.data_parallel_size): + if self.data_parallel_size_local > self.data_parallel_size: raise ValueError( "data_parallel_size_local must be <= data_parallel_size") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d6965a08ef04..b7cacb177dc5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -117,6 +117,9 @@ class EngineArgs: tensor_parallel_size: int = 1 data_parallel_size: int = 1 data_parallel_size_local: Optional[int] = None + data_parallel_start_rank: int = 0 + data_parallel_address: Optional[str] = None + data_parallel_rpc_port: Optional[int] = None enable_expert_parallel: bool = False max_parallel_loading_workers: Optional[int] = None block_size: Optional[int] = None @@ -435,6 +438,29 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'MoE layers will be sharded according to the ' 'product of the tensor-parallel-size and ' 'data-parallel-size.') + parser.add_argument('--data-parallel-size-local', + '-dpl', + type=int, + default=EngineArgs.data_parallel_size_local, + help='Number of data parallel replicas to run on ' + 'this node.') + parser.add_argument('--data-parallel-start-rank', + '-dpr', + type=int, + default=EngineArgs.data_parallel_start_rank, + help='Starting data parallel rank for secondary ' + 'nodes.') + parser.add_argument('--data-parallel-address', + '-dpa', + type=str, + default=EngineArgs.data_parallel_address, + help='Address of data parallel cluster head-node.') + parser.add_argument('--data-parallel-rpc-port', + '-dpp', + type=int, + default=EngineArgs.data_parallel_rpc_port, + help='Port for data parallel RPC communication.') + parser.add_argument( '--enable-expert-parallel', action='store_true', @@ -1192,11 +1218,18 @@ def create_engine_config( self.data_parallel_size_local is None) else self.data_parallel_size_local + # This port is only used when there are remote data parallel engines, + # otherwise the local IPC transport is used. 
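Taken together, the new options are intended to be combined roughly as follows (hypothetical host, port, and model; --headless is added below in this patch, and --data-parallel-start-rank moves to the serve subcommand in a later patch of this series):

    # Head node: API server plus two local engines, expecting two more
    # engines to register from another node.
    vllm serve <model> --data-parallel-size 4 --data-parallel-size-local 2 \
        --data-parallel-address 10.0.0.1 --data-parallel-rpc-port 29550

    # Secondary node: engines only, registering with the head node.
    vllm serve <model> --headless --data-parallel-size 4 \
        --data-parallel-size-local 2 --data-parallel-start-rank 2 \
        --data-parallel-address 10.0.0.1 --data-parallel-rpc-port 29550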
+ data_parallel_rpc_port = self.data_parallel_rpc_port if ( + self.data_parallel_rpc_port + is not None) else (ParallelConfig.data_parallel_rpc_port) + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, data_parallel_size=self.data_parallel_size, data_parallel_size_local=data_parallel_size_local, + data_parallel_rpc_port=data_parallel_rpc_port, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index e89ac4e21999..801dd6db3d7b 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -4,11 +4,20 @@ import uvloop +import vllm.envs as envs +from vllm import AsyncEngineArgs from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser +from vllm.v1.engine.core import EngineCoreProc +from vllm.v1.engine.core_client import CoreEngineProcManager +from vllm.v1.executor.abstract import Executor + +logger = init_logger(__name__) class ServeSubcommand(CLISubcommand): @@ -24,7 +33,10 @@ def cmd(args: argparse.Namespace) -> None: if hasattr(args, 'model_tag') and args.model_tag is not None: args.model = args.model_tag - uvloop.run(run_server(args)) + if args.headless: + run_headless(args) + else: + uvloop.run(run_server(args)) def validate(self, args: argparse.Namespace) -> None: validate_parsed_serve_args(args) @@ -41,6 +53,12 @@ def subparser_init( nargs='?', help="The model tag to serve " "(optional if specified in config)") + serve_parser.add_argument( + "--headless", + action='store_true', + default=False, + help="Run in headless mode. See multi-node data parallel " + "documentation for more details.") serve_parser.add_argument( "--config", type=str, @@ -56,3 +74,45 @@ def subparser_init( def cmd_init() -> list[CLISubcommand]: return [ServeSubcommand()] + + +def run_headless(args: argparse.Namespace): + + # Create the EngineConfig. + engine_args = AsyncEngineArgs.from_cli_args(args) + usage_context = UsageContext.OPENAI_API_SERVER + vllm_config = engine_args.create_engine_config(usage_context=usage_context) + + if not envs.VLLM_USE_V1: + raise RuntimeError("Headless mode is only supported for V1") + + parallel_config = vllm_config.parallel_config + local_engine_count = parallel_config.data_parallel_size_local + host = parallel_config.data_parallel_master_ip + port = engine_args.data_parallel_rpc_port # add to config too + input_address = f"tcp://{host}:{port}" + + if local_engine_count <= 0: + raise RuntimeError("data_parallel_size_local must be > 0 in " + "headless mode") + + logger.info( + "Launching %d data parallel engine(s) in headless mode, " + "with head node address %s.", local_engine_count, input_address) + + # Create the engines. 
+ engine_manager = CoreEngineProcManager( + target_fn=EngineCoreProc.run_engine_core, + local_engine_count=local_engine_count, + start_index=engine_args.data_parallel_start_rank, + vllm_config=vllm_config, + on_head_node=False, + input_address=input_address, + executor_class=Executor.get_class(vllm_config), + log_stats=not engine_args.disable_log_stats, + ) + + try: + engine_manager.join_first() + finally: + engine_manager.close() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b7fc3eb6ff7d..1399604692e4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -316,13 +316,6 @@ def __init__( log_stats: bool, engine_index: int = 0, ): - super().__init__(vllm_config, executor_class, log_stats) - - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) - - self.global_unfinished_reqs = False - # Create input socket. input_ctx = zmq.Context() # type: ignore[attr-defined] identity = engine_index.to_bytes(length=2, byteorder="little") @@ -336,6 +329,24 @@ def __init__( output_address = self.startup_handshake(input_socket, on_head_node, vllm_config.parallel_config) + # Set up data parallel environment. + self._init_data_parallel(vllm_config) + + # Initialize engine core and model. + super().__init__(vllm_config, executor_class, log_stats) + + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + + self.global_unfinished_reqs = False + + # Send ready message. + input_socket.send( + msgspec.msgpack.encode({ + "status": "READY", + "local": on_head_node + })) + # Background Threads and Queues for IO. These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, # and to overlap some serialization/deserialization with the @@ -358,8 +369,8 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, # Send registration message. input_socket.send( msgspec.msgpack.encode({ + "status": "HELLO", "local": on_head_node, - "status": "READY" })) # Receive initialization message. @@ -430,6 +441,9 @@ def signal_handler(signum, frame): if engine_core is not None: engine_core.shutdown() + def _init_data_parallel(self, vllm_config: VllmConfig): + pass + def run_busy_loop(self): """Core busy loop of the EngineCore.""" @@ -571,8 +585,20 @@ def __init__( _add_prefix(sys.stdout, process_name, pid) _add_prefix(sys.stderr, process_name, pid) - dp_size = vllm_config.parallel_config.data_parallel_size + # Counts forward-passes of the model so that we can synchronize + # finished with DP peers every N steps. + self.counter = 0 + + # Initialize the engine. dp_rank = vllm_config.parallel_config.data_parallel_rank + super().__init__(vllm_config, on_head_node, input_address, + executor_class, log_stats, dp_rank) + + def _init_data_parallel(self, vllm_config: VllmConfig): + + # Configure GPUs and stateless process group for data parallel. + dp_rank = vllm_config.parallel_config.data_parallel_rank + dp_size = vllm_config.parallel_config.data_parallel_size local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local assert dp_size > 1 @@ -589,14 +615,6 @@ def __init__( self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() - # Initialize the engine after setting up environment. - super().__init__(vllm_config, on_head_node, input_address, - executor_class, log_stats, dp_rank) - - # Counts forward-passes of the model so that we can synchronize - # finished with DP peers every N steps. 
- self.counter = 0 - def shutdown(self): super().shutdown() if dp_group := getattr(self, "dp_group", None): diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index ace6470faaea..c41ef85a1752 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -29,7 +29,7 @@ from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.executor.abstract import Executor from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder -from vllm.v1.utils import BackgroundProcHandle +from vllm.v1.utils import CoreEngineProcManager logger = init_logger(__name__) @@ -257,52 +257,6 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) -class CoreEngineProcManager: - """One per data parallel rank.""" - - def __init__( - self, - local_engine_count: int, - start_index: int, - vllm_config: VllmConfig, - on_head_node: bool, - input_address: str, - executor_class: type[Executor], - log_stats: bool, - ): - self.proc_handles = [] - try: - for local_index in range(local_engine_count): - index = local_index + start_index - # Start EngineCore in background process. - self.proc_handles.append( - BackgroundProcHandle( - input_address=input_address, - process_name=f"EngineCore_{index}", - target_fn=EngineCoreProc.run_engine_core, - process_kwargs={ - "vllm_config": vllm_config, - "on_head_node": on_head_node, - "dp_rank": index, - "local_dp_rank": local_index, - "executor_class": executor_class, - "log_stats": log_stats, - })) - finally: - if len(self.proc_handles) != local_engine_count: - self.close() - - def close(self): - for proc_handle in self.proc_handles: - proc_handle.shutdown() - - def finished_procs(self) -> dict[int, int]: - return { - handle.proc.name: handle.proc.exitcode - for handle in self.proc_handles if handle.proc.exitcode is not None - } - - class CoreEngine: """One per data parallel rank.""" @@ -398,18 +352,17 @@ def sigusr1_handler(signum, frame): self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) - # TODO + # TODO move address setup to separate method parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size local_engine_count = parallel_config.data_parallel_size_local - # TODO somewhere validate local count <= dp_size if local_engine_count == dp_size: input_address = get_open_zmq_ipc_path() output_address = get_open_zmq_ipc_path() else: host = parallel_config.data_parallel_master_ip - input_port = 13345 # todo from arg/config + input_port = parallel_config.data_parallel_rpc_port output_port = get_open_port() input_address = f"tcp://{host}:{input_port}" output_address = f"tcp://{host}:{output_port}" @@ -421,10 +374,10 @@ def sigusr1_handler(signum, frame): self.resources.output_socket = make_zmq_socket(self.ctx, output_address, zmq.constants.PULL) - # Start local engines. if local_engine_count: self.resources.local_engine_manager = CoreEngineProcManager( + EngineCoreProc.run_engine_core, vllm_config=vllm_config, executor_class=executor_class, log_stats=log_stats, @@ -446,56 +399,80 @@ def _wait_for_engine_startup(self, output_address: str, # Get a sync handle to the socket which can be sync or async. sync_input_socket = zmq.Socket.shadow(self.input_socket) + # TODO offline case compatibility + # Wait for engine core process(es) to send ready messages. 
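For reference while reading the revised startup tracking below, the engine-side sequence that this patch establishes, condensed into an illustrative helper (not the actual class structure):

    import msgspec.msgpack

    def engine_startup(input_socket, on_head_node):
        # 1. HELLO: announce this engine and whether it runs on the head node.
        input_socket.send(msgspec.msgpack.encode({"status": "HELLO",
                                                  "local": on_head_node}))
        # 2. Init message: output socket address plus DP master ip/port/size.
        init = msgspec.msgpack.decode(input_socket.recv())
        # 3. Set up the DP group and load the model (the slow part).
        # 4. READY: only now can the engine accept requests.
        input_socket.send(msgspec.msgpack.encode({"status": "READY",
                                                  "local": on_head_node}))
        return init["output_socket_address"]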
local_engine_count = parallel_config.data_parallel_size_local - # TODO offline case compatibility - local_indices = set(range(local_engine_count)) - remote_indices = set( - range(len(self.core_engines) - local_engine_count)) - while local_indices or remote_indices: + remote_engine_count = len(self.core_engines) - local_engine_count + + # TODO simplify the startup tracking logic below! + pending_hello_local = set(range(local_engine_count)) + pending_hello_remote = set( + range(local_engine_count, len(self.core_engines))) + pending_ready_local = set(pending_hello_local) + pending_ready_remote = set(pending_hello_remote) + while pending_ready_local or pending_ready_remote: while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS): - local_count = len(local_indices) - if remote_indices: - remote_count = len(remote_indices) + local_conn = local_engine_count - len(pending_hello_local) + local_ready = local_engine_count - len(pending_ready_local) + if local_ready != local_engine_count: logger.info( - "Waiting for %d local and %d remote core engine " - "proc(s) to start: %s, %s", local_count, remote_count, - local_indices, remote_indices) - else: - logger.info( - "Waiting for %d local core engine proc(s) " - "to start: %s", local_count, local_indices) + "Waiting for local core engine procs: " + "%d/%d connected, %d/%d ready.", local_conn, + local_engine_count, local_ready, local_engine_count) + if remote_engine_count: + remote_conn = remote_engine_count - len( + pending_hello_remote) + remote_ready = remote_engine_count - len( + pending_ready_remote) + if remote_ready != remote_engine_count: + logger.info( + "Waiting for remote core engine procs: " + "%d/%d connected, %d/%d ready.", remote_conn, + remote_engine_count, remote_ready, + remote_engine_count) + + # Receive HELLO and READY messages from the input socket. eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() - ready_msg = msgspec.msgpack.decode(ready_msg_bytes) - local, status = ready_msg["local"], ready_msg["status"] eng_index = int.from_bytes(eng_identity, byteorder="little") - if status != "READY": + msg = msgspec.msgpack.decode(ready_msg_bytes) + status, local = msg["status"], msg["local"] + hello_set = pending_hello_local if local else pending_hello_remote + ready_set = pending_ready_local if local else pending_ready_remote + if status == "HELLO": + index_set = hello_set + elif status == "READY": + index_set = ready_set + else: raise RuntimeError(f"{'Local' if local else 'Remote'} engine " f"{eng_index} failed: {status}") - - index_set = local_indices if local else remote_indices if eng_index not in index_set: raise RuntimeError( - f"Unexpected or duplicate " + f"Unexpected or duplicate {status} " + f"{'local' if local else 'remote'} engine: {eng_index}") + if status == "READY" and eng_index in hello_set: + raise RuntimeError( + f"Unexpected READY before HELLO for " f"{'local' if local else 'remote'} engine: {eng_index}") - # Send init message with DP config info. 
- init_message = self.encoder.encode({ - "output_socket_address": output_address, - "parallel_config": { - "data_parallel_master_ip": - parallel_config.data_parallel_master_ip, - "data_parallel_master_port": - parallel_config.data_parallel_master_port, - "data_parallel_size": parallel_config.data_parallel_size, - }, - }) - - sync_input_socket.send_multipart((eng_identity, init_message), - copy=False) - - logger.debug("%s core engine process %d ready.", - "Local" if local else "Remote", eng_index) + if status == "HELLO": + # Send init message with DP config info. + init_message = self.encoder.encode({ + "output_socket_address": output_address, + "parallel_config": { + "data_parallel_master_ip": + parallel_config.data_parallel_master_ip, + "data_parallel_master_port": + parallel_config.data_parallel_master_port, + "data_parallel_size": + parallel_config.data_parallel_size, + }, + }) + sync_input_socket.send_multipart((eng_identity, init_message), + copy=False) + + logger.debug("%s from %s core engine process %s.", status, + "local" if local else "remote", eng_index) index_set.discard(eng_index) # Double check that the process are running. diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 470e2a572ed8..adfdb86a7056 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -2,17 +2,21 @@ import multiprocessing import os +import time import weakref from collections import defaultdict from collections.abc import Sequence -from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, - Union, overload) +from multiprocessing import connection +from typing import (TYPE_CHECKING, Callable, Generic, Optional, TypeVar, Union, + overload) import torch +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.models.utils import extract_layer_index from vllm.utils import get_mp_context, kill_process_tree +from vllm.v1.executor.abstract import Executor if TYPE_CHECKING: from vllm.attention.layer import Attention @@ -90,7 +94,7 @@ def __repr__(self): return f"ConstantList({self._x})" -class BackgroundProcHandle: +class CoreEngineProcManager: """ Utility class to handle creation, readiness, and shutdown of background processes used by the AsyncLLM and LLMEngine. @@ -98,46 +102,87 @@ class BackgroundProcHandle: def __init__( self, - input_address: str, - process_name: str, target_fn: Callable, - process_kwargs: dict[Any, Any], + local_engine_count: int, + start_index: int, + vllm_config: VllmConfig, + on_head_node: bool, + input_address: str, + executor_class: type[Executor], + log_stats: bool, ): context = get_mp_context() - - assert "input_address" not in process_kwargs - process_kwargs["input_address"] = input_address - - # Run busy loop in background process. - self.proc = context.Process(target=target_fn, - kwargs=process_kwargs, - name=process_name) - self._finalizer = weakref.finalize(self, shutdown, self.proc, + common_kwargs = { + "vllm_config": vllm_config, + "on_head_node": on_head_node, + "input_address": input_address, + "executor_class": executor_class, + "log_stats": log_stats, + } + + self.processes = [] + for local_index in range(local_engine_count): + index = local_index + start_index + # Start EngineCore in background process. 
+ self.processes.append( + context.Process(target=target_fn, + name=f"EngineCore_{index}", + kwargs=common_kwargs | { + "dp_rank": index, + "local_dp_rank": local_index, + })) + + self._finalizer = weakref.finalize(self, shutdown, self.processes, input_address) - self.proc.start() - - def shutdown(self): + try: + for proc in self.processes: + proc.start() + finally: + # Kill other procs if not all are running. + if self.finished_procs(): + self.close() + + def close(self): + """Shutdown all procs.""" self._finalizer() + def join_first(self): + """Wait for any process to exit.""" + connection.wait(proc.sentinel for proc in self.processes) + + def finished_procs(self) -> dict[int, int]: + """Returns dict of proc name -> exit code for any finished procs.""" + return { + proc.name: proc.exitcode + for proc in self.processes if proc.exitcode is not None + } + # Note(rob): shutdown function cannot be a bound method, # else the gc cannot collect the object. -def shutdown(proc: multiprocessing.Process, input_address: str): +def shutdown(procs: list[multiprocessing.Process], input_address: str): # Shutdown the process. - if proc.is_alive(): - proc.terminate() - proc.join(5) + for proc in procs: + if proc.is_alive(): + proc.terminate() + + # Allow 5 seconds for remaining procs to terminate. + deadline = time.monotonic() + 5 + for proc in procs: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + proc.join(remaining) + for proc in procs: if proc.is_alive(): kill_process_tree(proc.pid) # Remove zmq ipc socket files. - ipc_sockets = (input_address, ) - for ipc_socket in ipc_sockets: - if ipc_socket.startswith("ipc://"): - socket_file = ipc_socket.replace("ipc://", "") - if os and os.path.exists(socket_file): - os.remove(socket_file) + if input_address.startswith("ipc://"): + socket_file = input_address[len("ipc://"):] + if os and os.path.exists(socket_file): + os.remove(socket_file) def bind_kv_cache( From 1ca3d1598f17d17d40d25a5ba3e44072798628d4 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 3 Apr 2025 19:23:39 -0700 Subject: [PATCH 03/12] Wire data_parallel_address arg Signed-off-by: Nick Hill --- vllm/engine/arg_utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b7cacb177dc5..729c1be1321d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1218,17 +1218,24 @@ def create_engine_config( self.data_parallel_size_local is None) else self.data_parallel_size_local + # DP address, used in multi-node case for torch distributed group + # and ZMQ sockets. + data_parallel_address = self.data_parallel_address if ( + self.data_parallel_address + is not None) else ParallelConfig.data_parallel_master_ip + # This port is only used when there are remote data parallel engines, # otherwise the local IPC transport is used. 
data_parallel_rpc_port = self.data_parallel_rpc_port if ( self.data_parallel_rpc_port - is not None) else (ParallelConfig.data_parallel_rpc_port) + is not None) else ParallelConfig.data_parallel_rpc_port parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, data_parallel_size=self.data_parallel_size, data_parallel_size_local=data_parallel_size_local, + data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, From a5511836d0cee8007d326be57cce7b49340ce2c0 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 3 Apr 2025 21:33:05 -0700 Subject: [PATCH 04/12] Some code cleanup Signed-off-by: Nick Hill --- vllm/v1/engine/core_client.py | 128 +++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 58 deletions(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c41ef85a1752..c5d9b16f8fee 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -11,6 +11,7 @@ from collections.abc import Awaitable from concurrent.futures import Future from dataclasses import dataclass +from enum import Enum, auto from threading import Thread from typing import Any, Callable, Optional, TypeVar, Union @@ -257,11 +258,20 @@ def collective_rpc(self, return self.engine_core.collective_rpc(method, timeout, args, kwargs) +class CoreEngineState(Enum): + NEW = auto() + CONNECTED = auto() + READY = auto() + + class CoreEngine: """One per data parallel rank.""" - def __init__(self, index: int = 0): + def __init__(self, index: int = 0, local: bool = True): + self.local = local self.identity = index.to_bytes(length=2, byteorder="little") + + self.state = CoreEngineState.NEW self.num_reqs_in_flight = 0 @@ -352,20 +362,12 @@ def sigusr1_handler(signum, frame): self.resources = BackgroundResources(ctx=sync_ctx) self._finalizer = weakref.finalize(self, self.resources) - # TODO move address setup to separate method parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size local_engine_count = parallel_config.data_parallel_size_local - if local_engine_count == dp_size: - input_address = get_open_zmq_ipc_path() - output_address = get_open_zmq_ipc_path() - else: - host = parallel_config.data_parallel_master_ip - input_port = parallel_config.data_parallel_rpc_port - output_port = get_open_port() - input_address = f"tcp://{host}:{input_port}" - output_address = f"tcp://{host}:{output_port}" + input_address, output_address = self._get_zmq_addresses( + parallel_config) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( @@ -386,7 +388,10 @@ def sigusr1_handler(signum, frame): local_engine_count=local_engine_count, start_index=0) - self.core_engines = [CoreEngine(i) for i in range(dp_size)] + self.core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(dp_size) + ] self.core_engine = self.core_engines[0] # Wait for engine core process(es) to start. 
@@ -394,6 +399,24 @@ def sigusr1_handler(signum, frame): self.utility_results: dict[int, AnyFuture] = {} + @staticmethod + def _get_zmq_addresses(parallel_config: ParallelConfig) -> tuple[str, str]: + """Returns (input_address, output_address).""" + dp_size = parallel_config.data_parallel_size + local_engine_count = parallel_config.data_parallel_size_local + + if local_engine_count == dp_size: + input_address = get_open_zmq_ipc_path() + output_address = get_open_zmq_ipc_path() + else: + host = parallel_config.data_parallel_master_ip + input_port = parallel_config.data_parallel_rpc_port + output_port = get_open_port() + input_address = f"tcp://{host}:{input_port}" + output_address = f"tcp://{host}:{output_port}" + + return input_address, output_address + def _wait_for_engine_startup(self, output_address: str, parallel_config: ParallelConfig): # Get a sync handle to the socket which can be sync or async. @@ -402,60 +425,39 @@ def _wait_for_engine_startup(self, output_address: str, # TODO offline case compatibility # Wait for engine core process(es) to send ready messages. - local_engine_count = parallel_config.data_parallel_size_local - remote_engine_count = len(self.core_engines) - local_engine_count - - # TODO simplify the startup tracking logic below! - pending_hello_local = set(range(local_engine_count)) - pending_hello_remote = set( - range(local_engine_count, len(self.core_engines))) - pending_ready_local = set(pending_hello_local) - pending_ready_remote = set(pending_hello_remote) - while pending_ready_local or pending_ready_remote: + local_count = parallel_config.data_parallel_size_local + remote_count = len(self.core_engines) - local_count + # [local, remote] counts + conn_pending, start_pending = [local_count, remote_count], [0, 0] + + while any(conn_pending) or any(start_pending): while not sync_input_socket.poll(timeout=STARTUP_POLL_PERIOD_MS): - local_conn = local_engine_count - len(pending_hello_local) - local_ready = local_engine_count - len(pending_ready_local) - if local_ready != local_engine_count: + if any(conn_pending): logger.info( - "Waiting for local core engine procs: " - "%d/%d connected, %d/%d ready.", local_conn, - local_engine_count, local_ready, local_engine_count) - if remote_engine_count: - remote_conn = remote_engine_count - len( - pending_hello_remote) - remote_ready = remote_engine_count - len( - pending_ready_remote) - if remote_ready != remote_engine_count: - logger.info( - "Waiting for remote core engine procs: " - "%d/%d connected, %d/%d ready.", remote_conn, - remote_engine_count, remote_ready, - remote_engine_count) + "Waiting for %d local, %d remote core engine proc(s) " + "to connect.", *conn_pending) + if any(start_pending): + logger.info( + "Waiting for %d local, %d remote core engine proc(s) " + "to start.", *start_pending) # Receive HELLO and READY messages from the input socket. 
eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() eng_index = int.from_bytes(eng_identity, byteorder="little") + if eng_index > len(self.core_engines): + raise RuntimeError( + f"Message from engine rank larger than " + f"configured data parallel size: {eng_index}") + engine = self.core_engines[eng_index] msg = msgspec.msgpack.decode(ready_msg_bytes) status, local = msg["status"], msg["local"] - hello_set = pending_hello_local if local else pending_hello_remote - ready_set = pending_ready_local if local else pending_ready_remote - if status == "HELLO": - index_set = hello_set - elif status == "READY": - index_set = ready_set - else: - raise RuntimeError(f"{'Local' if local else 'Remote'} engine " - f"{eng_index} failed: {status}") - if eng_index not in index_set: - raise RuntimeError( - f"Unexpected or duplicate {status} " - f"{'local' if local else 'remote'} engine: {eng_index}") - if status == "READY" and eng_index in hello_set: - raise RuntimeError( - f"Unexpected READY before HELLO for " - f"{'local' if local else 'remote'} engine: {eng_index}") + if local != engine.local: + raise RuntimeError(f"{status} message from " + f"{'local' if local else 'remote'} " + f" engine {eng_index}, expected it to be " + f"{'local' if engine.local else 'remote'}") + if status == "HELLO" and engine.state == CoreEngineState.NEW: - if status == "HELLO": # Send init message with DP config info. init_message = self.encoder.encode({ "output_socket_address": output_address, @@ -470,10 +472,20 @@ def _wait_for_engine_startup(self, output_address: str, }) sync_input_socket.send_multipart((eng_identity, init_message), copy=False) + conn_pending[0 if local else 1] -= 1 + start_pending[0 if local else 1] += 1 + engine.state = CoreEngineState.CONNECTED + elif status == "READY" and (engine.state + == CoreEngineState.CONNECTED): + start_pending[0 if local else 1] -= 1 + engine.state = CoreEngineState.READY + else: + raise RuntimeError(f"Unexpected {status} message for " + f"{'local' if local else 'remote'} engine " + f"{eng_index} in {engine.state} state.") logger.debug("%s from %s core engine process %s.", status, "local" if local else "remote", eng_index) - index_set.discard(eng_index) # Double check that the process are running. engine_manager = self.resources.local_engine_manager From a6621696742fec492842fb5c8168f069070c48b1 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 4 Apr 2025 11:08:49 -0700 Subject: [PATCH 05/12] Fix offline DP compatibility Signed-off-by: Nick Hill --- vllm/config.py | 1 - vllm/entrypoints/cli/serve.py | 1 + vllm/v1/engine/core_client.py | 50 +++++++++++++++++++++++------------ vllm/v1/utils.py | 13 +++++---- 4 files changed, 42 insertions(+), 23 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 57ab0ef05964..9152d847a6f8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1546,7 +1546,6 @@ def __post_init__(self) -> None: if self.data_parallel_size > 1: # Data parallel was specified in the engine args. self.data_parallel_master_port = get_open_port() - # TODO multi-node else: # Otherwise fall back to env vars (e.g. for offline SPMD case). 
self.data_parallel_size = envs.VLLM_DP_SIZE diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 801dd6db3d7b..28362613dcee 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -105,6 +105,7 @@ def run_headless(args: argparse.Namespace): target_fn=EngineCoreProc.run_engine_core, local_engine_count=local_engine_count, start_index=engine_args.data_parallel_start_rank, + local_start_index=0, vllm_config=vllm_config, on_head_node=False, input_address=input_address, diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index c5d9b16f8fee..144ea5bc9d64 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -363,11 +363,28 @@ def sigusr1_handler(signum, frame): self._finalizer = weakref.finalize(self, self.resources) parallel_config = vllm_config.parallel_config - dp_size = parallel_config.data_parallel_size local_engine_count = parallel_config.data_parallel_size_local + start_index = parallel_config.data_parallel_rank + local_start_index = parallel_config.data_parallel_rank_local + + # SPMD mode is where there is an LLM instance per DP rank and one + # core engine per LLM, see examples/offline_inference/data_parallel.py. + spmd_mode = local_start_index is not None + if spmd_mode: + assert local_engine_count == 1 + self.core_engines = [ + CoreEngine(index=local_start_index, local=True) + ] + else: + assert start_index == 0 + local_start_index = 0 + self.core_engines = [ + CoreEngine(index=i, local=(i < local_engine_count)) + for i in range(parallel_config.data_parallel_size) + ] input_address, output_address = self._get_zmq_addresses( - parallel_config) + parallel_config, spmd_mode) # Create input and output sockets. self.input_socket = self.resources.input_socket = make_zmq_socket( @@ -378,6 +395,7 @@ def sigusr1_handler(signum, frame): zmq.constants.PULL) # Start local engines. if local_engine_count: + # In server mode, start_index and local_start_index will both be 0. self.resources.local_engine_manager = CoreEngineProcManager( EngineCoreProc.run_engine_core, vllm_config=vllm_config, @@ -386,12 +404,9 @@ def sigusr1_handler(signum, frame): input_address=input_address, on_head_node=True, local_engine_count=local_engine_count, - start_index=0) + start_index=start_index, + local_start_index=local_start_index) - self.core_engines = [ - CoreEngine(index=i, local=(i < local_engine_count)) - for i in range(dp_size) - ] self.core_engine = self.core_engines[0] # Wait for engine core process(es) to start. @@ -400,12 +415,13 @@ def sigusr1_handler(signum, frame): self.utility_results: dict[int, AnyFuture] = {} @staticmethod - def _get_zmq_addresses(parallel_config: ParallelConfig) -> tuple[str, str]: + def _get_zmq_addresses(parallel_config: ParallelConfig, + spmd_mode: bool) -> tuple[str, str]: """Returns (input_address, output_address).""" dp_size = parallel_config.data_parallel_size local_engine_count = parallel_config.data_parallel_size_local - if local_engine_count == dp_size: + if local_engine_count == dp_size or spmd_mode: input_address = get_open_zmq_ipc_path() output_address = get_open_zmq_ipc_path() else: @@ -422,8 +438,6 @@ def _wait_for_engine_startup(self, output_address: str, # Get a sync handle to the socket which can be sync or async. sync_input_socket = zmq.Socket.shadow(self.input_socket) - # TODO offline case compatibility - # Wait for engine core process(es) to send ready messages. 
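A brief recap of the two launch modes being distinguished here, with a hypothetical layout (not part of the patch):

    # SPMD / offline mode (one LLM instance per DP rank, as in
    # examples/offline_inference/data_parallel.py): data_parallel_rank_local
    # is pre-set, so each client owns exactly one local engine.
    spmd_mode = parallel_config.data_parallel_rank_local is not None

    # Server mode with dp_size=4 and dp_size_local=2 (hypothetical):
    #   ranks 0, 1 -> local=True,  spawned by CoreEngineProcManager here
    #   ranks 2, 3 -> local=False, expected to register from a headless node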
local_count = parallel_config.data_parallel_size_local remote_count = len(self.core_engines) - local_count @@ -444,18 +458,20 @@ def _wait_for_engine_startup(self, output_address: str, # Receive HELLO and READY messages from the input socket. eng_identity, ready_msg_bytes = sync_input_socket.recv_multipart() eng_index = int.from_bytes(eng_identity, byteorder="little") - if eng_index > len(self.core_engines): - raise RuntimeError( - f"Message from engine rank larger than " - f"configured data parallel size: {eng_index}") - engine = self.core_engines[eng_index] + engine = next( + (e for e in self.core_engines if e.identity == eng_identity), + None) + if engine is None: + raise RuntimeError(f"Message from engine with unexpected data " + f"parallel rank: {eng_index}") msg = msgspec.msgpack.decode(ready_msg_bytes) status, local = msg["status"], msg["local"] if local != engine.local: raise RuntimeError(f"{status} message from " f"{'local' if local else 'remote'} " - f" engine {eng_index}, expected it to be " + f"engine {eng_index}, expected it to be " f"{'local' if engine.local else 'remote'}") + if status == "HELLO" and engine.state == CoreEngineState.NEW: # Send init message with DP config info. diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index adfdb86a7056..e6f947af4d86 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -105,6 +105,7 @@ def __init__( target_fn: Callable, local_engine_count: int, start_index: int, + local_start_index: int, vllm_config: VllmConfig, on_head_node: bool, input_address: str, @@ -121,14 +122,15 @@ def __init__( } self.processes = [] - for local_index in range(local_engine_count): - index = local_index + start_index + for index in range(local_engine_count): + local_index = local_start_index + index + global_index = start_index + index # Start EngineCore in background process. 
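# Illustrative sketch, not part of the patch: global DP rank and node-local
# DP rank are now derived independently, so a secondary node launched with
# --data-parallel-start-rank hosts global ranks that do not start at zero
# while its local ranks still start at zero. The numbers below are example
# values, not taken from the patch.
start_index = 2        # --data-parallel-start-rank on a secondary node
local_start_index = 0  # headless nodes always start local ranks at 0
local_engine_count = 2

for index in range(local_engine_count):
    local_dp_rank = local_start_index + index
    dp_rank = start_index + index
    print(f"EngineCore_{dp_rank}: dp_rank={dp_rank}, "
          f"local_dp_rank={local_dp_rank}")
# EngineCore_2: dp_rank=2, local_dp_rank=0
# EngineCore_3: dp_rank=3, local_dp_rank=1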
self.processes.append( context.Process(target=target_fn, - name=f"EngineCore_{index}", + name=f"EngineCore_{global_index}", kwargs=common_kwargs | { - "dp_rank": index, + "dp_rank": global_index, "local_dp_rank": local_index, })) @@ -172,7 +174,8 @@ def shutdown(procs: list[multiprocessing.Process], input_address: str): remaining = deadline - time.monotonic() if remaining <= 0: break - proc.join(remaining) + if proc.is_alive(): + proc.join(remaining) for proc in procs: if proc.is_alive(): From 8126f726c1d5f6d2308e23adf58cebbf0fc11399 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 7 Apr 2025 15:42:14 -0700 Subject: [PATCH 06/12] Address some review comments Signed-off-by: Nick Hill --- vllm/engine/arg_utils.py | 7 --- vllm/entrypoints/cli/serve.py | 9 +++- vllm/v1/engine/core.py | 84 +++++++++++++++++++---------------- 3 files changed, 54 insertions(+), 46 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 554638924ed5..7b6cbae8b152 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -118,7 +118,6 @@ class EngineArgs: tensor_parallel_size: int = 1 data_parallel_size: int = 1 data_parallel_size_local: Optional[int] = None - data_parallel_start_rank: int = 0 data_parallel_address: Optional[str] = None data_parallel_rpc_port: Optional[int] = None enable_expert_parallel: bool = False @@ -450,12 +449,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.data_parallel_size_local, help='Number of data parallel replicas to run on ' 'this node.') - parser.add_argument('--data-parallel-start-rank', - '-dpr', - type=int, - default=EngineArgs.data_parallel_start_rank, - help='Starting data parallel rank for secondary ' - 'nodes.') parser.add_argument('--data-parallel-address', '-dpa', type=str, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 0c2af1514885..b9f64026d756 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -60,6 +60,13 @@ def subparser_init( default=False, help="Run in headless mode. See multi-node data parallel " "documentation for more details.") + serve_parser.add_argument( + '--data-parallel-start-rank', + '-dpr', + type=int, + default=0, + help='Starting data parallel rank for secondary ' + 'nodes.') serve_parser.add_argument( "--config", type=str, @@ -105,7 +112,7 @@ def run_headless(args: argparse.Namespace): engine_manager = CoreEngineProcManager( target_fn=EngineCoreProc.run_engine_core, local_engine_count=local_engine_count, - start_index=engine_args.data_parallel_start_rank, + start_index=args.data_parallel_start_rank, local_start_index=0, vllm_config=vllm_config, on_head_node=False, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 1399604692e4..d0668fc8df0d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -43,6 +43,7 @@ logger = init_logger(__name__) POLLING_TIMEOUT_S = 2.5 +HANDSHAKE_TIMEOUT_MINS = 5 _R = TypeVar('_R') # Return type for collective_rpc @@ -324,43 +325,47 @@ def __init__( zmq.DEALER, identity=identity, bind=False) - - # Register engine with front-end. - output_address = self.startup_handshake(input_socket, on_head_node, - vllm_config.parallel_config) - - # Set up data parallel environment. - self._init_data_parallel(vllm_config) - - # Initialize engine core and model. 
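# Illustrative sketch, not part of the patch: the shutdown helper above gives
# every engine process one shared deadline, joins only processes that are
# still alive, and then escalates for stragglers. The kill() fallback below
# is an assumption about the unchanged tail of that function, not shown in
# the diff.
import multiprocessing
import time


def join_with_deadline(procs: list[multiprocessing.Process],
                       timeout: float = 5.0) -> None:
    deadline = time.monotonic() + timeout
    for proc in procs:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        if proc.is_alive():
            proc.join(remaining)
    for proc in procs:
        if proc.is_alive():
            proc.kill()  # assumed escalation once the graceful join times out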
- super().__init__(vllm_config, executor_class, log_stats) - - self.step_fn = (self.step if self.batch_queue is None else - self.step_with_batch_queue) - - self.global_unfinished_reqs = False - - # Send ready message. - input_socket.send( - msgspec.msgpack.encode({ - "status": "READY", - "local": on_head_node - })) - - # Background Threads and Queues for IO. These enable us to - # overlap ZMQ socket IO with GPU since they release the GIL, - # and to overlap some serialization/deserialization with the - # model forward pass. - # Threads handle Socket <-> Queues and core_busy_loop uses Queue. - self.input_queue: queue.Queue[tuple[EngineCoreRequestType, - Any]] = queue.Queue() - self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() - threading.Thread(target=self.process_input_socket, - args=(input_socket, ), - daemon=True).start() - threading.Thread(target=self.process_output_socket, - args=(output_address, engine_index), - daemon=True).start() + try: + # Register engine with front-end. + output_address = self.startup_handshake( + input_socket, on_head_node, vllm_config.parallel_config) + + # Set up data parallel environment. + self._init_data_parallel(vllm_config) + + # Initialize engine core and model. + super().__init__(vllm_config, executor_class, log_stats) + + self.step_fn = (self.step if self.batch_queue is None else + self.step_with_batch_queue) + + self.global_unfinished_reqs = False + + # Send ready message. + input_socket.send( + msgspec.msgpack.encode({ + "status": "READY", + "local": on_head_node + })) + + # Background Threads and Queues for IO. These enable us to + # overlap ZMQ socket IO with GPU since they release the GIL, + # and to overlap some serialization/deserialization with the + # model forward pass. + # Threads handle Socket <-> Queues and core_busy_loop uses Queue. + self.input_queue: queue.Queue[tuple[EngineCoreRequestType, + Any]] = queue.Queue() + self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue() + threading.Thread(target=self.process_input_socket, + args=(input_socket, ), + daemon=True).start() + input_socket = None + threading.Thread(target=self.process_output_socket, + args=(output_address, engine_index), + daemon=True).start() + finally: + if input_socket is not None: + input_socket.close(linger=0) @staticmethod def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, @@ -375,7 +380,10 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool, # Receive initialization message. logger.info("Waiting for init message from front-end.") - input_socket.poll(timeout=5 * 60 * 1000) + if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000): + raise RuntimeError("Did not receive response from front-end " + f"process within {HANDSHAKE_TIMEOUT_MINS} " + f"minutes") init_bytes = input_socket.recv() init_message = msgspec.msgpack.decode(init_bytes) logger.debug("Received init message: %s", init_message) From 8fdc6f5c120051bd8894443e53f3970d9207aa78 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 7 Apr 2025 16:09:42 -0700 Subject: [PATCH 07/12] Address other minor review comments Signed-off-by: Nick Hill --- vllm/v1/engine/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d0668fc8df0d..f52d010eefd4 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -318,7 +318,7 @@ def __init__( engine_index: int = 0, ): # Create input socket. 
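# Illustrative sketch, not part of the patch: poll-before-recv is what turns
# the blocking handshake into a bounded wait. zmq.Socket.poll() returns 0
# when the timeout expires, so the engine fails fast instead of hanging if
# the front-end never responds. Minimal standalone form, assuming `sock` is
# the already-connected DEALER socket.
import msgspec
import zmq

HANDSHAKE_TIMEOUT_MINS = 5


def recv_init_message(sock: zmq.Socket) -> dict:
    if not sock.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
        raise RuntimeError("Did not receive response from front-end process "
                           f"within {HANDSHAKE_TIMEOUT_MINS} minutes")
    return msgspec.msgpack.decode(sock.recv())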
-        input_ctx = zmq.Context()  # type: ignore[attr-defined]
+        input_ctx = zmq.Context()
         identity = engine_index.to_bytes(length=2, byteorder="little")
         input_socket = make_zmq_socket(input_ctx,
                                        input_address,
@@ -389,7 +389,7 @@ def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
         logger.debug("Received init message: %s", init_message)

         output_socket_address = init_message["output_socket_address"]
-        #TBD maybe replace IP with configured head node address
+        #TBD(nick) maybe replace IP with configured head node address

         received_parallel_config = init_message["parallel_config"]
         for key, value in received_parallel_config.items():

From efa8ad864370f2d2c0bb281354272bbc46469842 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 17 Apr 2025 12:19:21 -0700
Subject: [PATCH 08/12] Fix merge error, address @russellb's ipv6 review comment

Signed-off-by: Nick Hill

---
 vllm/utils.py                 | 4 ++++
 vllm/v1/engine/core_client.py | 8 ++++----
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/vllm/utils.py b/vllm/utils.py
index c6e2afff72d7..350dce8f02b8 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -604,6 +604,10 @@ def is_valid_ipv6_address(address: str) -> bool:


 def get_distributed_init_method(ip: str, port: int) -> str:
+    return get_tcp_uri(ip, port)
+
+
+def get_tcp_uri(ip: str, port: int) -> str:
     # Brackets are not permitted in ipv4 addresses,
     # see https://github.com/python/cpython/issues/103848
     return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index bc2d5f07ebbf..ba3bfbe66062 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -19,7 +19,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.utils import (get_open_port, get_open_zmq_inproc_path,
-                        get_open_zmq_ipc_path, make_zmq_socket)
+                        get_open_zmq_ipc_path, get_tcp_uri, make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
@@ -423,8 +423,8 @@ def _get_zmq_addresses(parallel_config: ParallelConfig,
         host = parallel_config.data_parallel_master_ip
         input_port = parallel_config.data_parallel_rpc_port
         output_port = get_open_port()
-        input_address = f"tcp://{host}:{input_port}"
-        output_address = f"tcp://{host}:{output_port}"
+        input_address = get_tcp_uri(host, input_port)
+        output_address = get_tcp_uri(host, output_port)

         return input_address, output_address

@@ -496,7 +496,7 @@ def _wait_for_engine_startup(self, output_address: str,
                         parallel_config.data_parallel_size,
                     },
                 })
-                sync_input_socket.send_multipart((eng_identity, init_message),
+                sync_input_socket.send_multipart((eng_identity, *init_message),
                                                  copy=False)
                 conn_pending[0 if local else 1] -= 1
                 start_pending[0 if local else 1] += 1

From 30ab14b38791f7eaa8d4eec2e07756a8459c8282 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 18 Apr 2025 08:58:33 -0700
Subject: [PATCH 09/12] Handle ipv6 URIs in all places

Signed-off-by: Nick Hill

---
 vllm/distributed/utils.py     | 3 ++-
 vllm/entrypoints/cli/serve.py | 4 ++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 2cb57afd4566..442a79bc7162 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -21,6 +21,7 @@

 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import get_tcp_uri

 logger = init_logger(__name__)

@@ -283,7 +284,7 @@ def
stateless_init_torch_distributed_process_group( always formed with process 1, 2, ..., 8, and the additional communication channel is formed with process 9 and 10. """ - init_method = f"tcp://{host}:{port}" + init_method = get_tcp_uri(host, port) backend = Backend(backend) # it is basically string timeout = _get_default_timeout(backend) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index b9f64026d756..4b3e134a485d 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -12,7 +12,7 @@ validate_parsed_serve_args) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, get_tcp_uri from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core_client import CoreEngineProcManager from vllm.v1.executor.abstract import Executor @@ -98,7 +98,7 @@ def run_headless(args: argparse.Namespace): local_engine_count = parallel_config.data_parallel_size_local host = parallel_config.data_parallel_master_ip port = engine_args.data_parallel_rpc_port # add to config too - input_address = f"tcp://{host}:{port}" + input_address = get_tcp_uri(host, port) if local_engine_count <= 0: raise RuntimeError("data_parallel_size_local must be > 0 in " From acc5af341fdbf91a7ae9b1d0b21526533fb52bf7 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 18 Apr 2025 19:10:31 -0700 Subject: [PATCH 10/12] Fix head node with no engines, don't require dp size on other nodes Signed-off-by: Nick Hill --- vllm/config.py | 14 ++++++++------ vllm/entrypoints/cli/serve.py | 13 +++++++++++-- vllm/v1/engine/core.py | 5 ++++- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d50987a77a80..6d4c5f1a77cf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1621,13 +1621,16 @@ class is dynamically inherited by the worker class. This is used to inject world_size: int = field(init=False) """world_size is TPxPP, it affects the number of workers we create.""" - world_size_across_dp: int = field(init=False) - """world_size_across_dp is TPxPPxDP, it is the size of the world - including data parallelism.""" rank: int = 0 """Global rank in distributed setup.""" + @property + def world_size_across_dp(self) -> int: + """world_size_across_dp is TPxPPxDP, it is the size of the world + including data parallelism.""" + return self.world_size * self.data_parallel_size + def get_next_dp_init_port(self) -> int: """ We might need to initialize process groups in multiple @@ -1680,6 +1683,7 @@ def compute_hash(self): factors: list[Any] = [] factors.append(self.pipeline_parallel_size) factors.append(self.tensor_parallel_size) + factors.append(self.data_parallel_size) return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: @@ -1690,7 +1694,7 @@ def __post_init__(self) -> None: raise ValueError( "data_parallel_size_local must be <= data_parallel_size") - if self.data_parallel_size > 1: + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. 
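# Illustrative sketch, not part of the patch: get_tcp_uri() (added in PATCH
# 08 and reused here) makes every TCP endpoint in the DP handshake IPv6-safe;
# literal IPv6 hosts must be bracketed while IPv4 hosts must not be. Expected
# outputs, assuming the function behaves as defined in vllm/utils.py above:
from vllm.utils import get_tcp_uri

assert get_tcp_uri("10.0.0.5", 13345) == "tcp://10.0.0.5:13345"
assert get_tcp_uri("fd12:3456::1", 13345) == "tcp://[fd12:3456::1]:13345"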
self.data_parallel_master_port = get_open_port() else: @@ -1701,8 +1705,6 @@ def __post_init__(self) -> None: self.data_parallel_master_ip = envs.VLLM_DP_MASTER_IP self.data_parallel_master_port = envs.VLLM_DP_MASTER_PORT - self.world_size_across_dp = self.world_size * self.data_parallel_size - if self.distributed_executor_backend == "external_launcher": import os os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 4b3e134a485d..04be7c033998 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse +import signal import uvloop @@ -65,8 +66,7 @@ def subparser_init( '-dpr', type=int, default=0, - help='Starting data parallel rank for secondary ' - 'nodes.') + help='Starting data parallel rank for secondary nodes.') serve_parser.add_argument( "--config", type=str, @@ -104,6 +104,14 @@ def run_headless(args: argparse.Namespace): raise RuntimeError("data_parallel_size_local must be > 0 in " "headless mode") + # Catch SIGTERM and SIGINT to allow graceful shutdown. + def signal_handler(signum, frame): + logger.debug("Received %d signal.", signum) + raise SystemExit + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + logger.info( "Launching %d data parallel engine(s) in headless mode, " "with head node address %s.", local_engine_count, input_address) @@ -124,4 +132,5 @@ def run_headless(args: argparse.Namespace): try: engine_manager.join_first() finally: + logger.info("Shutting down.") engine_manager.close() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 47e1b70cb046..b218eda84182 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -338,6 +338,9 @@ def __init__( output_address = self.startup_handshake( input_socket, on_head_node, vllm_config.parallel_config) + # Update config which may have changed from the handshake. + vllm_config.__post_init__() + # Set up data parallel environment. self._init_data_parallel(vllm_config) @@ -436,7 +439,7 @@ def signal_handler(signum, frame): try: parallel_config: ParallelConfig = kwargs[ "vllm_config"].parallel_config - if parallel_config.data_parallel_size > 1: + if parallel_config.data_parallel_size > 1 or dp_rank > 0: # Set data parallel rank for this engine process. 
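# Illustrative sketch, not part of the patch: converting SIGTERM/SIGINT into
# SystemExit lets the headless entrypoint unwind through its try/finally
# block, so the engine process manager is always closed instead of leaving
# orphaned engine processes behind.
import signal


def install_graceful_shutdown() -> None:

    def handler(signum, frame):
        # Raise a normal exception so enclosing `finally:` blocks still run.
        raise SystemExit

    signal.signal(signal.SIGTERM, handler)
    signal.signal(signal.SIGINT, handler)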
parallel_config.data_parallel_rank = dp_rank parallel_config.data_parallel_rank_local = local_dp_rank From 42c30bf4ba68bc511709d8529c0db4d7c419a1d5 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 12 May 2025 09:33:08 -0700 Subject: [PATCH 11/12] Fix test_startup_failure Signed-off-by: Nick Hill --- tests/v1/engine/test_engine_core_client.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index fd8d1fd7ff48..452fe1e37e2c 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -18,9 +18,10 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.core import EngineCore -from vllm.v1.engine.core_client import (AsyncMPClient, CoreEngine, - EngineCoreClient, SyncMPClient) +from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient, + SyncMPClient) from vllm.v1.executor.abstract import Executor +from vllm.v1.utils import CoreEngineProcManager from ...distributed.conftest import MockSubscriber from ...utils import create_new_process_for_each_test @@ -348,13 +349,13 @@ def test_startup_failure(monkeypatch: pytest.MonkeyPatch): # Monkey-patch to extract core process pid while it's starting. core_proc_pid = [None] - ce_ctor = CoreEngine.__init__ + cepm_ctor = CoreEngineProcManager.__init__ - def patched_ce_ctor(self, *args, **kwargs): - ce_ctor(self, *args, **kwargs) - core_proc_pid[0] = self.proc_handle.proc.pid + def patched_cepm_ctor(self: CoreEngineProcManager, *args, **kwargs): + cepm_ctor(self, *args, **kwargs) + core_proc_pid[0] = self.processes[0].pid - m.setattr(CoreEngine, "__init__", patched_ce_ctor) + m.setattr(CoreEngineProcManager, "__init__", patched_cepm_ctor) t = time.time() engine_args = EngineArgs(model=MODEL_NAME) From 3904d10662f3dc1cdd301125743ac4061992dd15 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 12 May 2025 10:00:25 -0700 Subject: [PATCH 12/12] Fix mock config related test failure Signed-off-by: Nick Hill --- tests/async_engine/test_async_llm_engine.py | 2 +- vllm/config.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 48e2e31e5db8..b6f44871497c 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -41,7 +41,7 @@ def __init__(self): self.abort_request_calls = 0 self.request_id = None # Ugly, remove dependency when possible - self.parallel_config = ParallelConfig(1, 1, False) + self.parallel_config = ParallelConfig() self.model_config = MockModelConfig() async def step_async(self, virtual_engine): diff --git a/vllm/config.py b/vllm/config.py index d83232e2e1b2..ff87ae9092f1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1759,7 +1759,8 @@ def __post_init__(self) -> None: if self.data_parallel_size_local > self.data_parallel_size: raise ValueError( - "data_parallel_size_local must be <= data_parallel_size") + f"data_parallel_size_local ({self.data_parallel_size_local}) " + f"must be <= data_parallel_size ({self.data_parallel_size})") if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args.
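# Illustrative sketch, not part of the patch: the sizing rules that
# ParallelConfig now validates and derives. The numbers below are example
# values, not taken from the patch.
tp, pp, dp, dp_local = 2, 1, 4, 2

assert dp_local <= dp                   # otherwise __post_init__ raises
world_size = tp * pp                    # workers managed by each engine
world_size_across_dp = world_size * dp  # TP x PP x DP, now a derived property
assert world_size_across_dp == 8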