Skip to content

Commit 3a77942

Browse files
joerunde authored and sumitd2 committed
[Bugfix] Use heartbeats instead of health checks (vllm-project#8583)
Signed-off-by: Sumit Dubey <[email protected]>
1 parent f8cb85e commit 3a77942

File tree

4 files changed

+87
-63
lines changed

4 files changed

+87
-63
lines changed

tests/mq_llm_engine/test_error_handling.py

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket):
153153
await client.check_health()
154154

155155
# Trigger an abort on the client side.
156-
async def bad_abort_after_2s():
157-
await asyncio.sleep(2.0)
158-
await client.abort(request_id="foo")
156+
# This request ID does not exist, and will cause the engine to error
157+
await client.abort(request_id="foo")
159158

160-
# Trigger an abort in 2s from now.
161-
abort_task = asyncio.create_task(bad_abort_after_2s())
162-
163-
# Exception in abort() will happen during this generation.
164-
# This will kill the engine and should return ENGINE_DEAD_ERROR
159+
# Future generation requests will now fail
165160
# with reference to the original KeyError("foo")
166161
with pytest.raises(MQEngineDeadError) as execinfo:
167162
async for _ in client.generate(
168163
inputs="Hello my name is",
169-
sampling_params=SamplingParams(max_tokens=2000),
164+
sampling_params=SamplingParams(max_tokens=10),
170165
request_id=uuid.uuid4()):
171166
pass
172167
assert "KeyError" in repr(execinfo.value)
173168
assert client.errored
174169

175-
await abort_task
176-
177170
# This should raise the original error.
178171
with pytest.raises(RAISED_ERROR):
179172
await client.check_health()

vllm/engine/multiprocessing/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -43,10 +43,6 @@ class RPCAbortRequest:
4343
request_id: str
4444

4545

46-
class RPCHealthRequest:
47-
pass
48-
49-
5046
class RPCStartupRequest(Enum):
5147
IS_SERVER_READY = 1
5248

@@ -56,8 +52,7 @@ class RPCStartupResponse:
5652
tracing_enabled: bool
5753

5854

59-
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest,
60-
RPCStartupRequest]
55+
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest]
6156

6257
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
6358

vllm/engine/multiprocessing/client.py

Lines changed: 22 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,8 @@
2020
IPC_HEALTH_EXT, IPC_INPUT_EXT,
2121
IPC_OUTPUT_EXT, RPC_REQUEST_T,
2222
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
23-
RPCError, RPCHealthRequest,
24-
RPCProcessRequest, RPCStartupRequest,
25-
RPCStartupResponse)
23+
RPCError, RPCProcessRequest,
24+
RPCStartupRequest, RPCStartupResponse)
2625
# yapf: enable
2726
from vllm.envs import VLLM_RPC_TIMEOUT
2827
from vllm.inputs import PromptInputs
@@ -95,9 +94,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig):
9594
self.output_socket: Socket = self.context.socket(zmq.constants.PULL)
9695
self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}")
9796

98-
# IPC path for ack of check_health requests.
99-
self.health_socket: Socket = self.context.socket(zmq.constants.PULL)
100-
self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
97+
# IPC path for acking heartbeats.
98+
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
99+
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
101100

102101
# IPC path for the data socket.
103102
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@@ -124,34 +123,28 @@ def get_data_socket(self) -> Iterator[Socket]:
124123
finally:
125124
socket.close(linger=0)
126125

