Skip to content
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
d5d7924
[Core] Async scheduling + structured outputs compatibility
njhill Oct 13, 2025
9810947
small fixes
njhill Oct 15, 2025
bc33394
misc code improvement
njhill Oct 16, 2025
e5f9634
simplify with context manager
njhill Oct 17, 2025
8cba549
readability/simplification updates
njhill Oct 17, 2025
66906ff
include sample_tokens() when logging error details
njhill Oct 17, 2025
ac87699
reorg logic a bit for readability
njhill Oct 17, 2025
885760b
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 17, 2025
eef1d44
update comment
njhill Oct 17, 2025
2d17506
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 17, 2025
ac60de7
TPU updates
njhill Oct 17, 2025
01eec54
add ray compatibility
njhill Oct 17, 2025
866a281
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 17, 2025
717fbad
Merge remote-tracking branch 'refs/remotes/origin/main' into async-sc…
njhill Oct 18, 2025
0c03cb2
fix import and test
njhill Oct 21, 2025
f6b3318
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 21, 2025
0127d64
test updates
njhill Oct 21, 2025
b8208bd
add to e2e async scheduling test
njhill Oct 21, 2025
09090a6
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 21, 2025
694616f
typing updates
njhill Oct 21, 2025
3f83daa
Merge remote-tracking branch 'refs/remotes/origin/main' into async-sc…
njhill Oct 21, 2025
59a522c
other minor cleanup
njhill Oct 21, 2025
19a8ef5
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 21, 2025
e741744
fix doc build
njhill Oct 21, 2025
68084a6
Merge remote-tracking branch 'refs/remotes/origin/main' into async-sc…
njhill Oct 22, 2025
489646b
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 22, 2025
bab7da7
fixes
njhill Oct 22, 2025
99566d2
fix fix
njhill Oct 22, 2025
757cbbf
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 23, 2025
31cbdf4
Merge remote-tracking branch 'refs/remotes/origin/main' into async-sc…
njhill Oct 23, 2025
34e9c26
Merge remote-tracking branch 'refs/remotes/origin/main' into async-sc…
njhill Oct 27, 2025
8964d5e
Merge remote-tracking branch 'refs/remotes/origin/main' into async-sc…
njhill Oct 28, 2025
7d09150
minor tpu_model_runner improvement
njhill Oct 28, 2025
f658423
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 28, 2025
a3f7330
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 28, 2025
b239fa8
Merge remote-tracking branch 'refs/remotes/origin/main' into async-sc…
njhill Oct 30, 2025
002d6cb
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 30, 2025
c123fb5
rename do_execute_model() per @WoosukKwon's suggestion
njhill Oct 30, 2025
ce58e38
only run sample_tokens() in final PP rank
njhill Oct 30, 2025
e845151
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 30, 2025
9361685
fix external_launcher mode
njhill Oct 31, 2025
a7ed3a0
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 31, 2025
ca5b20c
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 31, 2025
311e48d
Add some comments, update WorkerBase, some simpler formatting
njhill Oct 31, 2025
defcffd
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 31, 2025
d756533
fix import
njhill Oct 31, 2025
7af5dc4
Merge remote-tracking branch 'origin/main' into async-sched-struct-ou…
njhill Oct 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

from tblib import pickling_support

# Import fixture
from tests.v1.entrypoints.conftest import sample_json_schema # noqa

# ruff: noqa

# Install support for pickling exceptions so that we can nicely propagate
Expand Down
9 changes: 0 additions & 9 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_output = ModelRunnerOutput(
Expand Down Expand Up @@ -385,8 +383,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_output = ModelRunnerOutput(
Expand Down Expand Up @@ -431,8 +427,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_output = ModelRunnerOutput(
Expand Down Expand Up @@ -472,8 +466,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_output = ModelRunnerOutput(
Expand Down Expand Up @@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init():
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.grammar_bitmask is None


def test_schedule_skip_tokenizer_init_structured_output_request():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from vllm import SamplingParams
from vllm.logprobs import Logprob
from vllm.sampling_params import StructuredOutputsParams

from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal
Expand All @@ -15,9 +16,12 @@


@dynamo_config.patch(cache_size_limit=16)
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
def test_preempt_and_async_scheduling_e2e(
sample_json_schema, monkeypatch: pytest.MonkeyPatch
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, and various sampling parameters."""
uni/multiproc executor, and various sampling parameters
including structured outputs."""

first_prompt = (
"The following numbers of the sequence "
Expand All @@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
dict(bad_words=["the", " the"]),
dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0),
dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)),
dict(
structured_outputs=StructuredOutputsParams(json=sample_json_schema),
logprobs=2,
presence_penalty=-1.0,
),
]

default_params = dict(
Expand Down
19 changes: 18 additions & 1 deletion tests/v1/engine/test_engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def execute_model(
self,
scheduler_output,
non_block=False,
) -> Future[ModelRunnerOutput]:
) -> Future[ModelRunnerOutput | None]:
"""Make execute_model non-blocking."""

# DummyExecutor used only for testing async case.
Expand All @@ -263,6 +263,23 @@ def _execute():
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)

def sample_tokens(
    self, grammar_output, non_block=False
) -> Future[ModelRunnerOutput]:
    """Submit the workers' ``sample_tokens`` call to the thread pool.

    Returns a future that resolves to a deep copy of the first entry of
    the ``collective_rpc`` result.
    """
    # This dummy executor exists only to exercise the async case, so a
    # blocking call is rejected outright.
    assert non_block

    def _sample():
        worker_outputs = self.collective_rpc("sample_tokens", args=(grammar_output,))
        # The first output may be reused by the next batch, so hand back
        # an independent copy instead of the shared object.
        first_output = worker_outputs[0]
        return copy.deepcopy(first_output)

    # Reuse the existing pool rather than spawning a thread per call.
    return self.thread_pool.submit(_sample)

@property
def max_concurrent_batches(self) -> int:
    """Fixed ``max_concurrent_batches`` value (2) reported by this executor."""
    return 2
Expand Down
4 changes: 3 additions & 1 deletion tests/v1/executor/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def collective_rpc(
# Drop marker to show that this was run
with open(".marker", "w"):
...
return super().collective_rpc(method, timeout, args, kwargs)
return super().collective_rpc(
method, timeout, args, kwargs, non_block, unique_reply_rank
)


CustomMultiprocExecutorAsync = CustomMultiprocExecutor
Expand Down
2 changes: 0 additions & 2 deletions tests/v1/kv_connector/unit/test_kv_connector_lifecyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(),
)

Expand Down
4 changes: 1 addition & 3 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation():
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids=set(),
free_encoder_mm_hashes=set(),
structured_output_request_ids={},
grammar_bitmask=None,
free_encoder_mm_hashes=[],
)

engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)
Expand Down
12 changes: 0 additions & 12 deletions tests/v1/tpu/worker/test_tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)


Expand Down Expand Up @@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_runner._update_states(scheduler_output)
Expand Down Expand Up @@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_runner._update_states(scheduler_output)
Expand Down Expand Up @@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_runner._update_states(scheduler_output)
Expand Down Expand Up @@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_runner._update_states(scheduler_output)
Expand Down Expand Up @@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_runner._update_states(scheduler_output)
Expand Down
12 changes: 0 additions & 12 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)


Expand Down Expand Up @@ -216,8 +214,6 @@ def test_update_states_request_finished(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

metadata_before = model_runner.input_batch.sampling_metadata
Expand Down Expand Up @@ -248,8 +244,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

model_runner._update_states(scheduler_output)
Expand Down Expand Up @@ -277,8 +271,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

metadata_before = model_runner.input_batch.sampling_metadata
Expand Down Expand Up @@ -370,8 +362,6 @@ def test_update_states_no_changes(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

metadata_before = model_runner.input_batch.sampling_metadata
Expand Down Expand Up @@ -407,8 +397,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)

metadata_before = model_runner._update_states(scheduler_output)
Expand Down
30 changes: 18 additions & 12 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Literal

import torch

Expand Down Expand Up @@ -138,8 +138,11 @@ def from_connector(cls, connector: "KVConnectorBase", world_size: int):
return cls(connector.get_finished_count() or world_size)

def aggregate(
self, outputs: list[ModelRunnerOutput], output_rank: int = 0
) -> ModelRunnerOutput:
self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
) -> ModelRunnerOutput | None:
if not outputs[output_rank]:
return None

# Aggregate kv_connector_output from all workers

def update_finished_set(
Expand All @@ -161,6 +164,7 @@ def update_finished_set(
aggregated_kv_connector_stats = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
assert model_runner_output is not None
kv_output = model_runner_output.kv_connector_output
if not kv_output:
continue
Expand Down Expand Up @@ -204,6 +208,7 @@ def update_finished_set(
# select output of the worker specified by output_rank
output = outputs[output_rank]

assert output is not None
output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
Expand All @@ -215,13 +220,16 @@ def update_finished_set(
return output

def async_aggregate(
self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0
) -> Future[ModelRunnerOutput]:
self,
output_futures: Sequence[Future[ModelRunnerOutput | None]],
output_rank: int = 0,
) -> Future[ModelRunnerOutput | None]:
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future()
result_future: Future[ModelRunnerOutput | None] = Future()

outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
remaining = len(output_futures)

def make_callback(idx):
def callback(fut):
Expand All @@ -236,12 +244,10 @@ def callback(fut):
result_future.set_exception(e)

# this check assumes io_thread_pool uses a single thread
if all(outputs):
result_future.set_result(
self.aggregate(
cast(list[ModelRunnerOutput], outputs), output_rank
)
)
nonlocal remaining
remaining -= 1
if not remaining:
result_future.set_result(self.aggregate(outputs, output_rank))

return callback

Expand Down
8 changes: 8 additions & 0 deletions vllm/v1/core/sched/async_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@ def _update_after_schedule(
scheduler_output: SchedulerOutput,
) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
if (
request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders
Expand All @@ -25,6 +29,10 @@ def _update_after_schedule(
# TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1

scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)

def _update_request_with_output(
self,
request: Request,
Expand Down
8 changes: 7 additions & 1 deletion vllm/v1/core/sched/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
Expand Down Expand Up @@ -40,6 +40,12 @@ def schedule(self) -> "SchedulerOutput":
"""
raise NotImplementedError

@abstractmethod
def get_grammar_bitmask(
    self, scheduler_output: "SchedulerOutput"
) -> "GrammarOutput | None":
    """Return the grammar output for ``scheduler_output``, or ``None``.

    Abstract hook: this base version always raises, so concrete
    schedulers must override it.
    NOTE(review): the conditions under which ``None`` is returned are not
    visible here; presumably when nothing scheduled needs a grammar
    bitmask -- confirm against the implementation.
    """
    raise NotImplementedError

@abstractmethod
def update_from_output(
self,
Expand Down
Loading