vllm/v1/worker/gpu_model_runner.py (9 changes: 7 additions, 2 deletions)
@@ -2362,7 +2362,7 @@ def propose_draft_token_ids(
         sampling_metadata: SamplingMetadata,
         hidden_states: torch.Tensor,
         sample_hidden_states: torch.Tensor,
-        aux_hidden_states: Optional[torch.Tensor],
+        aux_hidden_states: Optional[list[torch.Tensor]],
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         common_attn_metadata: CommonAttentionMetadata,
     ) -> Union[list[list[int]], torch.Tensor]:
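The signature change reflects that the auxiliary hidden states arrive as one tensor per tapped layer rather than a single tensor. A minimal sketch of how such a list is consumed, mirroring the torch.cat(..., dim=-1) calls later in this diff (the layer count and tensor shapes are illustrative assumptions, not values from the PR):

    import torch

    # Hypothetical setup: 3 tapped layers, 16 scheduled tokens, hidden size 8.
    aux_hidden_states: list[torch.Tensor] = [torch.randn(16, 8) for _ in range(3)]
    num_scheduled_tokens = 16

    # Concatenate the per-layer slices along the feature dim, as the drafter input is built.
    target_hidden_states = torch.cat(
        [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
    assert target_hidden_states.shape == (16, 24)  # 3 layers x hidden size 8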
@@ -2382,6 +2382,7 @@ def propose_draft_token_ids(
         else:
             indices = []
             offset = 0
+            assert spec_decode_metadata is not None
             for num_draft, tokens in zip(
                     spec_decode_metadata.num_draft_tokens,
                     sampled_token_ids):
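The added assert (like the two below) mainly narrows the Optional[...] parameter for the type checker before its attributes are accessed. A small standalone illustration of that narrowing (the function and values here are hypothetical, not from the PR):

    from typing import Optional

    def count_drafts(num_draft_tokens: Optional[list[int]]) -> int:
        # Without the assert, a type checker would flag len() on an Optional value.
        assert num_draft_tokens is not None
        return len(num_draft_tokens)

    print(count_drafts([2, 1, 3]))  # 3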
@@ -2432,6 +2433,7 @@ def propose_draft_token_ids(
             # TODO(woosuk): Support M-RoPE.
             target_positions = self.positions.gpu[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
+                assert aux_hidden_states is not None
                 target_hidden_states = torch.cat(
                     [h[:num_scheduled_tokens] for h in aux_hidden_states],
                     dim=-1)
@@ -2457,6 +2459,7 @@ def propose_draft_token_ids(
             # TODO(woosuk): Support M-RoPE.
             target_positions = self.positions.gpu[token_indices]
             if self.use_aux_hidden_state_outputs:
+                assert aux_hidden_states is not None
                 target_hidden_states = torch.cat(
                     [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
@@ -2892,7 +2895,9 @@ def _dummy_run(
             assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
             assert num_reqs <= max_num_reqs, \
-                "Do not capture num_reqs > max_num_reqs for uniform batch"
+                f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
+                f"{max_num_reqs} for uniform batch. Num tokens: " \
+                f"{num_tokens}, max_query_len: {max_query_len}"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
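The expanded message surfaces the values behind the uniform-batch sizing in _dummy_run. A rough sketch of that sizing under the same ceiling-division convention (the cdiv helper below is a local stand-in for vLLM's utility of the same name, and the numbers are illustrative):

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, matching the convention assumed in the diff.
        return -(a // -b)

    num_tokens, max_query_len, max_num_reqs = 10, 4, 8  # illustrative values
    num_reqs = cdiv(num_tokens, max_query_len)          # 3 requests
    assert num_reqs <= max_num_reqs, (
        f"Do not capture num_reqs {num_reqs} > max_num_reqs {max_num_reqs} "
        f"for uniform batch. Num tokens: {num_tokens}, "
        f"max_query_len: {max_query_len}")

    num_scheduled_tokens_list = [max_query_len] * num_reqs
    if num_tokens % max_query_len != 0:
        num_scheduled_tokens_list[-1] = num_tokens % max_query_len
    print(num_scheduled_tokens_list)  # [4, 4, 2]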