127-
async def run_check_health_loop(self, timeout: int):
128-
"""Background loop that continually probes the RPCServer for health.
129-
130-
The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which
131-
the MQLLMEngine server is blocking on.
132-
133-
The Server replies on the HEALTH_SOCKET (rather than on the
134-
OUTPUT_SOCKET such that the messages are not intermingled with
135-
output streaming).
126+
async def run_heartbeat_loop(self, timeout: int):
127+
"""Background loop that continually listens to the RPCServer for
128+
heartbeats.
136129
"""
137-
138130
try:
139131
while True:
140-
if await self.health_socket.poll(timeout=timeout) == 0:
141-
# Wakeup every N seconds and do a health probe.
142-
await self._send_one_way_rpc_request(
143-
RPCHealthRequest(), self.input_socket)
144-
145-
# Wait for ack from the health socket.
146-
await self._await_ack(error_message="Health check failed.",
147-
socket=self.health_socket)
132+
if await self.heartbeat_socket.poll(timeout=timeout) == 0:
133+
# No heartbeat was received. Set error and exit the loop
134+
self._set_errored(
135+
TimeoutError("No heartbeat received "
136+
"from MQLLMEngine"))
137+
logger.debug("Shutting down MQLLMEngineClient check "
138+
"health loop due to timeout")
139+
break
140+
148141
else:
149-
# Server sent a health status message unprompted.
142+
# Heartbeat received- check the message
150143
await self._check_success(
151-
error_message="Health check failed.",
152-
socket=self.health_socket)
144+
error_message="Heartbeat failed.",
145+
socket=self.heartbeat_socket)
153146

154-
logger.debug("Health probe successful.")
147+
logger.debug("Heartbeat successful.")
155148

156149
except asyncio.CancelledError:
157150
logger.debug("Shutting down MQLLMEngineClient check health loop.")
@@ -234,7 +227,7 @@ async def setup(self):
234227

235228
# Start health_loop.
236229
self.health_loop = asyncio.create_task(
237-
self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT))
230+
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
238231

239232
def close(self):
240233
"""Destroy the ZeroMQ Context."""

vllm/engine/multiprocessing/engine.py

Lines changed: 60 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
import pickle
22
import signal
3+
import threading
4+
import time
35
from contextlib import contextmanager
46
from typing import Iterator, List, Optional, Union
57

@@ -15,10 +17,10 @@
1517
IPC_HEALTH_EXT, IPC_INPUT_EXT,
1618
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
1719
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
18-
RPCError, RPCHealthRequest,
19-
RPCProcessRequest, RPCStartupRequest,
20-
RPCStartupResponse)
20+
RPCError, RPCProcessRequest,
21+
RPCStartupRequest, RPCStartupResponse)
2122
# yapf: enable
23+
from vllm.envs import VLLM_RPC_TIMEOUT
2224
from vllm.logger import init_logger
2325
from vllm.outputs import RequestOutput
2426
from vllm.usage.usage_lib import UsageContext
@@ -91,16 +93,30 @@ def __init__(self,
9193
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
9294
self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}")
9395

94-
# Send health status back to client.
95-
self.health_socket = self.ctx.socket(zmq.constants.PUSH)
96-
self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
96+
# Send heartbeats back to client.
97+
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
98+
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
9799

98100
# IPC path for the data socket.
99101
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
100102

101103
# Error state.
102104
self._errored_with: Optional[BaseException] = None
103105

106+
# Heartbeat thread
107+
self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop,
108+
daemon=True)
109+
self._heartbeat_stop_event = threading.Event()
110+
# The heartbeat needs to be faster than what the client will wait for
111+
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
112+
self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0
113+
114+
self._last_alive_time = time.time()
115+
# The heartbeats can tolerate a long period of the engine chugging
116+
# away at a generation request.
117+
# The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds
118+
self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0
119+
104120
@property
105121
def dead_error(self) -> BaseException:
106122
if self._errored_with is not None:
@@ -131,6 +147,8 @@ def start(self):
131147
try:
132148
logger.debug("Starting Startup Loop.")
133149
self.run_startup_loop()
150+
logger.debug("Starting heartbeat thread")
151+
self.heartbeat_thread.start()
134152
logger.debug("Starting Engine Loop.")
135153
self.run_engine_loop()
136154
except Exception as e:
@@ -144,6 +162,7 @@ def start(self):
144162
def cleanup(self):
145163
"""Cleanup zeromq state on shutdown."""
146164
# Closes all sockets and destroys context.
165+
self._heartbeat_stop_event.set()
147166
self.ctx.destroy(linger=0)
148167
del self.engine
149168

@@ -182,9 +201,11 @@ def run_engine_loop(self):
182201
"""Core busy loop of the LLMEngine."""
183202

184203
while True:
204+
self._alive()
185205
if not self.engine.has_unfinished_requests():
186206
# Poll until there is work to do.
187207
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
208+
self._alive()
188209
self.engine.do_log_stats()
189210
logger.debug("Waiting for new requests in engine loop.")
190211

@@ -200,7 +221,6 @@ def run_engine_loop(self):
200221

201222
def engine_step(self) -> List[RequestOutput]:
202223
"""Engine step wrapper with error handling."""
203-
204224
try:
205225
return self.engine.step()
206226
except SystemExit:
@@ -229,10 +249,9 @@ def handle_new_input(self):
229249
self._handle_process_request(request)
230250
elif isinstance(request, RPCAbortRequest):
231251
self._handle_abort_request(request)
232-
elif isinstance(request, RPCHealthRequest):
233-
self._handle_health_request()
234252
else:
235-
raise ValueError("Unknown RPCRequest Type: {request}")
253+
raise ValueError("Unknown RPCRequest Type: "
254+
f"{type(request)}")
236255

237256
except Exception as e:
238257
self._set_errored(e)
@@ -279,13 +298,32 @@ def _handle_abort_request(self, request: RPCAbortRequest):
279298
if self.log_requests:
280299
logger.info("Aborted request %s.", request.request_id)
281300

282-
def _handle_health_request(self):
301+
def _heartbeat_loop(self):
302+
while not self._heartbeat_stop_event.wait(
303+
timeout=self.heartbeat_interval_seconds):
304+
# Loops until the stop event is set
305+
self._heartbeat()
306+
307+
logger.debug("Exiting MQLLMEngine heartbeat thread")
308+
309+
def _heartbeat(self):
310+
# Send unhealthy if engine has already errored
283311
if self._errored_with is not None:
284312
self._send_unhealthy(self._errored_with)
285313

286-
# Raises error if unhealthy.
287-
self.engine.check_health()
288-
self._send_healthy()
314+
# Check for life of the main loop
315+
elif time.time() - self._last_alive_time > self.last_alive_threshold:
316+
self._send_unhealthy(RuntimeError("Engine loop has died"))
317+
318+
else:
319+
# Otherwise- check health of the engine
320+
# self.engine.check_health() raises on unhealthy
321+
try:
322+
self.engine.check_health()
323+
self._send_healthy()
324+
except Exception as e:
325+
self._set_errored(e)
326+
self._send_unhealthy(e)
289327

290328
def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
291329
"""Send List of RequestOutput to RPCClient."""
@@ -295,12 +333,14 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
295333

296334
def _send_healthy(self):
297335
"""Send HEALTHY message to RPCClient."""
298-
self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
336+
if not self.heartbeat_socket.closed:
337+
self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False)
299338

300339
def _send_unhealthy(self, error: BaseException):
301340
"""Send UNHEALTHY message to RPCClient."""
302-
error_bytes = pickle.dumps(error)
303-
self.health_socket.send_multipart((error_bytes, ), copy=False)
341+
if not self.heartbeat_socket.closed:
342+
error_bytes = pickle.dumps(error)
343+
self.heartbeat_socket.send_multipart((error_bytes, ), copy=False)
304344

305345
def _async_socket_engine_callback(self,
306346
request_outputs: REQUEST_OUTPUTS_T):
@@ -313,6 +353,9 @@ def _set_errored(self, e: BaseException):
313353
if self._errored_with is None:
314354
self._errored_with = e
315355

356+
def _alive(self):
357+
self._last_alive_time = time.time()
358+
316359

317360
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
318361
ipc_path: str):

0 commit comments

Comments
 (0)