16 changes: 13 additions & 3 deletions tests/v1/core/test_scheduler.py
@@ -244,7 +244,9 @@ def test_schedule_partial_requests():
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
# Only the first request has a sampled token id because
# the other requests are still being prefilled.
sampled_token_ids=[[0], [], []],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -266,7 +268,7 @@ def test_schedule_partial_requests():


@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
"""Test scheduling behavior with concurrent partial requests.

This test verifies that: there are multiple long prefill requests in the
@@ -304,7 +306,7 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
@@ -325,6 +327,14 @@ def test_schedule_concurrent_partial_requestse(enable_prefix_caching: bool):
# Schedule the third step. All three requests are running.
# First and second requests are in the decode stage.
# All the remaining tokens in the third request are processed.
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
18 changes: 11 additions & 7 deletions tests/v1/engine/test_engine_core.py
@@ -231,8 +231,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
Test that the engine can handle multiple concurrent batches.
"""

def make_request_with_max_tokens(max_tokens: int) -> EngineCoreRequest:
def make_request_with_max_tokens(req_id: int,
max_tokens: int) -> EngineCoreRequest:
request = make_request()
request.request_id = req_id
request.sampling_params.max_tokens = max_tokens
return request

@@ -279,20 +281,22 @@ def max_concurrent_batches(self) -> int:
# Avoid all requests being scheduled at once.
enable_prefix_caching=False,
max_num_batched_tokens=10,
# Reduce startup time.
enforce_eager=True,
)
vllm_config = engine_args.create_engine_config()
engine_core = EngineCore(vllm_config=vllm_config,
log_stats=False,
executor_class=DummyExecutor)
assert engine_core.batch_queue is not None

# Add two requests in a row.
req = make_request_with_max_tokens(5)
engine_core.add_request(req)
req = make_request_with_max_tokens(5)
engine_core.add_request(req)
# Add two requests in a row. Each request has 12 prompt tokens.
req0 = make_request_with_max_tokens(0, 5)
engine_core.add_request(req0)
req1 = make_request_with_max_tokens(1, 5)
engine_core.add_request(req1)

# First saturate the batch queue.
# Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue() is None
assert engine_core.batch_queue.qsize() == 1
assert engine_core.step_with_batch_queue() is None
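
A rough sketch (not part of the diff) of the token accounting behind these assertions, with the numbers taken from the test setup: each prompt has 12 tokens and max_num_batched_tokens is 10, so the first batch can cover only 10 tokens of req0 and the rest must be prefilled in a later step while the batch queue stays saturated.

# Budget arithmetic for this test (editor's sketch, not vLLM code).
prompt_len = 12                            # prompt tokens per request
budget = 10                                # max_num_batched_tokens

batch1_req0 = min(prompt_len, budget)      # Batch 1: (10, req0)
leftover_req0 = prompt_len - batch1_req0   # 2 prompt tokens of req0 remain

assert (batch1_req0, leftover_req0) == (10, 2)
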
91 changes: 48 additions & 43 deletions vllm/v1/core/sched/scheduler.py
@@ -153,9 +153,9 @@ def schedule(self) -> SchedulerOutput:

num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens)
if self.scheduler_config.long_prefill_token_threshold > 0:
num_new_tokens = min(
num_new_tokens,
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
@@ -303,9 +303,9 @@ def schedule(self) -> SchedulerOutput:
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()
if self.scheduler_config.long_prefill_token_threshold > 0:
num_new_tokens = min(
num_new_tokens,
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
@@ -433,6 +433,18 @@ def schedule(self) -> SchedulerOutput:
grammar_bitmask=grammar_bitmask,
)

# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
# original number of scheduled tokens to determine input IDs.
# 2. Advancing the number of computed tokens here allows us to
# schedule the prefill request again immediately in the next
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
for req_id, num_scheduled_token in num_scheduled_tokens.items():
self.requests[req_id].num_computed_tokens += num_scheduled_token

self.finished_req_ids = set()
return scheduler_output
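
A toy sketch of point 2 above (made-up numbers, not vLLM code): because num_computed_tokens is advanced as soon as the tokens are scheduled, the next schedule() call can immediately pick up the remaining prefill chunk instead of waiting for the model runner to report back.

# Chunked prefill with the counter advanced at schedule time
# (hypothetical 35-token prompt and 10-token budget).
prompt_len = 35
token_budget = 10
num_computed_tokens = 0

chunk_sizes = []
while num_computed_tokens < prompt_len:
    num_scheduled = min(prompt_len - num_computed_tokens, token_budget)
    num_computed_tokens += num_scheduled  # advanced right after scheduling
    chunk_sizes.append(num_scheduled)

assert chunk_sizes == [10, 10, 10, 5]
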

@@ -561,28 +573,19 @@ def update_from_output(

req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]
if req_id not in scheduler_output.scheduled_spec_decode_tokens:
# When the request's num_computed_tokens catches up
# its num_tokens, the request generates output tokens.
# Otherwise, we ignore the sampler output for the request.
request.num_computed_tokens += num_tokens_scheduled
assert request.num_computed_tokens <= request.num_tokens
else:
# num_computed_tokens_step represents the number of tokens

scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections.
# It is calculated as:
# num_computed_tokens_step = num_scheduled_tokens -
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# tokens and rejections. If some tokens are rejected,
# num_computed_tokens is decreased by the number of rejected
# tokens, which is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens[req_id])

num_computed_tokens_step = num_scheduled_tokens[req_id] - (
len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens += num_computed_tokens_step
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected

cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
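
A worked example of the rejection arithmetic a few lines above (hypothetical values, not vLLM code): 3 speculative tokens were scheduled, the sampler accepted the first 2 and replaced the rejected third with a freshly sampled token, so 3 tokens came back and the counter is walked back by 3 + 1 - 3 = 1.

# Editor's sketch of the num_tokens_rejected formula.
scheduled_spec_token_ids = [11, 22, 33]  # 3 draft tokens scheduled
generated_token_ids = [11, 22, 99]       # 2 accepted + 1 replacement token

num_tokens_rejected = (len(scheduled_spec_token_ids) + 1
                       - len(generated_token_ids))
assert num_tokens_rejected == 1          # num_computed_tokens -= 1
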
@@ -605,24 +608,26 @@ def update_from_output(
new_logprobs = None
new_token_ids: list[int] = []

if request.num_computed_tokens >= request.num_tokens:
for output_token_id in generated_token_ids:
request.append_output_token_ids(output_token_id)
new_token_ids.append(output_token_id)

# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
break
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
# to return empty token ids for the request.
for output_token_id in generated_token_ids:
request.append_output_token_ids(output_token_id)
new_token_ids.append(output_token_id)

# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
self._free_request(request)
break

# Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None:
assert logprobs is not None
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
# Extract sample logprobs if needed.
if (request.sampling_params.logprobs is not None
and logprobs is not None):
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)

if new_token_ids and request.use_structured_output:
# NOTE: structured_output_request
19 changes: 19 additions & 0 deletions vllm/v1/sample/rejection_sampler.py
@@ -107,14 +107,33 @@ def forward(
@staticmethod
def parse_output(
output_token_ids: torch.Tensor,
ignored_req_idxs: list[int],
vocab_size: int,
) -> list[list[int]]:
"""Parse the output of the rejection sampler.

Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
ignored_req_idxs: The indices of the requests that should not be
sampled. This is usually because the request is still in the
prefill phase.
vocab_size: The size of the vocabulary.

Returns:
A list of lists of token IDs.
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))

ignored_req_idx_set = set(ignored_req_idxs)
outputs = [
row[valid_mask[i]].tolist()
if i not in ignored_req_idx_set else []
for i, row in enumerate(output_token_ids_np)
]
return outputs
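
A standalone sketch of what parse_output computes (toy token IDs and an assumed placeholder value of -1; not the vLLM module itself): placeholder slots are masked out row by row, and rows listed in ignored_req_idxs come back as empty lists.

import numpy as np

PLACEHOLDER_TOKEN_ID = -1  # assumed placeholder value for this sketch
VOCAB_SIZE = 100

# Shape [batch_size=3, max_spec_len + 1 = 3]; request 2 is still prefilling.
output_token_ids = np.array([
    [5, 7, PLACEHOLDER_TOKEN_ID],                     # last spec token rejected
    [9, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID],  # only one token kept
    [3, 4, 6],                                        # ignored (partial prefill)
])
ignored_req_idxs = [2]

valid_mask = ((output_token_ids != PLACEHOLDER_TOKEN_ID) &
              (output_token_ids < VOCAB_SIZE))
ignored = set(ignored_req_idxs)
outputs = [
    row[valid_mask[i]].tolist() if i not in ignored else []
    for i, row in enumerate(output_token_ids)
]
assert outputs == [[5, 7], [9], []]
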
19 changes: 15 additions & 4 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1085,16 +1085,21 @@ def execute_model(

# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
for i, generator in self.input_batch.generators.items():
req_id = self.input_batch.req_ids[i]
should_not_sampled_req_idxs = []
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we can clear the sampled tokens before returning.
should_not_sampled_req_idxs.append(i)

# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
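
The offset rewind above relies on CUDA-specific Philox internals (the diff subtracts 4 to undo one sampling call). A device-agnostic way to picture the same idea, using only the public generator-state API; this is a conceptual sketch, not what the runner actually does:

import torch

gen = torch.Generator().manual_seed(0)
saved_state = gen.get_state()      # snapshot before sampling

probs = torch.ones(8)
first = torch.multinomial(probs, 1, generator=gen)

gen.set_state(saved_state)         # "rewind" as if the token was never sampled
replay = torch.multinomial(probs, 1, generator=gen)
assert torch.equal(first, replay)  # the undone draw is reproduced exactly
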
@@ -1114,10 +1119,16 @@ def execute_model(
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
# Clear the sampled tokens for requests that should not be sampled.
for i in should_not_sampled_req_idxs:
valid_sampled_token_ids[i].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
sampled_token_ids, self.input_batch.vocab_size)
sampled_token_ids,
should_not_sampled_req_idxs,
self.input_batch.vocab_size,
)

if not self.use_spec_decode:
spec_token_ids = None