10 changes: 10 additions & 0 deletions vllm/config/scheduler.py
@@ -159,6 +159,12 @@ class SchedulerConfig:
structured outputs, speculative decoding, and pipeline parallelism.
"""

async_execute_model: bool = False
"""EXPERIMENTAL: If set to True, perform async model execution.
This may help reduce the CPU overheads, leading to better latency
and throughput. Moreover, this rely on async scheduling.
"""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -247,6 +253,10 @@ def __post_init__(self) -> None:
self.scheduler_cls = (
"vllm.v1.core.sched.async_scheduler.AsyncScheduler")

if self.async_execute_model:
assert self.async_scheduling, (
"async_execute_model requires async_scheduling to be True.")

@model_validator(mode='after')
def _verify_args(self) -> Self:
if (self.max_num_batched_tokens < self.max_model_len
45 changes: 30 additions & 15 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -33,8 +33,8 @@ class SpinTimer:
def record_activity(self):
pass

def spin(self):
sched_yield()
def spin(self, sleep_time: Optional[float] = None):
sched_yield(sleep_time)


class SpinSleepTimer(SpinTimer):
@@ -370,7 +370,11 @@ def wait_until_ready(self):
assert recv == b"READY"

@contextmanager
def acquire_write(self, timeout: Optional[float] = None):
def acquire_write(
self,
timeout: Optional[float] = None,
sleep_time: Optional[float] = None,
):
assert self._is_writer, "Only writers can acquire write"
start_time = time.monotonic()
n_warning = 1
@@ -385,7 +389,7 @@ def acquire_write(self, timeout: Optional[float] = None):
# we need to wait until it is read by all readers

# Release the processor to other threads
sched_yield()
sched_yield(sleep_time)

# if we wait for a long time, log a message
if (time.monotonic() - start_time
@@ -428,9 +432,12 @@ def acquire_read(self,
break

@contextmanager
def acquire_read(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None):
def acquire_read(
self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
sleep_time: Optional[float] = None,
):
assert self._is_local_reader, "Only readers can acquire read"
start_time = time.monotonic()
n_warning = 1
@@ -448,7 +455,7 @@ def acquire_read(self,
# we need to wait until it is written

# Release the processor to other threads
self._read_spin_timer.spin()
self._read_spin_timer.spin(sleep_time)

# if we wait for a long time, log a message
if (time.monotonic() - start_time
@@ -483,28 +490,36 @@ def acquire_read(self,
self._read_spin_timer.record_activity()
break

def enqueue(self, obj, timeout: Optional[float] = None):
def enqueue(
self,
obj,
timeout: Optional[float] = None,
sleep_time: Optional[float] = None,
):
""" Write to message queue with optional timeout (in seconds) """
assert self._is_writer, "Only writers can enqueue"
serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
if self.n_local_reader > 0:
if len(serialized_obj) >= self.buffer.max_chunk_bytes:
with self.acquire_write(timeout) as buf:
with self.acquire_write(timeout, sleep_time) as buf:
buf[0] = 1 # overflow
self.local_socket.send(serialized_obj)
else:
with self.acquire_write(timeout) as buf:
with self.acquire_write(timeout, sleep_time) as buf:
buf[0] = 0 # not overflow
buf[1:len(serialized_obj) + 1] = serialized_obj
if self.n_remote_reader > 0:
self.remote_socket.send(serialized_obj)

def dequeue(self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None):
def dequeue(
self,
timeout: Optional[float] = None,
cancel: Optional[Event] = None,
sleep_time: Optional[float] = None,
) -> Any:
""" Read from message queue with optional timeout (in seconds) """
if self._is_local_reader:
with self.acquire_read(timeout, cancel) as buf:
with self.acquire_read(timeout, cancel, sleep_time) as buf:
overflow = buf[0] == 1
if not overflow:
# no need to know the size of serialized object
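For context, a minimal usage sketch of the new sleep_time knob, assuming an already constructed MessageQueue named mq (construction elided); the parameter names come from this diff, while the 50-microsecond value is only an illustrative choice:

# Writer: yield roughly 50 microseconds per spin instead of busy-waiting,
# freeing the core for other threads in the worker process.
mq.enqueue(obj, timeout=1.0, sleep_time=5e-5)

# Local reader: the same knob applies to the read spin loop.
obj = mq.dequeue(timeout=1.0, sleep_time=5e-5)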
11 changes: 9 additions & 2 deletions vllm/distributed/utils.py
@@ -38,8 +38,15 @@
and sys.version_info[2] >= 8))


def sched_yield():
if USE_SCHED_YIELD:
def sched_yield(sleep_time: Optional[float] = None):
# When the worker process runs more than one thread, os.sched_yield()
# and time.sleep(0) only move the calling thread to the ready state, and
# the CPU may reschedule it immediately. Sleeping for a small amount of
# time puts the thread into the blocked state so that the CPU can
# schedule other threads.
if sleep_time is not None:
time.sleep(sleep_time)
elif USE_SCHED_YIELD:
os.sched_yield()
else:
time.sleep(0)
11 changes: 11 additions & 0 deletions vllm/engine/arg_utils.py
@@ -445,6 +445,8 @@ class EngineArgs:

async_scheduling: bool = SchedulerConfig.async_scheduling

async_execute_model: bool = SchedulerConfig.async_execute_model

kv_sharing_fast_prefill: bool = \
CacheConfig.kv_sharing_fast_prefill

@@ -864,6 +866,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
scheduler_group.add_argument("--async-scheduling",
**scheduler_kwargs["async_scheduling"])

scheduler_group.add_argument("--async-execute-model",
**scheduler_kwargs["async_execute_model"])

# vLLM arguments
vllm_kwargs = get_kwargs(VllmConfig)
vllm_group = parser.add_argument_group(
@@ -1254,6 +1259,12 @@ def create_engine_config(
raise ValueError("Async scheduling is not supported with "
"pipeline-parallel-size > 1.")

if self.async_execute_model:
# TODO(Ronald1995): Support async execute model with ray.
if self.distributed_executor_backend != "mp":
raise ValueError("Async execute model is only supported with "
"mp-based distributed executor backend.")

# Currently, async scheduling does not support speculative decoding.
# TODO(woosuk): Support it.
if self.speculative_config is not None:
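A short sketch of enabling the new flag programmatically, based only on the fields and checks added in this diff; the model name is a placeholder and the snippet is illustrative rather than a recommended configuration:

from vllm.engine.arg_utils import EngineArgs

# async_execute_model requires async_scheduling (SchedulerConfig assert),
# and create_engine_config() rejects it with a non-mp executor backend.
args = EngineArgs(
    model="facebook/opt-125m",          # placeholder model
    async_scheduling=True,              # prerequisite for async_execute_model
    async_execute_model=True,           # new EXPERIMENTAL flag from this PR
    distributed_executor_backend="mp",  # ray is not yet supported
)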
8 changes: 8 additions & 0 deletions vllm/v1/engine/core.py
@@ -155,6 +155,7 @@ def __init__(self,

self.request_block_hasher = get_request_block_hasher(
block_size, caching_hash_fn)
self.async_execute_model = self.vllm_config.scheduler_config.async_execute_model

def _initialize_kv_caches(
self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
@@ -341,6 +342,13 @@ def step_with_batch_queue(
# but peeking the first element in a queue is not thread-safe,
# so we need more work.
if not scheduled_batch and not self.batch_queue.empty():
# When async_execute_model is enabled, do not block on the future
# when total_num_scheduled_tokens is 0, because in that case no
# execute_model task has been sent to the workers.
if (self.async_execute_model
and scheduler_output.total_num_scheduled_tokens == 0):
return engine_core_outputs, scheduled_batch
Comment on lines +348 to +350
critical

This new conditional block introduces two critical issues:

  1. AttributeError Bug: scheduler_output can be None when this block is reached. This occurs if self.batch_queue is full, because scheduler.schedule() is not called, and scheduler_output remains None. Accessing scheduler_output.total_num_scheduled_tokens will then raise an AttributeError.

  2. Potential Livelock: Even if the AttributeError is fixed (e.g., by checking scheduler_output is not None), a logical flaw remains. If this condition is met, the function returns without processing items from self.batch_queue. Since the state that led to this condition might not change, subsequent calls to step_with_batch_queue could repeatedly hit the same condition, causing items in the queue to be starved and leading to a livelock.

The logic for when to process items from the queue versus returning early needs to be reconsidered to avoid these problems.
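A minimal guard for the first issue, using only names visible in this hunk (an illustrative sketch, not the author's fix); the starvation concern still requires rethinking when the early return is allowed at all:

# Sketch: skip the early return when batch_queue was full and
# scheduler.schedule() never ran, i.e. scheduler_output is None.
# Note this alone does not resolve the starvation risk described above.
if (self.async_execute_model
        and scheduler_output is not None
        and scheduler_output.total_num_scheduled_tokens == 0):
    return engine_core_outputs, scheduled_batch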


future, scheduler_output = self.batch_queue.get_nowait()

# Blocking until the first result is available.
38 changes: 37 additions & 1 deletion vllm/v1/executor/multiproc_executor.py
@@ -16,6 +16,7 @@
from multiprocessing.process import BaseProcess
from threading import Thread
from typing import Any, Callable, Optional, Union, cast
import queue

import cloudpickle

@@ -403,6 +404,12 @@ def __init__(
# Initializes a message queue for sending the model output
self.worker_response_mq = MessageQueue(1, 1)

# The queue size and thread pool size are set to 2 to match the
# executor's max_concurrent_batches when async scheduling is enabled.
self.exe_queue = queue.Queue(2)
self.exe_thread_pool = ThreadPoolExecutor(
max_workers=2, thread_name_prefix="execute_model")

# Initialize device and loads weights
self.worker.init_device()
self.worker.load_model()
@@ -586,6 +593,12 @@ class ResponseStatus(Enum):

def worker_busy_loop(self):
"""Main busy loop for Multiprocessing Workers"""
async_execute_model = self.worker.vllm_config.scheduler_config.async_execute_model
events = {
"d2h_copy_event": threading.Event(),
"update_sampled_tokens_event": threading.Event()
}
exe_count = 0
while True:
method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()

@@ -594,7 +607,19 @@ def worker_busy_loop(self):
func = getattr(self.worker, method)
elif isinstance(method, bytes):
func = partial(cloudpickle.loads(method), self.worker)
output = func(*args, **kwargs)

if async_execute_model and func.__name__ == "execute_model":
args = (*args, exe_count, events)
output = self.execute_model_with_queue(
func,
*args,
**kwargs,
)
exe_count += 1
if not output:
continue
else:
output = func(*args, **kwargs)
except Exception as e:
# Notes have been introduced in python 3.11
if hasattr(e, "add_note"):
@@ -610,3 +635,14 @@
if output_rank is None or self.rank == output_rank:
self.worker_response_mq.enqueue(
(WorkerProc.ResponseStatus.SUCCESS, output))

def execute_model_with_queue(self, func, *args, **kwargs):
"""Execute model with a queue for async execution."""
output = None
if not self.exe_queue.full():
output_future = self.exe_thread_pool.submit(func, *args, **kwargs)
self.exe_queue.put_nowait(output_future)
if self.exe_queue.full():
output = self.exe_queue.get().result()
self.exe_queue.task_done()
return output
Comment on lines +639 to +648
critical

The current implementation of execute_model_with_queue will lead to a deadlock. Here's why:

  1. On the first call to execute_model_with_queue, self.exe_queue is empty, so it's not full. A future is submitted and added to the queue. The function then returns None.
  2. In worker_busy_loop, because the output from execute_model_with_queue is None, the loop continues to the next iteration without sending a response back to the main process via self.worker_response_mq.enqueue() (due to the if not output: continue check).
  3. The MultiprocExecutor in the main process, which made the collective_rpc call, will hang indefinitely waiting for a response that will never arrive.

To prevent this deadlock, execute_model_with_queue must ensure that a response is sent for every execute_model RPC call. The pipelining logic needs to be revised to guarantee a reply, even for the first call that primes the pipeline.
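One illustrative shape for guaranteeing a reply per RPC, sketched under the explicit assumption that the engine core is taught to recognize a priming sentinel instead of real model output (the sentinel and that protocol change are hypothetical, not part of this PR):

# Hypothetical module-level sentinel; a plain string survives pickling
# across worker_response_mq, unlike an object() identity sentinel.
PIPELINE_PRIMED = "PIPELINE_PRIMED"

def execute_model_with_queue(self, func, *args, **kwargs):
    """Pipelined execution that still produces one reply per RPC (sketch)."""
    # Submit the new batch; at most two batches are ever in flight.
    future = self.exe_thread_pool.submit(func, *args, **kwargs)
    self.exe_queue.put_nowait(future)
    if self.exe_queue.qsize() < 2:
        # Priming call: reply with the sentinel instead of skipping the
        # response, so collective_rpc in the main process never hangs.
        return PIPELINE_PRIMED
    # Steady state: block on the oldest in-flight batch and return its
    # output, leaving one batch in flight to overlap with the next RPC.
    oldest = self.exe_queue.get()
    result = oldest.result()
    self.exe_queue.task_done()
    return result

worker_busy_loop would then enqueue whatever this returns instead of taking the "if not output: continue" path, and the engine core would need to understand the sentinel; whether that protocol change is acceptable is a design question for the author.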

2 changes: 1 addition & 1 deletion vllm/v1/sample/rejection_sampler.py
@@ -121,7 +121,7 @@ def parse_output(
Returns:
A list of lists of token IDs.
"""
output_token_ids_np = output_token_ids.cpu().numpy()
output_token_ids_np = output_token_ids.numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))