From f27261a09613a2ef047dc3fa4eed81da12e746d0 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 18 Aug 2025 16:17:39 +0200 Subject: [PATCH 01/73] Speculative Decoding with Draft Model Signed-off-by: Tomas Ruiz --- .gitignore | 3 + examples/offline_inference/spec_decode.py | 19 +- pyproject.toml | 5 + tests/v1/e2e/test_spec_decode.py | 137 +++++++++++++++ vllm/benchmarks/throughput.py | 39 +++- vllm/config/__init__.py | 9 +- vllm/engine/arg_utils.py | 5 +- vllm/model_executor/model_loader/__init__.py | 6 +- .../model_loader/base_loader.py | 9 +- .../model_loader/gguf_loader.py | 9 +- .../model_loader/tensorizer_loader.py | 13 +- vllm/v1/core/sched/scheduler.py | 8 +- vllm/v1/spec_decode/eagle.py | 166 ++++++++++++++---- vllm/v1/spec_decode/metrics.py | 10 ++ vllm/v1/worker/gpu_model_runner.py | 37 +++- vllm/v1/worker/utils.py | 29 +-- 16 files changed, 406 insertions(+), 98 deletions(-) diff --git a/.gitignore b/.gitignore index 465935d488f8..d025841c5ae8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Scripts for development +scripts/ + # version file generated by setuptools-scm /vllm/_version.py diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 5af232cb6af6..2a517abaab31 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -53,7 +53,7 @@ def parse_args(): "--method", type=str, default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"], + choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"], ) parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--prompt-lookup-max", type=int, default=5) @@ -68,7 +68,11 @@ def parse_args(): parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--draft-model", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") + parser.add_argument("--gpu-memory-utilization", type=float, default=0.8) + parser.add_argument("--request-id-prefix", type=str, default="") + parser.add_argument("--max-model-len", type=int, default=16384) return parser.parse_args() @@ -118,6 +122,15 @@ def main(): "prompt_lookup_max": args.prompt_lookup_max, "prompt_lookup_min": args.prompt_lookup_min, } + elif args.method == "draft_model": + assert args.draft_model is not None and args.draft_model != "" + speculative_config = { + "method": args.method, + "model": args.draft_model, + "num_speculative_tokens": args.num_spec_tokens, + "enforce_eager": args.enforce_eager, + "max_model_len": args.max_model_len, + } else: raise ValueError(f"unknown method: {args.method}") @@ -127,10 +140,10 @@ def main(): tensor_parallel_size=args.tp, enable_chunked_prefill=args.enable_chunked_prefill, enforce_eager=args.enforce_eager, - gpu_memory_utilization=0.8, + gpu_memory_utilization=args.gpu_memory_utilization, speculative_config=speculative_config, disable_log_stats=False, - max_model_len=16384, + max_model_len=args.max_model_len, limit_mm_per_prompt={"image": 5}, disable_chunked_mm_input=True, ) diff --git a/pyproject.toml b/pyproject.toml index e63f8aeae278..e41d8a26aa55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,6 +154,11 @@ markers = [ "skip_v1: do not run this test with v1", "optional: optional tests that are automatically skipped, include --optional to run them", ] +# Show print statements and logs during test execution +addopts = "-s --tb=short 
--log-cli-level=INFO" +log_cli = true +log_cli_format = "%(asctime)s [%(levelname)8s] %(name)s: %(message)s" +log_cli_date_format = "%Y-%m-%d %H:%M:%S" [tool.ty.src] root = "./vllm" diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index cd1d34fc6c3e..dd712050ed8a 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -3,6 +3,7 @@ from __future__ import annotations import random +from dataclasses import dataclass from typing import Any, Union import pytest @@ -13,7 +14,9 @@ from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR from vllm.distributed import cleanup_dist_env_and_memory +from vllm.outputs import RequestOutput from vllm.platforms import current_platform +from vllm.v1.spec_decode.metrics import compute_acceptance_rate def get_test_prompts(mm_enabled: bool): @@ -69,9 +72,17 @@ def get_test_prompts(mm_enabled: bool): @pytest.fixture def sampling_config(): + return greedy_sampling() + + +def greedy_sampling() -> SamplingParams: return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) +def stochastic_sampling() -> SamplingParams: + return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False) + + @pytest.fixture def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" @@ -230,3 +241,129 @@ def test_eagle_correctness( del spec_llm torch.cuda.empty_cache() cleanup_dist_env_and_memory() + + +@dataclass +class ArgsTest: + model: str + draft_model: str + sampling_config: SamplingParams + expected_acceptance_rate: float + expected_same_output_fraction: float + # Defaults + max_model_len: int = 1024 + gpu_memory_utilization: float = 0.5 + + +cases = [ + ArgsTest( + model="baidu/ERNIE-4.5-0.3B-PT", + draft_model="baidu/ERNIE-4.5-0.3B-PT", + sampling_config=greedy_sampling(), + expected_acceptance_rate=1.0, + expected_same_output_fraction=1.0, + ), + ArgsTest( + model="baidu/ERNIE-4.5-0.3B-PT", + draft_model="baidu/ERNIE-4.5-0.3B-PT", + sampling_config=stochastic_sampling(), + expected_acceptance_rate=0.2, + expected_same_output_fraction=0.0, + ), + ArgsTest( + model="meta-llama/Llama-3.2-1B-Instruct", + draft_model="meta-llama/Llama-3.2-1B-Instruct", + sampling_config=greedy_sampling(), + expected_acceptance_rate=0.8, + expected_same_output_fraction=0.5, + ), + ArgsTest( + model="meta-llama/Llama-3.2-1B-Instruct", + draft_model="meta-llama/Llama-3.2-1B-Instruct", + sampling_config=stochastic_sampling(), + expected_acceptance_rate=0.4, + expected_same_output_fraction=0.15, + ), + ArgsTest( + model="Qwen/Qwen3-1.7B", + draft_model="Qwen/Qwen3-0.6B", + sampling_config=greedy_sampling(), + expected_acceptance_rate=1.0, + expected_same_output_fraction=1.0, + ), + ArgsTest( + model="Qwen/Qwen3-1.7B", + draft_model="Qwen/Qwen3-0.6B", + sampling_config=stochastic_sampling(), + expected_acceptance_rate=0.9, + expected_same_output_fraction=0.9, + ), +] + + +@pytest.mark.parametrize("args", cases) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool, + monkeypatch: pytest.MonkeyPatch): + """Compare the outputs using and not using speculative decoding. 
+ In the greedy decoding case, the outputs must match EXACTLY.""" + monkeypatch.setenv("VLLM_USE_V1", "1") + test_prompts = get_test_prompts(mm_enabled=False) + + spec_llm = LLM( + model=args.model, + speculative_config={ + "model": args.draft_model, + "method": "draft_model", + "num_speculative_tokens": 3, + "max_model_len": args.max_model_len, + "enforce_eager": enforce_eager, + }, + max_model_len=args.max_model_len, + gpu_memory_utilization=args.gpu_memory_utilization, + enforce_eager=enforce_eager, + disable_log_stats=False, # enables get_metrics() + ) + spec_outputs = spec_llm.chat(test_prompts, args.sampling_config) + acceptance_rate = compute_acceptance_rate(spec_llm.get_metrics()) + del spec_llm # CLEANUP + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + assert acceptance_rate >= args.expected_acceptance_rate + + ref_llm = LLM( + model=args.model, + max_model_len=args.max_model_len, + gpu_memory_utilization=args.gpu_memory_utilization, + enforce_eager=enforce_eager, + ) + ref_outputs = ref_llm.chat(test_prompts, args.sampling_config) + del ref_llm # CLEANUP + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + assert len(ref_outputs) > 0 + assert len(ref_outputs) == len(spec_outputs) + + match_fraction = compute_exact_matches(ref_outputs, spec_outputs) + assert match_fraction >= args.expected_same_output_fraction + + print(f"spec-decode: target={args.model}, draft={args.draft_model}, " + f"temperature={args.sampling_config.temperature:.2f}, " + f"acceptance_rate={acceptance_rate:.2f}, " + f"match_fraction={match_fraction:.2f}") + + +def compute_exact_matches(ref_outputs: list[RequestOutput], + spec_outputs: list[RequestOutput]) -> float: + """Compute the fraction of the prompts that match exactly""" + assert len(ref_outputs) == len(spec_outputs) + matches = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + return matches / len(ref_outputs) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index f022a55e625f..26cc59b22732 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -31,14 +31,17 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.utils import merge_async_iterators +from vllm.v1.metrics.reader import Metric +from vllm.v1.spec_decode.metrics import compute_acceptance_rate def run_vllm( requests: list[SampleRequest], n: int, engine_args: EngineArgs, + do_profile: bool, disable_detokenize: bool = False, -) -> tuple[float, Optional[list[RequestOutput]]]: +) -> "Results": from vllm import LLM, SamplingParams llm = LLM(**dataclasses.asdict(engine_args)) assert all( @@ -74,12 +77,16 @@ def run_vllm( outputs = None if not use_beam_search: + if do_profile: + llm.start_profile() start = time.perf_counter() outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests, use_tqdm=True) end = time.perf_counter() + if do_profile: + llm.stop_profile() else: assert lora_requests is None, "BeamSearch API does not support LoRA" prompts = [request.prompt for request in requests] @@ -96,7 +103,8 @@ def run_vllm( ignore_eos=True, )) end = time.perf_counter() - return end - start, outputs + runtime = end - start + return Results(runtime=runtime, metrics=llm.get_metrics(), outputs=outputs) def run_vllm_chat( @@ -138,6 +146,13 @@ def run_vllm_chat( 
return end - start, outputs +@dataclasses.dataclass +class Results: + runtime: float + metrics: list[Metric] + outputs: list + + async def run_vllm_async( requests: list[SampleRequest], n: int, @@ -496,6 +511,12 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help='Path to save the throughput results in JSON format.') + parser.add_argument( + "--print-acceptance-rate", + action="store_true", + default=False, + help="Print the acceptance rate of the speculative decoding model.", + ) parser.add_argument("--async-engine", action='store_true', default=False, @@ -543,6 +564,10 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="Split of the HF dataset.") + parser.add_argument("--profile", + action="store_true", + default=False, + help="Profile the model.") # prefix repetition dataset prefix_repetition_group = parser.add_argument_group( @@ -604,9 +629,12 @@ def main(args: argparse.Namespace): args.disable_detokenize, )) else: - elapsed_time, request_outputs = run_vllm( + bresults = run_vllm( requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + do_profile=args.profile, + disable_detokenize=args.disable_detokenize) + elapsed_time = bresults.runtime + request_outputs = bresults.outputs elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -651,6 +679,9 @@ def main(args: argparse.Namespace): f"{total_output_tokens / elapsed_time:.2f} output tokens/s") print(f"Total num prompt tokens: {total_prompt_tokens}") print(f"Total num output tokens: {total_output_tokens}") + if args.print_acceptance_rate: + rate = compute_acceptance_rate(bresults.metrics) + print(f"Acceptance rate: {rate:.2f}") # Output JSON results if specified if args.output_json: diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index f6f1838aedfc..2a963bd1b8db 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2151,6 +2151,7 @@ def __post_init__(self): code_revision=self.code_revision, tokenizer_revision=self.target_model_config. tokenizer_revision, + max_model_len=self.max_model_len, spec_target_max_model_len=self.target_model_config. max_model_len, quantization=self.quantization, @@ -2192,11 +2193,6 @@ def __post_init__(self): ) else: self.method = "draft_model" - raise NotImplementedError( - "Speculative decoding with draft model is not " - "supported yet. Please consider using other " - "speculative decoding methods such as ngram, medusa, " - "eagle, or deepseek_mtp.") # Replace hf_config for EAGLE draft_model if self.method in ("eagle", "eagle3"): @@ -2407,6 +2403,9 @@ def num_lookahead_slots(self) -> int: def use_eagle(self) -> bool: return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp") + def uses_draft_model(self) -> bool: + return self.method == "draft_model" + def __repr__(self) -> str: method = self.method model = None if method == "ngram" else self.draft_model_config.model diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bee97f4cd04d..5c511b88da3d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1486,10 +1486,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # V1 supports N-gram, Medusa, and Eagle speculative decoding. if (self.speculative_config is not None and self.speculative_config.get("method") == "draft_model"): - raise NotImplementedError( - "Speculative decoding with draft model is not supported yet. 
" - "Please consider using other speculative decoding methods " - "such as ngram, medusa, eagle, or deepseek_mtp.") + return True V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 2dada794a8f3..0ce2267c9842 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -111,12 +111,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: def get_model(*, vllm_config: VllmConfig, - model_config: Optional[ModelConfig] = None) -> nn.Module: + model_config: Optional[ModelConfig] = None, + prefix: str = "") -> nn.Module: loader = get_model_loader(vllm_config.load_config) if model_config is None: model_config = vllm_config.model_config return loader.load_model(vllm_config=vllm_config, - model_config=model_config) + model_config=model_config, + prefix=prefix) __all__ = [ diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 4cf6c7988960..7d4a50a36250 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -31,8 +31,10 @@ def load_weights(self, model: nn.Module, inplace weights loading for an already-initialized model""" raise NotImplementedError - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model(self, + vllm_config: VllmConfig, + model_config: ModelConfig, + prefix: str = "") -> nn.Module: """Load a model with the given configurations.""" device_config = vllm_config.device_config load_config = vllm_config.load_config @@ -42,7 +44,8 @@ def load_model(self, vllm_config: VllmConfig, with set_default_torch_dtype(model_config.dtype): with target_device: model = initialize_model(vllm_config=vllm_config, - model_config=model_config) + model_config=model_config, + prefix=prefix) logger.debug("Loading weights on %s ...", load_device) # Quantization does not happen in `load_weights` but after it diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 9877cb3b7c06..054206598061 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -123,8 +123,10 @@ def load_weights(self, model: nn.Module, model.load_weights( self._get_weights_iterator(local_model_path, gguf_weights_map)) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model(self, + vllm_config: VllmConfig, + model_config: ModelConfig, + prefix: str = "") -> nn.Module: device_config = vllm_config.device_config local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) @@ -147,7 +149,8 @@ def load_model(self, vllm_config: VllmConfig, target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config) + model = initialize_model(vllm_config=vllm_config, + prefix=prefix) self.load_weights(model, model_config) process_weights_after_loading(model, model_config, target_device) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index fa01758ab4ce..b0737dd96209 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -58,6 +58,7 @@ def 
_get_weights_iterator( def _load_model_serialized_cpu( self, vllm_config: VllmConfig, + prefix: str = "", ) -> nn.Module: """Load a serialized model with tensorizer to the CPU. @@ -70,7 +71,8 @@ def _load_model_serialized_cpu( model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = initialize_model(vllm_config=vllm_config) + model = initialize_model(vllm_config=vllm_config, + prefix=prefix) model.load_weights(self._get_weights_iterator()) return model.eval() @@ -103,8 +105,10 @@ def load_weights(self, model: nn.Module, else: model.load_weights(self._get_weights_iterator()) - def load_model(self, vllm_config: VllmConfig, - model_config: ModelConfig) -> nn.Module: + def load_model(self, + vllm_config: VllmConfig, + model_config: ModelConfig, + prefix: str = "") -> nn.Module: parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) @@ -125,7 +129,8 @@ def load_model(self, vllm_config: VllmConfig, vllm_config=vllm_config) self.load_weights(model, model_config) return model - return self._load_model_serialized_cpu(vllm_config=vllm_config) + return self._load_model_serialized_cpu(vllm_config=vllm_config, + prefix=prefix) @staticmethod def save_model( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 2d40e96632c9..3a3755bc778f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -154,12 +154,14 @@ def __init__( cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config - self.use_eagle = False + use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens if speculative_config.use_eagle(): - self.use_eagle = True + use_eagle = True + self.num_lookahead_tokens = self.num_spec_tokens + if speculative_config.uses_draft_model(): self.num_lookahead_tokens = self.num_spec_tokens # Create the KV cache manager. 
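In short, the hunk above lets the draft_model method reserve lookahead KV slots the same way eagle already does; uses_draft_model() is the new predicate added to SpeculativeConfig in vllm/config/__init__.py earlier in this patch. A condensed sketch of the resulting scheduler logic (illustrative only, not a literal copy of the hunk):

    num_spec_tokens = num_lookahead_tokens = 0
    if speculative_config:
        num_spec_tokens = speculative_config.num_speculative_tokens
        if speculative_config.use_eagle() or speculative_config.uses_draft_model():
            # Both drafter types keep their own KV cache, so slots for the
            # tokens they are about to draft are reserved up front.
            num_lookahead_tokens = num_spec_tokens
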
@@ -167,7 +169,7 @@ def __init__( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, enable_caching=self.cache_config.enable_prefix_caching, - use_eagle=self.use_eagle, + use_eagle=use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index bf25c91d8390..4ebe584b9f05 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -12,8 +12,9 @@ from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) +from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal @@ -44,18 +45,20 @@ class EagleAttentionMetadata(Protocol): slot_mapping: torch.Tensor -class EagleProposer: +class SpecDecodeProposer: def __init__( self, vllm_config: VllmConfig, device: torch.device, + pass_hidden_states_to_model: bool, runner=None, ): self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method + self.pass_hidden_states_to_model = pass_hidden_states_to_model self.runner = runner self.dtype = vllm_config.model_config.dtype @@ -157,6 +160,8 @@ def propose( next_token_ids: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, + cudagraph_runtime_mode: CUDAGraphMode, + batch_descriptor: BatchDescriptor, mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] @@ -169,16 +174,22 @@ def propose( target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + if self.method == "draft_model": + # Use full input ids, no shifting needed + self.input_ids[:num_tokens] = target_token_ids + else: + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. 
+ # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups + assert len(self.runner.attn_groups) == 1 + assert len(self.runner.attn_groups[0]) == 1 attn_metadata = self.runner.attn_groups[0][0].metadata_builder\ .build_for_drafting(common_attn_metadata=common_attn_metadata, draft_index=0) @@ -195,7 +206,9 @@ def propose( num_input_tokens = num_tokens # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions - self.hidden_states[:num_tokens] = target_hidden_states + if self.pass_hidden_states_to_model: + self.hidden_states[:num_tokens] = target_hidden_states + if self.is_multimodal_model: input_ids = self.input_ids[:num_tokens] inputs_embeds = self.model.get_input_embeddings( @@ -209,16 +222,22 @@ def propose( inputs_embeds = None input_ids = self.input_ids[:num_input_tokens] + model_kwargs = { + "input_ids": input_ids, + "positions": self.positions[:num_input_tokens], + } + if self.pass_hidden_states_to_model: + model_kwargs[ + "hidden_states"] = self.hidden_states[:num_input_tokens] + model_kwargs["inputs_embeds"] = inputs_embeds + with set_forward_context(per_layer_attn_metadata, self.vllm_config, - num_tokens=num_input_tokens): - ret_hidden_states = self.model( - input_ids=input_ids, - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens], - inputs_embeds=inputs_embeds, - ) - if self.method in ("deepseek_mtp", "ernie_mtp"): + num_tokens=num_input_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor): + ret_hidden_states = self.model(**model_kwargs) + if self.method in ("draft_model", "deepseek_mtp", "ernie_mtp"): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: @@ -240,10 +259,22 @@ def propose( # [batch_size, num_tree_tokens] return torch.cat(draft_token_ids_list, dim=1) - draft_token_ids = logits.argmax(dim=-1) + if self.method == "draft_model": + # Reuse the next_token_ids to avoid a potential rejection + draft_token_ids = next_token_ids + else: + draft_token_ids = logits.argmax(dim=-1) + + if self.method == "draft_model": + # The draft model runs one forward pass to prefill + # the target_token_ids, and another forward pass for decoding + # based on the next_token_ids. I.e. it needs 1 more forward pass. + n_forward_passes = self.num_speculative_tokens + 1 + else: + n_forward_passes = self.num_speculative_tokens # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: + if n_forward_passes == 1: # [batch_size, 1] return draft_token_ids.view(-1, 1) @@ -263,7 +294,7 @@ def propose( attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(self.num_speculative_tokens - 1): + for _ in range(n_forward_passes - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. 
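To make the pass accounting in the hunk above concrete: the draft model first re-prefills the target tokens, then decodes once per speculative token, and its first "proposal" is simply the token the target model already sampled (that duplicate is stripped again further down in propose()). A minimal self-contained sketch with toy numbers and a stand-in fake_decode that is not part of vLLM:

    num_speculative_tokens = 2        # toy value
    next_token_id = 101               # token already sampled by the target model

    def fake_decode(token_id: int) -> int:
        # placeholder for one draft-model decode step (argmax of its logits)
        return token_id + 1

    n_forward_passes = num_speculative_tokens + 1   # 1 re-prefill + 2 decodes
    proposals = [next_token_id]                     # reused, so it cannot be rejected
    for _ in range(n_forward_passes - 1):
        proposals.append(fake_decode(proposals[-1]))
    # The leading duplicate of next_token_id is dropped before returning,
    # leaving exactly num_speculative_tokens proposed tokens per request.
    assert proposals[1:] == [102, 103]
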
@@ -309,6 +340,7 @@ def propose( self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states + if self.is_multimodal_model: inputs_embeds = self.model.get_input_embeddings(input_ids) self.inputs_embeds[:batch_size] = inputs_embeds @@ -318,22 +350,44 @@ def propose( inputs_embeds = None input_ids = self.input_ids[:input_batch_size] + model_kwargs = { + "input_ids": input_ids, + "positions": self.positions[:input_batch_size], + } + if self.pass_hidden_states_to_model: + model_kwargs[ + "hidden_states"] = self.hidden_states[:input_batch_size] + model_kwargs["inputs_embeds"] = inputs_embeds + + batch_descriptor = BatchDescriptor(num_tokens=input_batch_size, + uniform_decode=True) + cudagraph_runtime_mode, batch_descriptor = \ + self.runner.cudagraph_dispatcher.dispatch(batch_descriptor) + # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): - last_hidden_states, hidden_states = self.model( - input_ids=input_ids, - positions=self.positions[:input_batch_size], - hidden_states=self.hidden_states[:input_batch_size], - inputs_embeds=inputs_embeds, - ) + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=input_batch_size, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor): + ret_hidden_states = self.model(**model_kwargs) + if self.method in ("draft_model", "deepseek_mtp", "ernie_mtp"): + hidden_states = last_hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] + logits = self.model.compute_logits(last_hidden_states[:batch_size], None) draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) + if self.method == "draft_model": + # the first draft_token_ids are identical to next_token_ids, so + # they don't need to be returned as proposed tokens + draft_token_ids_list = draft_token_ids_list[1:] + # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids @@ -611,14 +665,19 @@ def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag with set_model_tag("eagle_head"): - self.model = get_model(vllm_config=self.vllm_config, - model_config=draft_model_config) + vllm_config_draft = replace(self.vllm_config, + model_config=draft_model_config) + self.model = get_model(vllm_config=vllm_config_draft, + model_config=draft_model_config, + prefix="draft_model") draft_attn_layer_names = ( get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names) + if self.vllm_config.speculative_config.uses_draft_model(): + return if supports_multimodal(target_model): # handle multimodality @@ -654,9 +713,11 @@ def load_model(self, target_model: nn.Module) -> None: def dummy_run( self, num_tokens: int, + forward_ctx_kwargs: dict, ) -> None: - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + with set_forward_context(vllm_config=self.vllm_config, + num_tokens=num_tokens, + **forward_ctx_kwargs): if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -664,12 +725,15 @@ def dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None - self.model( - input_ids=input_ids, - positions=self.positions[:num_tokens], - 
hidden_states=self.hidden_states[:num_tokens], - inputs_embeds=inputs_embeds, - ) + model_kwargs = { + "input_ids": input_ids, + "positions": self.positions[:num_tokens], + } + if self.pass_hidden_states_to_model: + model_kwargs["hidden_states"] = self.hidden_states[:num_tokens] + model_kwargs["inputs_embeds"] = inputs_embeds + + self.model(**model_kwargs) def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: @@ -691,6 +755,30 @@ def validate_same_kv_cache_group(self, ) == 1, "All eagle layers should belong to the same kv cache group" +class EagleProposer(SpecDecodeProposer): + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device, + runner=None): + super().__init__(vllm_config=vllm_config, + device=device, + runner=runner, + pass_hidden_states_to_model=True) + + +class DraftModelProposer(SpecDecodeProposer): + + def __init__(self, + vllm_config: VllmConfig, + device: torch.device, + runner=None): + super().__init__(vllm_config=vllm_config, + device=device, + runner=runner, + pass_hidden_states_to_model=False) + + # NOTE(woosuk): Currently, the below code is not used and we always use argmax # to sample the draft tokens. We will use this after we find a way to manage # the draft prob tensor. diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index b4bc3058c570..e9d8cee3f1a7 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -9,6 +9,7 @@ from vllm.config import SpeculativeConfig from vllm.logger import init_logger +from vllm.v1.metrics.reader import Metric logger = init_logger(__name__) @@ -176,3 +177,12 @@ def observe(self, spec_decoding_stats: SpecDecodingStats): for pos, counter in enumerate( self.counter_spec_decode_num_accepted_tokens_per_pos): counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) + + +def compute_acceptance_rate(metrics: list[Metric]) -> float: + name2metric = {metric.name: metric for metric in metrics} + n_draft_toks = name2metric[ + "vllm:spec_decode_num_draft_tokens"].value # type: ignore + n_accepted_toks = name2metric[ + "vllm:spec_decode_num_accepted_tokens"].value # type: ignore + return n_accepted_toks / n_draft_toks diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 549c5dd2bbb2..f933cfe335e2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -75,7 +75,8 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler -from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.eagle import (DraftModelProposer, EagleProposer, + SpecDecodeProposer) from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -237,6 +238,10 @@ def __init__( if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.uses_draft_model(): + self.drafter = DraftModelProposer(self.vllm_config, + self.device, + self) # type: ignore elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore @@ -1982,6 +1987,8 @@ def execute_model( aux_hidden_states, spec_decode_metadata, spec_decode_common_attn_metadata, + cudagraph_runtime_mode=cudagraph_runtime_mode, + 
batch_descriptor=batch_descriptor, ) with record_function_or_nullcontext("EPLB"): @@ -2029,6 +2036,8 @@ def propose_draft_token_ids( aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, + cudagraph_runtime_mode: CUDAGraphMode, + batch_descriptor: BatchDescriptor, ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": @@ -2055,8 +2064,10 @@ def propose_draft_token_ids( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) - elif self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + elif self.speculative_config.use_eagle( + ) or self.speculative_config.method == "draft_model": + assert isinstance(self.drafter, + (EagleProposer, DraftModelProposer)) # TODO(woosuk): Refactor the loop. req_ids = self.input_batch.req_ids next_token_ids: list[int] = [] @@ -2122,6 +2133,8 @@ def propose_draft_token_ids( sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, ) return draft_token_ids @@ -2634,9 +2647,21 @@ def _dummy_run( else: hidden_states = outputs - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) - self.drafter.dummy_run(num_tokens) + # Execute dummy run for drafter + is_eagle = (self.speculative_config + and self.speculative_config.use_eagle()) + is_draft_model = (self.speculative_config + and self.speculative_config.uses_draft_model()) + do_draft_dummy_run = is_eagle or is_draft_model + if do_draft_dummy_run: + assert isinstance(self.drafter, SpecDecodeProposer) + forward_ctx_kwargs = { + "attn_metadata": attn_metadata, + "cudagraph_runtime_mode": cudagraph_runtime_mode, + "batch_descriptor": batch_descriptor, + } + self.drafter.dummy_run(num_tokens, + forward_ctx_kwargs=forward_ctx_kwargs) # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 6767804c71b9..7d3c0be8c5a6 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -255,25 +254,11 @@ def bind_kv_cache( layers with layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. """ - # Bind kv_caches to ModelRunner - assert len(runner_kv_caches) == 0 - - # Convert kv_caches dict to a list of tensors in the order of layer_index. - index2name = defaultdict(list) - for layer_name in kv_caches: - index2name[extract_layer_index(layer_name)].append(layer_name) - - for layer_index in sorted(index2name.keys()): - layer_names = index2name[layer_index] - if len(layer_names) > 1: - # One typical case is encoder-decoder model, e.g., bart. - # The cross attention and self attention in the same decoder layer - # has different layer_name but the same layer_index. 
- raise NotImplementedError - layer_name = layer_names[0] - runner_kv_caches.append(kv_caches[layer_name]) - - # Bind kv_caches to forward context - for layer_name, kv_cache in kv_caches.items(): + layer_names1 = set(kv_caches.keys()) + layer_names2 = set(forward_context.keys()) + assert layer_names1 == layer_names2 + sorted_layers: list[str] = sorted(layer_names1, key=extract_layer_index) + for layer in sorted_layers: # NOTE: Use list because of v0 PP virtual engine. - forward_context[layer_name].kv_cache = [kv_cache] + forward_context[layer].kv_cache = [kv_caches[layer]] + runner_kv_caches.append(kv_caches[layer]) \ No newline at end of file From 3b06a7c98ea751cfb13a4bb12cd036714cf1f78e Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 8 Sep 2025 13:10:38 +0200 Subject: [PATCH 02/73] Unod change to 'vllm bench throughput' Signed-off-by: Tomas Ruiz --- vllm/benchmarks/throughput.py | 39 ++++------------------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py index 26cc59b22732..f022a55e625f 100644 --- a/vllm/benchmarks/throughput.py +++ b/vllm/benchmarks/throughput.py @@ -31,17 +31,14 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.utils import merge_async_iterators -from vllm.v1.metrics.reader import Metric -from vllm.v1.spec_decode.metrics import compute_acceptance_rate def run_vllm( requests: list[SampleRequest], n: int, engine_args: EngineArgs, - do_profile: bool, disable_detokenize: bool = False, -) -> "Results": +) -> tuple[float, Optional[list[RequestOutput]]]: from vllm import LLM, SamplingParams llm = LLM(**dataclasses.asdict(engine_args)) assert all( @@ -77,16 +74,12 @@ def run_vllm( outputs = None if not use_beam_search: - if do_profile: - llm.start_profile() start = time.perf_counter() outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests, use_tqdm=True) end = time.perf_counter() - if do_profile: - llm.stop_profile() else: assert lora_requests is None, "BeamSearch API does not support LoRA" prompts = [request.prompt for request in requests] @@ -103,8 +96,7 @@ def run_vllm( ignore_eos=True, )) end = time.perf_counter() - runtime = end - start - return Results(runtime=runtime, metrics=llm.get_metrics(), outputs=outputs) + return end - start, outputs def run_vllm_chat( @@ -146,13 +138,6 @@ def run_vllm_chat( return end - start, outputs -@dataclasses.dataclass -class Results: - runtime: float - metrics: list[Metric] - outputs: list - - async def run_vllm_async( requests: list[SampleRequest], n: int, @@ -511,12 +496,6 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help='Path to save the throughput results in JSON format.') - parser.add_argument( - "--print-acceptance-rate", - action="store_true", - default=False, - help="Print the acceptance rate of the speculative decoding model.", - ) parser.add_argument("--async-engine", action='store_true', default=False, @@ -564,10 +543,6 @@ def add_cli_args(parser: argparse.ArgumentParser): type=str, default=None, help="Split of the HF dataset.") - parser.add_argument("--profile", - action="store_true", - default=False, - help="Profile the model.") # prefix repetition dataset prefix_repetition_group = parser.add_argument_group( @@ -629,12 +604,9 @@ def main(args: argparse.Namespace): args.disable_detokenize, )) else: - bresults = run_vllm( + elapsed_time, request_outputs = run_vllm( requests, args.n, EngineArgs.from_cli_args(args), - 
do_profile=args.profile, - disable_detokenize=args.disable_detokenize) - elapsed_time = bresults.runtime - request_outputs = bresults.outputs + args.disable_detokenize) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -679,9 +651,6 @@ def main(args: argparse.Namespace): f"{total_output_tokens / elapsed_time:.2f} output tokens/s") print(f"Total num prompt tokens: {total_prompt_tokens}") print(f"Total num output tokens: {total_output_tokens}") - if args.print_acceptance_rate: - rate = compute_acceptance_rate(bresults.metrics) - print(f"Acceptance rate: {rate:.2f}") # Output JSON results if specified if args.output_json: From e41b0a398c5f4d19846479e60f624b22fe1ab8c2 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 8 Sep 2025 13:23:40 +0200 Subject: [PATCH 03/73] Don't return too early Signed-off-by: Tomas Ruiz --- vllm/engine/arg_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5c511b88da3d..74186edddd89 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1483,11 +1483,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # V1 supports N-gram, Medusa, and Eagle speculative decoding. - if (self.speculative_config is not None - and self.speculative_config.get("method") == "draft_model"): - return True - V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", From 10366b9b6f3c839bc4425f12600947833d735c7a Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 8 Sep 2025 14:10:03 +0200 Subject: [PATCH 04/73] Undo change to bind_kv_cache() Signed-off-by: Tomas Ruiz --- tests/v1/test_utils.py | 32 ++++++++++++++++++++++++++++++++ vllm/v1/worker/utils.py | 33 ++++++++++++++++++++++++++------- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/tests/v1/test_utils.py b/tests/v1/test_utils.py index 00d98a873a31..87123dac1daf 100644 --- a/tests/v1/test_utils.py +++ b/tests/v1/test_utils.py @@ -42,6 +42,38 @@ def test_bind_kv_cache(): assert runner_kv_caches[3] is kv_cache['layers.3.self_attn'] +def test_bind_kv_cache_draft_model(): + from vllm.attention import Attention + ctx = { + 'model.layers.0.attn': Attention(32, 128, 0.1), + 'model.layers.1.attn': Attention(32, 128, 0.1), + 'draft_model.layers.0.attn': Attention(32, 128, 0.1), + 'draft_model.layers.1.attn': Attention(32, 128, 0.1), + } + kv_cache = { + 'model.layers.0.attn': torch.zeros((1, )), + 'model.layers.1.attn': torch.zeros((1, )), + 'draft_model.layers.0.attn': torch.zeros((1, )), + 'draft_model.layers.1.attn': torch.zeros((1, )), + } + runner_kv_caches: list[torch.Tensor] = [] + bind_kv_cache(kv_cache, ctx, runner_kv_caches) + assert ctx['model.layers.0.attn'].kv_cache[0] is kv_cache[ + 'model.layers.0.attn'] + assert ctx['model.layers.1.attn'].kv_cache[0] is kv_cache[ + 'model.layers.1.attn'] + assert ctx['draft_model.layers.0.attn'].kv_cache[0] is kv_cache[ + 'draft_model.layers.0.attn'] + assert ctx['draft_model.layers.1.attn'].kv_cache[0] is kv_cache[ + 'draft_model.layers.1.attn'] + + # caches are ordered by layer_index, interleaving target and draft model + assert runner_kv_caches[0] is kv_cache['model.layers.0.attn'] + assert runner_kv_caches[1] is kv_cache['draft_model.layers.0.attn'] + assert runner_kv_caches[2] is kv_cache['model.layers.1.attn'] + assert runner_kv_caches[3] is kv_cache['draft_model.layers.1.attn'] + + def test_bind_kv_cache_non_attention(): from vllm.attention import 
Attention diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 7d3c0be8c5a6..ace521c1d002 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Optional @@ -254,11 +255,29 @@ def bind_kv_cache( layers with layer names as keys. runner_kv_caches: The kv_cache declared by ModelRunner. """ - layer_names1 = set(kv_caches.keys()) - layer_names2 = set(forward_context.keys()) - assert layer_names1 == layer_names2 - sorted_layers: list[str] = sorted(layer_names1, key=extract_layer_index) - for layer in sorted_layers: + # Bind kv_caches to ModelRunner + assert len(runner_kv_caches) == 0 + + # Convert kv_caches dict to a list of tensors in the order of layer_index. + index2name = defaultdict(list) + for layer_name in kv_caches: + index2name[extract_layer_index(layer_name)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + non_draft_layers = [ + name for name in layer_names if not name.startswith('draft_model.') + ] + if len(non_draft_layers) > 1: + # One typical case is encoder-decoder model, e.g., bart. + # The cross attention and self attention in the same decoder layer + # has different layer_name but the same layer_index. + raise NotImplementedError + + for layer_name in layer_names: + runner_kv_caches.append(kv_caches[layer_name]) + + # Bind kv_caches to forward context + for layer_name, kv_cache in kv_caches.items(): # NOTE: Use list because of v0 PP virtual engine. - forward_context[layer].kv_cache = [kv_caches[layer]] - runner_kv_caches.append(kv_caches[layer]) \ No newline at end of file + forward_context[layer_name].kv_cache = [kv_cache] From 92af339729b66781d2761b8da904ce10607dfff1 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 8 Sep 2025 14:18:23 +0200 Subject: [PATCH 05/73] Undo changes to pyproject.toml Signed-off-by: Tomas Ruiz --- pyproject.toml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e41d8a26aa55..e63f8aeae278 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,11 +154,6 @@ markers = [ "skip_v1: do not run this test with v1", "optional: optional tests that are automatically skipped, include --optional to run them", ] -# Show print statements and logs during test execution -addopts = "-s --tb=short --log-cli-level=INFO" -log_cli = true -log_cli_format = "%(asctime)s [%(levelname)8s] %(name)s: %(message)s" -log_cli_date_format = "%Y-%m-%d %H:%M:%S" [tool.ty.src] root = "./vllm" From f2f9876bca93edfe38b6d19526c72ec77c5a031a Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 8 Sep 2025 17:10:43 +0200 Subject: [PATCH 06/73] Simplify test array Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 33 +++++--------------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index dd712050ed8a..1155284b5b41 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -251,39 +251,13 @@ class ArgsTest: expected_acceptance_rate: float expected_same_output_fraction: float # Defaults + target_tensor_parallel_size: int = 1 + draft_tensor_parallel_size: int = 1 max_model_len: int = 1024 gpu_memory_utilization: float = 0.5 cases = [ - ArgsTest( - model="baidu/ERNIE-4.5-0.3B-PT", - 
draft_model="baidu/ERNIE-4.5-0.3B-PT", - sampling_config=greedy_sampling(), - expected_acceptance_rate=1.0, - expected_same_output_fraction=1.0, - ), - ArgsTest( - model="baidu/ERNIE-4.5-0.3B-PT", - draft_model="baidu/ERNIE-4.5-0.3B-PT", - sampling_config=stochastic_sampling(), - expected_acceptance_rate=0.2, - expected_same_output_fraction=0.0, - ), - ArgsTest( - model="meta-llama/Llama-3.2-1B-Instruct", - draft_model="meta-llama/Llama-3.2-1B-Instruct", - sampling_config=greedy_sampling(), - expected_acceptance_rate=0.8, - expected_same_output_fraction=0.5, - ), - ArgsTest( - model="meta-llama/Llama-3.2-1B-Instruct", - draft_model="meta-llama/Llama-3.2-1B-Instruct", - sampling_config=stochastic_sampling(), - expected_acceptance_rate=0.4, - expected_same_output_fraction=0.15, - ), ArgsTest( model="Qwen/Qwen3-1.7B", draft_model="Qwen/Qwen3-0.6B", @@ -318,9 +292,11 @@ def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool, "num_speculative_tokens": 3, "max_model_len": args.max_model_len, "enforce_eager": enforce_eager, + "tensor_parallel_size": args.draft_tensor_parallel_size, }, max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, + tensor_parallel_size=args.target_tensor_parallel_size, enforce_eager=enforce_eager, disable_log_stats=False, # enables get_metrics() ) @@ -336,6 +312,7 @@ def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool, model=args.model, max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, + tensor_parallel_size=args.target_tensor_parallel_size, enforce_eager=enforce_eager, ) ref_outputs = ref_llm.chat(test_prompts, args.sampling_config) From 824ba10269ec2e7d82fc03149c9862d62a67c362 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 9 Sep 2025 05:07:46 +0000 Subject: [PATCH 07/73] Ensure EAGLE loads correctly Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4ebe584b9f05..5a5a5abe692e 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -664,12 +664,18 @@ def load_model(self, target_model: nn.Module) -> None: get_layers_from_vllm_config(self.vllm_config, Attention).keys()) from vllm.compilation.backends import set_model_tag - with set_model_tag("eagle_head"): - vllm_config_draft = replace(self.vllm_config, - model_config=draft_model_config) - self.model = get_model(vllm_config=vllm_config_draft, - model_config=draft_model_config, - prefix="draft_model") + + if self.vllm_config.speculative_config.uses_draft_model(): + with set_model_tag("draft_model"): + vllm_config_draft = replace(self.vllm_config, + model_config=draft_model_config) + self.model = get_model(vllm_config=vllm_config_draft, + model_config=draft_model_config, + prefix="draft_model") + else: + with set_model_tag("eagle_head"): + self.model = get_model(vllm_config=self.vllm_config, + model_config=draft_model_config) draft_attn_layer_names = ( get_layers_from_vllm_config(self.vllm_config, Attention).keys() - From 5e248c1b33264429924acb71c1592b216eba1572 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 9 Sep 2025 09:35:17 +0200 Subject: [PATCH 08/73] Pass input_embeds when model is multimodal Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5a5a5abe692e..678ac472c694 100644 --- 
a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -229,6 +229,7 @@ def propose( if self.pass_hidden_states_to_model: model_kwargs[ "hidden_states"] = self.hidden_states[:num_input_tokens] + if self.is_multimodal_model: model_kwargs["inputs_embeds"] = inputs_embeds with set_forward_context(per_layer_attn_metadata, @@ -737,6 +738,7 @@ def dummy_run( } if self.pass_hidden_states_to_model: model_kwargs["hidden_states"] = self.hidden_states[:num_tokens] + if self.is_multimodal_model: model_kwargs["inputs_embeds"] = inputs_embeds self.model(**model_kwargs) From 1669ea7965a0097d3e62d16b2b3b10aeeccfd603 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 9 Sep 2025 11:09:38 +0200 Subject: [PATCH 09/73] Raise NotImplementedError on Mrope or Multimodal models Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 678ac472c694..43132bc461c6 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -77,6 +77,13 @@ def __init__( self.is_multimodal_model = vllm_config.model_config \ .is_multimodal_model + if self.is_multimodal_model and self.method == "draft_model": + raise NotImplementedError("Speculative Decoding with draft models " + "does not support multimodal models yet") + if self.draft_model_config.uses_mrope and self.method == "draft_model": + raise NotImplementedError("Speculative Decoding with draft models " + "does not support M-RoPE yet") + self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.vllm_config.model_config.enforce_eager) @@ -358,6 +365,7 @@ def propose( if self.pass_hidden_states_to_model: model_kwargs[ "hidden_states"] = self.hidden_states[:input_batch_size] + if self.is_multimodal_model: model_kwargs["inputs_embeds"] = inputs_embeds batch_descriptor = BatchDescriptor(num_tokens=input_batch_size, From 54e107d299b63cb07003dd8fcaa6fa862ceafd23 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 17 Sep 2025 13:36:44 +0200 Subject: [PATCH 10/73] Speculative decoding with draft model separate from EAGLE Signed-off-by: Tomas Ruiz --- vllm/engine/arg_utils.py | 6 - vllm/v1/core/sched/scheduler.py | 6 +- vllm/v1/spec_decode/draft_model.py | 256 ++++++++++++++++++ vllm/v1/spec_decode/eagle.py | 412 ++++++++++++----------------- vllm/v1/worker/gpu_model_runner.py | 64 +++-- 5 files changed, 465 insertions(+), 279 deletions(-) create mode 100644 vllm/v1/spec_decode/draft_model.py diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f18b6c1eb33d..71ceb475be85 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1520,12 +1520,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: recommend_to_remove=False) return False - # No OTLP observability so far. 
- if (self.otlp_traces_endpoint or self.collect_detailed_traces): - _raise_or_fallback(feature_name="--otlp-traces-endpoint", - recommend_to_remove=False) - return False - V1_BACKENDS = [ "FLASH_ATTN_VLLM_V1", "FLASH_ATTN", diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 86f3b1712c7c..3685d5d25032 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -154,12 +154,12 @@ def __init__( cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config - use_eagle = False + self.use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: self.num_spec_tokens = speculative_config.num_speculative_tokens if speculative_config.use_eagle(): - use_eagle = True + self.use_eagle = True self.num_lookahead_tokens = self.num_spec_tokens if speculative_config.uses_draft_model(): self.num_lookahead_tokens = self.num_spec_tokens @@ -169,7 +169,7 @@ def __init__( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, enable_caching=self.cache_config.enable_prefix_caching, - use_eagle=use_eagle, + use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, dcp_world_size=self.dcp_world_size, diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py new file mode 100644 index 000000000000..58a2c23e8b42 --- /dev/null +++ b/vllm/v1/spec_decode/draft_model.py @@ -0,0 +1,256 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import replace +from typing import Any + +import torch + +from vllm.attention.layer import Attention +from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config +from vllm.config.compilation import CUDAGraphMode +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.model_executor.model_loader import get_model +from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.spec_decode.eagle import (PADDING_SLOT_ID, SpecDecodeBaseProposer, + drafter_prepare_inputs) +from vllm.v1.worker.ubatching import dbo_current_ubatch_id + + +class DraftModelProposer(SpecDecodeBaseProposer): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__(vllm_config, device, runner) + self._raise_if_multimodal() + self._raise_if_mrope() + + def _raise_if_multimodal(self): + if self.is_multimodal_model: + raise NotImplementedError("Speculative Decoding with draft models " + "does not support multimodal models yet") + + def _raise_if_mrope(self): + if self.draft_model_config.uses_mrope: + raise NotImplementedError("Speculative Decoding with draft models " + "does not support M-RoPE yet") + + def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: + self._raise_if_multimodal() + self._raise_if_mrope() + return { + "input_ids": self.input_ids[:num_tokens], + "positions": self.positions[:num_tokens], + } + + def dummy_run(self, num_tokens: int, forward_ctx_kwargs: dict): + model_kwargs = self._model_kwargs(num_tokens) + with set_forward_context( + vllm_config=self.vllm_config, + num_tokens=num_tokens, + **forward_ctx_kwargs, + ): + self.model(**model_kwargs) + + # Copied and adapted from eagle.py + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] + target_positions: torch.Tensor, + # [batch_size] + 
next_token_ids: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, + cudagraph_runtime_mode: CUDAGraphMode, + batch_descriptor: BatchDescriptor, + ) -> torch.Tensor: + num_tokens = target_token_ids.shape[0] + batch_size = next_token_ids.shape[0] + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + + self.input_ids[:num_tokens] = target_token_ids + + assert self.runner is not None + + # FIXME: need to consider multiple kv_cache_groups + assert len(self.runner.attn_groups) == 1 + assert len(self.runner.attn_groups[0]) == 1 + ubatch_id = dbo_current_ubatch_id() + attn_metadata_builder = self.runner.attn_groups[0][ + 0].metadata_builders[ubatch_id] + attn_metadata = attn_metadata_builder.build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0) + + # At this moment, we assume all draft model layers belong to the same KV + # cache group, thus using the same attention metadata. + per_layer_attn_metadata = {} + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) + else: + num_input_tokens = num_tokens + # copy inputs to buffer for cudagraph + self.positions[:num_tokens] = target_positions + + model_kwargs = self._model_kwargs(num_input_tokens) + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ): + last_hidden_states = self.model(**model_kwargs) + + sample_hidden_states = last_hidden_states[last_token_indices] + logits = self.model.compute_logits(sample_hidden_states, None) + positions = target_positions[last_token_indices] + + if isinstance(attn_metadata, TreeAttentionMetadata): + raise NotImplementedError("Speculative Decoding with draft models " + "does not support tree attention yet") + + # Reuse the next_token_ids to avoid a potential rejection + draft_token_ids = next_token_ids + + # The draft model runs one forward pass to prefill + # the target_token_ids, and another forward pass for decoding + # based on the next_token_ids. I.e. it needs 1 more forward pass. + n_forward_passes = self.num_speculative_tokens + 1 + # Early exit if there is only one draft token to be generated. + if n_forward_passes == 1: + # [batch_size, 1] + return draft_token_ids.view(-1, 1) + + # Generate the remaining draft tokens. + draft_token_ids_list = [draft_token_ids] + + if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: + input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) + else: + input_batch_size = batch_size + + attn_metadata.num_actual_tokens = batch_size + attn_metadata.max_query_len = 1 + attn_metadata.query_start_loc = self.arange[:batch_size + 1] + for _ in range(n_forward_passes - 1): + # Update the inputs. + # cast to int32 is crucial when draft model is compiled. + # tensor.argmax() returns int64 by default. + input_ids = draft_token_ids_list[-1].int() + positions += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated with this adjustment should be ignored. 
+ exceeds_max_model_len = positions >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) + + # Increment the sequence lengths. + attn_metadata.max_seq_len += 1 + attn_metadata.seq_lens += 1 + # Consider max model length. + attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + self.max_model_len) + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize their overheads in attention. + attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) + + # Compute the slot mapping. + block_numbers = clamped_positions // self.block_size + block_ids = attn_metadata.block_table.gather( + dim=1, index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + attn_metadata.slot_mapping = (block_ids * self.block_size + + clamped_positions % self.block_size) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) + + # copy inputs to buffer for cudagraph + self.input_ids[:batch_size] = input_ids + self.positions[:batch_size] = clamped_positions + + model_kwargs = self._model_kwargs(input_batch_size) + batch_descriptor = BatchDescriptor(num_tokens=input_batch_size, + uniform_decode=True) + cudagraph_runtime_mode, batch_descriptor = ( + self.runner.cudagraph_dispatcher.dispatch(batch_descriptor)) + + # Run the model. + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=input_batch_size, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ): + last_hidden_states = self.model(**model_kwargs) + + logits = self.model.compute_logits(last_hidden_states[:batch_size], + None) + draft_token_ids = logits.argmax(dim=-1) + draft_token_ids_list.append(draft_token_ids) + + # the first draft_token_ids are identical to next_token_ids, so + # they don't need to be returned as proposed tokens + draft_token_ids_list = draft_token_ids_list[1:] + + # [batch_size, num_speculative_tokens] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + return draft_token_ids + + def load_model(self) -> None: + draft_model_config: ModelConfig = ( + self.vllm_config.speculative_config.draft_model_config) + vllm_config_draft: VllmConfig = replace( + self.vllm_config, model_config=draft_model_config) + + # This must be computed before loading the draft model + # because that mutates the forward_context of the vllm_config + target_attn_layer_names = set( + get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + + from vllm.compilation.backends import set_model_tag + + with set_model_tag("draft_model"): + self.model = get_model( + vllm_config=vllm_config_draft, + model_config=draft_model_config, + prefix="draft_model", + ) + + # This must be computed after loading the draft model + # because that mutates the forward_context of the vllm_config + draft_attn_layer_names = ( + get_layers_from_vllm_config(self.vllm_config, Attention).keys() - + target_attn_layer_names) + self.attn_layer_names = list(draft_attn_layer_names) + + # Copied from eagle.py + def prepare_inputs( + self, + common_attn_metadata: CommonAttentionMetadata, + # [batch_size] + num_rejected_tokens: torch.Tensor, + ) -> tuple[CommonAttentionMetadata, torch.Tensor]: + return drafter_prepare_inputs( + 
self.token_arange_np, + common_attn_metadata, + num_rejected_tokens, + ) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b58609a59939..66aa8f6a6baa 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -12,9 +12,8 @@ from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) -from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal @@ -46,20 +45,18 @@ class EagleAttentionMetadata(Protocol): slot_mapping: torch.Tensor -class SpecDecodeProposer: +class SpecDecodeBaseProposer: def __init__( self, vllm_config: VllmConfig, device: torch.device, - pass_hidden_states_to_model: bool, runner=None, ): self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method - self.pass_hidden_states_to_model = pass_hidden_states_to_model self.runner = runner self.dtype = vllm_config.model_config.dtype @@ -70,21 +67,10 @@ def __init__( self.max_num_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) self.token_arange_np = np.arange(self.max_num_tokens) - # We need to get the hidden size from the draft model config because - # the draft model's hidden size can be different from the target model's - # hidden size (e.g., Llama 3.3 70B). - self.hidden_size = self.draft_model_config.get_hidden_size() self.is_multimodal_model = vllm_config.model_config \ .is_multimodal_model - if self.is_multimodal_model and self.method == "draft_model": - raise NotImplementedError("Speculative Decoding with draft models " - "does not support multimodal models yet") - if self.draft_model_config.uses_mrope and self.method == "draft_model": - raise NotImplementedError("Speculative Decoding with draft models " - "does not support M-RoPE yet") - self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.vllm_config.model_config.enforce_eager) @@ -99,20 +85,35 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=device) - self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) - max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.max_batch_size = vllm_config.scheduler_config.max_num_seqs self.arange = torch.arange( # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. - max_batch_size + 1, + self.max_batch_size + 1, device=device, dtype=torch.int32, ) + +class EagleProposer(SpecDecodeBaseProposer): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__(vllm_config, device, runner) + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). 
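The layer bookkeeping in DraftModelProposer.load_model above relies on attention layers being registered by name in the forward context: whatever names appear after loading the draft model but not before must belong to the draft. A sketch of that idea with plain sets standing in for get_layers_from_vllm_config (the layer names are made up):

target_layers = {
    "model.layers.0.self_attn.attn",
    "model.layers.1.self_attn.attn",
}
after_loading_draft = target_layers | {
    "draft_model.layers.0.self_attn.attn",
}
draft_attn_layer_names = sorted(after_loading_draft - target_layers)
assert draft_attn_layer_names == ["draft_model.layers.0.self_attn.attn"]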
+ self.hidden_size = self.draft_model_config.get_hidden_size() + self.hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) + self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, @@ -154,7 +155,7 @@ def __init__( len(self.tree_choices) + 1, device=device, dtype=torch.int32, - ).repeat(max_batch_size, 1) + ).repeat(self.max_batch_size, 1) def propose( self, @@ -168,8 +169,6 @@ def propose( next_token_ids: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - cudagraph_runtime_mode: CUDAGraphMode, - batch_descriptor: BatchDescriptor, mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] @@ -182,22 +181,16 @@ def propose( target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size - if self.method == "draft_model": - # Use full input ids, no shifting needed - self.input_ids[:num_tokens] = target_token_ids - else: - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups - assert len(self.runner.attn_groups) == 1 - assert len(self.runner.attn_groups[0]) == 1 ubatch_id = dbo_current_ubatch_id() attn_metadata_builder = \ self.runner.attn_groups[0][0].metadata_builders[ubatch_id] @@ -216,9 +209,7 @@ def propose( num_input_tokens = num_tokens # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions - if self.pass_hidden_states_to_model: - self.hidden_states[:num_tokens] = target_hidden_states - + self.hidden_states[:num_tokens] = target_hidden_states if self.is_multimodal_model: input_ids = self.input_ids[:num_tokens] inputs_embeds = self.model.get_input_embeddings( @@ -232,24 +223,16 @@ def propose( inputs_embeds = None input_ids = self.input_ids[:num_input_tokens] - model_kwargs = { - "input_ids": input_ids, - "positions": self.positions[:num_input_tokens], - } - if self.pass_hidden_states_to_model: - model_kwargs[ - "hidden_states"] = self.hidden_states[:num_input_tokens] - if self.is_multimodal_model: - model_kwargs["inputs_embeds"] = inputs_embeds - with set_forward_context(per_layer_attn_metadata, self.vllm_config, - num_tokens=num_input_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor): - ret_hidden_states = self.model(**model_kwargs) - if self.method in ("draft_model", "deepseek_mtp", "ernie_mtp", - "qwen3_next_mtp"): + num_tokens=num_input_tokens): + ret_hidden_states = self.model( + input_ids=input_ids, + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + inputs_embeds=inputs_embeds, + ) + if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: @@ -271,22 +254,10 @@ def propose( # [batch_size, num_tree_tokens] 
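The shift-by-one bookkeeping documented above is easiest to see with concrete numbers. A self-contained toy example, with arbitrary integer token ids encoding a1, b1, b2, c1, c2, c3 and the sampled next tokens a2, b3, c4:

import torch

target_token_ids = torch.tensor([11, 21, 22, 31, 32, 33])   # a1 b1 b2 c1 c2 c3
next_token_ids = torch.tensor([12, 23, 34])                 # a2 b3 c4
query_start_loc = torch.tensor([0, 1, 3, 6])
last_token_indices = query_start_loc[1:] - 1                # [0, 2, 5]

input_ids = torch.zeros_like(target_token_ids)
input_ids[:-1] = target_token_ids[1:]          # shift left by one token
input_ids[last_token_indices] = next_token_ids # overwrite each request's last slot
assert input_ids.tolist() == [12, 22, 23, 32, 33, 34]       # a2 b2 b3 c2 c3 c4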
return torch.cat(draft_token_ids_list, dim=1) - if self.method == "draft_model": - # Reuse the next_token_ids to avoid a potential rejection - draft_token_ids = next_token_ids - else: - draft_token_ids = logits.argmax(dim=-1) - - if self.method == "draft_model": - # The draft model runs one forward pass to prefill - # the target_token_ids, and another forward pass for decoding - # based on the next_token_ids. I.e. it needs 1 more forward pass. - n_forward_passes = self.num_speculative_tokens + 1 - else: - n_forward_passes = self.num_speculative_tokens + draft_token_ids = logits.argmax(dim=-1) # Early exit if there is only one draft token to be generated. - if n_forward_passes == 1: + if self.num_speculative_tokens == 1: # [batch_size, 1] return draft_token_ids.view(-1, 1) @@ -306,7 +277,7 @@ def propose( attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(n_forward_passes - 1): + for _ in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. @@ -352,7 +323,6 @@ def propose( self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states - if self.is_multimodal_model: inputs_embeds = self.model.get_input_embeddings(input_ids) self.inputs_embeds[:batch_size] = inputs_embeds @@ -362,47 +332,28 @@ def propose( inputs_embeds = None input_ids = self.input_ids[:input_batch_size] - model_kwargs = { - "input_ids": input_ids, - "positions": self.positions[:input_batch_size], - } - if self.pass_hidden_states_to_model: - model_kwargs[ - "hidden_states"] = self.hidden_states[:input_batch_size] - if self.is_multimodal_model: - model_kwargs["inputs_embeds"] = inputs_embeds - - batch_descriptor = BatchDescriptor(num_tokens=input_batch_size, - uniform_decode=True) - cudagraph_runtime_mode, batch_descriptor = \ - self.runner.cudagraph_dispatcher.dispatch(batch_descriptor) - # Run the model. 
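The max-model-len handling shown in DraftModelProposer.propose above (and mirrored in this loop) can be condensed into a standalone example. The constants here are placeholders only; the real padding slot id and block size come from the vLLM configuration.

import torch

PADDING_SLOT_ID = -1    # placeholder for the reserved slot that attention ignores
BLOCK_SIZE = 4
MAX_MODEL_LEN = 8

positions = torch.tensor([5, 7, 9])               # the third request ran past the limit
block_table = torch.tensor([[0, 1], [2, 3], [4, 5]])

exceeds = positions >= MAX_MODEL_LEN
clamped = torch.where(exceeds, 0, positions)      # keep RoPE positions in range

block_numbers = clamped // BLOCK_SIZE
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * BLOCK_SIZE + clamped % BLOCK_SIZE
slot_mapping = slot_mapping.masked_fill(exceeds, PADDING_SLOT_ID)

assert slot_mapping.tolist() == [5, 15, PADDING_SLOT_ID]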
- with set_forward_context( - per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor): - ret_hidden_states = self.model(**model_kwargs) - if self.method in ("draft_model", "deepseek_mtp", "ernie_mtp", - "qwen3_next_mtp"): - last_hidden_states = ret_hidden_states - hidden_states = ret_hidden_states - else: - last_hidden_states, hidden_states = ret_hidden_states + with set_forward_context(per_layer_attn_metadata, + self.vllm_config, + num_tokens=input_batch_size): + ret_hidden_states = self.model( + input_ids=input_ids, + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + inputs_embeds=inputs_embeds, + ) + if self.method in ("deepseek_mtp", "ernie_mtp", + "qwen3_next_mtp"): + last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states hidden_states = hidden_states[:batch_size] - logits = self.model.compute_logits(last_hidden_states[:batch_size], None) draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) - if self.method == "draft_model": - # the first draft_token_ids are identical to next_token_ids, so - # they don't need to be returned as proposed tokens - draft_token_ids_list = draft_token_ids_list[1:] - # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids @@ -583,96 +534,12 @@ def prepare_inputs( # [batch_size] num_rejected_tokens: torch.Tensor ) -> tuple[CommonAttentionMetadata, torch.Tensor]: - """ - This function is used to prepare the inputs for the spec decode. - It updates to the common_attn_metadata to account for the rejected - tokens (and newly sampled tokens). It also returns the token indices - of the tokens that should be fed to the speculator. - """ - # E.g. 
- # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] - # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] - # num_rejected_tokens: [n1, n2, n3] - # This function computes the intermediate values: - # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] - # And returns: - # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] - # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] - - device = common_attn_metadata.query_start_loc.device - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - - num_rejected_tokens - - # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) - # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] - new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens - new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() - - # [q1 - n1, q2 - n2, q3 - n3] -> - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - new_query_start_loc_cpu = torch.zeros( - query_start_loc_cpu.shape, - dtype=torch.int32, - pin_memory=is_pin_memory_available()) - new_query_start_loc_np = new_query_start_loc_cpu.numpy() - np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) - - total_num_tokens = new_query_start_loc_np[-1] - # Example assuming num_tokens_per_req_np = [2, 4, 3] - # this implies that `new_query_start_locs` is: - # [0, 2, 6, 9] -> - # [0, 0, 2, 2, 2, 2, 6, 6, 6] - # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) - # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> - # [0, 1, 0, 1, 2, 3, 0, 1, 2] - # _r1_ ____r2____ ___r3__ - token_offests = self.token_arange_np[:total_num_tokens] \ - - new_query_start_locs_expanded - - # Expand starting positions to match token pattern - # [0, q1, q1 + q2] -> - # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] - # _r1_ _____r2_______ ___________r3____________ - old_query_start_locs_expanded = np.repeat( - query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) - # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 - token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to( - device, non_blocking=True) - - spec_common_attn_metadata = CommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, - non_blocking=True), - seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), - query_start_loc_cpu=new_query_start_loc_cpu, - seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. 
- num_computed_tokens_cpu, - num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=total_num_tokens, - max_query_len=new_query_len_per_req.max().item(), - max_seq_len=new_seq_lens_cpu.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], - causal=True, + return drafter_prepare_inputs( + self.token_arange_np, + common_attn_metadata, + num_rejected_tokens, ) - return spec_common_attn_metadata, token_indices - def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -680,26 +547,15 @@ def load_model(self, target_model: nn.Module) -> None: get_layers_from_vllm_config(self.vllm_config, Attention).keys()) from vllm.compilation.backends import set_model_tag - - if self.vllm_config.speculative_config.uses_draft_model(): - with set_model_tag("draft_model"): - vllm_config_draft = replace(self.vllm_config, - model_config=draft_model_config) - self.model = get_model(vllm_config=vllm_config_draft, - model_config=draft_model_config, - prefix="draft_model") - else: - with set_model_tag("eagle_head"): - self.model = get_model(vllm_config=self.vllm_config, - model_config=draft_model_config) + with set_model_tag("eagle_head"): + self.model = get_model(vllm_config=self.vllm_config, + model_config=draft_model_config) draft_attn_layer_names = ( get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names) - if self.vllm_config.speculative_config.uses_draft_model(): - return if supports_multimodal(target_model): # handle multimodality @@ -735,11 +591,9 @@ def load_model(self, target_model: nn.Module) -> None: def dummy_run( self, num_tokens: int, - forward_ctx_kwargs: dict, ) -> None: - with set_forward_context(vllm_config=self.vllm_config, - num_tokens=num_tokens, - **forward_ctx_kwargs): + with set_forward_context(None, self.vllm_config, + num_tokens=num_tokens): if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -747,16 +601,12 @@ def dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None - model_kwargs = { - "input_ids": input_ids, - "positions": self.positions[:num_tokens], - } - if self.pass_hidden_states_to_model: - model_kwargs["hidden_states"] = self.hidden_states[:num_tokens] - if self.is_multimodal_model: - model_kwargs["inputs_embeds"] = inputs_embeds - - self.model(**model_kwargs) + self.model( + input_ids=input_ids, + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + inputs_embeds=inputs_embeds, + ) def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: @@ -778,30 +628,6 @@ def validate_same_kv_cache_group(self, ) == 1, "All eagle layers should belong to the same kv cache group" -class EagleProposer(SpecDecodeProposer): - - def __init__(self, - vllm_config: VllmConfig, - device: torch.device, - runner=None): - super().__init__(vllm_config=vllm_config, - device=device, - runner=runner, - pass_hidden_states_to_model=True) - - -class DraftModelProposer(SpecDecodeProposer): - - def __init__(self, - vllm_config: VllmConfig, - device: torch.device, - runner=None): - super().__init__(vllm_config=vllm_config, - device=device, - runner=runner, - pass_hidden_states_to_model=False) - - # NOTE(woosuk): Currently, the below code is not used and we always use argmax # to sample the draft tokens. 
We will use this after we find a way to manage # the draft prob tensor. @@ -843,3 +669,97 @@ def compute_probs_and_sample_next_token( next_token_ids, ) return next_token_ids, probs + + +def drafter_prepare_inputs( + token_arange_np: np.ndarray, + common_attn_metadata: CommonAttentionMetadata, + # [batch_size] + num_rejected_tokens: torch.Tensor +) -> tuple[CommonAttentionMetadata, torch.Tensor]: + """ + This function is used to prepare the inputs for the spec decode. + It updates to the common_attn_metadata to account for the rejected + tokens (and newly sampled tokens). It also returns the token indices + of the tokens that should be fed to the speculator. + """ + # E.g. + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1, q1 + q2, q1 + q2 + q3] + # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] + # num_rejected_tokens: [n1, n2, n3] + # This function computes the intermediate values: + # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] + # And returns: + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # common_attn_metadata.seq_lens{_cpu}: + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # token_indices: [0, 1, ..., q1 - n1 - 1, + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + + device = common_attn_metadata.query_start_loc.device + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ + - num_rejected_tokens + + # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] + new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens + new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() + + # [q1 - n1, q2 - n2, q3 - n3] -> + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + new_query_start_loc_cpu = torch.zeros(query_start_loc_cpu.shape, + dtype=torch.int32, + pin_memory=is_pin_memory_available()) + new_query_start_loc_np = new_query_start_loc_cpu.numpy() + np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) + + total_num_tokens = new_query_start_loc_np[-1] + # Example assuming num_tokens_per_req_np = [2, 4, 3] + # this implies that `new_query_start_locs` is: + # [0, 2, 6, 9] -> + # [0, 0, 2, 2, 2, 2, 6, 6, 6] + # _r1_ ____r2____ ___r3__ + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], + new_num_tokens_per_req_np) + # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> + # [0, 1, 0, 1, 2, 3, 0, 1, 2] + # _r1_ ____r2____ ___r3__ + token_offests = token_arange_np[:total_num_tokens] \ + - new_query_start_locs_expanded + + # Expand starting positions to match token pattern + # [0, q1, q1 + q2] -> + # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] + # _r1_ _____r2_______ ___________r3____________ + old_query_start_locs_expanded = np.repeat(query_start_loc_cpu[:-1].numpy(), + new_num_tokens_per_req_np) + # Final token indices are: + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + token_indices_np = token_offests + old_query_start_locs_expanded + token_indices = torch.from_numpy(token_indices_np).to(device, + non_blocking=True) + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), + seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + query_start_loc_cpu=new_query_start_loc_cpu, + 
seq_lens_cpu=new_seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=new_seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, + ) + + return spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dd2dc7f0dd25..2da4a5247ece 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -81,8 +81,8 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler -from vllm.v1.spec_decode.eagle import (DraftModelProposer, EagleProposer, - SpecDecodeProposer) +from vllm.v1.spec_decode.draft_model import DraftModelProposer +from vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -270,9 +270,9 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.uses_draft_model(): - self.drafter = DraftModelProposer(self.vllm_config, - self.device, - self) # type: ignore + self.drafter = DraftModelProposer(vllm_config=self.vllm_config, + device=self.device, + runner=self) # type: ignore elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore @@ -2337,7 +2337,7 @@ def propose_draft_token_ids( sampling_metadata=sampling_metadata, ) elif self.speculative_config.use_eagle( - ) or self.speculative_config.method == "draft_model": + ) or self.speculative_config.uses_draft_model(): assert isinstance(self.drafter, (EagleProposer, DraftModelProposer)) # TODO(woosuk): Refactor the loop. 
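The index arithmetic in drafter_prepare_inputs above is easier to follow with concrete numbers. A small numpy-only sketch (the per-request query lengths and rejected-token counts are made up, chosen so that the new lengths match the [2, 4, 3] example in the comments):

import numpy as np

query_start_loc = np.array([0, 3, 8, 12])          # [0, q1, q1+q2, q1+q2+q3]
num_rejected = np.array([1, 1, 1])                 # [n1, n2, n3]

query_lens = np.diff(query_start_loc)              # [3, 5, 4]
new_lens = query_lens - num_rejected               # [2, 4, 3]

new_query_start_loc = np.zeros_like(query_start_loc)
np.cumsum(new_lens, out=new_query_start_loc[1:])   # [0, 2, 6, 9]

total = new_query_start_loc[-1]
# Offset of each surviving token inside its own request: [0,1, 0,1,2,3, 0,1,2]
offsets = np.arange(total) - np.repeat(new_query_start_loc[:-1], new_lens)
# Flat indices into the original token buffer: [0,1, 3,4,5,6, 8,9,10]
token_indices = offsets + np.repeat(query_start_loc[:-1], new_lens)

assert token_indices.tolist() == [0, 1, 3, 4, 5, 6, 8, 9, 10]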
@@ -2397,17 +2397,29 @@ def propose_draft_token_ids( mm_embeds = self._gather_mm_embeddings(scheduler_output, shift_computed_tokens=1) - draft_token_ids = self.drafter.propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - sampling_metadata=sampling_metadata, - common_attn_metadata=common_attn_metadata, - mm_embeds=mm_embeds, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ) + if self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + draft_token_ids = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, + mm_embeds=mm_embeds, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ) + elif self.speculative_config.uses_draft_model(): + assert isinstance(self.drafter, DraftModelProposer) + draft_token_ids = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + next_token_ids=next_token_ids, + common_attn_metadata=common_attn_metadata, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ) return draft_token_ids def propose_ngram_draft_token_ids( @@ -2503,7 +2515,11 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.device) if hasattr(self, "drafter"): logger.info("Loading drafter model...") - self.drafter.load_model(self.model) + if self.speculative_config.uses_draft_model(): + assert isinstance(self.drafter, DraftModelProposer) + self.drafter.load_model() + else: + self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: if supports_eagle3(self.model): self.model.set_aux_hidden_state_layers( @@ -3009,13 +3025,14 @@ def _dummy_run( hidden_states = outputs # Execute dummy run for drafter - is_eagle = (self.speculative_config - and self.speculative_config.use_eagle()) + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + self.drafter.dummy_run(num_tokens) + is_draft_model = (self.speculative_config and self.speculative_config.uses_draft_model()) - do_draft_dummy_run = is_eagle or is_draft_model - if do_draft_dummy_run: - assert isinstance(self.drafter, SpecDecodeProposer) + if is_draft_model: + assert isinstance(self.drafter, DraftModelProposer) forward_ctx_kwargs = { "attn_metadata": attn_metadata, "cudagraph_runtime_mode": cudagraph_runtime_mode, @@ -3023,7 +3040,6 @@ def _dummy_run( } self.drafter.dummy_run(num_tokens, forward_ctx_kwargs=forward_ctx_kwargs) - # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real # requests to process. 
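A condensed view of the branching introduced in propose_draft_token_ids above, showing which inputs each drafter consumes. The string sets below stand in for the actual tensors; the real code constructs EagleProposer or DraftModelProposer and calls their respective propose() methods.

def drafter_propose_kwargs(method: str) -> set[str]:
    kwargs = {"target_token_ids", "target_positions", "next_token_ids",
              "common_attn_metadata", "cudagraph_runtime_mode",
              "batch_descriptor"}
    if method != "draft_model":
        # EAGLE-style heads also condition on the target model's hidden
        # states and, for multimodal models, the gathered mm embeddings.
        kwargs |= {"target_hidden_states", "sampling_metadata", "mm_embeds"}
    return kwargs

assert "target_hidden_states" not in drafter_propose_kwargs("draft_model")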
From 36fb9406a4b9f8ac447af4e47bc7c8e15bc68d5a Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 18 Sep 2025 17:24:25 +0200 Subject: [PATCH 11/73] Pass last_token_indices Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 23 +--- vllm/v1/spec_decode/eagle.py | 196 ++++++++++++++--------------- vllm/v1/worker/gpu_model_runner.py | 1 + 3 files changed, 100 insertions(+), 120 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index e5b951cbebb5..b5cbe0d5e5cf 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import replace -from typing import Any +from typing import Any, Optional import torch @@ -12,8 +12,7 @@ from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.spec_decode.eagle import (PADDING_SLOT_ID, SpecDecodeBaseProposer, - drafter_prepare_inputs) +from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer from vllm.v1.worker.ubatching import dbo_current_ubatch_id @@ -65,13 +64,15 @@ def propose( target_positions: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, + last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, cudagraph_runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + if last_token_indices is None: + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 self.input_ids[:num_tokens] = target_token_ids @@ -241,17 +242,3 @@ def load_model(self) -> None: get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names) - - # Copied from eagle.py - def prepare_inputs( - self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: list[list[int]], - num_draft_tokens: list[int], - ) -> tuple[CommonAttentionMetadata, torch.Tensor]: - return drafter_prepare_inputs( - self.token_arange_np, - common_attn_metadata, - sampled_token_ids, - num_draft_tokens, - ) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index eef23c4f7c0d..78bfb3834e7d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -718,12 +718,101 @@ def prepare_inputs( sampled_token_ids: list[list[int]], num_draft_tokens: list[int], ) -> tuple[CommonAttentionMetadata, torch.Tensor]: - return drafter_prepare_inputs( - self.token_arange_np, - common_attn_metadata, - sampled_token_ids, - num_draft_tokens, + """ + This function is used to prepare the inputs for speculative decoding. + It updates to the common_attn_metadata to account for the rejected + tokens (and newly sampled tokens). It also returns the token indices + of the tokens that should be fed to the speculator. + """ + # E.g. 
+ # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1, q1 + q2, q1 + q2 + q3] + # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] + # num_rejected_tokens: [n1, n2, n3] + # This function computes the intermediate values: + # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] + # And returns: + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # common_attn_metadata.seq_lens{_cpu}: + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # token_indices: [0, 1, ..., q1 - n1 - 1, + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + + device = common_attn_metadata.query_start_loc.device + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ + - num_rejected_tokens + + # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] + new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens + new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() + + # [q1 - n1, q2 - n2, q3 - n3] -> + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + new_query_start_loc_cpu = torch.zeros( + query_start_loc_cpu.shape, + dtype=torch.int32, + pin_memory=is_pin_memory_available()) + new_query_start_loc_np = new_query_start_loc_cpu.numpy() + np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) + + total_num_tokens = new_query_start_loc_np[-1] + # Example assuming num_tokens_per_req_np = [2, 4, 3] + # this implies that `new_query_start_locs` is: + # [0, 2, 6, 9] -> + # [0, 0, 2, 2, 2, 2, 6, 6, 6] + # _r1_ ____r2____ ___r3__ + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], + new_num_tokens_per_req_np) + # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> + # [0, 1, 0, 1, 2, 3, 0, 1, 2] + # _r1_ ____r2____ ___r3__ + token_offests = self.token_arange_np[:total_num_tokens] \ + - new_query_start_locs_expanded + + # Expand starting positions to match token pattern + # [0, q1, q1 + q2] -> + # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] + # _r1_ _____r2_______ ___________r3____________ + old_query_start_locs_expanded = np.repeat( + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + # Final token indices are: + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + token_indices_np = token_offests + old_query_start_locs_expanded + token_indices = torch.from_numpy(token_indices_np).to( + device, non_blocking=True) + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=new_query_start_loc_cpu.to(device, + non_blocking=True), + seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens_cpu=new_seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. 
+ num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=new_seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, ) + return spec_common_attn_metadata, token_indices def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ @@ -854,100 +943,3 @@ def compute_probs_and_sample_next_token( next_token_ids, ) return next_token_ids, probs - - -def drafter_prepare_inputs( - token_arange_np: np.ndarray, common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: list[list[int]], num_draft_tokens: list[int] -) -> tuple[CommonAttentionMetadata, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding. - It updates to the common_attn_metadata to account for the rejected - tokens (and newly sampled tokens). It also returns the token indices - of the tokens that should be fed to the speculator. - """ - # E.g. - # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] - # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] - # num_rejected_tokens: [n1, n2, n3] - # This function computes the intermediate values: - # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] - # And returns: - # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] - # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] - - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32) - - device = common_attn_metadata.query_start_loc.device - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - - num_rejected_tokens - - # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) - # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] - new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens - new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() - - # [q1 - n1, q2 - n2, q3 - n3] -> - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - new_query_start_loc_cpu = torch.zeros(query_start_loc_cpu.shape, - dtype=torch.int32, - pin_memory=is_pin_memory_available()) - new_query_start_loc_np = new_query_start_loc_cpu.numpy() - np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) - - total_num_tokens = new_query_start_loc_np[-1] - # Example assuming num_tokens_per_req_np = [2, 4, 3] - # this implies that `new_query_start_locs` is: - # [0, 2, 6, 9] -> - # [0, 0, 2, 2, 2, 2, 6, 6, 6] - # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) - # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> - # [0, 1, 0, 1, 2, 3, 0, 1, 2] - # _r1_ ____r2____ ___r3__ - token_offests = token_arange_np[:total_num_tokens] \ - - new_query_start_locs_expanded - - # Expand starting positions to match token pattern - # [0, q1, q1 + q2] -> - # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] - # _r1_ _____r2_______ ___________r3____________ - old_query_start_locs_expanded = 
np.repeat(query_start_loc_cpu[:-1].numpy(), - new_num_tokens_per_req_np) - # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 - token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to(device, - non_blocking=True) - - spec_common_attn_metadata = CommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), - seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), - query_start_loc_cpu=new_query_start_loc_cpu, - seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, - num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=total_num_tokens, - max_query_len=new_query_len_per_req.max().item(), - max_seq_len=new_seq_lens_cpu.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], - causal=True, - ) - return spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8ebee699f23a..90c88fbc5e93 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2472,6 +2472,7 @@ def propose_draft_token_ids( target_token_ids=target_token_ids, target_positions=target_positions, next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, common_attn_metadata=common_attn_metadata, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, From b0185607746597603d89e67b23928b55253f9e0b Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 18 Sep 2025 23:23:06 +0200 Subject: [PATCH 12/73] Undo unnecessary changes Signed-off-by: Tomas Ruiz --- vllm/v1/worker/gpu_model_runner.py | 918 ++++++++++++----------------- 1 file changed, 384 insertions(+), 534 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 90c88fbc5e93..e2c9085db624 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -139,12 +139,12 @@ def __init__( with torch.cuda.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self._sampled_token_ids_cpu = self._sampled_token_ids.to( - "cpu", non_blocking=True) + 'cpu', non_blocking=True) self._async_copy_ready_event.record() def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. - + This function blocks until the copy is finished. """ self._async_copy_ready_event.synchronize() @@ -180,7 +180,6 @@ def __init__( self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -197,7 +196,7 @@ def __init__( self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] - self.is_pooling_model = model_config.runner_type == "pooling" + self.is_pooling_model = (model_config.runner_type == 'pooling') self.is_multimodal_raw_input_only_model = ( model_config.is_multimodal_raw_input_only_model) @@ -233,7 +232,8 @@ def __init__( if self.model_config.is_encoder_decoder: # Maximum length of the encoder input, only for encoder-decoder # models. 
- self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens + self.max_encoder_len = scheduler_config.\ + max_num_encoder_input_tokens else: self.max_encoder_len = 0 @@ -311,25 +311,22 @@ def __init__( block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( - self.vllm_config, - self.device, - self.pin_memory, + self.vllm_config, self.device, self.pin_memory, self.is_pooling_model, - self.vllm_config.model_config.logits_processors, - ), + self.vllm_config.model_config.logits_processors), is_pooling_model=self.is_pooling_model, ) self.use_async_scheduling = self.scheduler_config.async_scheduling - self.async_output_copy_stream = (torch.cuda.Stream() if - self.use_async_scheduling else None) + self.async_output_copy_stream = torch.cuda.Stream() if \ + self.use_async_scheduling else None # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. - if (self.compilation_config.cudagraph_capture_sizes and - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE): + if self.compilation_config.cudagraph_capture_sizes and \ + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: self.cudagraph_batch_sizes = list( reversed(self.compilation_config.cudagraph_capture_sizes)) @@ -388,11 +385,10 @@ def __init__( # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange( - max(self.max_num_reqs + 1, self.max_model_len, - self.max_num_tokens), - dtype=np.int64, - ) + self.arange_np = np.arange(max(self.max_num_reqs + 1, + self.max_model_len, + self.max_num_tokens), + dtype=np.int64) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -406,18 +402,17 @@ def __init__( self.kv_sharing_fast_prefill_logits_indices = torch.zeros( self.max_num_tokens, dtype=torch.int32, device=self.device) - self.uniform_decode_query_len = ( - 1 if not self.speculative_config else 1 + - self.speculative_config.num_speculative_tokens) + self.uniform_decode_query_len = 1 if not self.speculative_config else \ + 1 + self.speculative_config.num_speculative_tokens # Cudagraph dispatcher for runtime cudagraph dispatching. 
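For reference, the capture sizes kept in ascending order in self.cudagraph_batch_sizes above are what the proposers' pad_for_cudagraph calls round batch sizes up to. A sketch of that rounding under the assumption that padding goes to the next captured size (the list below is made up, and the real logic lives in VllmConfig.pad_for_cudagraph):

import bisect

cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]   # ascending, illustrative

def pad_for_cudagraph(num_tokens: int) -> int:
    if num_tokens > cudagraph_batch_sizes[-1]:
        return num_tokens   # larger than any captured graph: no padding applied
    idx = bisect.bisect_left(cudagraph_batch_sizes, num_tokens)
    return cudagraph_batch_sizes[idx]

assert pad_for_cudagraph(3) == 4
assert pad_for_cudagraph(16) == 16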
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) - self.mm_budget = (MultiModalBudget( + self.mm_budget = MultiModalBudget( self.model_config, self.scheduler_config, self.mm_registry, - ) if self.supports_mm_inputs else None) + ) if self.supports_mm_inputs else None self.reorder_batch_threshold: Optional[int] = None @@ -434,22 +429,17 @@ def __init__( (self.max_model_len, 1), dtype=torch.int64, device="cpu", - pin_memory=self.pin_memory, - ) - - def _make_buffer( - self, - *size: Union[int, torch.SymInt], - dtype: torch.dtype, - numpy: bool = True, - ) -> CpuGpuBuffer: - return CpuGpuBuffer( - *size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory, - with_numpy=numpy, - ) + pin_memory=self.pin_memory) + + def _make_buffer(self, + *size: Union[int, torch.SymInt], + dtype: torch.dtype, + numpy: bool = True) -> CpuGpuBuffer: + return CpuGpuBuffer(*size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy) def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() @@ -462,10 +452,9 @@ def _init_model_kwargs(self, num_tokens: int): token_type_id_requests = dict[int, Any]() for i, param in enumerate(pooling_params): - if (param.extra_kwargs is not None - and (token_types := - param.extra_kwargs.get("compressed_token_type_ids")) - is not None): + if param.extra_kwargs is not None and \ + (token_types := param.extra_kwargs.get( + "compressed_token_type_ids")) is not None: token_type_id_requests[i] = token_types if len(token_type_id_requests) == 0: @@ -506,17 +495,17 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # required for DCP with q_len > 1, so we assert here. Remove this # assert once the custom mask is support is added to FA3. if self.dcp_world_size > 1: - assert self.reorder_batch_threshold == 1, ( - "DCP not support reorder_batch_threshold > 1 now.") + assert self.reorder_batch_threshold == 1, \ + "DCP not support reorder_batch_threshold > 1 now." reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, - decode_threshold=self.reorder_batch_threshold, - ) + decode_threshold=self.reorder_batch_threshold) # Note: used for model runner override. def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties""" + """Initialize attributes from torch.cuda.get_device_properties + """ self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count @@ -572,8 +561,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if (sampling_params and sampling_params.sampling_type - == SamplingType.RANDOM_SEED): + if sampling_params and \ + sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: @@ -678,8 +667,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. 
- spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( - req_id, ()) + spec_token_ids = ( + scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) if spec_token_ids: num_spec_tokens = len(spec_token_ids) start_index = self.input_batch.num_tokens_no_spec[req_index] @@ -715,17 +704,14 @@ def _update_states_after_model_execute( return # Find the number of accepted tokens for each sequence. - num_accepted_tokens = ((torch.cat( + num_accepted_tokens = (torch.cat( [ output_token_ids, - torch.full( - (output_token_ids.size(0), 1), - -1, - device=output_token_ids.device, - ), + torch.full((output_token_ids.size(0), 1), + -1, + device=output_token_ids.device), ], - dim=1, - ) == -1).int().argmax(-1).cpu().numpy()) + dim=1) == -1).int().argmax(-1).cpu().numpy() for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens @@ -751,7 +737,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): if mm_input.get("use_audio_in_video") is True: use_audio_in_video = True - req_state.mrope_positions, req_state.mrope_position_delta = ( + req_state.mrope_positions, req_state.mrope_position_delta = \ MRotaryEmbedding.get_input_positions_tensor( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, @@ -760,7 +746,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): second_per_grid_ts=second_per_grid_ts, audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, - )) + ) def _extract_mm_kwargs( self, @@ -819,7 +805,7 @@ def _get_cumsum_and_arange( def _prepare_input_ids(self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray) -> None: """Prepare the input IDs for the current batch. - + Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the GPU need to be copied into the corresponding slots into input_ids.""" @@ -845,7 +831,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # last token in each common request. flattened_index = cu_num_tokens[cur_index].item() - 1 flattened_indices.append(flattened_index) - indices_match &= prev_index == flattened_index + indices_match &= (prev_index == flattened_index) max_flattened_index = max(max_flattened_index, flattened_index) num_commmon_tokens = len(flattened_indices) if num_commmon_tokens < total_num_scheduled_tokens: @@ -864,8 +850,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, self.input_ids.gpu[:num_commmon_tokens].copy_( self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], - non_blocking=True, - ) + non_blocking=True) return # Upload the index tensors asynchronously # so the scatter can be non-blocking. 
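The comment above describes scattering the token sampled for each request in the previous step into the slot of that request's last scheduled token, so the flattened input buffer never has to round-trip through the CPU. A toy version with concrete tensors:

import torch

input_ids = torch.zeros(9, dtype=torch.int64)        # flattened batch buffer
prev_sampled = torch.tensor([[101], [202], [303]])   # [num_reqs, 1], from last step
last_token_slots = torch.tensor([2, 6, 8])           # cu_num_tokens[1:] - 1
common_req_indices = torch.tensor([0, 1, 2])         # requests carried over

input_ids.scatter_(dim=0,
                   index=last_token_slots,
                   src=prev_sampled[common_req_indices, 0])
assert input_ids[[2, 6, 8]].tolist() == [101, 202, 303]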
@@ -877,14 +862,12 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, prev_common_req_indices_tensor = torch.tensor( prev_common_req_indices, dtype=torch.int64, - pin_memory=self.pin_memory, - ).to(self.device, non_blocking=True) + pin_memory=self.pin_memory).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, index=input_ids_index_tensor, src=self.input_batch.prev_sampled_token_ids[ - prev_common_req_indices_tensor, 0], - ) + prev_common_req_indices_tensor, 0]) def _get_encoder_seq_lens( self, @@ -906,16 +889,10 @@ def _get_encoder_seq_lens( def _prepare_inputs( self, scheduler_output: "SchedulerOutput" - ) -> tuple[ - PerLayerAttnMetadata, - torch.Tensor, - Optional[SpecDecodeMetadata], - np.ndarray, - Optional[CommonAttentionMetadata], - int, - Optional[UBatchSlices], - Optional[torch.Tensor], - ]: + ) -> tuple[PerLayerAttnMetadata, torch.Tensor, + Optional[SpecDecodeMetadata], np.ndarray, + Optional[CommonAttentionMetadata], int, Optional[UBatchSlices], + Optional[torch.Tensor]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -949,11 +926,9 @@ def _prepare_inputs( # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np, - ) + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -970,12 +945,10 @@ def _prepare_inputs( # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select( - self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids.cpu[:total_num_scheduled_tokens], - ) + torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + torch.from_numpy(token_indices), + out=self.input_ids.cpu[:total_num_scheduled_tokens]) self.input_batch.block_table.compute_slot_mapping( req_indices, positions_np) @@ -994,12 +967,11 @@ def _prepare_inputs( num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens num_tokens_padded = num_tokens_unpadded + self.get_local_padding( num_tokens_unpadded) - ubatch_slices, num_tokens_after_padding = ubatch_split( - max_num_scheduled_tokens, - num_tokens_unpadded, - num_tokens_padded, - self.vllm_config, - ) + ubatch_slices, num_tokens_after_padding = \ + ubatch_split(max_num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded, + self.vllm_config) self.seq_lens.np[:num_reqs] = ( self.input_batch.num_computed_tokens_cpu[:num_reqs] + @@ -1032,8 +1004,7 @@ def _prepare_inputs( # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self.mrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True, - ) + non_blocking=True) else: # Common case (1D positions) self.positions.copy_to_gpu(total_num_scheduled_tokens) @@ -1054,10 +1025,8 @@ def _prepare_inputs( # Iterate over the dictionary rather than all requests since not all # requests have draft tokens. 
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) - for ( - req_id, - draft_token_ids, - ) in scheduler_output.scheduled_spec_decode_tokens.items(): + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) @@ -1143,8 +1112,8 @@ def _prepare_inputs( encoder_seq_lens=encoder_seq_lens, ) - if (self.speculative_config - and spec_decode_common_attn_metadata is None): + if self.speculative_config and \ + spec_decode_common_attn_metadata is None: spec_decode_common_attn_metadata = common_attn_metadata for attn_group in self.attn_groups[kv_cache_group_id]: @@ -1174,11 +1143,10 @@ def _prepare_inputs( for ubid, common_attn_metadata in enumerate( common_attn_metadata_list): assert common_attn_metadata.max_query_len == 1 - attn_metadata_i = attn_group.get_metadata_builder( + attn_metadata_i = (attn_group.get_metadata_builder( ubatch_id=ubid).build( common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - ) + common_attn_metadata=common_attn_metadata)) for layer_name in kv_cache_group_spec.layer_names: assert type(attn_metadata) is list attn_metadata[ubid][layer_name] = attn_metadata_i @@ -1187,8 +1155,7 @@ def _prepare_inputs( attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args, - ) + **extra_attn_metadata_args) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1196,16 +1163,10 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return ( - attn_metadata, - logits_indices, - spec_decode_metadata, - num_scheduled_tokens, - spec_decode_common_attn_metadata, - max_num_scheduled_tokens, - ubatch_slices, - num_tokens_after_padding, - ) + return (attn_metadata, logits_indices, spec_decode_metadata, + num_scheduled_tokens, spec_decode_common_attn_metadata, + max_num_scheduled_tokens, ubatch_slices, + num_tokens_after_padding) def _compute_cascade_attn_prefix_len( self, @@ -1278,18 +1239,17 @@ def _compute_cascade_attn_prefix_len( num_reqs = len(num_scheduled_tokens) common_prefix_len = min( common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min(), - ) + self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) # common_prefix_len should be a multiple of the block size. 
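The two adjustments applied to the cascade-attention prefix around this point are easier to see with numbers: the shared prefix is capped by the smallest number of already-computed tokens across requests, then rounded down to a whole number of KV-cache blocks (values below are made up):

block_size = 16
num_computed_tokens = [75, 120, 68]
common_prefix_len = 70                                               # longest common prefix found

common_prefix_len = min(common_prefix_len, min(num_computed_tokens)) # 68
common_prefix_len = (common_prefix_len // block_size) * block_size   # 64
assert common_prefix_len == 64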
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size) - use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or ( - isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None) - use_local_attention = isinstance( - kv_cache_spec, ChunkedLocalAttentionSpec) or ( - isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None) + use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or + (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None)) + use_local_attention = ( + isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) + or (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None)) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -1309,10 +1269,10 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): req = self.requests[req_id] assert req.mrope_positions is not None - num_computed_tokens = self.input_batch.num_computed_tokens_cpu[ - index] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] num_prompt_tokens = len(req.prompt_token_ids) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: @@ -1442,8 +1402,8 @@ def _prepare_kv_sharing_fast_prefill( num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) else: num_logits_padded = num_logits - logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[: - num_logits_padded] + logits_indices_padded = ( + self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) return logits_indices_padded def _batch_mm_kwargs_from_scheduler( @@ -1535,8 +1495,8 @@ def _gather_mm_embeddings( num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] - num_computed_tokens = (req_state.num_computed_tokens + - shift_computed_tokens) + num_computed_tokens = \ + req_state.num_computed_tokens + shift_computed_tokens for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position start_pos = pos_info.offset @@ -1563,8 +1523,8 @@ def _gather_mm_embeddings( mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None, ( - f"Encoder cache miss for {mm_hash}.") + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." 
if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] @@ -1662,11 +1622,9 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return tuple(tasks) def sync_and_slice_intermediate_tensors( - self, - num_tokens: int, - intermediate_tensors: IntermediateTensors, - sync_self: bool, - ) -> IntermediateTensors: + self, num_tokens: int, intermediate_tensors: IntermediateTensors, + sync_self: bool) -> IntermediateTensors: + assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size @@ -1678,7 +1636,8 @@ def sync_and_slice_intermediate_tensors( assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): is_scattered = k == "residual" and is_rs - copy_len = num_tokens // tp if is_scattered else num_tokens + copy_len = num_tokens // tp if is_scattered else \ + num_tokens self.intermediate_tensors[k][:copy_len].copy_( v[:copy_len], non_blocking=True) @@ -1738,14 +1697,14 @@ def get_dp_padding(self, num_tokens_across_dp = DPMetadata.num_tokens_across_dp( num_tokens, dp_size, dp_rank) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor( - [max_tokens_across_dp_cpu] * dp_size, - device="cpu", - dtype=torch.int32, - ) + num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * + dp_size, + device="cpu", + dtype=torch.int32) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding def get_local_padding(self, num_tokens_unpadded: int) -> int: + num_tokens_padded = num_tokens_unpadded if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE @@ -1759,8 +1718,8 @@ def get_local_padding(self, num_tokens_unpadded: int) -> int: # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if (self.vllm_config.compilation_config.pass_config. - enable_sequence_parallelism and tp_size > 1): + if self.vllm_config.compilation_config.pass_config. 
\ + enable_sequence_parallelism and tp_size > 1: num_tokens_padded = round_up(num_tokens_unpadded, tp_size) num_pad_tokens = num_tokens_padded - num_tokens_unpadded @@ -1783,10 +1742,10 @@ def _pool( num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs == len( - self.input_batch.pooling_params), ( - "Either all or none of the requests in" - " a batch must be pooling request") + assert self.input_batch.num_reqs ==\ + len(self.input_batch.pooling_params), \ + "Either all or none of the requests in" \ + " a batch must be pooling request" hidden_states = hidden_states[:num_scheduled_tokens] pooling_metadata = self.input_batch.get_pooling_metadata() @@ -1801,6 +1760,7 @@ def _pool( pooler_output: list[Optional[torch.Tensor]] = [] for raw_output, seq_len, prompt_len in zip( raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): + output = raw_output.data if seq_len == prompt_len else None pooler_output.append(output) @@ -1838,16 +1798,10 @@ def _preprocess( intermediate_tensors: Optional[IntermediateTensors] = None, ubatch_slices: Optional[UBatchSlices] = None, num_tokens_after_padding: Optional[torch.Tensor] = None, - ) -> tuple[ - int, - int, - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - torch.Tensor, - Optional[IntermediateTensors], - dict[str, Any], - ]: + ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, + Optional[IntermediateTensors], dict[str, Any]]: + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if ubatch_slices: assert num_tokens_after_padding is not None @@ -1921,9 +1875,8 @@ def _preprocess( ) def _sample( - self, - logits: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata], + self, logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata] ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata @@ -1962,12 +1915,9 @@ def _sample( return sampler_output def _bookkeeping_sync( - self, - scheduler_output: "SchedulerOutput", - sampler_output: SamplerOutput, - logits: Optional[torch.Tensor], - hidden_states: torch.Tensor, - num_scheduled_tokens: int, + self, scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, num_scheduled_tokens: int ) -> tuple[ dict[str, int], Optional[LogprobsLists], @@ -1981,10 +1931,8 @@ def _bookkeeping_sync( if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - discard_sampled_tokens_req_indices = self.discard_request_indices.np[: - self - . - num_discarded_requests] + discard_sampled_tokens_req_indices = \ + self.discard_request_indices.np[:self.num_discarded_requests] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -1993,13 +1941,14 @@ def _bookkeeping_sync( # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() + req_id_to_index_output_copy = \ + self.input_batch.req_id_to_index.copy() # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. 
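        # (e.g. logprobs_tensors.tolists() below copies device tensors to the
        # host, which blocks until the pending GPU work has completed)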
logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = (logprobs_tensors.tolists() - if logprobs_tensors is not None else None) + logprobs_lists = logprobs_tensors.tolists() \ + if logprobs_tensors is not None else None # Compute prompt logprobs if needed. prompt_logprobs_dict = self._get_prompt_logprobs_dict( @@ -2034,9 +1983,10 @@ def _bookkeeping_sync( # Cache the sampled tokens on the GPU and avoid CPU sync. # These will be copied into input_ids in the next step # when preparing inputs. - self.input_batch.prev_sampled_token_ids = sampled_token_ids - self.input_batch.prev_sampled_token_ids_invalid_indices = ( - invalid_req_indices_set) + self.input_batch.prev_sampled_token_ids = \ + sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = \ + invalid_req_indices_set self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) @@ -2051,8 +2001,8 @@ def _bookkeeping_sync( req_ids = self.input_batch.req_ids for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: - sampled_ids = ([-1] if req_idx not in invalid_req_indices_set - else None) + sampled_ids = [-1] if \ + req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: @@ -2066,7 +2016,7 @@ def _bookkeeping_sync( f"{self.max_model_len}") self.input_batch.token_ids_cpu[req_idx, - start_idx:end_idx] = (sampled_ids) + start_idx:end_idx] = sampled_ids self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -2109,16 +2059,10 @@ def execute_model( self.prepare_inputs_event.synchronize() try: # Prepare the decoder inputs. - ( - attn_metadata, - logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np, - spec_decode_common_attn_metadata, - max_query_len, - ubatch_slices, - num_tokens_after_padding, - ) = self._prepare_inputs(scheduler_output) + (attn_metadata, logits_indices, spec_decode_metadata, + num_scheduled_tokens_np, spec_decode_common_attn_metadata, + max_query_len, ubatch_slices, num_tokens_after_padding + ) = self._prepare_inputs(scheduler_output) finally: if self.prepare_inputs_event is not None: @@ -2133,12 +2077,8 @@ def execute_model( positions, intermediate_tensors, model_kwargs, - ) = self._preprocess( - scheduler_output, - intermediate_tensors, - ubatch_slices, - num_tokens_after_padding, - ) + ) = self._preprocess(scheduler_output, intermediate_tensors, + ubatch_slices, num_tokens_after_padding) if ubatch_slices is not None: num_input_tokens = num_input_tokens // 2 @@ -2149,25 +2089,22 @@ def execute_model( == self.input_batch.num_reqs * max_query_len) batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = ( - self.cudagraph_dispatcher.dispatch(batch_descriptor)) + cudagraph_runtime_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch(batch_descriptor) # Run the model. # Use persistent buffers for CUDA graphs. 
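        # (CUDA graph replay reads inputs from fixed device addresses, so the
        # batch is staged into the pre-allocated persistent buffers rather
        # than into freshly allocated tensors)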
- with ( - set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices, - ), - record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as - kv_connector_output, - ): + with (set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices, + ), record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as + kv_connector_output): model_output = self.model( input_ids=input_ids, positions=positions, @@ -2195,11 +2132,8 @@ def execute_model( if self.is_pooling_model: # Return the pooling output. - output = self._pool( - hidden_states, - num_scheduled_tokens, - num_scheduled_tokens_np, - ) + output = self._pool(hidden_states, num_scheduled_tokens, + num_scheduled_tokens_np) output.kv_connector_output = kv_connector_output return output @@ -2218,8 +2152,7 @@ def execute_model( get_pp_group().send_tensor_dict( hidden_states.tensors, all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors, - ) + all_gather_tensors=all_gather_tensors) logits = None else: sample_hidden_states = hidden_states[logits_indices] @@ -2230,11 +2163,9 @@ def execute_model( if logits is not None: model_output_broadcast_data["logits"] = logits.contiguous() - model_output_broadcast_data = ( - get_pp_group().broadcast_tensor_dict( - model_output_broadcast_data, - src=len(get_pp_group().ranks) - 1, - )) + model_output_broadcast_data = get_pp_group( + ).broadcast_tensor_dict(model_output_broadcast_data, + src=len(get_pp_group().ranks) - 1) assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] @@ -2262,15 +2193,14 @@ def propose_draft_token_ids(sampled_token_ids): batch_descriptor=batch_descriptor, ) - use_padded_batch = ( - self.speculative_config - and (self.speculative_config.use_eagle() - or self.speculative_config.uses_draft_model()) - and not self.speculative_config.disable_padded_drafter_batch) - run_draft_before_bookkeeping = use_padded_batch - if run_draft_before_bookkeeping: - # EAGLE speculative decoding can use the GPU sampled tokens - # as inputs, and does not need to wait for bookkeeping to finish. + use_padded_batch = self.speculative_config and \ + (self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model()) and \ + not self.speculative_config.disable_padded_drafter_batch + if use_padded_batch: + # EAGLE and draft model speculative decoding can use the + # GPU sampled tokens as inputs, and does not need + # to wait for bookkeeping to finish. 
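            # (ngram and other CPU-based methods, by contrast, consume the
            # host-side token lists and therefore run after _bookkeeping_sync)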
propose_draft_token_ids(sampler_output.sampled_token_ids) with record_function_or_nullcontext("Bookkeep"): @@ -2282,15 +2212,11 @@ def propose_draft_token_ids(sampled_token_ids): req_ids_output_copy, req_id_to_index_output_copy, invalid_req_indices, - ) = self._bookkeeping_sync( - scheduler_output, - sampler_output, - logits, - hidden_states, - num_scheduled_tokens, - ) + ) = self._bookkeeping_sync(scheduler_output, sampler_output, + logits, hidden_states, + num_scheduled_tokens) - if self.speculative_config and not run_draft_before_bookkeeping: + if self.speculative_config and not use_padded_batch: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) @@ -2380,33 +2306,29 @@ def propose_draft_token_ids( # When padded-batch is disabled, the sampled_token_ids should be # the cpu-side list[list[int]] of valid sampled tokens for each # request, with invalid requests having empty lists. - assert isinstance( - sampled_token_ids, - list), ("sampled_token_ids should be a python list when" - "padded-batch is disabled.") + assert isinstance(sampled_token_ids, list), \ + "sampled_token_ids should be a python list when" \ + "padded-batch is disabled." next_token_ids = self.drafter.prepare_next_token_ids_cpu( - sampled_token_ids, - self.requests, - self.input_batch, - scheduler_output.num_scheduled_tokens, - ) + sampled_token_ids, self.requests, self.input_batch, + scheduler_output.num_scheduled_tokens) else: # When using padded-batch, the sampled_token_ids should be # the gpu tensor of sampled tokens for each request, of shape # (num_reqs, num_spec_tokens + 1) with rejected tokens having # value -1. - assert isinstance(sampled_token_ids, torch.Tensor), ( - "sampled_token_ids should be a torch.Tensor when" - "padded-batch is enabled.") - next_token_ids, valid_sampled_tokens_count = ( + assert isinstance(sampled_token_ids, torch.Tensor), \ + "sampled_token_ids should be a torch.Tensor when" \ + "padded-batch is enabled." + next_token_ids, valid_sampled_tokens_count = \ self.drafter.prepare_next_token_ids_padded( common_attn_metadata, sampled_token_ids, self.requests, self.input_batch, self.discard_request_indices.gpu, - self.num_discarded_requests, - )) + self.num_discarded_requests + ) if spec_decode_metadata is None: token_indices_to_sample = None @@ -2417,29 +2339,24 @@ def propose_draft_token_ids( if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1, - ) + dim=-1) else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.speculative_config.disable_padded_drafter_batch: token_indices_to_sample = None - common_attn_metadata, token_indices = ( + common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( common_attn_metadata, sampled_token_ids, - spec_decode_metadata.num_draft_tokens, - )) + spec_decode_metadata.num_draft_tokens) else: - ( - common_attn_metadata, - token_indices, - token_indices_to_sample, - ) = self.drafter.prepare_inputs_padded( - common_attn_metadata, - spec_decode_metadata, - valid_sampled_tokens_count, - ) + common_attn_metadata, token_indices, \ + token_indices_to_sample =\ + self.drafter.prepare_inputs_padded( + common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. 
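For reference, the index bookkeeping that `prepare_inputs` performs on the
non-padded path can be re-derived with toy numbers. The following is a
standalone NumPy sketch, not the vLLM implementation; all names and values
are illustrative:

import numpy as np

# Three requests with 2, 4 and 3 scheduled tokens; the target model rejected
# the last 0, 2 and 1 draft tokens of each request, respectively.
query_start_loc = np.array([0, 2, 6, 9])
num_rejected = np.array([0, 2, 1])

query_lens = np.diff(query_start_loc)                  # [2, 4, 3]
new_lens = query_lens - num_rejected                   # [2, 2, 2]
new_query_start_loc = np.concatenate(([0], np.cumsum(new_lens)))  # [0, 2, 4, 6]

# Flat indices of the surviving tokens inside the original token buffer.
starts = np.repeat(query_start_loc[:-1], new_lens)     # [0, 0, 2, 2, 6, 6]
offsets = np.arange(new_lens.sum()) - np.repeat(new_query_start_loc[:-1], new_lens)
token_indices = starts + offsets

assert token_indices.tolist() == [0, 1, 2, 3, 6, 7]

The gather above (`self.input_ids.gpu[token_indices]`) then selects exactly
these surviving tokens as the drafter's target tokens.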
@@ -2453,7 +2370,6 @@ def propose_draft_token_ids( if self.supports_mm_inputs: mm_embeds = self._gather_mm_embeddings(scheduler_output, shift_computed_tokens=1) - if self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) draft_token_ids = self.drafter.propose( @@ -2517,9 +2433,9 @@ def propose_ngram_draft_token_ids( def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, ( - f"Config `{config_name}` not supported. " - f"Allowed configs: {allowed_config_names}") + assert config_name in allowed_config_names, \ + f"Config `{config_name}` not supported. " \ + f"Allowed configs: {allowed_config_names}" config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -2532,15 +2448,12 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") - torch.distributed.broadcast( - num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0, - ) + torch.distributed.broadcast(num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0) num_local_physical_experts = int(num_local_physical_experts.item()) new_ep_size = get_ep_group().world_size global_expert_load, old_global_expert_indices = ( @@ -2548,10 +2461,10 @@ def load_model(self, eep_scale_up: bool = False) -> None: num_logical_experts = global_expert_load.shape[1] self.parallel_config.eplb_config.num_redundant_experts = ( num_local_physical_experts * new_ep_size - num_logical_experts) - assert (old_global_expert_indices.shape[1] % - num_local_physical_experts == 0) - old_ep_size = (old_global_expert_indices.shape[1] // - num_local_physical_experts) + assert old_global_expert_indices.shape[ + 1] % num_local_physical_experts == 0 + old_ep_size = old_global_expert_indices.shape[ + 1] // num_local_physical_experts rank_mapping = { old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) @@ -2568,20 +2481,19 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config) if self.lora_config: - self.model = self.load_lora_model( - self.model, - self.model_config, - self.scheduler_config, - self.lora_config, - self.device, - ) + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) if hasattr(self, "drafter"): logger.info("Loading drafter model...") - if self.speculative_config.uses_draft_model(): + if self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + self.drafter.load_model(self.model) + elif self.speculative_config.uses_draft_model(): assert isinstance(self.drafter, DraftModelProposer) self.drafter.load_model() - else: - self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: if supports_eagle3(self.model): self.model.set_aux_hidden_state_layers( @@ -2592,15 +2504,13 @@ def load_model(self, eep_scale_up: bool = False) -> None: "aux_hidden_state_outputs was requested") time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info( - "Model loading took %.4f GiB and %.6f seconds", - 
self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load, - ) + logger.info("Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load) prepare_communication_buffer_for_model(self.model) - if (is_mixture_of_experts(self.model) - and self.parallel_config.enable_eplb): + if is_mixture_of_experts( + self.model) and self.parallel_config.enable_eplb: logger.info("EPLB is enabled for model %s.", self.model_config.model) self.eplb_state = EplbState.build( @@ -2612,44 +2522,37 @@ def load_model(self, eep_scale_up: bool = False) -> None: rank_mapping, ) - if (self.vllm_config.compilation_config.level - == CompilationLevel.DYNAMO_AS_IS and supports_dynamo()): + if ( + self.vllm_config.compilation_config.level == \ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo() + ): backend = self.vllm_config.compilation_config.init_backend( self.vllm_config) compilation_counter.dynamo_as_is_count += 1 self.model.compile( fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, - backend=backend, - ) + backend=backend) return # for other compilation levels, cudagraph behavior is controlled by # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. - if (self.compilation_config.cudagraph_mode.has_full_cudagraphs() - and not self.parallel_config.enable_dbo): + if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \ + and not self.parallel_config.enable_dbo: self.model = CUDAGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL) elif self.parallel_config.enable_dbo: if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.model = UBatchWrapper( - self.model, - self.vllm_config, - CUDAGraphMode.FULL, - self.device, - ) + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.FULL, self.device) else: - self.model = UBatchWrapper( - self.model, - self.vllm_config, - CUDAGraphMode.NONE, - self.device, - ) + self.model = UBatchWrapper(self.model, self.vllm_config, + CUDAGraphMode.NONE, self.device) def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, ( - "Cannot reload weights before model is loaded.") + assert getattr(self, "model", None) is not None, \ + "Cannot reload weights before model is loaded." 
model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") model = self.get_model() @@ -2801,8 +2704,7 @@ def rand_input_ids() -> torch.Tensor: self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype, - ) + dtype=input_ids.dtype) logger.debug_once("Randomizing dummy data for DP Rank") input_ids.copy_(rand_input_ids()[:input_ids.size(0)], @@ -2875,9 +2777,9 @@ def _dummy_run( num_pad = 0 should_ubatch = False if ubatch_enabled: - should_ubatch = (num_tokens - >= self.parallel_config.dbo_decode_token_threshold - and allow_microbatching) + should_ubatch = num_tokens >= \ + self.parallel_config.dbo_decode_token_threshold and \ + allow_microbatching (should_ubatch, num_tokens_across_dp) = get_dp_padding_ubatch( num_tokens, num_tokens, should_ubatch, self.vllm_config) @@ -2889,9 +2791,7 @@ def _dummy_run( assert int(num_tokens_across_dp[0]) == num_tokens // 2 assert cudagraph_runtime_mode in { - CUDAGraphMode.NONE, - CUDAGraphMode.PIECEWISE, - CUDAGraphMode.FULL, + CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } if not should_ubatch: @@ -2911,8 +2811,8 @@ def _dummy_run( # When setting max_query_len = 1, we switch to and capture the optimized # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. - max_query_len = (self.uniform_decode_query_len - if uniform_decode else num_tokens) + max_query_len = self.uniform_decode_query_len if uniform_decode else \ + num_tokens if allow_microbatching: assert self.uniform_decode_query_len == 1 assert uniform_decode is True @@ -2939,8 +2839,8 @@ def _dummy_run( max_query_len = num_prefill_tokens elif uniform_decode: num_reqs = num_tokens // max_query_len - assert num_reqs <= max_num_reqs, ( - "Do not capture num_reqs > max_num_reqs for uniform batch") + assert num_reqs <= max_num_reqs, \ + "Do not capture num_reqs > max_num_reqs for uniform batch" num_scheduled_tokens_list = [max_query_len] * num_reqs if num_tokens % max_query_len != 0: num_scheduled_tokens_list[-1] += num_tokens % max_query_len @@ -2965,10 +2865,8 @@ def _dummy_run( ubatch_slices = [ UBatchSlice(slice(0, num_reqs // 2), slice(0, num_tokens // 2)), - UBatchSlice( - slice(num_reqs // 2, num_reqs), - slice(num_tokens // 2, num_tokens), - ), + UBatchSlice(slice(num_reqs // 2, num_reqs), + slice(num_tokens // 2, num_tokens)) ] attn_metadata: Optional[PerLayerAttnMetadata] = None @@ -3010,8 +2908,7 @@ def _dummy_run( block_table[kv_cache_group_id].get_device_tensor(num_reqs), slot_mapping=self.input_batch.block_table[ kv_cache_group_id].slot_mapping.gpu[:num_tokens], - causal=True, - ) + causal=True) for attn_group in self.attn_groups[kv_cache_group_id]: if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( @@ -3019,17 +2916,17 @@ def _dummy_run( for ubid, common_attn_metadata in enumerate( common_attn_metadata_list): assert common_attn_metadata.max_query_len == 1 - attn_metadata_i = attn_group.get_metadata_builder( - ubatch_id=ubid).build_for_cudagraph_capture( - common_attn_metadata) + attn_metadata_i = (attn_group\ + .get_metadata_builder(ubatch_id=ubid)\ + .build_for_cudagraph_capture(common_attn_metadata)) for layer_name in kv_cache_group_spec.layer_names: assert type(attn_metadata) is list - attn_metadata[ubid][layer_name] = ( - attn_metadata_i) + attn_metadata[ubid][ + layer_name] = attn_metadata_i else: assert type(attn_metadata) is dict - attn_metadata_i = attn_group.get_metadata_builder( - ).build_for_cudagraph_capture(common_attn_metadata) + 
attn_metadata_i = attn_group.get_metadata_builder()\ + .build_for_cudagraph_capture(common_attn_metadata) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -3061,8 +2958,7 @@ def _dummy_run( self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, - device=self.device, - )) + device=self.device)) intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) @@ -3070,9 +2966,10 @@ def _dummy_run( batch_descriptor = None else: # filter out the valid batch descriptor - _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens, - uniform_decode=uniform_decode)) + _cg_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch( + BatchDescriptor(num_tokens=num_tokens, + uniform_decode=uniform_decode)) # sanity check assert cudagraph_runtime_mode == _cg_mode, ( f"Cudagraph runtime mode mismatch at dummy_run. " @@ -3080,18 +2977,14 @@ def _dummy_run( if ubatch_slices is not None: num_tokens = num_tokens // 2 - with ( - self.maybe_randomize_inputs(input_ids), - set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices, - ), - ): + with self.maybe_randomize_inputs(input_ids), set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ubatch_slices=ubatch_slices): outputs = self.model( input_ids=input_ids, positions=positions, @@ -3105,22 +2998,20 @@ def _dummy_run( else: hidden_states = outputs - # Execute dummy run for drafter if self.speculative_config and self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) - is_draft_model = (self.speculative_config - and self.speculative_config.uses_draft_model()) - if is_draft_model: + if (self.speculative_config + and self.speculative_config.uses_draft_model()): assert isinstance(self.drafter, DraftModelProposer) forward_ctx_kwargs = { "attn_metadata": attn_metadata, "cudagraph_runtime_mode": cudagraph_runtime_mode, "batch_descriptor": batch_descriptor, } - self.drafter.dummy_run(num_tokens, - forward_ctx_kwargs=forward_ctx_kwargs) + self.drafter.dummy_run(num_tokens, forward_ctx_kwargs) + # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real # requests to process. @@ -3172,7 +3063,7 @@ def _dummy_sampler_run( sampler_output = self.sampler(logits=logits, sampling_metadata=dummy_metadata) except RuntimeError as e: - if "out of memory" in str(e): + if 'out of memory' in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up sampler with " f"{num_reqs} dummy requests. Please try lowering " @@ -3190,12 +3081,10 @@ def _dummy_sampler_run( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn( - num_tokens, - logits.shape[-1], - device=self.device, - dtype=logits.dtype, - ) + target_logits = torch.randn(num_tokens, + logits.shape[-1], + device=self.device, + dtype=logits.dtype) # NOTE(woosuk): Here, we should use int32 because the sampler uses # int32 for bonus_token_ids. If the dtype mismatches, re-compilation # will occur at runtime. 
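The int32 note above can be illustrated in isolation (made-up shapes, not the
actual warmup code): re-compilation is triggered when dtypes change, so the
dummy tensors should use the same dtypes the rejection sampler will see at
runtime.

import torch

num_reqs, num_spec, vocab = 4, 2, 128
# Logits for the drafted positions plus one bonus position per request.
target_logits = torch.randn(num_reqs * (num_spec + 1), vocab)
# int32 on purpose: warming up with int64 ids and later running with int32
# would change the input signature and force a recompile at runtime.
draft_token_ids = torch.randint(0, vocab, (num_reqs, num_spec), dtype=torch.int32)
bonus_token_ids = torch.randint(0, vocab, (num_reqs, 1), dtype=torch.int32)
assert bonus_token_ids.dtype == torch.int32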
@@ -3254,7 +3143,7 @@ def _dummy_pooler_run_task( return model.pooler(hidden_states=hidden_states, pooling_metadata=dummy_metadata) except RuntimeError as e: - if "out of memory" in str(e): + if 'out of memory' in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " @@ -3295,9 +3184,8 @@ def profile_run(self) -> None: # modality with the max possible input tokens even when # it supports multiple. dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = ( - mm_budget. - max_items_per_batch_by_modality[dummy_modality]) + max_mm_items_per_batch = mm_budget \ + .max_items_per_batch_by_modality[dummy_modality] logger.info( "Encoder cache will be initialized with a budget of " @@ -3315,9 +3203,9 @@ def profile_run(self) -> None: ) # Run multimodal encoder. - dummy_encoder_outputs = ( + dummy_encoder_outputs = \ self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs)) + **batched_dummy_mm_inputs) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, @@ -3329,8 +3217,8 @@ def profile_run(self) -> None: enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states = self._dummy_run( - self.max_num_tokens, is_profile=True) + hidden_states, last_hidden_states \ + = self._dummy_run(self.max_num_tokens, is_profile=True) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -3387,15 +3275,14 @@ def freeze_gc(): self._capture_cudagraphs( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False, - ) + uniform_decode=False) # Capture full cudagraph for uniform decode batches if we # don't already have full mixed prefill-decode cudagraphs. - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and cudagraph_mode.separate_routine()): - max_num_tokens = (self.scheduler_config.max_num_seqs * - self.uniform_decode_query_len) + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ + cudagraph_mode.separate_routine(): + max_num_tokens = self.scheduler_config.max_num_seqs * \ + self.uniform_decode_query_len decode_cudagraph_batch_sizes = [ x for x in self.cudagraph_batch_sizes if x <= max_num_tokens and x >= self.uniform_decode_query_len @@ -3405,8 +3292,7 @@ def freeze_gc(): self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, - ) + uniform_decode=True) # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. @@ -3420,22 +3306,16 @@ def freeze_gc(): elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. 
- logger.info( - "Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, - cuda_graph_size / (1 << 30), - ) + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / (1 << 30)) return cuda_graph_size - def _capture_cudagraphs( - self, - compilation_cases: list[int], - cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool, - ): - assert (cudagraph_runtime_mode != CUDAGraphMode.NONE - and cudagraph_runtime_mode - in [CUDAGraphMode.FULL, CUDAGraphMode.PIECEWISE]) + def _capture_cudagraphs(self, compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool): + assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ + cudagraph_runtime_mode in [CUDAGraphMode.FULL, + CUDAGraphMode.PIECEWISE] # Only rank 0 should print progress bar during capture if is_global_first_rank(): @@ -3444,9 +3324,7 @@ def _capture_cudagraphs( disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", - cudagraph_runtime_mode.name, - ), - ) + cudagraph_runtime_mode.name)) enable_dbo = self.parallel_config.enable_dbo # DBO Only supports running Full cudagraphs with uniform # decode lengths @@ -3454,7 +3332,8 @@ def _capture_cudagraphs( for num_tokens in compilation_cases: # If the number of tokens is greater than the microbatching # threshold, don't generate a microbatched cudagraph - if num_tokens < self.parallel_config.dbo_decode_token_threshold: + if (num_tokens + < self.parallel_config.dbo_decode_token_threshold): continue # Warmup @@ -3462,23 +3341,19 @@ def _capture_cudagraphs( self.compilation_config.cudagraph_num_of_warmups): force_attention = ( cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run( - num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=True, - allow_microbatching=True, - skip_eplb=True, - ) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=True, + allow_microbatching=True, + skip_eplb=True) # Graph Capture - self._dummy_run( - num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, - allow_microbatching=True, - skip_eplb=True, - ) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True, + allow_microbatching=True, + skip_eplb=True) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: for _ in range(self.compilation_config.cudagraph_num_of_warmups): @@ -3487,33 +3362,30 @@ def _capture_cudagraphs( # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. 
- force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL - self._dummy_run( - num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - skip_eplb=True, - remove_lora=False, - ) - self._dummy_run( - num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - skip_eplb=True, - remove_lora=False, - ) + force_attention = ( + cudagraph_runtime_mode == CUDAGraphMode.FULL) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + skip_eplb=True, + remove_lora=False) + self._dummy_run(num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + skip_eplb=True, + remove_lora=False) self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. """ - assert len(self.attn_groups) == 0, ( - "Attention backends are already initialized") + assert len(self.attn_groups) == 0, \ + "Attention backends are already initialized" def get_attn_backends_for_layers( - layer_names: list[str], + layer_names: list[str] ) -> dict[type[AttentionBackend], list[str]]: layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase, @@ -3591,26 +3463,26 @@ def initialize_cudagraph_capture(self) -> None: # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported - if (cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL - and min_cg_support != AttentionCGSupport.ALWAYS): + if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ + and min_cg_support != AttentionCGSupport.ALWAYS: msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " f"with {min_cg_builder_name} backend (support: " f"{min_cg_support})") if min_cg_support == AttentionCGSupport.NEVER: # if not supported any full cudagraphs, just raise it. 
- msg += ("; please try cudagraph_mode=PIECEWISE, and " - "make sure compilation level is piecewise") + msg += "; please try cudagraph_mode=PIECEWISE, and "\ + "make sure compilation level is piecewise" raise ValueError(msg) # attempt to resolve the full cudagraph related mode if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = ( - CUDAGraphMode.FULL_AND_PIECEWISE) + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.FULL_AND_PIECEWISE else: msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" - cudagraph_mode = self.compilation_config.cudagraph_mode = ( - CUDAGraphMode.FULL_DECODE_ONLY) + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.FULL_DECODE_ONLY logger.warning(msg) # check that if we are doing spec-decode + decode full-cudagraphs it is @@ -3623,18 +3495,18 @@ def initialize_cudagraph_capture(self) -> None: f"{min_cg_builder_name} (support: {min_cg_support})") if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = ( - CUDAGraphMode.PIECEWISE) + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.PIECEWISE else: msg += "; setting cudagraph_mode=NONE" - cudagraph_mode = self.compilation_config.cudagraph_mode = ( - CUDAGraphMode.NONE) + cudagraph_mode = self.compilation_config.cudagraph_mode = \ + CUDAGraphMode.NONE logger.warning(msg) # double check that we can support full cudagraph if they are requested # even after automatic downgrades - if (cudagraph_mode.has_full_cudagraphs() - and min_cg_support == AttentionCGSupport.NEVER): + if cudagraph_mode.has_full_cudagraphs() \ + and min_cg_support == AttentionCGSupport.NEVER: raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " f"supported with {min_cg_builder_name} backend (" f"support:{min_cg_support}) " @@ -3645,8 +3517,7 @@ def initialize_cudagraph_capture(self) -> None: # initializing attn backends). self.cudagraph_dispatcher.initialize_cudagraph_keys( self.compilation_config.cudagraph_mode, - self.uniform_decode_query_len, - ) + self.uniform_decode_query_len) def calculate_reorder_batch_threshold(self) -> None: """ @@ -3662,8 +3533,8 @@ def calculate_reorder_batch_threshold(self) -> None: attn_metadata_builder_i.reorder_batch_threshold) if reorder_batch_threshold_i is not None: if self.reorder_batch_threshold is not None: - if (reorder_batch_threshold_i - != self.reorder_batch_threshold): + if reorder_batch_threshold_i != \ + self.reorder_batch_threshold: raise ValueError( f"Attention backend reorders decodes with " f"threshold {reorder_batch_threshold_i} but other " @@ -3718,7 +3589,7 @@ def _allocate_kv_cache_tensors( Returns: dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. 
- """ + """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: tensor = torch.zeros(kv_cache_tensor.size, @@ -3733,24 +3604,21 @@ def _allocate_kv_cache_tensors( if layer_name in self.runner_only_attn_layers: continue layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys()), ( - "Some layers are not correctly initialized") + assert layer_names == set(kv_cache_raw_tensors.keys( + )), "Some layers are not correctly initialized" return kv_cache_raw_tensors def _attn_group_iterator(self) -> Iterator[AttentionGroup]: return itertools.chain.from_iterable(self.attn_groups) def _kv_cache_spec_attn_group_iterator( - self, ) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]: + self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]: if not self.kv_cache_config.kv_cache_groups: return for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups): for attn_group in attn_groups: - yield ( - self.kv_cache_config.kv_cache_groups[kv_cache_spec_id]. - kv_cache_spec, - attn_group, - ) + yield self.kv_cache_config.kv_cache_groups[ + kv_cache_spec_id].kv_cache_spec, attn_group def _reshape_kv_cache_tensors( self, @@ -3777,20 +3645,17 @@ def _reshape_kv_cache_tensors( continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = raw_tensor.numel( - ) // kv_cache_spec.page_size_bytes + num_blocks = (raw_tensor.numel() // + kv_cache_spec.page_size_bytes) if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_cache_shape = attn_backend.get_kv_cache_shape( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - ) + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = ( - attn_backend.get_kv_cache_stride_order()) + kv_cache_stride_order = \ + attn_backend.get_kv_cache_stride_order() assert len(kv_cache_stride_order) == len( kv_cache_shape) except (AttributeError, NotImplementedError): @@ -3808,16 +3673,16 @@ def _reshape_kv_cache_tensors( kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = ( - kv_cache_raw_tensors[layer_name].view(dtype).view( - kv_cache_shape).permute(*inv_order)) + kv_caches[layer_name] = kv_cache_raw_tensors[ + layer_name].view(dtype).view(kv_cache_shape).permute( + *inv_order) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] state_tensors = [] storage_offset_bytes = 0 - for shape, dtype in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): + for (shape, dtype) in zip(kv_cache_spec.shapes, + kv_cache_spec.dtypes): dtype_size = get_dtype_size(dtype) num_element_per_page = ( kv_cache_spec.page_size_bytes // dtype_size) @@ -3858,19 +3723,14 @@ def _update_hybrid_attention_mamba_layout( kv_cache = kv_caches[layer_name] if (isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2): - assert kv_cache.shape[1] != 2, ( - "Fail to determine whether the layout is " - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " - f"a tensor of shape {kv_cache.shape}") + assert kv_cache.shape[1] != 2, \ + "Fail to determine whether the layout is " \ + "(2, num_blocks, ...) or (num_blocks, 2, ...) 
for " \ + f"a tensor of shape {kv_cache.shape}" hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_( - size=kv_cache.shape, - stride=( - hidden_size, - 2 * hidden_size, - *kv_cache.stride()[2:], - ), - ) + kv_cache.as_strided_(size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, + *kv_cache.stride()[2:])) def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: @@ -3890,19 +3750,15 @@ def initialize_kv_cache_tensors( kv_cache_raw_tensors) # Set up cross-layer KV cache sharing - for ( - layer_name, - target_layer_name, - ) in self.shared_kv_cache_layers.items(): + for layer_name, target_layer_name in self.shared_kv_cache_layers.items( + ): logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) kv_caches[layer_name] = kv_caches[target_layer_name] - bind_kv_cache( - kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches, - ) + bind_kv_cache(kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches) return kv_caches def maybe_add_kv_sharing_layers_to_kv_cache_groups( @@ -3957,7 +3813,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(kv_caches) - if self.device.type == "xpu": + if self.device.type == 'xpu': get_kv_transfer_group().set_host_xfer_buffer_ops( copy_kv_blocks) @@ -3989,13 +3845,13 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla, - ) + use_mla=use_mla) encoder_only_attn_specs[attn_spec].append(layer_name) self.runner_only_attn_layers.add(layer_name) if len(encoder_only_attn_specs) > 0: - assert len(encoder_only_attn_specs) == 1, ( - "Only support one encoder-only attention spec now") + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" spec, layer_names = encoder_only_attn_specs.popitem() self.kv_cache_config.kv_cache_groups.append( KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) @@ -4036,38 +3892,32 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, - use_mla=use_mla, - ) - elif self.attention_chunk_size is not None and isinstance( - attn_module, ChunkedLocalAttention): + use_mla=use_mla) + elif self.attention_chunk_size is not None \ + and isinstance(attn_module, ChunkedLocalAttention): kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, attention_chunk_size=self.attention_chunk_size, - use_mla=use_mla, - ) + use_mla=use_mla) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla, - ) + use_mla=use_mla) elif attn_module.attn_type == AttentionType.ENCODER_DECODER: kv_cache_spec[layer_name] = CrossAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=use_mla, - ) - elif attn_module.attn_type in ( - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - ): + use_mla=use_mla) + elif attn_module.attn_type in (AttentionType.ENCODER, + AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. 
continue else: From daee8ec1eb4661bb7f1cfa5fe7a3be09c33d1e88 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Fri, 19 Sep 2025 09:12:01 +0200 Subject: [PATCH 13/73] Move more methods to base class Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 18 ++- vllm/v1/spec_decode/eagle.py | 268 +++++++++++++++---------------- 2 files changed, 147 insertions(+), 139 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 195f534dc500..b78eecec85d7 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -19,7 +19,7 @@ from vllm.v1.spec_decode.metrics import compute_acceptance_rate -def get_test_prompts(mm_enabled: bool): +def get_test_prompts(mm_enabled: bool, quiet: bool = False): prompt_types = ["repeat", "sentence"] if mm_enabled: prompt_types.append("mm") @@ -28,7 +28,9 @@ def get_test_prompts(mm_enabled: bool): random.seed(0) random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) - print(f"Prompt types: {random_prompt_type_choices}") + + if not quiet: + print(f"Prompt types: {random_prompt_type_choices}") # Generate a mixed batch of prompts, some of which can be easily # predicted by n-gram matching and some which likely cannot. @@ -270,12 +272,17 @@ class ArgsTest: @pytest.mark.parametrize("args", cases) @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool, - monkeypatch: pytest.MonkeyPatch): +@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False]) +def test_draft_model_correctness( + args: ArgsTest, + enforce_eager: bool, + disable_padded_drafter_batch: bool, + monkeypatch: pytest.MonkeyPatch, +): """Compare the outputs using and not using speculative decoding. In the greedy decoding case, the outputs must match EXACTLY.""" monkeypatch.setenv("VLLM_USE_V1", "1") - test_prompts = get_test_prompts(mm_enabled=False) + test_prompts = get_test_prompts(mm_enabled=False, quiet=True) spec_llm = LLM( model=args.model, @@ -286,6 +293,7 @@ def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool, "max_model_len": args.max_model_len, "enforce_eager": enforce_eager, "tensor_parallel_size": args.draft_tensor_parallel_size, + "disable_padded_drafter_batch": disable_padded_drafter_batch, }, max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 78bfb3834e7d..a08f77a8b9b1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -107,6 +107,38 @@ def __init__( device=device, with_numpy=True) + def prepare_next_token_ids_cpu( + self, sampled_token_ids: list[list[int]], + requests: dict[str, + CachedRequestState], gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int]) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
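                # (nothing was sampled for this request in this step, so the
                # next input token is the upcoming prompt token at seq_len)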
+ req_id = req_ids[i] + req_state = requests[req_id] + seq_len = (req_state.num_computed_tokens + + num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.input_ids.device) + return next_token_ids + def prepare_next_token_ids_padded(self, common_attn_metadata: CommonAttentionMetadata, sampled_token_ids: torch.Tensor, @@ -174,6 +206,108 @@ def prepare_next_token_ids_padded(self, return next_token_ids, valid_sampled_tokens_count + def prepare_inputs( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: list[list[int]], + num_draft_tokens: list[int], + ) -> tuple[CommonAttentionMetadata, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It updates to the common_attn_metadata to account for the rejected + tokens (and newly sampled tokens). It also returns the token indices + of the tokens that should be fed to the speculator. + """ + # E.g. + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1, q1 + q2, q1 + q2 + q3] + # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] + # num_rejected_tokens: [n1, n2, n3] + # This function computes the intermediate values: + # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] + # And returns: + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # common_attn_metadata.seq_lens{_cpu}: + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # token_indices: [0, 1, ..., q1 - n1 - 1, + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + + device = common_attn_metadata.query_start_loc.device + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ + - num_rejected_tokens + + # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] + new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens + new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() + + # [q1 - n1, q2 - n2, q3 - n3] -> + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + new_query_start_loc_cpu = torch.zeros( + query_start_loc_cpu.shape, + dtype=torch.int32, + pin_memory=is_pin_memory_available()) + new_query_start_loc_np = new_query_start_loc_cpu.numpy() + np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) + + total_num_tokens = new_query_start_loc_np[-1] + # Example assuming num_tokens_per_req_np = [2, 4, 3] + # this implies that `new_query_start_locs` is: + # [0, 2, 6, 9] -> + # [0, 0, 2, 2, 2, 2, 6, 6, 6] + # _r1_ ____r2____ ___r3__ + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], + new_num_tokens_per_req_np) + # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> + # [0, 1, 0, 1, 2, 3, 0, 1, 2] + # _r1_ ____r2____ ___r3__ + token_offests = self.token_arange_np[:total_num_tokens] \ + - new_query_start_locs_expanded + + # Expand starting positions to match token pattern + # [0, q1, q1 + q2] -> + # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] + # _r1_ _____r2_______ ___________r3____________ + old_query_start_locs_expanded = np.repeat( + 
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + # Final token indices are: + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + token_indices_np = token_offests + old_query_start_locs_expanded + token_indices = torch.from_numpy(token_indices_np).to( + device, non_blocking=True) + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=new_query_start_loc_cpu.to(device, + non_blocking=True), + seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens_cpu=new_seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=new_seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, + ) + return spec_common_attn_metadata, token_indices + def prepare_inputs_padded(self, common_attn_metadata: CommonAttentionMetadata, spec_decode_metadata: SpecDecodeMetadata, @@ -510,38 +644,6 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids - def prepare_next_token_ids_cpu( - self, sampled_token_ids: list[list[int]], - requests: dict[str, - CachedRequestState], gpu_input_batch: InputBatch, - num_scheduled_tokens: dict[str, int]) -> torch.Tensor: - """ - This function is used to prepare the inputs for speculative decoding. - It calculates the next token ids for each request based on the sampled - token ids from the CPU. If a request has no sampled token ids (e.g., - during the initial decoding steps), it falls back to using the request - state to get the next token id. - """ - req_ids = gpu_input_batch.req_ids - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = req_ids[i] - req_state = requests[req_id] - seq_len = (req_state.num_computed_tokens + - num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.input_ids.device) - return next_token_ids - def propose_tree( self, batch_size: int, @@ -712,108 +814,6 @@ def propose_tree( total_num_drafts = self.cu_drafts_per_level[level + 1] return draft_token_ids_list - def prepare_inputs( - self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: list[list[int]], - num_draft_tokens: list[int], - ) -> tuple[CommonAttentionMetadata, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding. - It updates to the common_attn_metadata to account for the rejected - tokens (and newly sampled tokens). It also returns the token indices - of the tokens that should be fed to the speculator. - """ - # E.g. 
- # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] - # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] - # num_rejected_tokens: [n1, n2, n3] - # This function computes the intermediate values: - # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] - # And returns: - # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] - # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] - - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - - device = common_attn_metadata.query_start_loc.device - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - - num_rejected_tokens - - # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) - # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] - new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens - new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() - - # [q1 - n1, q2 - n2, q3 - n3] -> - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - new_query_start_loc_cpu = torch.zeros( - query_start_loc_cpu.shape, - dtype=torch.int32, - pin_memory=is_pin_memory_available()) - new_query_start_loc_np = new_query_start_loc_cpu.numpy() - np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) - - total_num_tokens = new_query_start_loc_np[-1] - # Example assuming num_tokens_per_req_np = [2, 4, 3] - # this implies that `new_query_start_locs` is: - # [0, 2, 6, 9] -> - # [0, 0, 2, 2, 2, 2, 6, 6, 6] - # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) - # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> - # [0, 1, 0, 1, 2, 3, 0, 1, 2] - # _r1_ ____r2____ ___r3__ - token_offests = self.token_arange_np[:total_num_tokens] \ - - new_query_start_locs_expanded - - # Expand starting positions to match token pattern - # [0, q1, q1 + q2] -> - # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] - # _r1_ _____r2_______ ___________r3____________ - old_query_start_locs_expanded = np.repeat( - query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) - # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 - token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to( - device, non_blocking=True) - - spec_common_attn_metadata = CommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, - non_blocking=True), - seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), - query_start_loc_cpu=new_query_start_loc_cpu, - seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. 
- num_computed_tokens_cpu, - num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=total_num_tokens, - max_query_len=new_query_len_per_req.max().item(), - max_seq_len=new_seq_lens_cpu.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], - causal=True, - ) - return spec_common_attn_metadata, token_indices - def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ self.vllm_config.speculative_config.draft_model_config From 07d1b97fddcd0212204b59c593384739a0cb4b11 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 22 Sep 2025 16:00:42 +0200 Subject: [PATCH 14/73] Fix call to model.compute_logits() Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index b5cbe0d5e5cf..f3db41ef7d8b 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -111,7 +111,7 @@ def propose( last_hidden_states = self.model(**model_kwargs) sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + logits = self.model.compute_logits(sample_hidden_states) positions = target_positions[last_token_indices] if isinstance(attn_metadata, TreeAttentionMetadata): @@ -203,8 +203,7 @@ def propose( ): last_hidden_states = self.model(**model_kwargs) - logits = self.model.compute_logits(last_hidden_states[:batch_size], - None) + logits = self.model.compute_logits(last_hidden_states[:batch_size]) draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) From 86d80401a22212de53b87b61c80c5e41726183e3 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 22 Sep 2025 16:49:36 +0200 Subject: [PATCH 15/73] Move .propose() to superclass Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 186 ++--------------- vllm/v1/spec_decode/eagle.py | 307 ++++++++++++++++++----------- vllm/v1/worker/gpu_model_runner.py | 22 +-- 3 files changed, 223 insertions(+), 292 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index f3db41ef7d8b..1647e24c0ca2 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -1,19 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import replace -from typing import Any, Optional +from typing import Any import torch from vllm.attention.layer import Attention from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config -from vllm.config.compilation import CUDAGraphMode -from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model -from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata -from vllm.v1.attention.backends.utils import CommonAttentionMetadata -from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer -from vllm.v1.worker.ubatching import dbo_current_ubatch_id +from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer class DraftModelProposer(SpecDecodeBaseProposer): @@ -24,7 +20,19 @@ def __init__( device: torch.device, runner=None, ): - super().__init__(vllm_config, device, runner) + super().__init__( + vllm_config=vllm_config, + device=device, + 
pass_hidden_states_to_model=False, + pass_cudagraph_args_to_forward_ctx=True, + # the first draft_token_ids are identical to next_token_ids, so + # they don't need to be returned as proposed tokens + drop_first_drafted_tokens=True, + runner=runner) + # The draft model runs one forward pass to prefill + # the target_token_ids, and another forward pass for decoding + # based on the next_token_ids. I.e. it needs 1 more forward pass. + self.num_forward_passes = self.num_speculative_tokens + 1 self._raise_if_multimodal() self._raise_if_mrope() @@ -48,6 +56,7 @@ def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: def dummy_run(self, num_tokens: int, forward_ctx_kwargs: dict): model_kwargs = self._model_kwargs(num_tokens) + assert isinstance(self.model, torch.nn.Module) with set_forward_context( vllm_config=self.vllm_config, num_tokens=num_tokens, @@ -55,166 +64,11 @@ def dummy_run(self, num_tokens: int, forward_ctx_kwargs: dict): ): self.model(**model_kwargs) - # Copied and adapted from eagle.py - def propose( - self, - # [num_tokens] - target_token_ids: torch.Tensor, - # [num_tokens] - target_positions: torch.Tensor, - # [batch_size] - next_token_ids: torch.Tensor, - last_token_indices: Optional[torch.Tensor], - common_attn_metadata: CommonAttentionMetadata, - cudagraph_runtime_mode: CUDAGraphMode, - batch_descriptor: BatchDescriptor, - ) -> torch.Tensor: - num_tokens = target_token_ids.shape[0] - batch_size = next_token_ids.shape[0] - if last_token_indices is None: - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 - + def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, num_tokens: int, + last_token_indices: torch.Tensor) -> None: self.input_ids[:num_tokens] = target_token_ids - assert self.runner is not None - - # FIXME: need to consider multiple kv_cache_groups - assert len(self.runner.attn_groups) == 1 - assert len(self.runner.attn_groups[0]) == 1 - ubatch_id = dbo_current_ubatch_id() - attn_metadata_builder = self.runner.attn_groups[0][ - 0].metadata_builders[ubatch_id] - attn_metadata = attn_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0) - - # At this moment, we assume all draft model layers belong to the same KV - # cache group, thus using the same attention metadata. 
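# Illustrative aside (standalone sketch; the layer names are invented): saying
# that all draft-model layers share one KV cache group boils down to mapping
# every attention layer name to the very same metadata object.
shared_metadata = object()  # stand-in for the real attention metadata
draft_layer_names = ["model.layers.0.self_attn", "model.layers.1.self_attn"]
per_layer_attn_metadata = {name: shared_metadata for name in draft_layer_names}
assert all(m is shared_metadata for m in per_layer_attn_metadata.values())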
- per_layer_attn_metadata = {} - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata - - if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - else: - num_input_tokens = num_tokens - # copy inputs to buffer for cudagraph - self.positions[:num_tokens] = target_positions - - model_kwargs = self._model_kwargs(num_input_tokens) - with set_forward_context( - per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ): - last_hidden_states = self.model(**model_kwargs) - - sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states) - positions = target_positions[last_token_indices] - - if isinstance(attn_metadata, TreeAttentionMetadata): - raise NotImplementedError("Speculative Decoding with draft models " - "does not support tree attention yet") - - # Reuse the next_token_ids to avoid a potential rejection - draft_token_ids = next_token_ids - - # The draft model runs one forward pass to prefill - # the target_token_ids, and another forward pass for decoding - # based on the next_token_ids. I.e. it needs 1 more forward pass. - n_forward_passes = self.num_speculative_tokens + 1 - # Early exit if there is only one draft token to be generated. - if n_forward_passes == 1: - # [batch_size, 1] - return draft_token_ids.view(-1, 1) - - # Generate the remaining draft tokens. - draft_token_ids_list = [draft_token_ids] - - if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) - else: - input_batch_size = batch_size - - attn_metadata.num_actual_tokens = batch_size - attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size + 1] - for _ in range(n_forward_passes - 1): - # Update the inputs. - # cast to int32 is crucial when draft model is compiled. - # tensor.argmax() returns int64 by default. - input_ids = draft_token_ids_list[-1].int() - positions += 1 - - # NOTE(woosuk): We should handle the case where the draft model - # generates tokens beyond the max model length. Since it is complex - # to remove such requests from the batch, we keep them in the batch - # but adjust the position ids and slot mappings to avoid the - # out-of-range access during the model execution. The draft tokens - # generated with this adjustment should be ignored. - exceeds_max_model_len = positions >= self.max_model_len - # Mask out the position ids that exceed the max model length. - # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) - - # Increment the sequence lengths. - attn_metadata.max_seq_len += 1 - attn_metadata.seq_lens += 1 - # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) - # For the requests that exceed the max model length, we set the - # sequence length to 1 to minimize their overheads in attention. - attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) - - # Compute the slot mapping. 
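# Illustrative aside (toy numbers, not part of the patch): in a paged KV cache
# with block_size B, a token at clamped position p of a request whose
# block-table row is `blocks` is written to slot
#     blocks[p // B] * B + p % B
# which is the same arithmetic that the gather/modulo code in this routine
# applies to the whole batch at once.
B = 16
blocks = [7, 3]          # the request owns physical blocks 7 and 3
p = 18                   # second logical block, offset 2
slot = blocks[p // B] * B + p % B
assert slot == 3 * B + 2  # slot 50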
- block_numbers = clamped_positions // self.block_size - block_ids = attn_metadata.block_table.gather( - dim=1, index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - attn_metadata.slot_mapping = (block_ids * self.block_size + - clamped_positions % self.block_size) - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, - PADDING_SLOT_ID) - - # copy inputs to buffer for cudagraph - self.input_ids[:batch_size] = input_ids - self.positions[:batch_size] = clamped_positions - - model_kwargs = self._model_kwargs(input_batch_size) - batch_descriptor = BatchDescriptor(num_tokens=input_batch_size, - uniform_decode=True) - cudagraph_runtime_mode, batch_descriptor = ( - self.runner.cudagraph_dispatcher.dispatch(batch_descriptor)) - - # Run the model. - with set_forward_context( - per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ): - last_hidden_states = self.model(**model_kwargs) - - logits = self.model.compute_logits(last_hidden_states[:batch_size]) - draft_token_ids = logits.argmax(dim=-1) - draft_token_ids_list.append(draft_token_ids) - - # the first draft_token_ids are identical to next_token_ids, so - # they don't need to be returned as proposed tokens - draft_token_ids_list = draft_token_ids_list[1:] - - # [batch_size, num_speculative_tokens] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - return draft_token_ids - def load_model(self) -> None: draft_model_config: ModelConfig = ( self.vllm_config.speculative_config.draft_model_config) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 05663469cf4c..995b65f60eb8 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast +from abc import ABC, abstractmethod from dataclasses import replace from importlib.util import find_spec -from typing import Optional, Protocol +from typing import Optional, Protocol, TypedDict import numpy as np import torch @@ -12,8 +13,9 @@ from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) +from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal @@ -48,18 +50,25 @@ class EagleAttentionMetadata(Protocol): slot_mapping: torch.Tensor -class SpecDecodeBaseProposer: +class SpecDecodeBaseProposer(ABC): def __init__( self, vllm_config: VllmConfig, device: torch.device, + pass_hidden_states_to_model: bool, + pass_cudagraph_args_to_forward_ctx: bool, + drop_first_drafted_tokens: bool, runner=None, ): self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method + self.pass_hidden_states_to_model = pass_hidden_states_to_model + self.pass_cudagraph_args_to_forward_ctx \ + = pass_cudagraph_args_to_forward_ctx + 
self.drop_first_drafted_tokens = drop_first_drafted_tokens self.runner = runner self.dtype = vllm_config.model_config.dtype @@ -67,9 +76,14 @@ def __init__( self.block_size = vllm_config.cache_config.block_size self.num_speculative_tokens = ( self.speculative_config.num_speculative_tokens) + self.num_forward_passes = self.num_speculative_tokens self.max_num_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) self.token_arange_np = np.arange(self.max_num_tokens) + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). + self.hidden_size = self.draft_model_config.get_hidden_size() self.is_multimodal_model = vllm_config.model_config \ .is_multimodal_model @@ -88,25 +102,73 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=device) + self.hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) - self.max_batch_size = vllm_config.scheduler_config.max_num_seqs - max_num_slots_for_arange = max(self.max_batch_size + 1, - self.max_num_tokens) - self.arange = torch.arange( - # We need +1 here because the arange is used to set query_start_loc, - # which has one more element than batch_size. - max_num_slots_for_arange, - device=device, - dtype=torch.int32, - ) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + max_batch_size = vllm_config.scheduler_config.max_num_seqs + max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) + self.arange = torch.arange(max_num_slots_for_arange, + device=device, + dtype=torch.int32) + + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) self.backup_next_token_ids = CpuGpuBuffer( - self.max_batch_size, + max_batch_size, dtype=torch.int32, pin_memory=is_pin_memory_available(), device=device, with_numpy=True) + # Determine allowed attention backends once during initialization. + self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] + if current_platform.is_rocm(): + rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] + # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend + if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): + from vllm.v1.attention.backends.rocm_aiter_fa import ( + AiterFlashAttentionMetadata) + rocm_types.append(AiterFlashAttentionMetadata) + self.allowed_attn_types = tuple(rocm_types) + else: + self.allowed_attn_types = (FlashAttentionMetadata, + TreeAttentionMetadata) + + # Parse the speculative token tree. + spec_token_tree = self.speculative_config.speculative_token_tree + self.tree_choices: list[tuple[int, + ...]] = ast.literal_eval(spec_token_tree) + tree_depth = len(self.tree_choices[-1]) + # Precompute per-level properties of the tree. + num_drafts_per_level = [0] * tree_depth + for node in self.tree_choices: + num_drafts_per_level[len(node) - 1] += 1 + self.cu_drafts_per_level = [num_drafts_per_level[0]] + self.child_drafts_per_level = [num_drafts_per_level[0]] + for level in range(1, tree_depth): + self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + + num_drafts_per_level[level]) + self.child_drafts_per_level.append(num_drafts_per_level[level] // + num_drafts_per_level[level - 1]) + # Precompute draft position offsets in flattened tree. 
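# Illustrative aside (standalone sketch): for a chain-shaped speculative token
# tree such as "[(0,), (0, 0), (0, 0, 0)]", the per-level bookkeeping above
# yields one draft per level, and the flattened draft positions start right
# after the root token, i.e. offsets 1..len(tree_choices).
import ast

tree_choices = ast.literal_eval("[(0,), (0, 0), (0, 0, 0)]")
tree_depth = len(tree_choices[-1])
num_drafts_per_level = [0] * tree_depth
for node in tree_choices:
    num_drafts_per_level[len(node) - 1] += 1
assert num_drafts_per_level == [1, 1, 1]
draft_pos_offsets = list(range(1, len(tree_choices) + 1))  # [1, 2, 3]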
+ self.tree_draft_pos_offsets = torch.arange( + 1, + len(self.tree_choices) + 1, + device=device, + dtype=torch.int32, + ).repeat(max_batch_size, 1) + + # Lazily loaded attributes. + self.model: Optional[nn.Module] = None + self.attn_layer_names: list[str] = [] + def prepare_next_token_ids_cpu( self, sampled_token_ids: list[list[int]], requests: dict[str, @@ -361,68 +423,6 @@ def prepare_inputs_padded(self, return spec_common_attn_metadata, token_indices, token_indices_to_sample - -class EagleProposer(SpecDecodeBaseProposer): - - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - runner=None, - ): - super().__init__(vllm_config, device, runner) - # We need to get the hidden size from the draft model config because - # the draft model's hidden size can be different from the target model's - # hidden size (e.g., Llama 3.3 70B). - self.hidden_size = self.draft_model_config.get_hidden_size() - self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) - - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) - - # Determine allowed attention backends once during initialization. - self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...] - if current_platform.is_rocm(): - rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] - # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend - if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): - from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) - rocm_types.append(AiterFlashAttentionMetadata) - self.allowed_attn_types = tuple(rocm_types) - else: - self.allowed_attn_types = (FlashAttentionMetadata, - TreeAttentionMetadata) - - # Parse the speculative token tree. - spec_token_tree = self.speculative_config.speculative_token_tree - self.tree_choices: list[tuple[int, - ...]] = ast.literal_eval(spec_token_tree) - tree_depth = len(self.tree_choices[-1]) - # Precompute per-level properties of the tree. - num_drafts_per_level = [0] * tree_depth - for node in self.tree_choices: - num_drafts_per_level[len(node) - 1] += 1 - self.cu_drafts_per_level = [num_drafts_per_level[0]] - self.child_drafts_per_level = [num_drafts_per_level[0]] - for level in range(1, tree_depth): - self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + - num_drafts_per_level[level]) - self.child_drafts_per_level.append(num_drafts_per_level[level] // - num_drafts_per_level[level - 1]) - # Precompute draft position offsets in flattened tree. - self.tree_draft_pos_offsets = torch.arange( - 1, - len(self.tree_choices) + 1, - device=device, - dtype=torch.int32, - ).repeat(self.max_batch_size, 1) - def propose( self, # [num_tokens] @@ -436,6 +436,7 @@ def propose( last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, + cudagraph_args: "CudaGraphArgs", mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] @@ -450,14 +451,11 @@ def propose( target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. 
- # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + self.set_input_ids_first_pass(target_token_ids, next_token_ids, + num_tokens, last_token_indices) assert self.runner is not None + assert isinstance(self.model, nn.Module) # FIXME: need to consider multiple kv_cache_groups ubatch_id = dbo_current_ubatch_id() @@ -478,10 +476,14 @@ def propose( num_input_tokens = num_tokens # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions - self.hidden_states[:num_tokens] = target_hidden_states + if self.pass_hidden_states_to_model: + # target_hidden_states and self.hidden_states can have different + # hidden dims. E.g. large target model and small draft model. + self.hidden_states[:num_tokens] = target_hidden_states + if self.is_multimodal_model: input_ids = self.input_ids[:num_tokens] - inputs_embeds = self.model.get_input_embeddings( + inputs_embeds = self.model.get_input_embeddings( # type: ignore input_ids, multimodal_embeddings=mm_embeds or None, ) @@ -492,25 +494,36 @@ def propose( inputs_embeds = None input_ids = self.input_ids[:num_input_tokens] - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): - ret_hidden_states = self.model( - input_ids=input_ids, - positions=self.positions[:num_input_tokens], - hidden_states=self.hidden_states[:num_input_tokens], - inputs_embeds=inputs_embeds, - ) - if self.method in ("deepseek_mtp", "ernie_mtp", "qwen3_next_mtp"): + model_kwargs = { + "input_ids": input_ids, + "positions": self.positions[:num_input_tokens], + "inputs_embeds": inputs_embeds, + } + if self.pass_hidden_states_to_model: + model_kwargs[ + "hidden_states"] = self.hidden_states[:num_input_tokens] + + forward_ctx_kwargs = dict( + attn_metadata=per_layer_attn_metadata, + vllm_config=self.vllm_config, + num_tokens=num_input_tokens, + ) + if self.pass_cudagraph_args_to_forward_ctx: + forward_ctx_kwargs.update(cudagraph_args) + + with set_forward_context(**forward_ctx_kwargs): + ret_hidden_states = self.model(**model_kwargs) + if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states) + logits = self.model.compute_logits( + sample_hidden_states) # type: ignore # Early exit if there is only one draft token to be generated. - if self.num_speculative_tokens == 1: + if self.num_forward_passes == 1: draft_token_ids = logits.argmax(dim=-1) return draft_token_ids.view(-1, 1) @@ -552,7 +565,7 @@ def propose( common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] common_attn_metadata.query_start_loc_cpu = torch.from_numpy( self.token_arange_np[:batch_size + 1]).clone() - for token_index in range(self.num_speculative_tokens - 1): + for token_index in range(self.num_forward_passes - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. 
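# Illustrative aside (simplified, hypothetical helper; the real code above
# builds these dicts inline): after this refactor, the shared propose() only
# toggles, among other flags, whether hidden states are fed to the model and
# whether cudagraph arguments are forwarded into the forward context.
def build_kwargs(input_ids, positions, hidden_states, num_tokens,
                 pass_hidden_states_to_model, cudagraph_args=None):
    model_kwargs = {"input_ids": input_ids, "positions": positions}
    if pass_hidden_states_to_model:      # EAGLE-style heads
        model_kwargs["hidden_states"] = hidden_states
    forward_ctx_kwargs = {"num_tokens": num_tokens}
    if cudagraph_args is not None:       # draft-model proposer
        forward_ctx_kwargs.update(cudagraph_args)
    return model_kwargs, forward_ctx_kwargs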
@@ -610,7 +623,8 @@ def propose( self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states if self.is_multimodal_model: - inputs_embeds = self.model.get_input_embeddings(input_ids) + inputs_embeds = self.model.get_input_embeddings( + input_ids) # type: ignore self.inputs_embeds[:batch_size] = inputs_embeds inputs_embeds = self.inputs_embeds[:input_batch_size] input_ids = None @@ -619,17 +633,28 @@ def propose( input_ids = self.input_ids[:input_batch_size] # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): - ret_hidden_states = self.model( - input_ids=input_ids, - positions=self.positions[:input_batch_size], - hidden_states=self.hidden_states[:input_batch_size], - inputs_embeds=inputs_embeds, - ) - if self.method in ("deepseek_mtp", "ernie_mtp", - "qwen3_next_mtp"): + model_kwargs = { + "input_ids": input_ids, + "positions": self.positions[:input_batch_size], + "inputs_embeds": inputs_embeds, + } + if self.pass_hidden_states_to_model: + model_kwargs[ + "hidden_states"] = self.hidden_states[:input_batch_size] + + forward_ctx_kwargs = dict( + attn_metadata=per_layer_attn_metadata, + vllm_config=self.vllm_config, + num_tokens=input_batch_size, + ) + cudagraph_args = self.new_cudagraph_args( + num_tokens=input_batch_size) + if self.pass_cudagraph_args_to_forward_ctx: + forward_ctx_kwargs.update(cudagraph_args) + + with set_forward_context(**forward_ctx_kwargs): + ret_hidden_states = self.model(**model_kwargs) # type: ignore + if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = ret_hidden_states else: @@ -639,6 +664,9 @@ def propose( draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) + if self.drop_first_drafted_tokens: + draft_token_ids_list = draft_token_ids_list[1:] + # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids @@ -774,6 +802,7 @@ def propose_tree( else: num_input_tokens = num_tokens # Run the model. 
+ assert isinstance(self.model, nn.Module) with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): @@ -811,6 +840,47 @@ def propose_tree( total_num_drafts = self.cu_drafts_per_level[level + 1] return draft_token_ids_list + @abstractmethod + def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, num_tokens: int, + last_token_indices: torch.Tensor) -> None: + raise NotImplementedError() + + def model_returns_tuple(self) -> bool: + return self.method not in ("deepseek_mtp", "ernie_mtp", + "qwen3_next_mtp", "draft_model") + + def new_cudagraph_args(self, num_tokens: int) -> "CudaGraphArgs": + batch_descriptor = BatchDescriptor(num_tokens=num_tokens, + uniform_decode=True) + cudagraph_runtime_mode, batch_descriptor = ( + self.runner.cudagraph_dispatcher.dispatch(batch_descriptor)) + return CudaGraphArgs( + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ) + + +class CudaGraphArgs(TypedDict): + cudagraph_runtime_mode: CUDAGraphMode + batch_descriptor: BatchDescriptor + + +class EagleProposer(SpecDecodeBaseProposer): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__(vllm_config=vllm_config, + device=device, + pass_hidden_states_to_model=True, + pass_cudagraph_args_to_forward_ctx=False, + drop_first_drafted_tokens=False, + runner=runner) + def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ self.vllm_config.speculative_config.draft_model_config @@ -830,7 +900,7 @@ def load_model(self, target_model: nn.Module) -> None: if supports_multimodal(target_model): # handle multimodality - self.model.config.image_token_index = ( + self.model.config.image_token_index = ( # type: ignore target_model.config.image_token_index) target_language_model = target_model.get_language_model() else: @@ -838,11 +908,11 @@ def load_model(self, target_model: nn.Module) -> None: # share embed_tokens with the target model if needed if get_pp_group().world_size == 1 \ and self.model.model.embed_tokens.weight.shape \ - == target_language_model.model.embed_tokens.weight.shape: + == target_language_model.model.embed_tokens.weight.shape: # type: ignore logger.info( "Assuming the EAGLE head shares the same vocab embedding" " with the target model.") - del self.model.model.embed_tokens + del self.model.model.embed_tokens # type: ignore self.model.model.embed_tokens = ( target_language_model.model.embed_tokens) else: @@ -856,7 +926,7 @@ def load_model(self, target_model: nn.Module) -> None: if self.vllm_config.speculative_config.method != "eagle3" and \ hasattr(target_language_model, "lm_head"): logger.info("Loading EAGLE LM head weights from the target model.") - self.model.lm_head = target_language_model.lm_head + self.model.lm_head = target_language_model.lm_head # type: ignore @torch.inference_mode() def dummy_run( @@ -872,6 +942,7 @@ def dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None + assert isinstance(self.model, nn.Module) self.model( input_ids=input_ids, positions=self.positions[:num_tokens], @@ -898,6 +969,16 @@ def validate_same_kv_cache_group(self, ]) ) == 1, "All eagle layers should belong to the same kv cache group" + def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, num_tokens: int, + last_token_indices: torch.Tensor) -> None: + # Shift the input ids by one token. 
+ # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids + # NOTE(woosuk): Currently, the below code is not used and we always use argmax # to sample the draft tokens. We will use this after we find a way to manage diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a95940ba3c1a..927e7b945cef 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2478,8 +2478,14 @@ def propose_draft_token_ids( if self.supports_mm_inputs: mm_embeds = self._gather_mm_embeddings(scheduler_output, shift_computed_tokens=1) - if self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if (self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model()): + assert isinstance(self.drafter, + (EagleProposer, DraftModelProposer)) + cudagraph_args = dict( + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ) draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -2489,17 +2495,7 @@ def propose_draft_token_ids( sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds, - ) - elif self.speculative_config.uses_draft_model(): - assert isinstance(self.drafter, DraftModelProposer) - draft_token_ids = self.drafter.propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - next_token_ids=next_token_ids, - last_token_indices=token_indices_to_sample, - common_attn_metadata=common_attn_metadata, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, + cudagraph_args=cudagraph_args, ) return draft_token_ids From d37d780a3626ed1582043bcd221109dfb4768dea Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 24 Sep 2025 09:12:39 +0200 Subject: [PATCH 16/73] Minimize git diffs in EAGLE Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 3 +- vllm/v1/spec_decode/eagle.py | 727 ++++++++++++++--------------- 2 files changed, 360 insertions(+), 370 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 1647e24c0ca2..484fdd7db0e5 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -25,6 +25,7 @@ def __init__( device=device, pass_hidden_states_to_model=False, pass_cudagraph_args_to_forward_ctx=True, + one_extra_forward_pass=True, # the first draft_token_ids are identical to next_token_ids, so # they don't need to be returned as proposed tokens drop_first_drafted_tokens=True, @@ -69,7 +70,7 @@ def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, last_token_indices: torch.Tensor) -> None: self.input_ids[:num_tokens] = target_token_ids - def load_model(self) -> None: + def load_model(self, target_model: Any) -> None: draft_model_config: ModelConfig = ( self.vllm_config.speculative_config.draft_model_config) vllm_config_draft: VllmConfig = replace( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index ed76213b6b40..dac219d7a738 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast -from abc import ABC, 
abstractmethod from dataclasses import replace from importlib.util import find_spec from typing import Optional, TypedDict @@ -39,7 +38,7 @@ PADDING_SLOT_ID = -1 -class SpecDecodeBaseProposer(ABC): +class SpecDecodeBaseProposer: def __init__( self, @@ -47,6 +46,7 @@ def __init__( device: torch.device, pass_hidden_states_to_model: bool, pass_cudagraph_args_to_forward_ctx: bool, + one_extra_forward_pass: bool, drop_first_drafted_tokens: bool, runner=None, ): @@ -66,6 +66,8 @@ def __init__( self.num_speculative_tokens = ( self.speculative_config.num_speculative_tokens) self.num_forward_passes = self.num_speculative_tokens + if one_extra_forward_pass: + self.num_forward_passes += 1 self.max_num_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) self.token_arange_np = np.arange(self.max_num_tokens) @@ -153,264 +155,6 @@ def __init__( dtype=torch.int32, ).repeat(max_batch_size, 1) - # Lazily loaded attributes. - self.model: Optional[nn.Module] = None - self.attn_layer_names: list[str] = [] - - def prepare_next_token_ids_cpu( - self, sampled_token_ids: list[list[int]], - requests: dict[str, - CachedRequestState], gpu_input_batch: InputBatch, - num_scheduled_tokens: dict[str, int]) -> torch.Tensor: - """ - This function is used to prepare the inputs for speculative decoding. - It calculates the next token ids for each request based on the sampled - token ids from the CPU. If a request has no sampled token ids (e.g., - during the initial decoding steps), it falls back to using the request - state to get the next token id. - """ - req_ids = gpu_input_batch.req_ids - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = req_ids[i] - req_state = requests[req_id] - seq_len = (req_state.num_computed_tokens + - num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.input_ids.device) - return next_token_ids - - def prepare_next_token_ids_padded(self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: torch.Tensor, - requests: dict[str, CachedRequestState], - gpu_input_batch: InputBatch, - discard_request_indices: torch.Tensor, - num_discarded_requests: int) -> \ - tuple[torch.Tensor, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding. - It calculates the next token ids and the number of valid sampled tokens - for each request, considering the "discarded" requests whose next token - is not sampled and comes from `request.get_token_id()` instead. - It also accounts for the rejected tokens in `sampled_token_ids`. - This function must use device functions to operate on the inputs, and - should not introduce any blocking CPU-GPU synchronization. - """ - # TODO(Ben): Combine this into a custom fused kernel - - # Precompute get_token_id for when there is no valid next token - num_reqs = gpu_input_batch.num_reqs - self.backup_next_token_ids.np[:num_reqs] = np.array([ - requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item()) - for i in range(num_reqs) - ]) - self.backup_next_token_ids.copy_to_gpu(num_reqs) - - # Mask out the sampled tokens indices that should not be sampled. 
- discard_sampled_tokens_req_indices = \ - discard_request_indices[:num_discarded_requests] - - valid_sampled_token_ids_gpu = sampled_token_ids.clone() - valid_sampled_token_ids_gpu.index_fill_( - 0, discard_sampled_tokens_req_indices, -1) - - # Generate a mask for all valid tokens within those requests - max_gen_len = sampled_token_ids.shape[-1] - if max_gen_len == 1: - valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, - dtype=torch.bool) - else: - valid_mask = ( - (valid_sampled_token_ids_gpu != -1) & - (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)) - - # Count the number of valid tokens in each request - valid_sampled_tokens_count = valid_mask.sum(dim=1) - - # Get the rightmost valid index per row - last_valid_indices = valid_sampled_tokens_count - 1 - last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) - - # Get last valid token from each row - # (assume undefined state where there is no valid token) - selected_tokens = torch.gather( - valid_sampled_token_ids_gpu, 1, - last_valid_indices_safe.unsqueeze(1)).squeeze(1) - - # Use last token if valid, pre-computed backup if not - batch_size = valid_sampled_token_ids_gpu.shape[0] - next_token_ids = torch.where( - last_valid_indices != -1, selected_tokens, - self.backup_next_token_ids.gpu[:batch_size]) - - return next_token_ids, valid_sampled_tokens_count - - def prepare_inputs( - self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: list[list[int]], - num_draft_tokens: list[int], - ) -> tuple[CommonAttentionMetadata, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding. - It updates to the common_attn_metadata to account for the rejected - tokens (and newly sampled tokens). It also returns the token indices - of the tokens that should be fed to the speculator. - """ - # E.g. 
- # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] - # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] - # num_rejected_tokens: [n1, n2, n3] - # This function computes the intermediate values: - # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] - # And returns: - # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] - # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] - - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - - device = common_attn_metadata.query_start_loc.device - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - - num_rejected_tokens - - # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) - # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] - new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens - new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() - - # [q1 - n1, q2 - n2, q3 - n3] -> - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - new_query_start_loc_cpu = torch.zeros( - query_start_loc_cpu.shape, - dtype=torch.int32, - pin_memory=is_pin_memory_available()) - new_query_start_loc_np = new_query_start_loc_cpu.numpy() - np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) - - total_num_tokens = new_query_start_loc_np[-1] - # Example assuming num_tokens_per_req_np = [2, 4, 3] - # this implies that `new_query_start_locs` is: - # [0, 2, 6, 9] -> - # [0, 0, 2, 2, 2, 2, 6, 6, 6] - # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) - # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> - # [0, 1, 0, 1, 2, 3, 0, 1, 2] - # _r1_ ____r2____ ___r3__ - token_offests = self.token_arange_np[:total_num_tokens] \ - - new_query_start_locs_expanded - - # Expand starting positions to match token pattern - # [0, q1, q1 + q2] -> - # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] - # _r1_ _____r2_______ ___________r3____________ - old_query_start_locs_expanded = np.repeat( - query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) - # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 - token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to( - device, non_blocking=True) - - spec_common_attn_metadata = CommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, - non_blocking=True), - seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), - query_start_loc_cpu=new_query_start_loc_cpu, - seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. 
- num_computed_tokens_cpu, - num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=total_num_tokens, - max_query_len=new_query_len_per_req.max().item(), - max_seq_len=new_seq_lens_cpu.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], - causal=True, - ) - return spec_common_attn_metadata, token_indices - - def prepare_inputs_padded(self, - common_attn_metadata: CommonAttentionMetadata, - spec_decode_metadata: SpecDecodeMetadata, - valid_sampled_tokens_count: torch.Tensor) -> \ - tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding - It updates the common_attn_metadata for speculative decoding, - but does not consider the rejected tokens. Instead, all tokens - are included as inputs to the speculator, with the rejected tokens - used as padding and filtered out later by `token_indices_to_sample`. - No blocking CPU operations should be introduced in this function. - """ - num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1] - ]) - - num_rejected_tokens_gpu = torch.where( - num_draft_tokens_gpu > 0, - num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu)) - - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) - - total_num_tokens = query_start_loc_cpu[-1].item() - token_indices = self.arange[:total_num_tokens] - - spec_common_attn_metadata = CommonAttentionMetadata( - query_start_loc=common_attn_metadata.query_start_loc, - seq_lens=common_attn_metadata.seq_lens, - query_start_loc_cpu=query_start_loc_cpu, - seq_lens_cpu=common_attn_metadata.seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. - num_computed_tokens_cpu, - num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=total_num_tokens, - max_query_len=new_query_len_per_req.max().item(), - max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping[token_indices], - causal=True, - ) - - token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ - - num_rejected_tokens_gpu - - return spec_common_attn_metadata, token_indices, token_indices_to_sample - def propose( self, # [num_tokens] @@ -443,7 +187,6 @@ def propose( num_tokens, last_token_indices) assert self.runner is not None - assert isinstance(self.model, nn.Module) # Select the correct attention metadata builders for EAGLE layers. # Get the attention metadata builders once and reuse for later. @@ -472,7 +215,7 @@ def propose( if self.is_multimodal_model: input_ids = self.input_ids[:num_tokens] - inputs_embeds = self.model.get_input_embeddings( # type: ignore + inputs_embeds = self.model.get_input_embeddings( input_ids, multimodal_embeddings=mm_embeds or None, ) @@ -508,8 +251,7 @@ def propose( else: last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] - logits = self.model.compute_logits( - sample_hidden_states) # type: ignore + logits = self.model.compute_logits(sample_hidden_states) # Early exit if there is only one draft token to be generated. 
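# Illustrative aside (standalone): the early exit and the drafting loop below
# are driven by num_forward_passes, which differs between proposers -- EAGLE
# yields one draft per pass, while a separate draft model spends an extra
# first pass re-encoding the target tokens and drops that first "draft".
def num_forward_passes(num_speculative_tokens, one_extra_forward_pass):
    return num_speculative_tokens + (1 if one_extra_forward_pass else 0)

assert num_forward_passes(3, one_extra_forward_pass=False) == 3  # EAGLE
assert num_forward_passes(3, one_extra_forward_pass=True) == 4   # draft model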
if self.num_forward_passes == 1: @@ -611,8 +353,7 @@ def propose( self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states if self.is_multimodal_model: - inputs_embeds = self.model.get_input_embeddings( - input_ids) # type: ignore + inputs_embeds = self.model.get_input_embeddings(input_ids) self.inputs_embeds[:batch_size] = inputs_embeds inputs_embeds = self.inputs_embeds[:input_batch_size] input_ids = None @@ -620,44 +361,220 @@ def propose( inputs_embeds = None input_ids = self.input_ids[:input_batch_size] - # Run the model. - model_kwargs = { - "input_ids": input_ids, - "positions": self.positions[:input_batch_size], - "inputs_embeds": inputs_embeds, - } - if self.pass_hidden_states_to_model: - model_kwargs[ - "hidden_states"] = self.hidden_states[:input_batch_size] + # Run the model. + model_kwargs = { + "input_ids": input_ids, + "positions": self.positions[:input_batch_size], + "inputs_embeds": inputs_embeds, + } + if self.pass_hidden_states_to_model: + model_kwargs[ + "hidden_states"] = self.hidden_states[:input_batch_size] + + forward_ctx_kwargs = dict( + attn_metadata=per_layer_attn_metadata, + vllm_config=self.vllm_config, + num_tokens=input_batch_size, + ) + if self.pass_cudagraph_args_to_forward_ctx: + cudagraph_args = self.decoding_cudagraph_args( + num_tokens=input_batch_size) + forward_ctx_kwargs.update(cudagraph_args) + + with set_forward_context(**forward_ctx_kwargs): + ret_hidden_states = self.model(**model_kwargs) + if not self.model_returns_tuple(): + last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + hidden_states = hidden_states[:batch_size] + logits = self.model.compute_logits(last_hidden_states[:batch_size]) + draft_token_ids = logits.argmax(dim=-1) + draft_token_ids_list.append(draft_token_ids) + + if self.drop_first_drafted_tokens: + draft_token_ids_list = draft_token_ids_list[1:] + + # [batch_size, num_speculative_tokens] + draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + return draft_token_ids + + def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, num_tokens: int, + last_token_indices: torch.Tensor) -> None: + # Shift the input ids by one token. + # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + # Replace the last token with the next token. + # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] + self.input_ids[last_token_indices] = next_token_ids + + def model_returns_tuple(self) -> bool: + return self.method not in ("deepseek_mtp", "ernie_mtp", + "qwen3_next_mtp", "draft_model") + + def decoding_cudagraph_args(self, num_tokens: int) -> "CudaGraphArgs": + batch_descriptor = BatchDescriptor(num_tokens=num_tokens, + uniform_decode=True) + cudagraph_runtime_mode, batch_descriptor = ( + self.runner.cudagraph_dispatcher.dispatch(batch_descriptor)) + return CudaGraphArgs( + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ) + + def prepare_next_token_ids_cpu( + self, sampled_token_ids: list[list[int]], + requests: dict[str, + CachedRequestState], gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int]) -> torch.Tensor: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids for each request based on the sampled + token ids from the CPU. 
If a request has no sampled token ids (e.g., + during the initial decoding steps), it falls back to using the request + state to get the next token id. + """ + req_ids = gpu_input_batch.req_ids + next_token_ids: list[int] = [] + for i, token_ids in enumerate(sampled_token_ids): + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. + req_id = req_ids[i] + req_state = requests[req_id] + seq_len = (req_state.num_computed_tokens + + num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.input_ids.device) + return next_token_ids + + def prepare_next_token_ids_padded(self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int) -> \ + tuple[torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding. + It calculates the next token ids and the number of valid sampled tokens + for each request, considering the "discarded" requests whose next token + is not sampled and comes from `request.get_token_id()` instead. + It also accounts for the rejected tokens in `sampled_token_ids`. + This function must use device functions to operate on the inputs, and + should not introduce any blocking CPU-GPU synchronization. + """ + # TODO(Ben): Combine this into a custom fused kernel + + # Precompute get_token_id for when there is no valid next token + num_reqs = gpu_input_batch.num_reqs + self.backup_next_token_ids.np[:num_reqs] = np.array([ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item()) + for i in range(num_reqs) + ]) + self.backup_next_token_ids.copy_to_gpu(num_reqs) + + # Mask out the sampled tokens indices that should not be sampled. 
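# Illustrative aside (toy tensors; the vocab size is made up): discarded
# requests are masked by writing -1 into their rows, after which the number of
# valid sampled tokens per request falls out of a boolean reduction, as below.
import torch

sampled = torch.tensor([[5, 9, -1], [7, 8, 2], [4, -1, -1]])
discard = torch.tensor([1])                  # request 1 is discarded
masked = sampled.clone().index_fill_(0, discard, -1)
valid_count = ((masked != -1) & (masked < 32_000)).sum(dim=1)
assert valid_count.tolist() == [2, 0, 1]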
+ discard_sampled_tokens_req_indices = \ + discard_request_indices[:num_discarded_requests] + + valid_sampled_token_ids_gpu = sampled_token_ids.clone() + valid_sampled_token_ids_gpu.index_fill_( + 0, discard_sampled_tokens_req_indices, -1) + + # Generate a mask for all valid tokens within those requests + max_gen_len = sampled_token_ids.shape[-1] + if max_gen_len == 1: + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, + dtype=torch.bool) + else: + valid_mask = ( + (valid_sampled_token_ids_gpu != -1) & + (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)) + + # Count the number of valid tokens in each request + valid_sampled_tokens_count = valid_mask.sum(dim=1) + + # Get the rightmost valid index per row + last_valid_indices = valid_sampled_tokens_count - 1 + last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) + + # Get last valid token from each row + # (assume undefined state where there is no valid token) + selected_tokens = torch.gather( + valid_sampled_token_ids_gpu, 1, + last_valid_indices_safe.unsqueeze(1)).squeeze(1) + + # Use last token if valid, pre-computed backup if not + batch_size = valid_sampled_token_ids_gpu.shape[0] + next_token_ids = torch.where( + last_valid_indices != -1, selected_tokens, + self.backup_next_token_ids.gpu[:batch_size]) + + return next_token_ids, valid_sampled_tokens_count + + def prepare_inputs_padded(self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor) -> \ + tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + """ + This function is used to prepare the inputs for speculative decoding + It updates the common_attn_metadata for speculative decoding, + but does not consider the rejected tokens. Instead, all tokens + are included as inputs to the speculator, with the rejected tokens + used as padding and filtered out later by `token_indices_to_sample`. + No blocking CPU operations should be introduced in this function. 
+ """ + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1] + ]) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu)) + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - forward_ctx_kwargs = dict( - attn_metadata=per_layer_attn_metadata, - vllm_config=self.vllm_config, - num_tokens=input_batch_size, - ) - cudagraph_args = self.new_cudagraph_args( - num_tokens=input_batch_size) - if self.pass_cudagraph_args_to_forward_ctx: - forward_ctx_kwargs.update(cudagraph_args) + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) - with set_forward_context(**forward_ctx_kwargs): - ret_hidden_states = self.model(**model_kwargs) # type: ignore - if not self.model_returns_tuple(): - last_hidden_states = ret_hidden_states - hidden_states = ret_hidden_states - else: - last_hidden_states, hidden_states = ret_hidden_states - hidden_states = hidden_states[:batch_size] - logits = self.model.compute_logits(last_hidden_states[:batch_size]) - draft_token_ids = logits.argmax(dim=-1) - draft_token_ids_list.append(draft_token_ids) + total_num_tokens = query_start_loc_cpu[-1].item() + token_indices = self.arange[:total_num_tokens] - if self.drop_first_drafted_tokens: - draft_token_ids_list = draft_token_ids_list[1:] + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=common_attn_metadata.query_start_loc, + seq_lens=common_attn_metadata.seq_lens, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens_cpu=common_attn_metadata.seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, + ) - # [batch_size, num_speculative_tokens] - draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - return draft_token_ids + token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ + - num_rejected_tokens_gpu + + return spec_common_attn_metadata, token_indices, token_indices_to_sample def propose_tree( self, @@ -789,7 +706,6 @@ def propose_tree( else: num_input_tokens = num_tokens # Run the model. 
- assert isinstance(self.model, nn.Module) with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): @@ -827,71 +743,108 @@ def propose_tree( total_num_drafts = self.cu_drafts_per_level[level + 1] return draft_token_ids_list - @abstractmethod - def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, - next_token_ids: torch.Tensor, num_tokens: int, - last_token_indices: torch.Tensor) -> None: - raise NotImplementedError() - - def model_returns_tuple(self) -> bool: - return self.method not in ("deepseek_mtp", "ernie_mtp", - "qwen3_next_mtp", "draft_model") - - def new_cudagraph_args(self, num_tokens: int) -> "CudaGraphArgs": - batch_descriptor = BatchDescriptor(num_tokens=num_tokens, - uniform_decode=True) - cudagraph_runtime_mode, batch_descriptor = ( - self.runner.cudagraph_dispatcher.dispatch(batch_descriptor)) - return CudaGraphArgs( - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ) - - def _get_attention_metadata_builder( - self) -> list[AttentionMetadataBuilder]: - """Find and return the attention metadata builders for EAGLE layers. - - Returns: - The metadata builders for EAGLE layers. - - Raises: - AssertionError: If no metadata builders are found for EAGLE layers. + def prepare_inputs( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: list[list[int]], + num_draft_tokens: list[int], + ) -> tuple[CommonAttentionMetadata, torch.Tensor]: """ - builder = None - chosen_layer = self.attn_layer_names[0] + This function is used to prepare the inputs for speculative decoding. + It updates to the common_attn_metadata to account for the rejected + tokens (and newly sampled tokens). It also returns the token indices + of the tokens that should be fed to the speculator. + """ + # E.g. 
+ # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1, q1 + q2, q1 + q2 + q3] + # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] + # num_rejected_tokens: [n1, n2, n3] + # This function computes the intermediate values: + # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] + # And returns: + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # common_attn_metadata.seq_lens{_cpu}: + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # token_indices: [0, 1, ..., q1 - n1 - 1, + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] - for kv_cache_group in self.runner.attn_groups: - for attn_group in kv_cache_group: - if chosen_layer in attn_group.layer_names: - builder = attn_group.get_metadata_builder() - break - if builder is not None: - break + num_rejected_tokens = [ + n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 + for i, n in enumerate(num_draft_tokens) + ] + num_rejected_tokens = torch.tensor(num_rejected_tokens, + dtype=torch.int32) - assert builder is not None, ( - "Failed to find attention metadata builder for EAGLE layers.") - return builder + device = common_attn_metadata.query_start_loc.device + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ + - num_rejected_tokens + # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] + new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens + new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() -class CudaGraphArgs(TypedDict): - cudagraph_runtime_mode: CUDAGraphMode - batch_descriptor: BatchDescriptor + # [q1 - n1, q2 - n2, q3 - n3] -> + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + new_query_start_loc_cpu = torch.zeros( + query_start_loc_cpu.shape, + dtype=torch.int32, + pin_memory=is_pin_memory_available()) + new_query_start_loc_np = new_query_start_loc_cpu.numpy() + np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) + total_num_tokens = new_query_start_loc_np[-1] + # Example assuming num_tokens_per_req_np = [2, 4, 3] + # this implies that `new_query_start_locs` is: + # [0, 2, 6, 9] -> + # [0, 0, 2, 2, 2, 2, 6, 6, 6] + # _r1_ ____r2____ ___r3__ + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], + new_num_tokens_per_req_np) + # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> + # [0, 1, 0, 1, 2, 3, 0, 1, 2] + # _r1_ ____r2____ ___r3__ + token_offests = self.token_arange_np[:total_num_tokens] \ + - new_query_start_locs_expanded -class EagleProposer(SpecDecodeBaseProposer): + # Expand starting positions to match token pattern + # [0, q1, q1 + q2] -> + # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] + # _r1_ _____r2_______ ___________r3____________ + old_query_start_locs_expanded = np.repeat( + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + # Final token indices are: + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + token_indices_np = token_offests + old_query_start_locs_expanded + token_indices = torch.from_numpy(token_indices_np).to( + device, non_blocking=True) - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - runner=None, - ): - super().__init__(vllm_config=vllm_config, - device=device, - pass_hidden_states_to_model=True, - pass_cudagraph_args_to_forward_ctx=False, - 
drop_first_drafted_tokens=False, - runner=runner) + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=new_query_start_loc_cpu.to(device, + non_blocking=True), + seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens_cpu=new_seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. + num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + max_seq_len=new_seq_lens_cpu.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], + causal=True, + ) + + return spec_common_attn_metadata, token_indices def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ @@ -912,7 +865,7 @@ def load_model(self, target_model: nn.Module) -> None: if supports_multimodal(target_model): # handle multimodality - self.model.config.image_token_index = ( # type: ignore + self.model.config.image_token_index = ( target_model.config.image_token_index) target_language_model = target_model.get_language_model() else: @@ -982,7 +935,6 @@ def dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None - assert isinstance(self.model, nn.Module) self.model( input_ids=input_ids, positions=self.positions[:num_tokens], @@ -990,6 +942,31 @@ def dummy_run( inputs_embeds=inputs_embeds, ) + def _get_attention_metadata_builder( + self) -> list[AttentionMetadataBuilder]: + """Find and return the attention metadata builders for EAGLE layers. + + Returns: + The metadata builders for EAGLE layers. + + Raises: + AssertionError: If no metadata builders are found for EAGLE layers. + """ + builder = None + chosen_layer = self.attn_layer_names[0] + + for kv_cache_group in self.runner.attn_groups: + for attn_group in kv_cache_group: + if chosen_layer in attn_group.layer_names: + builder = attn_group.get_metadata_builder() + break + if builder is not None: + break + + assert builder is not None, ( + "Failed to find attention metadata builder for EAGLE layers.") + return builder + def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ @@ -1009,15 +986,27 @@ def validate_same_kv_cache_group(self, ]) ) == 1, "All eagle layers should belong to the same kv cache group" - def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, - next_token_ids: torch.Tensor, num_tokens: int, - last_token_indices: torch.Tensor) -> None: - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. 
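The index arithmetic in prepare_inputs() above is easier to follow with concrete numbers. The following standalone NumPy sketch is not part of the patch; the batch shape, draft counts, and sampled lengths are invented purely for illustration, but the steps mirror the code.

import numpy as np

# Three requests with query lengths [3, 4, 2]; request 0 had one draft rejected.
query_start_loc = np.array([0, 3, 7, 9])
num_draft_tokens = [2, 3, 0]            # drafts that were verified per request
sampled_lens = [2, 4, 1]                # len(sampled_token_ids[i])

# A request that had drafts keeps len(sampled) - 1 of them; the rest were rejected.
num_rejected = np.array([n + 1 - s if n > 0 else 0
                         for n, s in zip(num_draft_tokens, sampled_lens)])   # [1, 0, 0]

query_lens = np.diff(query_start_loc)   # [3, 4, 2]
kept_per_req = query_lens - num_rejected            # [2, 4, 2]

new_query_start_loc = np.zeros(4, dtype=np.int64)
np.cumsum(kept_per_req, out=new_query_start_loc[1:])   # new_query_start_loc -> [0, 2, 6, 8]

total = new_query_start_loc[-1]
offsets = np.arange(total) - np.repeat(new_query_start_loc[:-1], kept_per_req)
token_indices = offsets + np.repeat(query_start_loc[:-1], kept_per_req)
print(token_indices.tolist())   # [0, 1, 3, 4, 5, 6, 7, 8] -- request 0 drops its rejected token 2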
- # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + +class CudaGraphArgs(TypedDict): + cudagraph_runtime_mode: CUDAGraphMode + batch_descriptor: BatchDescriptor + + +class EagleProposer(SpecDecodeBaseProposer): + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + runner=None, + ): + super().__init__(vllm_config, + device, + pass_hidden_states_to_model=True, + pass_cudagraph_args_to_forward_ctx=False, + one_extra_forward_pass=False, + drop_first_drafted_tokens=False, + runner=runner) # NOTE(woosuk): Currently, the below code is not used and we always use argmax From 5967e0999e4ad334146c1c3b9d1d8a1d74645797 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 24 Sep 2025 09:18:47 +0200 Subject: [PATCH 17/73] Fix missing input Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 8 ++++---- vllm/v1/worker/gpu_model_runner.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 484fdd7db0e5..99d525fbe9d4 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -25,15 +25,14 @@ def __init__( device=device, pass_hidden_states_to_model=False, pass_cudagraph_args_to_forward_ctx=True, + # The draft model runs one forward pass to prefill + # the target_token_ids, and another forward pass for decoding + # based on the next_token_ids. I.e. it needs 1 more forward pass. one_extra_forward_pass=True, # the first draft_token_ids are identical to next_token_ids, so # they don't need to be returned as proposed tokens drop_first_drafted_tokens=True, runner=runner) - # The draft model runs one forward pass to prefill - # the target_token_ids, and another forward pass for decoding - # based on the next_token_ids. I.e. it needs 1 more forward pass. 
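As a rough illustration of the control flow these flags select, here is a minimal sketch of the K + 1 forward passes needed by a separate draft model. It is an assumption-laden simplification, not the actual implementation: draft_model stands in for the real module, and attention metadata, positions, and KV-cache bookkeeping are all omitted.

import torch

# Minimal sketch, assuming a callable draft_model(token_ids) -> logits.
# Pass 1 prefills the draft KV cache with the target tokens and its own
# prediction is discarded in favour of the verified next_token_ids; the
# remaining num_speculative_tokens passes each draft one greedy token.
def draft_with_separate_model(draft_model, target_token_ids: torch.Tensor,
                              next_token_ids: torch.Tensor,
                              num_speculative_tokens: int) -> torch.Tensor:
    draft_model(target_token_ids)             # pass 1: prefill only
    token_ids = next_token_ids
    drafts = []
    for _ in range(num_speculative_tokens):   # passes 2 .. K + 1
        logits = draft_model(token_ids)
        token_ids = logits.argmax(dim=-1)     # greedy draft token per request
        drafts.append(token_ids)
    return torch.stack(drafts, dim=1)         # [batch_size, K]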
- self.num_forward_passes = self.num_speculative_tokens + 1 self._raise_if_multimodal() self._raise_if_mrope() @@ -71,6 +70,7 @@ def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, self.input_ids[:num_tokens] = target_token_ids def load_model(self, target_model: Any) -> None: + """Takes target_model to satisfy the type checker.""" draft_model_config: ModelConfig = ( self.vllm_config.speculative_config.draft_model_config) vllm_config_draft: VllmConfig = replace( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a92e9b8dce0e..b966286aa4d8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2604,7 +2604,8 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.drafter.load_model(self.model) elif self.speculative_config.uses_draft_model(): assert isinstance(self.drafter, DraftModelProposer) - self.drafter.load_model() + # Passed something to satisfy the type checker + self.drafter.load_model(None) if self.use_aux_hidden_state_outputs: if supports_eagle3(self.model): self.model.set_aux_hidden_state_layers( From 7b03a455e08c0963fff2b589d7e2b1845463bb8a Mon Sep 17 00:00:00 2001 From: Benjamin Chislett Date: Thu, 25 Sep 2025 01:43:51 +0000 Subject: [PATCH 18/73] fix next_token_ids issue Signed-off-by: Benjamin Chislett --- vllm/v1/spec_decode/draft_model.py | 2 +- vllm/v1/spec_decode/eagle.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 99d525fbe9d4..539686c4cecb 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -29,7 +29,7 @@ def __init__( # the target_token_ids, and another forward pass for decoding # based on the next_token_ids. I.e. it needs 1 more forward pass. one_extra_forward_pass=True, - # the first draft_token_ids are identical to next_token_ids, so + # the first draft_token_ids are replaced by next_token_ids, so # they don't need to be returned as proposed tokens drop_first_drafted_tokens=True, runner=runner) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index dac219d7a738..9ccf097932fe 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -284,7 +284,10 @@ def propose( f"{self.allowed_attn_types}") # Generate the remaining draft tokens. 
- draft_token_ids_list = [draft_token_ids] + if self.drop_first_drafted_tokens: + draft_token_ids_list = [next_token_ids] + else: + draft_token_ids_list = [draft_token_ids] if self.use_cuda_graph and \ batch_size <= self.cudagraph_batch_sizes[-1]: From c7d2fd528b6d901f85fdc2d268dea2e9cf7178f2 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 25 Sep 2025 10:13:02 +0200 Subject: [PATCH 19/73] Test also acceptance-len Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 19 ++++++++++++++++--- vllm/v1/spec_decode/metrics.py | 10 ++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index b78eecec85d7..9f8ecf3158fd 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -16,7 +16,8 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.outputs import RequestOutput from vllm.platforms import current_platform -from vllm.v1.spec_decode.metrics import compute_acceptance_rate +from vllm.v1.spec_decode.metrics import (compute_acceptance_len, + compute_acceptance_rate) def get_test_prompts(mm_enabled: bool, quiet: bool = False): @@ -243,7 +244,9 @@ class ArgsTest: model: str draft_model: str sampling_config: SamplingParams + num_speculative_tokens: int expected_acceptance_rate: float + expected_acceptance_len: float expected_same_output_fraction: float # Defaults target_tensor_parallel_size: int = 1 @@ -253,17 +256,23 @@ class ArgsTest: cases = [ + # Same model for draft and target, greedy sampling. ArgsTest( model="Qwen/Qwen3-0.6B", draft_model="Qwen/Qwen3-0.6B", sampling_config=greedy_sampling(), + num_speculative_tokens=3, # K + expected_acceptance_len=3 + 1, # K + 1 expected_acceptance_rate=1.0, expected_same_output_fraction=1.0, ), + # Smaller draft model, stochastic sampling. 
ArgsTest( model="Qwen/Qwen3-1.7B", draft_model="Qwen/Qwen3-0.6B", sampling_config=stochastic_sampling(), + num_speculative_tokens=3, + expected_acceptance_len=2.85 + 1, expected_acceptance_rate=0.9, expected_same_output_fraction=0.9, ), @@ -289,7 +298,7 @@ def test_draft_model_correctness( speculative_config={ "model": args.draft_model, "method": "draft_model", - "num_speculative_tokens": 3, + "num_speculative_tokens": args.num_speculative_tokens, "max_model_len": args.max_model_len, "enforce_eager": enforce_eager, "tensor_parallel_size": args.draft_tensor_parallel_size, @@ -302,12 +311,15 @@ def test_draft_model_correctness( disable_log_stats=False, # enables get_metrics() ) spec_outputs = spec_llm.chat(test_prompts, args.sampling_config) - acceptance_rate = compute_acceptance_rate(spec_llm.get_metrics()) + metrics = spec_llm.get_metrics() + acceptance_rate: float = compute_acceptance_rate(metrics) + acceptance_len: float = compute_acceptance_len(metrics) del spec_llm # CLEANUP torch.cuda.empty_cache() cleanup_dist_env_and_memory() assert acceptance_rate >= args.expected_acceptance_rate + assert acceptance_len >= args.expected_acceptance_len ref_llm = LLM( model=args.model, @@ -330,6 +342,7 @@ def test_draft_model_correctness( print(f"spec-decode: target={args.model}, draft={args.draft_model}, " f"temperature={args.sampling_config.temperature:.2f}, " f"acceptance_rate={acceptance_rate:.2f}, " + f"acceptance_len={acceptance_len:.2f}, " f"match_fraction={match_fraction:.2f}") diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 30a3654c5e8e..437a9cf9f6e6 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -212,6 +212,16 @@ def compute_acceptance_rate(metrics: list[Metric]) -> float: return n_accepted_toks / n_draft_toks +def compute_acceptance_len(metrics: list[Metric]) -> float: + name2metric = {metric.name: metric for metric in metrics} + n_drafts = name2metric["vllm:spec_decode_num_drafts"].value # type: ignore + n_accepted_toks = name2metric[ + "vllm:spec_decode_num_accepted_tokens"].value # type: ignore + if n_drafts == 0: + return 1 + return 1 + (n_accepted_toks / n_drafts) + + def make_per_engine(counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[str]]): """Create a counter for each label value.""" From ac90311b6099908c895d340dad1b81005415731b Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Fri, 26 Sep 2025 11:18:43 +0200 Subject: [PATCH 20/73] Pass missing argument in test_eagle.py Signed-off-by: Tomas Ruiz --- tests/v1/spec_decode/test_eagle.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 5096f9fd647b..004a9e290b7c 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -543,7 +543,8 @@ def create_deterministic_logits(token_ids): next_token_ids=next_token_ids, last_token_indices=None, common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata) + sampling_metadata=sampling_metadata, + cudagraph_args=dict()) assert result.shape == (batch_size, num_speculative_tokens) @@ -698,7 +699,8 @@ def create_deterministic_logits(token_ids, k: int): next_token_ids=next_token_ids, last_token_indices=None, common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata) + sampling_metadata=sampling_metadata, + cudagraph_args=dict()) assert result.shape == (batch_size, num_speculative_tokens) # The tokens are expected to be consecutive 
integers starting From b477e101b17dadf0549b468ad37e10e43d1e1d4c Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Fri, 26 Sep 2025 14:17:43 +0200 Subject: [PATCH 21/73] CKPT: Remove extra forward Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 16 +++++ vllm/v1/attention/backends/utils.py | 3 + vllm/v1/spec_decode/draft_model.py | 100 +++++++++++++++++++++++++++- 3 files changed, 116 insertions(+), 3 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 823bd7e746a6..1813327dc062 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -16,6 +16,7 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.outputs import RequestOutput from vllm.platforms import current_platform +from vllm.v1.spec_decode.draft_model import append_new_toks from vllm.v1.spec_decode.metrics import (compute_acceptance_len, compute_acceptance_rate) @@ -357,3 +358,18 @@ def compute_exact_matches(ref_outputs: list[RequestOutput], print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") return matches / len(ref_outputs) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_append_new_toks(device: str): + toks = torch.tensor([11, 12, 13, 21, 22, 31], device=device) + start_locs = torch.tensor([0, 3, 5, 6], device=device) + new_toks = torch.tensor([13, 23, 32], device=device) + + expected_toks = torch.tensor([11, 12, 13, 13, 21, 22, 23, 31, 32], + device=device) + expected_start_locs = torch.tensor([0, 4, 7, 9], device=device) + actual_toks, actual_start_locs = append_new_toks(toks, start_locs, + new_toks) + assert torch.all(actual_toks == expected_toks) + assert torch.all(actual_start_locs == expected_start_locs) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index f37a829f401c..9e6f48b9cf54 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -83,6 +83,9 @@ class CommonAttentionMetadata: # Needed by CrossAttentionBuilder encoder_seq_lens: Optional[np.ndarray] = None + def batch_size(self) -> int: + return self.seq_lens_cpu.shape[0] + def slice_query_start_locs( query_start_loc: torch.Tensor, diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 539686c4cecb..fbaba79048e6 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -9,7 +9,9 @@ from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata class DraftModelProposer(SpecDecodeBaseProposer): @@ -28,14 +30,72 @@ def __init__( # The draft model runs one forward pass to prefill # the target_token_ids, and another forward pass for decoding # based on the next_token_ids. I.e. it needs 1 more forward pass. 
- one_extra_forward_pass=True, + one_extra_forward_pass=False, # the first draft_token_ids are replaced by next_token_ids, so # they don't need to be returned as proposed tokens - drop_first_drafted_tokens=True, + drop_first_drafted_tokens=False, runner=runner) self._raise_if_multimodal() self._raise_if_mrope() + def prepare_inputs_padded(self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor) -> \ + tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + tup = super().prepare_inputs_padded(common_attn_metadata, + spec_decode_metadata, + valid_sampled_tokens_count) + common_attn_metadata, token_indices, token_indices_to_sample = tup + cad = common_attn_metadata + batch_size = common_attn_metadata.batch_size() + + # token_indices is [0, ..., N], extend by batch_size + new_token_indices = self.arange[:len(token_indices) + batch_size] + # token indices to sample must be increased + # by [+1, +2, ..., +batch_size] + new_token_indices_to_sample = token_indices_to_sample + self.arange[ + 1:batch_size + 1] + + # query start loc mus be increased by [+0, +1, +2, ..., +batch_size] + new_query_start_loc = cad.query_start_loc + self.arange[:len( + cad.query_start_loc)] + # seq lens must be increased by [+1, +1, ..., +1] size batch_size + new_seq_lens = cad.seq_lens + torch.ones_like(cad.seq_lens) + # num requests stays unchanged + new_num_reqs = cad.num_reqs + # num computed tokens are increased by [+1, +1, ..., +1] size batch_size + new_num_computed_tokens_cpu = cad.num_computed_tokens_cpu \ + + torch.ones_like(cad.num_computed_tokens_cpu) + # num actual tokens increases by batch_size + new_num_actual_tokens = cad.num_actual_tokens + batch_size + # max query len and max seq len increases by 1 + new_max_query_len = cad.max_query_len + 1 + new_max_seq_len = cad.max_seq_len + 1 + # block table tensor depends on num_requests, which doesn't change + new_block_table_tensor = cad.block_table_tensor + # slot mapping depends on num_scheduled_tokens, + # which increased by batch_size + assert len(self.runner.input_batch.block_table.block_tables) == 1 + kv_cache_group_id = 0 + new_slot_mapping = self.runner.input_batch.block_table[ + kv_cache_group_id].slot_mapping.gpu[:new_num_actual_tokens] + + new_cad = CommonAttentionMetadata( + query_start_loc=new_query_start_loc, + query_start_loc_cpu=new_query_start_loc.to("cpu"), + seq_lens=new_seq_lens, + seq_lens_cpu=new_seq_lens.to("cpu"), + num_reqs=new_num_reqs, + num_computed_tokens_cpu=new_num_computed_tokens_cpu, + num_actual_tokens=new_num_actual_tokens, + max_query_len=new_max_query_len, + max_seq_len=new_max_seq_len, + block_table_tensor=new_block_table_tensor, + slot_mapping=new_slot_mapping, + ) + return new_cad, new_token_indices, new_token_indices_to_sample + def _raise_if_multimodal(self): if self.is_multimodal_model: raise NotImplementedError("Speculative Decoding with draft models " @@ -67,7 +127,15 @@ def dummy_run(self, num_tokens: int, forward_ctx_kwargs: dict): def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, next_token_ids: torch.Tensor, num_tokens: int, last_token_indices: torch.Tensor) -> None: - self.input_ids[:num_tokens] = target_token_ids + start_locs = torch.zeros(last_token_indices.shape[0] + 1, + device=last_token_indices.device, + dtype=torch.int32) + start_locs[1:] = last_token_indices + 1 + input_ids, _ = append_new_toks(toks=target_token_ids, + start_locs=start_locs, + new_toks=next_token_ids) + num_tokens = 
input_ids.shape[0] + self.input_ids[:num_tokens] = input_ids def load_model(self, target_model: Any) -> None: """Takes target_model to satisfy the type checker.""" @@ -96,3 +164,29 @@ def load_model(self, target_model: Any) -> None: get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names) + + +def append_new_toks( + toks: torch.Tensor, start_locs: torch.Tensor, + new_toks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + long_len = toks.shape[0] + new_toks.shape[0] + long_toks = torch.zeros(long_len, device=toks.device, dtype=toks.dtype) + + # compute indices for previous toks + toks_idxs = torch.ones_like(toks) + toks_idxs[start_locs[1:-1]] += 1 + toks_idxs = toks_idxs.cumsum(0) - 1 + + # compute indices for new toks + new_toks_idxs = start_locs[1:] + torch.arange(new_toks.shape[0], + device=toks.device) + + # assign toks and new toks + long_toks[toks_idxs] = toks + long_toks[new_toks_idxs] = new_toks + + # compute new start locs + new_start_locs = torch.zeros_like(start_locs) + new_start_locs[1:] = new_toks_idxs + 1 + + return long_toks, new_start_locs From 309d827e267b4201188e3b618130cd75c842ca74 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Fri, 26 Sep 2025 17:01:43 +0200 Subject: [PATCH 22/73] Prevent illegal access to hidden_states Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 8 ++++++++ vllm/v1/worker/gpu_model_runner.py | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index fbaba79048e6..79144828f8a8 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -81,6 +81,14 @@ def prepare_inputs_padded(self, new_slot_mapping = self.runner.input_batch.block_table[ kv_cache_group_id].slot_mapping.gpu[:new_num_actual_tokens] + # new_positions = self.runner.positions.gpu[:new_num_actual_tokens] + # block_numbers = new_positions // self.block_size + # block_ids = new_block_table_tensor.gather( + # dim=1, index=block_numbers.view(1, -1)) + # block_ids = block_ids.view(-1) + # new_slot_mapping = (block_ids * self.block_size + # + new_positions % self.block_size) + new_cad = CommonAttentionMetadata( query_start_loc=new_query_start_loc, query_start_loc_cpu=new_query_start_loc.to("cpu"), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6b87119abe3e..91a904510567 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2570,7 +2570,9 @@ def propose_draft_token_ids( target_token_ids = self.input_ids.gpu[token_indices] # TODO(woosuk): Support M-RoPE. target_positions = self.positions.gpu[token_indices] - if self.use_aux_hidden_state_outputs: + if self.speculative_config.uses_draft_model(): + target_hidden_states = None + elif self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) From 2e97fabdbd237c316e818ccad6f516fa1d505692 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Sun, 28 Sep 2025 15:02:50 +0200 Subject: [PATCH 23/73] Remove forward. single prompt works. 
Batch fails --- tests/v1/e2e/test_spec_decode.py | 2 +- vllm/v1/spec_decode/draft_model.py | 95 +++++++++++++++++++----------- vllm/v1/spec_decode/eagle.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 22 ++++++- 4 files changed, 86 insertions(+), 37 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 1813327dc062..8d6709dbd4dd 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -291,7 +291,7 @@ def test_draft_model_correctness( """Compare the outputs using and not using speculative decoding. In the greedy decoding case, the outputs must match EXACTLY.""" monkeypatch.setenv("VLLM_USE_V1", "1") - test_prompts = get_test_prompts(mm_enabled=False, quiet=True) + test_prompts = get_test_prompts(mm_enabled=False, quiet=True)[:2] # success for single prompt spec_llm = LLM( model=args.model, diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 79144828f8a8..eeab82d95560 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -26,7 +26,7 @@ def __init__( vllm_config=vllm_config, device=device, pass_hidden_states_to_model=False, - pass_cudagraph_args_to_forward_ctx=True, + pass_cudagraph_args_to_forward_ctx=False, # The draft model runs one forward pass to prefill # the target_token_ids, and another forward pass for decoding # based on the next_token_ids. I.e. it needs 1 more forward pass. @@ -37,25 +37,52 @@ def __init__( runner=runner) self._raise_if_multimodal() self._raise_if_mrope() + + def update_propose_kwargs(self, propose_kwargs: dict): + common_attn_metadata = propose_kwargs["common_attn_metadata"] + target_token_ids = propose_kwargs["target_token_ids"] + next_token_ids = propose_kwargs["next_token_ids"] + target_positions = propose_kwargs["target_positions"] + token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 + + # update target_token_ids + start_locs = torch.zeros(token_indices_to_sample.shape[0] + 1, + device=token_indices_to_sample.device, + dtype=torch.int32) + start_locs[1:] = token_indices_to_sample + 1 + new_target_token_ids, _ = append_new_toks(toks=target_token_ids, + start_locs=start_locs, + new_toks=next_token_ids) + # update positions + positions_to_append = target_positions[token_indices_to_sample] + 1 + new_target_positions, _ = append_new_toks(toks=target_positions, + start_locs=start_locs, + new_toks=positions_to_append) + # update common_attn_metadata + new_common_attn_metadata = self.update_common_attn_metadata( + new_target_positions, common_attn_metadata) + # update token_indices_to_sample + new_token_indices_to_sample = new_common_attn_metadata.query_start_loc[1:] - 1 + + new_propose_kwargs = dict( + target_token_ids=new_target_token_ids, + target_positions=new_target_positions, + next_token_ids=None, + last_token_indices=new_token_indices_to_sample, + common_attn_metadata=new_common_attn_metadata, + ) + return propose_kwargs | new_propose_kwargs - def prepare_inputs_padded(self, - common_attn_metadata: CommonAttentionMetadata, - spec_decode_metadata: SpecDecodeMetadata, - valid_sampled_tokens_count: torch.Tensor) -> \ - tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: - tup = super().prepare_inputs_padded(common_attn_metadata, - spec_decode_metadata, - valid_sampled_tokens_count) - common_attn_metadata, token_indices, token_indices_to_sample = tup + def update_common_attn_metadata(self, new_positions: torch.Tensor, common_attn_metadata: CommonAttentionMetadata): cad = 
common_attn_metadata batch_size = common_attn_metadata.batch_size() # token_indices is [0, ..., N], extend by batch_size - new_token_indices = self.arange[:len(token_indices) + batch_size] + # new_token_indices = self.arange[:len(target_token_ids) + len(next_token_ids)] # token indices to sample must be increased # by [+1, +2, ..., +batch_size] - new_token_indices_to_sample = token_indices_to_sample + self.arange[ - 1:batch_size + 1] + # new_token_indices_to_sample = last_token_indices + self.arange[ + # 1:batch_size + 1] # query start loc mus be increased by [+0, +1, +2, ..., +batch_size] new_query_start_loc = cad.query_start_loc + self.arange[:len( @@ -77,17 +104,16 @@ def prepare_inputs_padded(self, # slot mapping depends on num_scheduled_tokens, # which increased by batch_size assert len(self.runner.input_batch.block_table.block_tables) == 1 - kv_cache_group_id = 0 - new_slot_mapping = self.runner.input_batch.block_table[ - kv_cache_group_id].slot_mapping.gpu[:new_num_actual_tokens] - - # new_positions = self.runner.positions.gpu[:new_num_actual_tokens] - # block_numbers = new_positions // self.block_size - # block_ids = new_block_table_tensor.gather( - # dim=1, index=block_numbers.view(1, -1)) - # block_ids = block_ids.view(-1) - # new_slot_mapping = (block_ids * self.block_size - # + new_positions % self.block_size) + # kv_cache_group_id = 0 + # new_slot_mapping = self.runner.input_batch.block_table[ + # kv_cache_group_id].slot_mapping.gpu[:new_num_actual_tokens] + + block_numbers = new_positions // self.block_size + block_ids = new_block_table_tensor.gather( + dim=1, index=block_numbers.view(1, -1)) + block_ids = block_ids.view(-1) + new_slot_mapping = (block_ids * self.block_size + + new_positions % self.block_size) new_cad = CommonAttentionMetadata( query_start_loc=new_query_start_loc, @@ -102,7 +128,7 @@ def prepare_inputs_padded(self, block_table_tensor=new_block_table_tensor, slot_mapping=new_slot_mapping, ) - return new_cad, new_token_indices, new_token_indices_to_sample + return new_cad def _raise_if_multimodal(self): if self.is_multimodal_model: @@ -135,15 +161,16 @@ def dummy_run(self, num_tokens: int, forward_ctx_kwargs: dict): def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, next_token_ids: torch.Tensor, num_tokens: int, last_token_indices: torch.Tensor) -> None: - start_locs = torch.zeros(last_token_indices.shape[0] + 1, - device=last_token_indices.device, - dtype=torch.int32) - start_locs[1:] = last_token_indices + 1 - input_ids, _ = append_new_toks(toks=target_token_ids, - start_locs=start_locs, - new_toks=next_token_ids) - num_tokens = input_ids.shape[0] - self.input_ids[:num_tokens] = input_ids + # start_locs = torch.zeros(last_token_indices.shape[0] + 1, + # device=last_token_indices.device, + # dtype=torch.int32) + # start_locs[1:] = last_token_indices + 1 + # input_ids, _ = append_new_toks(toks=target_token_ids, + # start_locs=start_locs, + # new_toks=next_token_ids) + # num_tokens = input_ids.shape[0] + # self.input_ids[:num_tokens] = input_ids + self.input_ids[:num_tokens] = target_token_ids def load_model(self, target_model: Any) -> None: """Takes target_model to satisfy the type checker.""" diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d839a98ccbae..5db73d0076bd 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -172,7 +172,7 @@ def propose( mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] - batch_size = 
next_token_ids.shape[0] + batch_size = common_attn_metadata.batch_size() if last_token_indices is None: last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 @@ -246,6 +246,7 @@ def propose( with set_forward_context(**forward_ctx_kwargs): ret_hidden_states = self.model(**model_kwargs) + self.runner.log_toks("Draft forward", model_kwargs["input_ids"]) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states @@ -392,6 +393,7 @@ def propose( with set_forward_context(**forward_ctx_kwargs): ret_hidden_states = self.model(**model_kwargs) + self.runner.log_toks("Draft forward", model_kwargs["input_ids"]) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = ret_hidden_states diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 91a904510567..d060565fe6c4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -173,12 +173,22 @@ def get_output(self) -> ModelRunnerOutput: class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + + def log_toks(self, msg: str, toks): + if not self.do_log: + return + logger.info("%s: %s", msg, [self.tokenizer.decode(t) for t in toks]) def __init__( self, vllm_config: VllmConfig, device: torch.device, ): + self.do_log = True + if self.do_log: + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + vllm_config.model_config.model) self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -2087,6 +2097,8 @@ def _sample( bonus_token_ids, sampling_metadata, ) + t0 = output_token_ids[0] + self.log_toks("sampled token ids", t0[t0 != -1]) sampler_output.sampled_token_ids = output_token_ids self._update_states_after_model_execute(output_token_ids) @@ -2234,6 +2246,8 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + if self.do_log: + logger.info("=======BEGIN STEP=======") with record_function_or_nullcontext("Preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. 
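The slot-mapping recomputation in update_common_attn_metadata earlier in this series follows the usual paged-KV-cache rule: a token at position p of a request whose block-table row is r lives in slot block_table[r, p // block_size] * block_size + p % block_size. A standalone toy example (block table, block size, and positions are all made up, and this is only a sketch of the rule, not the patch's code path):

import torch

block_size = 4
block_table = torch.tensor([[7, 2],    # request 0 owns physical blocks 7 and 2
                            [5, 9]])   # request 1 owns physical blocks 5 and 9
positions = torch.arange(6).unsqueeze(0)    # positions 0..5, shared by both requests

block_numbers = positions // block_size     # [[0, 0, 0, 0, 1, 1]]
block_ids = block_table.gather(dim=1, index=block_numbers.expand(2, -1))
slot_mapping = block_ids * block_size + positions % block_size
print(slot_mapping.tolist())
# [[28, 29, 30, 31, 8, 9], [20, 21, 22, 23, 36, 37]]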
@@ -2303,6 +2317,7 @@ def execute_model( inputs_embeds=inputs_embeds, **model_kwargs, ) + self.log_toks("Target forward", input_ids) with record_function_or_nullcontext("Postprocess"): if self.use_aux_hidden_state_outputs: @@ -2382,6 +2397,7 @@ def propose_draft_token_ids(sampled_token_ids): cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ) + self.log_toks("draft token ids [0]", self._draft_token_ids[0]) use_padded_batch = self.speculative_config and \ (self.speculative_config.use_eagle() @@ -2553,6 +2569,7 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.speculative_config.disable_padded_drafter_batch: + raise ValueError() token_indices_to_sample = None common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( @@ -2590,7 +2607,7 @@ def propose_draft_token_ids( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ) - draft_token_ids = self.drafter.propose( + propose_kwargs = dict( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -2601,6 +2618,9 @@ def propose_draft_token_ids( mm_embeds=mm_embeds, cudagraph_args=cudagraph_args, ) + if self.speculative_config.uses_draft_model(): + propose_kwargs = self.drafter.update_propose_kwargs(propose_kwargs) + draft_token_ids = self.drafter.propose(**propose_kwargs) return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: From 89b9c1d88e8f2d2ca19735ba4f059cbe4c02fde6 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Sun, 28 Sep 2025 15:40:07 +0200 Subject: [PATCH 24/73] Remove unnecessary if-else statement Signed-off-by: Tomas Ruiz --- vllm/v1/worker/gpu_model_runner.py | 34 +++++++++++++----------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 917ab1df3a38..992de28ee3f0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2663,25 +2663,21 @@ def propose_draft_token_ids( else: mm_embed_inputs = None - if (self.speculative_config.use_eagle() - or self.speculative_config.uses_draft_model()): - assert isinstance(self.drafter, - (EagleProposer, DraftModelProposer)) - cudagraph_args: CudaGraphArgs = dict( - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ) - draft_token_ids = self.drafter.propose( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=token_indices_to_sample, - sampling_metadata=sampling_metadata, - common_attn_metadata=common_attn_metadata, - mm_embed_inputs=mm_embed_inputs, - cudagraph_args=cudagraph_args, - ) + cudagraph_args: CudaGraphArgs = dict( + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ) + draft_token_ids = self.drafter.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, + sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, + mm_embed_inputs=mm_embed_inputs, + cudagraph_args=cudagraph_args, + ) return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: From e74c71e1123931b623486987a587cf18ad742842 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 30 Sep 2025 17:59:21 +0200 
Subject: [PATCH 25/73] Minimize changes Signed-off-by: Tomas Ruiz --- vllm/v1/worker/gpu_model_runner.py | 42 ++++++++++++++---------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 449d48917e56..d1588e4ed8cd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2681,29 +2681,25 @@ def propose_draft_token_ids( ) else: mm_embed_inputs = None - if (self.speculative_config.use_eagle() - or self.speculative_config.uses_draft_model()): - assert isinstance(self.drafter, - (EagleProposer, DraftModelProposer)) - cudagraph_args = dict( - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ) - propose_kwargs = dict( - target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=token_indices_to_sample, - sampling_metadata=sampling_metadata, - common_attn_metadata=common_attn_metadata, - mm_embed_inputs=mm_embed_inputs, - cudagraph_args=cudagraph_args, - ) - if isinstance(self.drafter, DraftModelProposer): - propose_kwargs = self.drafter.update_propose_kwargs( - propose_kwargs) - draft_token_ids = self.drafter.propose(**propose_kwargs) + cudagraph_args = dict( + cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_descriptor, + ) + propose_kwargs = dict( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=token_indices_to_sample, + sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, + mm_embed_inputs=mm_embed_inputs, + cudagraph_args=cudagraph_args, + ) + if isinstance(self.drafter, DraftModelProposer): + propose_kwargs = self.drafter.update_propose_kwargs( + propose_kwargs) + draft_token_ids = self.drafter.propose(**propose_kwargs) return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: From 994e9cc251dbca9ce5a5a1eae981fe779ade8b9d Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 30 Sep 2025 19:59:20 +0200 Subject: [PATCH 26/73] Commit unit test success --- tests/v1/e2e/test_spec_decode.py | 3 +-- vllm/v1/attention/backends/utils.py | 3 +++ vllm/v1/spec_decode/draft_model.py | 29 ++++++----------------------- vllm/v1/spec_decode/eagle.py | 8 ++++++++ vllm/v1/worker/gpu_model_runner.py | 16 ++++++++++++---- 5 files changed, 30 insertions(+), 29 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 9b3788826215..b7247bc5bb69 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -363,8 +363,7 @@ def test_draft_model_correctness( """Compare the outputs using and not using speculative decoding. 
In the greedy decoding case, the outputs must match EXACTLY.""" monkeypatch.setenv("VLLM_USE_V1", "1") - test_prompts = get_test_prompts( - mm_enabled=False, quiet=True)[:2] # success for single prompt + test_prompts = get_test_prompts(mm_enabled=False, quiet=True) spec_llm = LLM( model=args.model, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 9e6f48b9cf54..10eeb39f488e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -86,6 +86,9 @@ class CommonAttentionMetadata: def batch_size(self) -> int: return self.seq_lens_cpu.shape[0] + def last_token_indices(self) -> torch.Tensor: + return self.query_start_loc[1:] - 1 + def slice_query_start_locs( query_start_loc: torch.Tensor, diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index b4055603e004..eb1df89c8083 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -78,15 +78,6 @@ def update_common_attn_metadata( common_attn_metadata: CommonAttentionMetadata): cad = common_attn_metadata batch_size = common_attn_metadata.batch_size() - - # token_indices is [0, ..., N], extend by batch_size - # new_token_indices = \ - # self.arange[:len(target_token_ids) + len(next_token_ids)] - # token indices to sample must be increased - # by [+1, +2, ..., +batch_size] - # new_token_indices_to_sample = last_token_indices + self.arange[ - # 1:batch_size + 1] - # query start loc mus be increased by [+0, +1, +2, ..., +batch_size] new_query_start_loc = cad.query_start_loc + self.arange[:len( cad.query_start_loc)] @@ -104,20 +95,12 @@ def update_common_attn_metadata( new_max_seq_len = cad.max_seq_len + 1 # block table tensor depends on num_requests, which doesn't change new_block_table_tensor = cad.block_table_tensor - # slot mapping depends on num_scheduled_tokens, - # which increased by batch_size - assert len(self.runner.input_batch.block_table.block_tables) == 1 - # kv_cache_group_id = 0 - # new_slot_mapping = self.runner.input_batch.block_table[ - # kv_cache_group_id].slot_mapping.gpu[:new_num_actual_tokens] - - block_numbers = new_positions // self.block_size - block_ids = new_block_table_tensor.gather(dim=1, - index=block_numbers.view( - 1, -1)) - block_ids = block_ids.view(-1) - new_slot_mapping = (block_ids * self.block_size + - new_positions % self.block_size) + # slot mappings are extended (interleaved) by the next serial id + last_slot_mapping_ids = cad.slot_mapping[cad.last_token_indices()] + new_slot_mapping, _ = append_new_toks(toks=cad.slot_mapping, + start_locs=cad.query_start_loc, + new_toks=last_slot_mapping_ids + + 1) new_cad = CommonAttentionMetadata( query_start_loc=new_query_start_loc, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 00adef6be5c2..a8e80381d111 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -272,6 +272,10 @@ def propose( with set_forward_context(**forward_ctx_kwargs): ret_hidden_states = self.model(**model_kwargs) self.runner.log_toks("Draft forward", model_kwargs["input_ids"]) + if self.runner.do_log: + logger.info("Draft forward positions: %s", + model_kwargs["positions"].tolist()) + logger.info("Draft attn_metadata: %s", attn_metadata) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states @@ -442,6 +446,10 @@ def propose( ret_hidden_states = self.model(**model_kwargs) self.runner.log_toks("Draft forward", model_kwargs["input_ids"]) + if 
self.runner.do_log: + logger.info("Draft forward positions: %s", + model_kwargs["positions"].tolist()) + logger.info("Draft attn_metadata: %s", attn_metadata) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = ret_hidden_states diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d1588e4ed8cd..c6b70f0652a1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -184,7 +184,7 @@ def __init__( vllm_config: VllmConfig, device: torch.device, ): - self.do_log = True + self.do_log = False if self.do_log: from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained( @@ -2144,8 +2144,10 @@ def _sample( bonus_token_ids, sampling_metadata, ) - t0 = output_token_ids[0] - self.log_toks("sampled token ids", t0[t0 != -1]) + if self.do_log: + for idx, token_ids in enumerate(output_token_ids): + self.log_toks(f"sampled token ids [{idx}]", + token_ids[token_ids != -1]) sampler_output.sampled_token_ids = output_token_ids self._update_states_after_model_execute(output_token_ids) @@ -2398,6 +2400,10 @@ def execute_model( **model_kwargs, ) self.log_toks("Target forward", input_ids) + if self.do_log: + logger.info("Target forward positions: %s", positions.tolist()) + logger.info("Target attn_metadata: %s", + list(attn_metadata.values())[0]) with record_function_or_nullcontext("Postprocess"): if self.use_aux_hidden_state_outputs: @@ -2477,7 +2483,9 @@ def propose_draft_token_ids(sampled_token_ids): cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ) - self.log_toks("draft token ids [0]", self._draft_token_ids[0]) + if self.do_log: + for idx, draft_token_ids in enumerate(self._draft_token_ids): + self.log_toks(f"draft token ids [{idx}]", draft_token_ids) use_padded_batch = self.speculative_config and \ (self.speculative_config.use_eagle() From 26ab913411f25e0c9f1eee1e7536dac23289031f Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 30 Sep 2025 20:06:43 +0200 Subject: [PATCH 27/73] Remove unnecessary variables --- vllm/v1/spec_decode/draft_model.py | 18 +++++------------- vllm/v1/spec_decode/eagle.py | 20 +++----------------- vllm/v1/worker/gpu_model_runner.py | 1 - 3 files changed, 8 insertions(+), 31 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index eb1df89c8083..a19817ff3bd0 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -21,19 +21,11 @@ def __init__( device: torch.device, runner=None, ): - super().__init__( - vllm_config=vllm_config, - device=device, - pass_hidden_states_to_model=False, - pass_cudagraph_args_to_forward_ctx=False, - # The draft model runs one forward pass to prefill - # the target_token_ids, and another forward pass for decoding - # based on the next_token_ids. I.e. it needs 1 more forward pass. 
- one_extra_forward_pass=False, - # the first draft_token_ids are replaced by next_token_ids, so - # they don't need to be returned as proposed tokens - drop_first_drafted_tokens=False, - runner=runner) + super().__init__(vllm_config=vllm_config, + device=device, + pass_hidden_states_to_model=False, + pass_cudagraph_args_to_forward_ctx=False, + runner=runner) self._raise_if_multimodal() self._raise_if_mrope() diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a8e80381d111..736dd29afd7a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -47,8 +47,6 @@ def __init__( device: torch.device, pass_hidden_states_to_model: bool, pass_cudagraph_args_to_forward_ctx: bool, - one_extra_forward_pass: bool, - drop_first_drafted_tokens: bool, runner=None, ): self.vllm_config = vllm_config @@ -58,7 +56,6 @@ def __init__( self.pass_hidden_states_to_model = pass_hidden_states_to_model self.pass_cudagraph_args_to_forward_ctx \ = pass_cudagraph_args_to_forward_ctx - self.drop_first_drafted_tokens = drop_first_drafted_tokens self.runner = runner self.dtype = vllm_config.model_config.dtype @@ -66,9 +63,6 @@ def __init__( self.block_size = vllm_config.cache_config.block_size self.num_speculative_tokens = ( self.speculative_config.num_speculative_tokens) - self.num_forward_passes = self.num_speculative_tokens - if one_extra_forward_pass: - self.num_forward_passes += 1 self.max_num_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) self.token_arange_np = np.arange(self.max_num_tokens) @@ -285,7 +279,7 @@ def propose( logits = self.model.compute_logits(sample_hidden_states) # Early exit if there is only one draft token to be generated. - if self.num_forward_passes == 1: + if self.num_speculative_tokens == 1: draft_token_ids = logits.argmax(dim=-1) return draft_token_ids.view(-1, 1) @@ -323,10 +317,7 @@ def propose( f"{self.allowed_attn_types}") # Generate the remaining draft tokens. - if self.drop_first_drafted_tokens: - draft_token_ids_list = [next_token_ids] - else: - draft_token_ids_list = [draft_token_ids] + draft_token_ids_list = [draft_token_ids] if self.use_cuda_graph and \ batch_size <= self.cudagraph_batch_sizes[-1]: @@ -339,7 +330,7 @@ def propose( common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] common_attn_metadata.query_start_loc_cpu = torch.from_numpy( self.token_arange_np[:batch_size + 1]).clone() - for token_index in range(self.num_forward_passes - 1): + for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. # tensor.argmax() returns int64 by default. 
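The use_cuda_graph check kept in the hunk above pads the drafting batch up to a captured CUDA graph size. A toy stand-in for that rounding (the capture sizes here are invented, and pad_for_cudagraph below is only assumed to round up to the nearest captured size) looks like:

import bisect

cudagraph_batch_sizes = [1, 2, 4, 8, 16]   # hypothetical capture sizes, kept sorted

def pad_for_cudagraph(batch_size: int) -> int:
    # Smallest captured size that can hold the batch.
    return cudagraph_batch_sizes[bisect.bisect_left(cudagraph_batch_sizes, batch_size)]

assert pad_for_cudagraph(3) == 4
assert pad_for_cudagraph(8) == 8
# Batches larger than the biggest captured size fall back to eager mode, which is
# what the `batch_size <= self.cudagraph_batch_sizes[-1]` guard checks.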
@@ -460,9 +451,6 @@ def propose( draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) - if self.drop_first_drafted_tokens: - draft_token_ids_list = draft_token_ids_list[1:] - # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids @@ -1094,8 +1082,6 @@ def __init__( device, pass_hidden_states_to_model=True, pass_cudagraph_args_to_forward_ctx=False, - one_extra_forward_pass=False, - drop_first_drafted_tokens=False, runner=runner) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c6b70f0652a1..5c2e4eea6c54 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2656,7 +2656,6 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.speculative_config.disable_padded_drafter_batch: - raise ValueError() token_indices_to_sample = None common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( From 01dd981d4c6cb658af3742ea57fe3bfdb03b54b1 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 30 Sep 2025 20:12:30 +0200 Subject: [PATCH 28/73] Minimize changes --- vllm/v1/spec_decode/draft_model.py | 11 +---------- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index a19817ff3bd0..35c5cb21fee9 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -24,7 +24,7 @@ def __init__( super().__init__(vllm_config=vllm_config, device=device, pass_hidden_states_to_model=False, - pass_cudagraph_args_to_forward_ctx=False, + pass_cudagraph_args_to_forward_ctx=True, runner=runner) self._raise_if_multimodal() self._raise_if_mrope() @@ -140,15 +140,6 @@ def dummy_run(self, num_tokens: int, forward_ctx_kwargs: dict): def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, next_token_ids: torch.Tensor, num_tokens: int, last_token_indices: torch.Tensor) -> None: - # start_locs = torch.zeros(last_token_indices.shape[0] + 1, - # device=last_token_indices.device, - # dtype=torch.int32) - # start_locs[1:] = last_token_indices + 1 - # input_ids, _ = append_new_toks(toks=target_token_ids, - # start_locs=start_locs, - # new_toks=next_token_ids) - # num_tokens = input_ids.shape[0] - # self.input_ids[:num_tokens] = input_ids self.input_ids[:num_tokens] = target_token_ids def load_model(self, target_model: Any) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5c2e4eea6c54..2800981c5a5d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -92,7 +92,7 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.draft_model import DraftModelProposer -from vllm.v1.spec_decode.eagle import EagleProposer +from vllm.v1.spec_decode.eagle import CudaGraphArgs, EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -2688,7 +2688,7 @@ def propose_draft_token_ids( ) else: mm_embed_inputs = None - cudagraph_args = dict( + cudagraph_args: CudaGraphArgs = dict( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ) From 09a0bb3a43d4c67eadc7f152b0f3b579d24f9fdf Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 30 
Sep 2025 20:13:50 +0200 Subject: [PATCH 29/73] Remove token logging Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 11 ----------- vllm/v1/worker/gpu_model_runner.py | 24 ------------------------ 2 files changed, 35 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 736dd29afd7a..068271a80837 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -265,11 +265,6 @@ def propose( with set_forward_context(**forward_ctx_kwargs): ret_hidden_states = self.model(**model_kwargs) - self.runner.log_toks("Draft forward", model_kwargs["input_ids"]) - if self.runner.do_log: - logger.info("Draft forward positions: %s", - model_kwargs["positions"].tolist()) - logger.info("Draft attn_metadata: %s", attn_metadata) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = last_hidden_states @@ -435,12 +430,6 @@ def propose( with set_forward_context(**forward_ctx_kwargs): ret_hidden_states = self.model(**model_kwargs) - self.runner.log_toks("Draft forward", - model_kwargs["input_ids"]) - if self.runner.do_log: - logger.info("Draft forward positions: %s", - model_kwargs["positions"].tolist()) - logger.info("Draft attn_metadata: %s", attn_metadata) if not self.model_returns_tuple(): last_hidden_states = ret_hidden_states hidden_states = ret_hidden_states diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2800981c5a5d..6305d0024aa4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -174,21 +174,11 @@ def get_output(self) -> ModelRunnerOutput: class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def log_toks(self, msg: str, toks): - if not self.do_log: - return - logger.info("%s: %s", msg, [self.tokenizer.decode(t) for t in toks]) - def __init__( self, vllm_config: VllmConfig, device: torch.device, ): - self.do_log = False - if self.do_log: - from transformers import AutoTokenizer - self.tokenizer = AutoTokenizer.from_pretrained( - vllm_config.model_config.model) self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -2144,10 +2134,6 @@ def _sample( bonus_token_ids, sampling_metadata, ) - if self.do_log: - for idx, token_ids in enumerate(output_token_ids): - self.log_toks(f"sampled token ids [{idx}]", - token_ids[token_ids != -1]) sampler_output.sampled_token_ids = output_token_ids self._update_states_after_model_execute(output_token_ids) @@ -2327,8 +2313,6 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: - if self.do_log: - logger.info("=======BEGIN STEP=======") with record_function_or_nullcontext("Preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. 
@@ -2399,11 +2383,6 @@ def execute_model( inputs_embeds=inputs_embeds, **model_kwargs, ) - self.log_toks("Target forward", input_ids) - if self.do_log: - logger.info("Target forward positions: %s", positions.tolist()) - logger.info("Target attn_metadata: %s", - list(attn_metadata.values())[0]) with record_function_or_nullcontext("Postprocess"): if self.use_aux_hidden_state_outputs: @@ -2483,9 +2462,6 @@ def propose_draft_token_ids(sampled_token_ids): cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ) - if self.do_log: - for idx, draft_token_ids in enumerate(self._draft_token_ids): - self.log_toks(f"draft token ids [{idx}]", draft_token_ids) use_padded_batch = self.speculative_config and \ (self.speculative_config.use_eagle() From 42faf1ce8c8380c1c4fbe681a87942b4c2841a19 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 30 Sep 2025 20:39:57 +0200 Subject: [PATCH 30/73] Relocate utility method Signed-off-by: Tomas Ruiz --- tests/v1/attention/test_attention_backends.py | 16 ++++ tests/v1/e2e/test_spec_decode.py | 16 ---- vllm/v1/attention/backends/utils.py | 63 +++++++++++++++ vllm/v1/spec_decode/draft_model.py | 77 +------------------ 4 files changed, 83 insertions(+), 89 deletions(-) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 6c17be759ab6..ec90d34e8903 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -17,6 +17,7 @@ from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + append_new_toks, set_kv_cache_layout) from vllm.v1.kv_cache_interface import FullAttentionSpec @@ -566,3 +567,18 @@ def sliding_window_mask_mod( LARGE_BLOCK_BACKENDS, sliding_window_mask_mod_fn, block_size=128) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_append_new_toks(device: str): + toks = torch.tensor([11, 12, 13, 21, 22, 31], device=device) + start_locs = torch.tensor([0, 3, 5, 6], device=device) + new_toks = torch.tensor([13, 23, 32], device=device) + + expected_toks = torch.tensor([11, 12, 13, 13, 21, 22, 23, 31, 32], + device=device) + expected_start_locs = torch.tensor([0, 4, 7, 9], device=device) + actual_toks, actual_start_locs = append_new_toks(toks, start_locs, + new_toks) + assert torch.all(actual_toks == expected_toks) + assert torch.all(actual_start_locs == expected_start_locs) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index b7247bc5bb69..4200d8d01b64 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -16,7 +16,6 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.outputs import RequestOutput from vllm.platforms import current_platform -from vllm.v1.spec_decode.draft_model import append_new_toks from vllm.v1.spec_decode.metrics import (compute_acceptance_len, compute_acceptance_rate) @@ -430,18 +429,3 @@ def compute_exact_matches(ref_outputs: list[RequestOutput], print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") return matches / len(ref_outputs) - - -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_append_new_toks(device: str): - toks = torch.tensor([11, 12, 13, 21, 22, 31], device=device) - start_locs = torch.tensor([0, 3, 5, 6], device=device) - new_toks = torch.tensor([13, 23, 32], device=device) - - expected_toks = 
torch.tensor([11, 12, 13, 13, 21, 22, 23, 31, 32], - device=device) - expected_start_locs = torch.tensor([0, 4, 7, 9], device=device) - actual_toks, actual_start_locs = append_new_toks(toks, start_locs, - new_toks) - assert torch.all(actual_toks == expected_toks) - assert torch.all(actual_start_locs == expected_start_locs) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 10eeb39f488e..95e2ca533a5f 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -105,6 +105,69 @@ def slice_query_start_locs( query_start_loc[request_slice.start] +def extend_all_queries_by_1(common_attn_metadata: CommonAttentionMetadata, + arange: torch.Tensor) -> CommonAttentionMetadata: + """ + Creates a new CommonAttentionMetadata with all query lengths increased by 1. + Also all seq lens are increased by 1. + This is useful e.g. in speculative decoding with draft models, where we + extend each sequence by 1 token. + """ + cad = common_attn_metadata + # query start loc must be increased by [+0, +1, +2, ..., +batch_size] + new_query_start_loc = cad.query_start_loc \ + + arange[:len(cad.query_start_loc)] + new_seq_lens = cad.seq_lens + 1 + # slot mappings are extended (interleaved) by the next serial id + last_slot_mapping_ids = cad.slot_mapping[cad.last_token_indices()] + new_slot_mapping, _ = append_new_toks(toks=cad.slot_mapping, + start_locs=cad.query_start_loc, + new_toks=last_slot_mapping_ids + 1) + new_cad = CommonAttentionMetadata( + query_start_loc=new_query_start_loc, + query_start_loc_cpu=new_query_start_loc.to("cpu"), + seq_lens=new_seq_lens, + seq_lens_cpu=new_seq_lens.to("cpu"), + num_reqs=cad.num_reqs, # num requests stays unchanged + num_computed_tokens_cpu=cad.num_computed_tokens_cpu + 1, + # each request is extended by 1 token -> batch_size tokens are added + num_actual_tokens=cad.num_actual_tokens + cad.batch_size(), + # All query lens increase by 1, so max query len increases by 1 + max_query_len=cad.max_query_len + 1, + max_seq_len=cad.max_seq_len + 1, + # block table tensor depends on num requests, which stays constant + block_table_tensor=cad.block_table_tensor, + slot_mapping=new_slot_mapping, + ) + return new_cad + + +def append_new_toks( + toks: torch.Tensor, start_locs: torch.Tensor, + new_toks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + long_len = toks.shape[0] + new_toks.shape[0] + long_toks = torch.zeros(long_len, device=toks.device, dtype=toks.dtype) + + # compute indices for previous toks + toks_idxs = torch.ones_like(toks) + toks_idxs[start_locs[1:-1]] += 1 + toks_idxs = toks_idxs.cumsum(0) - 1 + + # compute indices for new toks + new_toks_idxs = start_locs[1:] + torch.arange(new_toks.shape[0], + device=toks.device) + + # assign toks and new toks + long_toks[toks_idxs] = toks + long_toks[new_toks_idxs] = new_toks + + # compute new start locs + new_start_locs = torch.zeros_like(start_locs) + new_start_locs[1:] = new_toks_idxs + 1 + + return long_toks, new_start_locs + + def _make_metadata_with_slice( ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 35c5cb21fee9..ea76fd4dfc65 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -9,7 +9,8 @@ from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model 
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import (append_new_toks, + extend_all_queries_by_1) from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer @@ -50,8 +51,8 @@ def update_propose_kwargs(self, propose_kwargs: dict): start_locs=start_locs, new_toks=positions_to_append) # update common_attn_metadata - new_common_attn_metadata = self.update_common_attn_metadata( - new_target_positions, common_attn_metadata) + new_common_attn_metadata = extend_all_queries_by_1( + common_attn_metadata, arange=self.arange) # update token_indices_to_sample new_token_indices_to_sample = new_common_attn_metadata.query_start_loc[ 1:] - 1 @@ -65,50 +66,6 @@ def update_propose_kwargs(self, propose_kwargs: dict): ) return propose_kwargs | new_propose_kwargs - def update_common_attn_metadata( - self, new_positions: torch.Tensor, - common_attn_metadata: CommonAttentionMetadata): - cad = common_attn_metadata - batch_size = common_attn_metadata.batch_size() - # query start loc mus be increased by [+0, +1, +2, ..., +batch_size] - new_query_start_loc = cad.query_start_loc + self.arange[:len( - cad.query_start_loc)] - # seq lens must be increased by [+1, +1, ..., +1] size batch_size - new_seq_lens = cad.seq_lens + torch.ones_like(cad.seq_lens) - # num requests stays unchanged - new_num_reqs = cad.num_reqs - # num computed tokens are increased by [+1, +1, ..., +1] size batch_size - new_num_computed_tokens_cpu = cad.num_computed_tokens_cpu \ - + torch.ones_like(cad.num_computed_tokens_cpu) - # num actual tokens increases by batch_size - new_num_actual_tokens = cad.num_actual_tokens + batch_size - # max query len and max seq len increases by 1 - new_max_query_len = cad.max_query_len + 1 - new_max_seq_len = cad.max_seq_len + 1 - # block table tensor depends on num_requests, which doesn't change - new_block_table_tensor = cad.block_table_tensor - # slot mappings are extended (interleaved) by the next serial id - last_slot_mapping_ids = cad.slot_mapping[cad.last_token_indices()] - new_slot_mapping, _ = append_new_toks(toks=cad.slot_mapping, - start_locs=cad.query_start_loc, - new_toks=last_slot_mapping_ids + - 1) - - new_cad = CommonAttentionMetadata( - query_start_loc=new_query_start_loc, - query_start_loc_cpu=new_query_start_loc.to("cpu"), - seq_lens=new_seq_lens, - seq_lens_cpu=new_seq_lens.to("cpu"), - num_reqs=new_num_reqs, - num_computed_tokens_cpu=new_num_computed_tokens_cpu, - num_actual_tokens=new_num_actual_tokens, - max_query_len=new_max_query_len, - max_seq_len=new_max_seq_len, - block_table_tensor=new_block_table_tensor, - slot_mapping=new_slot_mapping, - ) - return new_cad - def _raise_if_multimodal(self): if self.supports_mm_inputs: raise NotImplementedError("Speculative Decoding with draft models " @@ -169,29 +126,3 @@ def load_model(self, target_model: Any) -> None: get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names) - - -def append_new_toks( - toks: torch.Tensor, start_locs: torch.Tensor, - new_toks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - long_len = toks.shape[0] + new_toks.shape[0] - long_toks = torch.zeros(long_len, device=toks.device, dtype=toks.dtype) - - # compute indices for previous toks - toks_idxs = torch.ones_like(toks) - toks_idxs[start_locs[1:-1]] += 1 - toks_idxs = toks_idxs.cumsum(0) - 1 - - # compute indices for new toks - new_toks_idxs = start_locs[1:] + torch.arange(new_toks.shape[0], - device=toks.device) - - # 
assign toks and new toks - long_toks[toks_idxs] = toks - long_toks[new_toks_idxs] = new_toks - - # compute new start locs - new_start_locs = torch.zeros_like(start_locs) - new_start_locs[1:] = new_toks_idxs + 1 - - return long_toks, new_start_locs From 044e45ca1b38d3e78c0a271d14135e789443e5d5 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 30 Sep 2025 21:08:39 +0200 Subject: [PATCH 31/73] Simplify extend_flat_seqs() Signed-off-by: Tomas Ruiz --- tests/v1/attention/test_attention_backends.py | 28 +++++----- vllm/v1/attention/backends/utils.py | 54 +++++++++---------- vllm/v1/spec_decode/draft_model.py | 39 ++++++-------- vllm/v1/spec_decode/eagle.py | 2 +- 4 files changed, 61 insertions(+), 62 deletions(-) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index ec90d34e8903..3b89c5117fed 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -17,7 +17,7 @@ from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - append_new_toks, + extend_flat_seqs, set_kv_cache_layout) from vllm.v1.kv_cache_interface import FullAttentionSpec @@ -570,15 +570,19 @@ def sliding_window_mask_mod( @pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_append_new_toks(device: str): - toks = torch.tensor([11, 12, 13, 21, 22, 31], device=device) - start_locs = torch.tensor([0, 3, 5, 6], device=device) - new_toks = torch.tensor([13, 23, 32], device=device) - - expected_toks = torch.tensor([11, 12, 13, 13, 21, 22, 23, 31, 32], +def test_extend_flat_seqs(device: str): + # fmt: off + seqs = torch.tensor([11, 12, 13, + 21, 22, + 31], device=device) + end_locs = torch.tensor([2, 4, 5], device=device) + new_vals = torch.tensor([14, + 23, + 32], device=device) + expected_seqs = torch.tensor([11, 12, 13, 14, + 21, 22, 23, + 31, 32], device=device) - expected_start_locs = torch.tensor([0, 4, 7, 9], device=device) - actual_toks, actual_start_locs = append_new_toks(toks, start_locs, - new_toks) - assert torch.all(actual_toks == expected_toks) - assert torch.all(actual_start_locs == expected_start_locs) + # fmt: on + actual_seqs = extend_flat_seqs(seqs, end_locs, new_vals) + assert torch.all(actual_seqs == expected_seqs) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 95e2ca533a5f..98b99239c22e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -120,9 +120,9 @@ def extend_all_queries_by_1(common_attn_metadata: CommonAttentionMetadata, new_seq_lens = cad.seq_lens + 1 # slot mappings are extended (interleaved) by the next serial id last_slot_mapping_ids = cad.slot_mapping[cad.last_token_indices()] - new_slot_mapping, _ = append_new_toks(toks=cad.slot_mapping, - start_locs=cad.query_start_loc, - new_toks=last_slot_mapping_ids + 1) + new_slot_mapping = extend_flat_seqs(seqs=cad.slot_mapping, + end_locs=cad.last_token_indices(), + new_vals=last_slot_mapping_ids + 1) new_cad = CommonAttentionMetadata( query_start_loc=new_query_start_loc, query_start_loc_cpu=new_query_start_loc.to("cpu"), @@ -142,30 +142,30 @@ def extend_all_queries_by_1(common_attn_metadata: CommonAttentionMetadata, return new_cad -def append_new_toks( - toks: torch.Tensor, start_locs: torch.Tensor, - new_toks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - long_len = toks.shape[0] + new_toks.shape[0] - long_toks 
= torch.zeros(long_len, device=toks.device, dtype=toks.dtype) - - # compute indices for previous toks - toks_idxs = torch.ones_like(toks) - toks_idxs[start_locs[1:-1]] += 1 - toks_idxs = toks_idxs.cumsum(0) - 1 - - # compute indices for new toks - new_toks_idxs = start_locs[1:] + torch.arange(new_toks.shape[0], - device=toks.device) - - # assign toks and new toks - long_toks[toks_idxs] = toks - long_toks[new_toks_idxs] = new_toks - - # compute new start locs - new_start_locs = torch.zeros_like(start_locs) - new_start_locs[1:] = new_toks_idxs + 1 - - return long_toks, new_start_locs +def extend_flat_seqs(seqs: torch.Tensor, end_locs: torch.Tensor, + new_vals: torch.Tensor) -> torch.Tensor: + """ + This function appends a single new value into multiple sequences + that are stored in a flat format. E.g. + [x1, x2, y1] and [x3, y2] become [x1, x2, x3, y1, y2] + """ + new_len = seqs.shape[0] + new_vals.shape[0] + new_seqs = torch.zeros(new_len, device=seqs.device, dtype=seqs.dtype) + + # indices for previous seqs + start_locs = end_locs[:-1] + 1 + seqs_new_idxs = torch.ones_like(seqs) + seqs_new_idxs[start_locs] += 1 + seqs_new_idxs = seqs_new_idxs.cumsum(0) - 1 + + # indices for new values + new_val_idxs = end_locs + 1 + torch.arange(new_vals.shape[0], + device=seqs.device) + # assign seqs and new vals + new_seqs[seqs_new_idxs] = seqs + new_seqs[new_val_idxs] = new_vals + + return new_seqs def _make_metadata_with_slice( diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index ea76fd4dfc65..826ec3237e10 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -9,8 +9,9 @@ from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model -from vllm.v1.attention.backends.utils import (append_new_toks, - extend_all_queries_by_1) +from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, + extend_all_queries_by_1, + extend_flat_seqs) from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer @@ -31,38 +32,32 @@ def __init__( self._raise_if_mrope() def update_propose_kwargs(self, propose_kwargs: dict): - common_attn_metadata = propose_kwargs["common_attn_metadata"] + cad: CommonAttentionMetadata = propose_kwargs["common_attn_metadata"] target_token_ids = propose_kwargs["target_token_ids"] next_token_ids = propose_kwargs["next_token_ids"] target_positions = propose_kwargs["target_positions"] - token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 + token_indices_to_sample = cad.last_token_indices() # update target_token_ids - start_locs = torch.zeros(token_indices_to_sample.shape[0] + 1, - device=token_indices_to_sample.device, - dtype=torch.int32) - start_locs[1:] = token_indices_to_sample + 1 - new_target_token_ids, _ = append_new_toks(toks=target_token_ids, - start_locs=start_locs, - new_toks=next_token_ids) + end_locs = cad.last_token_indices() + new_target_token_ids = extend_flat_seqs(seqs=target_token_ids, + end_locs=end_locs, + new_vals=next_token_ids) # update positions positions_to_append = target_positions[token_indices_to_sample] + 1 - new_target_positions, _ = append_new_toks(toks=target_positions, - start_locs=start_locs, - new_toks=positions_to_append) - # update common_attn_metadata - new_common_attn_metadata = extend_all_queries_by_1( - common_attn_metadata, arange=self.arange) - # update token_indices_to_sample - new_token_indices_to_sample = 
new_common_attn_metadata.query_start_loc[ - 1:] - 1 + new_target_positions = extend_flat_seqs(seqs=target_positions, + end_locs=end_locs, + new_vals=positions_to_append) + + new_cad: CommonAttentionMetadata = extend_all_queries_by_1( + cad, arange=self.arange) new_propose_kwargs = dict( target_token_ids=new_target_token_ids, target_positions=new_target_positions, next_token_ids=None, - last_token_indices=new_token_indices_to_sample, - common_attn_metadata=new_common_attn_metadata, + last_token_indices=None, + common_attn_metadata=new_cad, ) return propose_kwargs | new_propose_kwargs diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 068271a80837..c26a54fd51cf 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -192,7 +192,7 @@ def propose( batch_size = common_attn_metadata.batch_size() if last_token_indices is None: - last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 + last_token_indices = common_attn_metadata.last_token_indices() if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) From 7a1949d43c328bf2ae08d3e3637c07f5e9883c0c Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 30 Sep 2025 21:09:44 +0200 Subject: [PATCH 32/73] Document test Signed-off-by: Tomas Ruiz --- tests/v1/attention/test_attention_backends.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 3b89c5117fed..c665baacba9d 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -571,6 +571,11 @@ def sliding_window_mask_mod( @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_extend_flat_seqs(device: str): + """The extend_flat_seqs() function appends a single new value into multiple + sequences that are stored in a flat format. E.g. 
+ [x1, x2, y1] and [x3, y2] become [x1, x2, x3, y1, y2] + """ + # fmt: off seqs = torch.tensor([11, 12, 13, 21, 22, From 316a6b822f81fa6c8d2e6da3670b5d986901027d Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 30 Sep 2025 21:19:05 +0200 Subject: [PATCH 33/73] Document funcs Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 69 ++++++++++++++++++------------ 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 826ec3237e10..877dd22b5aaa 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -32,34 +32,8 @@ def __init__( self._raise_if_mrope() def update_propose_kwargs(self, propose_kwargs: dict): - cad: CommonAttentionMetadata = propose_kwargs["common_attn_metadata"] - target_token_ids = propose_kwargs["target_token_ids"] - next_token_ids = propose_kwargs["next_token_ids"] - target_positions = propose_kwargs["target_positions"] - token_indices_to_sample = cad.last_token_indices() - - # update target_token_ids - end_locs = cad.last_token_indices() - new_target_token_ids = extend_flat_seqs(seqs=target_token_ids, - end_locs=end_locs, - new_vals=next_token_ids) - # update positions - positions_to_append = target_positions[token_indices_to_sample] + 1 - new_target_positions = extend_flat_seqs(seqs=target_positions, - end_locs=end_locs, - new_vals=positions_to_append) - - new_cad: CommonAttentionMetadata = extend_all_queries_by_1( - cad, arange=self.arange) - - new_propose_kwargs = dict( - target_token_ids=new_target_token_ids, - target_positions=new_target_positions, - next_token_ids=None, - last_token_indices=None, - common_attn_metadata=new_cad, - ) - return propose_kwargs | new_propose_kwargs + return update_propose_kwargs(arange=self.arange, + propose_kwargs=propose_kwargs) def _raise_if_multimodal(self): if self.supports_mm_inputs: @@ -121,3 +95,42 @@ def load_model(self, target_model: Any) -> None: get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) self.attn_layer_names = list(draft_attn_layer_names) + + +def update_propose_kwargs(arange: torch.Tensor, propose_kwargs: dict): + """ + This function: + - Merges the target_token_ids and the next_token_ids into a + single flat tensor. + - Appends new positions for these next_token_ids. + - Updates the common_attn_metadata to reflect that all query lengths are +1. 
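The flat-sequence merge described above is built on extend_flat_seqs() from vllm/v1/attention/backends/utils.py. As a reference-only sketch (not the tensorized implementation in the patch), the same result can be produced with a plain Python loop, reusing the values from the unit test in PATCH 31:

import torch

def extend_flat_seqs_reference(seqs, end_locs, new_vals):
    # Append one new value after each packed sequence. Readable but
    # CPU-bound; the patched helper does this with pure tensor indexing.
    out, prev_end = [], -1
    for end, val in zip(end_locs.tolist(), new_vals.tolist()):
        out.extend(seqs[prev_end + 1:end + 1].tolist())
        out.append(val)
        prev_end = end
    return torch.tensor(out, dtype=seqs.dtype)

seqs = torch.tensor([11, 12, 13, 21, 22, 31])   # three packed sequences
end_locs = torch.tensor([2, 4, 5])              # last index of each sequence
new_vals = torch.tensor([14, 23, 32])           # one new token per sequence
assert extend_flat_seqs_reference(seqs, end_locs, new_vals).tolist() == \
    [11, 12, 13, 14, 21, 22, 23, 31, 32]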
+ """ + cad: CommonAttentionMetadata = propose_kwargs["common_attn_metadata"] + target_token_ids = propose_kwargs["target_token_ids"] + next_token_ids = propose_kwargs["next_token_ids"] + target_positions = propose_kwargs["target_positions"] + token_indices_to_sample = cad.last_token_indices() + + # merge target_token_ids and next_token_ids + end_locs = cad.last_token_indices() + new_target_token_ids = extend_flat_seqs(seqs=target_token_ids, + end_locs=end_locs, + new_vals=next_token_ids) + # append new positions + positions_to_append = target_positions[token_indices_to_sample] + 1 + new_target_positions = extend_flat_seqs(seqs=target_positions, + end_locs=end_locs, + new_vals=positions_to_append) + + # update common_attn_metadata + new_cad: CommonAttentionMetadata = extend_all_queries_by_1(cad, + arange=arange) + + new_propose_kwargs = dict( + target_token_ids=new_target_token_ids, + target_positions=new_target_positions, + next_token_ids=None, + last_token_indices=None, + common_attn_metadata=new_cad, + ) + return propose_kwargs | new_propose_kwargs From af060300f576c59207190f00929b4a548123efb1 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 1 Oct 2025 12:59:59 +0200 Subject: [PATCH 34/73] Update BatchDescriptor with correct num_tokens Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index c26a54fd51cf..ddaaf27ba7dd 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -261,6 +261,12 @@ def propose( num_tokens=num_input_tokens, ) if self.pass_cudagraph_args_to_forward_ctx: + # Update num_tokens in batch descriptor, eg after cudagraph padding + old_bd: BatchDescriptor = cudagraph_args["batch_descriptor"] + if old_bd is not None: + new_bd = BatchDescriptor(num_tokens=num_input_tokens, + uniform_decode=old_bd.uniform_decode) + cudagraph_args["batch_descriptor"] = new_bd forward_ctx_kwargs.update(cudagraph_args) with set_forward_context(**forward_ctx_kwargs): From a791d2ebbb69f0691891f418ae8d201119b4349e Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 1 Oct 2025 13:31:06 +0200 Subject: [PATCH 35/73] Make sure AL benchmark can run Signed-off-by: Tomas Ruiz --- examples/offline_inference/spec_decode.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index f17f5d054505..5a788c8b9b55 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -54,7 +54,7 @@ def parse_args(): "--method", type=str, default="eagle", - choices=["ngram", "eagle", "eagle3", "mtp"], + choices=["ngram", "eagle", "eagle3", "mtp", "draft_model"], ) parser.add_argument("--num-spec-tokens", type=int, default=2) parser.add_argument("--prompt-lookup-max", type=int, default=5) @@ -74,7 +74,6 @@ def parse_args(): parser.add_argument("--custom-mm-prompts", action="store_true") parser.add_argument("--gpu-memory-utilization", type=float, default=0.8) parser.add_argument("--request-id-prefix", type=str, default="") - parser.add_argument("--max-model-len", type=int, default=16384) return parser.parse_args() From 1de5ef47875b7f70e7ddda8fd386ba8f70cf68ce Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 1 Oct 2025 14:03:45 +0200 Subject: [PATCH 36/73] Extend drafter max_num_tokens Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git 
a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index ddaaf27ba7dd..543b00e318ff 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -64,7 +64,9 @@ def __init__( self.num_speculative_tokens = ( self.speculative_config.num_speculative_tokens) self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + vllm_config.scheduler_config.max_num_batched_tokens + + vllm_config.scheduler_config.max_num_seqs * + self.num_speculative_tokens) self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's From 4371d47e634a6eff06a6f713975c70a6121d0e8e Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 1 Oct 2025 18:39:00 +0200 Subject: [PATCH 37/73] CKPT: Find bug affecting acceptance length --- vllm/v1/attention/backends/utils.py | 12 +++++------- vllm/v1/spec_decode/draft_model.py | 27 ++++++++++++++++++++++----- vllm/v1/spec_decode/eagle.py | 26 +++++++++++++++++++++++++- vllm/v1/worker/gpu_model_runner.py | 28 ++++++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 13 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 98b99239c22e..2a78eae8f06c 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -86,9 +86,6 @@ class CommonAttentionMetadata: def batch_size(self) -> int: return self.seq_lens_cpu.shape[0] - def last_token_indices(self) -> torch.Tensor: - return self.query_start_loc[1:] - 1 - def slice_query_start_locs( query_start_loc: torch.Tensor, @@ -105,8 +102,9 @@ def slice_query_start_locs( query_start_loc[request_slice.start] -def extend_all_queries_by_1(common_attn_metadata: CommonAttentionMetadata, - arange: torch.Tensor) -> CommonAttentionMetadata: +def extend_all_queries_by_1( + common_attn_metadata: CommonAttentionMetadata, arange: torch.Tensor, + last_token_indices: torch.Tensor) -> CommonAttentionMetadata: """ Creates a new CommonAttentionMetadata with all query lengths increased by 1. Also all seq lens are increased by 1. 
@@ -119,9 +117,9 @@ def extend_all_queries_by_1(common_attn_metadata: CommonAttentionMetadata, + arange[:len(cad.query_start_loc)] new_seq_lens = cad.seq_lens + 1 # slot mappings are extended (interleaved) by the next serial id - last_slot_mapping_ids = cad.slot_mapping[cad.last_token_indices()] + last_slot_mapping_ids = cad.slot_mapping[last_token_indices] new_slot_mapping = extend_flat_seqs(seqs=cad.slot_mapping, - end_locs=cad.last_token_indices(), + end_locs=last_token_indices, new_vals=last_slot_mapping_ids + 1) new_cad = CommonAttentionMetadata( query_start_loc=new_query_start_loc, diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 877dd22b5aaa..58986848a430 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -8,6 +8,7 @@ from vllm.attention.layer import Attention from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context +from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, extend_all_queries_by_1, @@ -97,6 +98,9 @@ def load_model(self, target_model: Any) -> None: self.attn_layer_names = list(draft_attn_layer_names) +logger = init_logger(__name__) + + def update_propose_kwargs(arange: torch.Tensor, propose_kwargs: dict): """ This function: @@ -109,10 +113,12 @@ def update_propose_kwargs(arange: torch.Tensor, propose_kwargs: dict): target_token_ids = propose_kwargs["target_token_ids"] next_token_ids = propose_kwargs["next_token_ids"] target_positions = propose_kwargs["target_positions"] - token_indices_to_sample = cad.last_token_indices() + token_indices_to_sample = propose_kwargs["last_token_indices"] + if token_indices_to_sample is None: + token_indices_to_sample = cad.query_start_loc[1:] - 1 # merge target_token_ids and next_token_ids - end_locs = cad.last_token_indices() + end_locs = token_indices_to_sample new_target_token_ids = extend_flat_seqs(seqs=target_token_ids, end_locs=end_locs, new_vals=next_token_ids) @@ -123,14 +129,25 @@ def update_propose_kwargs(arange: torch.Tensor, propose_kwargs: dict): new_vals=positions_to_append) # update common_attn_metadata - new_cad: CommonAttentionMetadata = extend_all_queries_by_1(cad, - arange=arange) + new_cad: CommonAttentionMetadata = extend_all_queries_by_1( + cad, arange=arange, last_token_indices=token_indices_to_sample) + + # new token indices to sample incease by [+1, +2, +3, ..., +batch_size] + new_token_indices_to_sample = token_indices_to_sample \ + + arange_like(token_indices_to_sample) + 1 + + logger.info("old last_token_indices: %s, new last_token_indices: %s.", + token_indices_to_sample, new_token_indices_to_sample) new_propose_kwargs = dict( target_token_ids=new_target_token_ids, target_positions=new_target_positions, next_token_ids=None, - last_token_indices=None, + last_token_indices=new_token_indices_to_sample, common_attn_metadata=new_cad, ) return propose_kwargs | new_propose_kwargs + + +def arange_like(x: torch.Tensor) -> torch.Tensor: + return torch.arange(x.shape[0], device=x.device) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 543b00e318ff..73e71da7a43d 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -194,7 +194,7 @@ def propose( batch_size = common_attn_metadata.batch_size() if last_token_indices is None: - last_token_indices = common_attn_metadata.last_token_indices() + last_token_indices = 
common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -280,6 +280,19 @@ def propose( last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) + self.runner.log_tokens("draft input_ids", input_ids) + if self.runner.do_log: + logger.info("draft positions: %s", model_kwargs["positions"]) + logger.info("draft position length: %s", + model_kwargs["positions"].shape[0]) + logger.info("draft last_token_indices: %s", last_token_indices) + logger.info("draft batch_descriptor: %s", + forward_ctx_kwargs["batch_descriptor"]) + logger.info("draft attn_metadata.slot_mapping: %s", + attn_metadata.slot_mapping) + for _idx, _block in enumerate(attn_metadata.block_table): + logger.info("draft attn_metadata.block_table [%d]: %s", _idx, + _block.tolist()) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: @@ -443,6 +456,15 @@ def propose( hidden_states = ret_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states + self.runner.log_tokens("draft input_ids", input_ids) + if self.runner.do_log: + logger.info("draft positions: %s", model_kwargs["positions"]) + logger.info("draft position length: %s", + model_kwargs["positions"].shape[0]) + logger.info("draft batch_descriptor: %s", + forward_ctx_kwargs["batch_descriptor"]) + # logger.info(f"draft attn_metadata={attn_metadata}") + hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size]) draft_token_ids = logits.argmax(dim=-1) @@ -450,6 +472,8 @@ def propose( # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + for idx, tokens in enumerate(draft_token_ids): + self.runner.log_tokens(f"draft token ids {idx}", tokens) return draft_token_ids def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6305d0024aa4..e4e3418008c0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -174,11 +174,21 @@ def get_output(self) -> ModelRunnerOutput: class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): + def log_tokens(self, msg: str, tokens: list[int]): + if not self.do_log: + return + logger.info("%s: %s", msg, [self.tokenizer.decode(t) for t in tokens]) + def __init__( self, vllm_config: VllmConfig, device: torch.device, ): + self.do_log = True + if self.do_log: + from transformers import AutoTokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + vllm_config.model_config.model) self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -2313,6 +2323,8 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: + if self.do_log: + logger.info("======STEP======") with record_function_or_nullcontext("Preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. 
@@ -2383,6 +2395,18 @@ def execute_model( inputs_embeds=inputs_embeds, **model_kwargs, ) + self.log_tokens("tgt input_ids", input_ids) + if self.do_log: + logger.info("tgt positions: %s", positions) + logger.info("tgt position length: %s", positions.shape[0]) + logger.info("tgt logits_indices: %s", logits_indices) + logger.info("tgt batch_descriptor: %s", batch_descriptor) + _attn_metadata = list(attn_metadata.values())[0] + logger.info("tgt attn_metadata.slot_mapping: %s", + _attn_metadata.slot_mapping) + for _idx, _block in enumerate(_attn_metadata.block_table): + logger.info("tgt attn_metadata.block_table [%d]: %s", _idx, + _block.tolist()) with record_function_or_nullcontext("Postprocess"): if self.use_aux_hidden_state_outputs: @@ -2446,6 +2470,8 @@ def execute_model( with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) + for idx, tokens in enumerate(sampler_output.sampled_token_ids): + self.log_tokens(f"Sampler output {idx}", tokens[tokens != -1]) def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None @@ -2632,6 +2658,8 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.speculative_config.disable_padded_drafter_batch: + raise NotImplementedError( + "This path cannot be reached using `vllm serve ...`") token_indices_to_sample = None common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( From 1718892649476860421bb38be023ddb191c42b55 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 2 Oct 2025 14:23:19 +0200 Subject: [PATCH 38/73] Fix AL for default drafter padding --- make-problem.sh | 12 ++ tests/v1/e2e/test_spec_decode.py | 3 - tests/v1/test_outputs.py | 20 +++ vllm/v1/attention/backends/utils.py | 17 ++- vllm/v1/outputs.py | 3 + vllm/v1/spec_decode/draft_model.py | 200 +++++++++++++++++++++------- vllm/v1/spec_decode/eagle.py | 83 +++++++----- vllm/v1/worker/gpu_model_runner.py | 11 +- 8 files changed, 260 insertions(+), 89 deletions(-) create mode 100755 make-problem.sh create mode 100644 tests/v1/test_outputs.py diff --git a/make-problem.sh b/make-problem.sh new file mode 100755 index 000000000000..6265e0d2ad71 --- /dev/null +++ b/make-problem.sh @@ -0,0 +1,12 @@ +export CUDA_LAUNCH_BLOCKING=1 +nohup python examples/offline_inference/spec_decode.py \ + --model-dir Qwen/Qwen3-1.7B \ + --draft-model Qwen/Qwen3-1.7B \ + --method draft_model \ + --num_spec_tokens 3 \ + --dataset-name hf \ + --dataset-path philschmid/mt-bench \ + --num_prompts 80 \ + --temp 0.0 \ + --gpu-memory-utilization 0.9 \ + --enforce-eager > al.log 2>&1 & \ No newline at end of file diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4200d8d01b64..1434f96faa86 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -352,11 +352,9 @@ def test_mtp_correctness( @pytest.mark.parametrize("args", cases) @pytest.mark.parametrize("enforce_eager", [True, False]) -@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False]) def test_draft_model_correctness( args: ArgsTest, enforce_eager: bool, - disable_padded_drafter_batch: bool, monkeypatch: pytest.MonkeyPatch, ): """Compare the outputs using and not using speculative decoding. 
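The acceptance metrics asserted by these tests follow the usual spec-decode definitions; a rough sketch of the relationship with made-up counts (the exact bodies of compute_acceptance_rate() and compute_acceptance_len() live in vllm/v1/spec_decode/metrics.py and may differ in detail):

# Made-up counters for one benchmark run.
num_drafted = 100    # draft tokens proposed by the draft model
num_accepted = 80    # of those, accepted by the target model
num_steps = 50       # target forward passes (verification steps)

acceptance_rate = num_accepted / num_drafted       # 0.8
acceptance_len = 1 + num_accepted / num_steps      # 2.6 tokens per step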
@@ -373,7 +371,6 @@ def test_draft_model_correctness( "max_model_len": args.max_model_len, "enforce_eager": enforce_eager, "tensor_parallel_size": args.draft_tensor_parallel_size, - "disable_padded_drafter_batch": disable_padded_drafter_batch, }, max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, diff --git a/tests/v1/test_outputs.py b/tests/v1/test_outputs.py new file mode 100644 index 000000000000..7556d5e8f4f7 --- /dev/null +++ b/tests/v1/test_outputs.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.v1.outputs import SamplerOutput + + +def test_sampler_output(): + # fmt: off + # -1 is the padding token + sampled_token_ids = torch.tensor([ + [1, 2, 3, -1], + [1, -1, -1, -1], + [3, 2, -1, -1] + ]) + # fmt: on + so = SamplerOutput(sampled_token_ids=sampled_token_ids, + logprobs_tensors=None) + expected_n_sampled_tokens = torch.tensor([3, 1, 2]) + assert so.n_sampled_tokens().eq(expected_n_sampled_tokens).all() diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 2a78eae8f06c..0ed7495ba146 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,7 +4,7 @@ import enum import functools from abc import abstractmethod -from dataclasses import dataclass, fields, make_dataclass +from dataclasses import dataclass, fields, make_dataclass, replace from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional, Protocol, TypeVar, Union, get_args) @@ -86,6 +86,12 @@ class CommonAttentionMetadata: def batch_size(self) -> int: return self.seq_lens_cpu.shape[0] + def replace(self, **kwargs) -> "CommonAttentionMetadata": + return replace(self, **kwargs) + + def query_lens(self) -> torch.Tensor: + return self.query_start_loc[1:] - self.query_start_loc[:-1] + def slice_query_start_locs( query_start_loc: torch.Tensor, @@ -104,23 +110,20 @@ def slice_query_start_locs( def extend_all_queries_by_1( common_attn_metadata: CommonAttentionMetadata, arange: torch.Tensor, - last_token_indices: torch.Tensor) -> CommonAttentionMetadata: + new_slot_mapping: torch.Tensor) -> CommonAttentionMetadata: """ Creates a new CommonAttentionMetadata with all query lengths increased by 1. Also all seq lens are increased by 1. This is useful e.g. in speculative decoding with draft models, where we extend each sequence by 1 token. + The slot mapping is computed externally, as it requires more information. 
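As a quick check on the bookkeeping in extend_all_queries_by_1() above, this is how the main fields shift for a hypothetical 3-request batch with packed query lengths 3, 2 and 1 (illustrative numbers, not taken from a real run):

import torch

query_start_loc = torch.tensor([0, 3, 5, 6])    # packed query lengths 3, 2, 1
seq_lens = torch.tensor([7, 4, 9])
arange = torch.arange(8)

new_query_start_loc = query_start_loc + arange[:len(query_start_loc)]
new_seq_lens = seq_lens + 1

assert new_query_start_loc.tolist() == [0, 4, 7, 9]   # lengths become 4, 3, 2
assert new_seq_lens.tolist() == [8, 5, 10]
# num_actual_tokens grows 6 -> 9 (one extra token per request), max_query_len
# and max_seq_len grow by 1, and the slot mapping is supplied by the caller.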
""" cad = common_attn_metadata # query start loc must be increased by [+0, +1, +2, ..., +batch_size] new_query_start_loc = cad.query_start_loc \ + arange[:len(cad.query_start_loc)] new_seq_lens = cad.seq_lens + 1 - # slot mappings are extended (interleaved) by the next serial id - last_slot_mapping_ids = cad.slot_mapping[last_token_indices] - new_slot_mapping = extend_flat_seqs(seqs=cad.slot_mapping, - end_locs=last_token_indices, - new_vals=last_slot_mapping_ids + 1) + new_cad = CommonAttentionMetadata( query_start_loc=new_query_start_loc, query_start_loc_cpu=new_query_start_loc.to("cpu"), diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 01f3676abd92..ebdfbe5d5984 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -80,6 +80,9 @@ class SamplerOutput: sampled_token_ids: torch.Tensor logprobs_tensors: Optional[LogprobsTensors] + def n_sampled_tokens(self) -> torch.Tensor: + return self.sampled_token_ids.ne(-1).sum(dim=1) + @dataclass class KVConnectorOutput: diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 58986848a430..0ea94f10d4a8 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import replace -from typing import Any +from dataclasses import dataclass, replace +from typing import Any, Optional import torch @@ -13,7 +13,10 @@ from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, extend_all_queries_by_1, extend_flat_seqs) -from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer +from vllm.v1.outputs import SamplerOutput +from vllm.v1.spec_decode.eagle import (PADDING_SLOT_ID, SpecDecodeBaseProposer, + num_rejected_tokens) +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata class DraftModelProposer(SpecDecodeBaseProposer): @@ -31,10 +34,17 @@ def __init__( runner=runner) self._raise_if_multimodal() self._raise_if_mrope() + self._raise_if_disabled_padded_drafter_batch() - def update_propose_kwargs(self, propose_kwargs: dict): + def update_propose_kwargs( + self, propose_kwargs: dict, sampler_output: SamplerOutput, + spec_decode_metadata: Optional[SpecDecodeMetadata]): return update_propose_kwargs(arange=self.arange, - propose_kwargs=propose_kwargs) + propose_kwargs=propose_kwargs, + sampler_output=sampler_output, + spec_decode_metadata=spec_decode_metadata, + block_size=self.block_size, + max_model_len=self.max_model_len) def _raise_if_multimodal(self): if self.supports_mm_inputs: @@ -46,6 +56,12 @@ def _raise_if_mrope(self): raise NotImplementedError("Speculative Decoding with draft models " "does not support M-RoPE yet") + def _raise_if_disabled_padded_drafter_batch(self): + if self.vllm_config.speculative_config.disable_padded_drafter_batch: + raise NotImplementedError( + "Speculative Decoding with draft models does not support " + "disabled padded drafter batch yet") + def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: self._raise_if_multimodal() self._raise_if_mrope() @@ -101,53 +117,147 @@ def load_model(self, target_model: Any) -> None: logger = init_logger(__name__) -def update_propose_kwargs(arange: torch.Tensor, propose_kwargs: dict): +def update_propose_kwargs(arange: torch.Tensor, propose_kwargs: dict, + sampler_output: SamplerOutput, + spec_decode_metadata: Optional[SpecDecodeMetadata], + block_size: int, max_model_len: int) -> dict: """ - This function: - - Merges the target_token_ids and the 
next_token_ids into a - single flat tensor. - - Appends new positions for these next_token_ids. - - Updates the common_attn_metadata to reflect that all query lengths are +1. + - Trims unnecessary tokens from the input, like those rejected by + the sampler, or those already processed by the draft model. + - Merges the next_token_ids with the existing token ids into + a flat sequence. """ cad: CommonAttentionMetadata = propose_kwargs["common_attn_metadata"] - target_token_ids = propose_kwargs["target_token_ids"] - next_token_ids = propose_kwargs["next_token_ids"] - target_positions = propose_kwargs["target_positions"] - token_indices_to_sample = propose_kwargs["last_token_indices"] - if token_indices_to_sample is None: - token_indices_to_sample = cad.query_start_loc[1:] - 1 - - # merge target_token_ids and next_token_ids - end_locs = token_indices_to_sample - new_target_token_ids = extend_flat_seqs(seqs=target_token_ids, - end_locs=end_locs, - new_vals=next_token_ids) - # append new positions - positions_to_append = target_positions[token_indices_to_sample] + 1 - new_target_positions = extend_flat_seqs(seqs=target_positions, - end_locs=end_locs, - new_vals=positions_to_append) + inputs = DraftModelInputs(cad=cad, + token_ids=propose_kwargs["target_token_ids"], + positions=propose_kwargs["target_positions"]) + inputs = trim_accepted_and_rejected_tokens( + inputs=inputs, + sampler_output=sampler_output, + spec_decode_metadata=spec_decode_metadata) + inputs = merge_next_token_ids_into_token_ids( + inputs=inputs, + next_token_ids=propose_kwargs["next_token_ids"], + block_size=block_size, + max_model_len=max_model_len, + arange=arange) + new_propose_kwargs = dict( + target_token_ids=inputs.token_ids, + target_positions=inputs.positions, + next_token_ids=None, + last_token_indices=None, + common_attn_metadata=inputs.cad, + ) + return propose_kwargs | new_propose_kwargs - # update common_attn_metadata - new_cad: CommonAttentionMetadata = extend_all_queries_by_1( - cad, arange=arange, last_token_indices=token_indices_to_sample) - # new token indices to sample incease by [+1, +2, +3, ..., +batch_size] - new_token_indices_to_sample = token_indices_to_sample \ - + arange_like(token_indices_to_sample) + 1 +@dataclass +class DraftModelInputs: + token_ids: torch.Tensor + positions: torch.Tensor + cad: CommonAttentionMetadata - logger.info("old last_token_indices: %s, new last_token_indices: %s.", - token_indices_to_sample, new_token_indices_to_sample) - new_propose_kwargs = dict( - target_token_ids=new_target_token_ids, - target_positions=new_target_positions, - next_token_ids=None, - last_token_indices=new_token_indices_to_sample, - common_attn_metadata=new_cad, +def trim_accepted_and_rejected_tokens( + inputs: DraftModelInputs, sampler_output: SamplerOutput, + spec_decode_metadata: Optional[SpecDecodeMetadata] +) -> DraftModelInputs: + """ + Removes from the input.token_ids any tokens that have already been processed + by the draft model, as well as tokens rejected by the sampler. + Adjusts the positions accordingly, the slot mapping, + and the common_attn_metadata. 
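To make the trimming arithmetic concrete, here is a made-up two-request example with 2 draft tokens per request (3 target tokens each); the formulas mirror the function body below, but the numbers do not come from a real run:

import torch

query_start_loc = torch.tensor([0, 3, 6])
n_sampled = torch.tensor([3, 1])    # bonus token + accepted drafts per request
num_draft = torch.tensor([2, 2])

n_accepted = n_sampled - 1                         # [2, 0]
# num_rejected_tokens() additionally guards the num_draft == 0 case.
n_rejected = num_draft + 1 - n_sampled             # [0, 2]
from_loc = query_start_loc[:-1] + n_accepted       # [2, 3]
to_loc = query_start_loc[1:] - 1 - n_rejected      # [2, 3]
new_query_lens = to_loc - from_loc + 1

assert new_query_lens.tolist() == [1, 1]   # each request keeps a single token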
+ """ + cad: CommonAttentionMetadata = inputs.cad + + # Compute the new token ids and positions + n_accepted_tokens = sampler_output.n_sampled_tokens() - 1 + n_rejected_tokens = num_rejected_tokens(spec_decode_metadata, + sampler_output.n_sampled_tokens()) + from_loc = cad.query_start_loc[:-1] + n_accepted_tokens + to_loc = cad.query_start_loc[1:] - 1 - n_rejected_tokens + idxs = compute_subrange_indices(from_loc, to_loc) + new_token_ids = inputs.token_ids[idxs] + new_positions = inputs.positions[idxs] + + # The new slot mapping is a subset of the previous one, + # so no recomputation is needed. + new_slot_mapping = cad.slot_mapping[idxs] + + # Update common_attn_metadata + new_query_lens = to_loc - from_loc + 1 + new_query_start_loc = torch.zeros_like(cad.query_start_loc) + new_query_start_loc[1:] = new_query_lens.cumsum(0) + + new_cad: CommonAttentionMetadata = cad.replace( + query_start_loc=new_query_start_loc, + query_start_loc_cpu=new_query_start_loc.to("cpu"), + num_actual_tokens=new_token_ids.shape[0], + max_query_len=new_query_lens.max().item(), + slot_mapping=new_slot_mapping, ) - return propose_kwargs | new_propose_kwargs + return DraftModelInputs(token_ids=new_token_ids, + positions=new_positions, + cad=new_cad) + +def compute_subrange_indices(from_locs: torch.Tensor, to_locs: torch.Tensor): + # Compute lengths of each subrange + lengths = to_locs - from_locs + 1 + # Build an index for each subrange + # torch.arange(max_len) creates [0, 1, ..., max_len-1] + # broadcasting + masking ensures we only keep valid positions + max_len = lengths.max() + offsets = torch.arange(max_len, device=from_locs.device).unsqueeze( + 0) # shape [1, max_len] + mask = offsets < lengths.unsqueeze(1) # shape [n, max_len] + # Build all indices + all_indices = from_locs.unsqueeze(1) + offsets + all_indices = all_indices[mask] # flatten valid indices only + return all_indices -def arange_like(x: torch.Tensor) -> torch.Tensor: - return torch.arange(x.shape[0], device=x.device) + +def merge_next_token_ids_into_token_ids( + inputs: DraftModelInputs, + next_token_ids: torch.Tensor, + block_size: int, + max_model_len: int, + arange: torch.Tensor, +) -> DraftModelInputs: + """ + Merges the next token ids with the existing token ids into a flat sequence. + Does the same for the positions, computes new slot mapping, + and updates the common_attn_metadata. + """ + cad: CommonAttentionMetadata = inputs.cad + + # merge token_ids and next_token_ids + query_end_locs = cad.query_start_loc[1:] - 1 + new_token_ids = extend_flat_seqs(seqs=inputs.token_ids, + end_locs=query_end_locs, + new_vals=next_token_ids) + # append new positions + positions_to_append = inputs.positions[query_end_locs] + 1 + new_positions = extend_flat_seqs(seqs=inputs.positions, + end_locs=query_end_locs, + new_vals=positions_to_append) + + # recompute slot mapping + batch_size, n_blocks_per_req = cad.block_table_tensor.shape + req_indices = torch.arange(batch_size, device=cad.query_start_loc.device) + req_indices = torch.repeat_interleave(req_indices, cad.query_lens() + 1) + block_table_indices = (req_indices * n_blocks_per_req + + new_positions // block_size) + block_nums = cad.block_table_tensor.view(-1)[block_table_indices] + block_offsets = new_positions % block_size + new_slot_mapping = block_nums * block_size + block_offsets + # Mask out the position ids that exceed the max model length. 
+ exceeds_max_model_len = new_positions >= max_model_len + new_slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + + # update common_attn_metadata + new_cad: CommonAttentionMetadata = extend_all_queries_by_1( + cad, arange=arange, new_slot_mapping=new_slot_mapping) + return DraftModelInputs(token_ids=new_token_ids, + positions=new_positions, + cad=new_cad) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 73e71da7a43d..6c6ea26af332 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -386,27 +386,14 @@ def propose( common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. - if self.uses_mrope: - # all dimensions of positions are the same - block_numbers = clamped_positions[0] // self.block_size - else: - block_numbers = clamped_positions // self.block_size - block_ids = common_attn_metadata.block_table_tensor.gather( - dim=1, index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - if self.uses_mrope: - common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions[0] % self.block_size) - else: - common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions % self.block_size) + slot_mapping = self.compute_slot_mapping( + positions=clamped_positions, + block_table_tensor=common_attn_metadata.block_table_tensor) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. - common_attn_metadata.slot_mapping.masked_fill_( - exceeds_max_model_len, PADDING_SLOT_ID) + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + common_attn_metadata.slot_mapping = slot_mapping # Rebuild attention metadata attn_metadata = builder.build_for_drafting( # type: ignore @@ -463,7 +450,11 @@ def propose( model_kwargs["positions"].shape[0]) logger.info("draft batch_descriptor: %s", forward_ctx_kwargs["batch_descriptor"]) - # logger.info(f"draft attn_metadata={attn_metadata}") + logger.info("draft attn_metadata.slot_mapping: %s", + attn_metadata.slot_mapping) + for _idx, _block in enumerate(attn_metadata.block_table): + logger.info("draft attn_metadata.block_table [%d]: %s", + _idx, _block.tolist()) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size]) @@ -476,6 +467,24 @@ def propose( self.runner.log_tokens(f"draft token ids {idx}", tokens) return draft_token_ids + def compute_slot_mapping(self, positions: torch.Tensor, + block_table_tensor: torch.Tensor) -> torch.Tensor: + if self.uses_mrope: + # all dimensions of positions are the same + block_numbers = positions[0] // self.block_size + else: + block_numbers = positions // self.block_size + block_ids = block_table_tensor.gather(dim=1, + index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + if self.uses_mrope: + slot_mapping = (block_ids * self.block_size + + positions[0] % self.block_size) + else: + slot_mapping = (block_ids * self.block_size + + positions % self.block_size) + return slot_mapping + def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, next_token_ids: torch.Tensor, num_tokens: int, last_token_indices: torch.Tensor) -> None: @@ -611,17 +620,12 @@ def prepare_inputs_padded(self, used as padding and filtered out later by `token_indices_to_sample`. No blocking CPU operations should be introduced in this function. 
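The same block-table lookup appears in both merge_next_token_ids_into_token_ids() and compute_slot_mapping() above; a tiny numeric illustration with a made-up block table and block_size = 4:

import torch

block_size = 4
block_table = torch.tensor([[7, 9, 2]])   # one request mapped to three blocks
positions = torch.tensor([5, 6])          # both tokens fall in logical block 1

block_numbers = positions // block_size             # [1, 1]
block_ids = block_table[0, block_numbers]           # physical block 9
slot_mapping = block_ids * block_size + positions % block_size
assert slot_mapping.tolist() == [37, 38]            # 9*4 + 1 and 9*4 + 2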
""" - num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1] - ]) - - num_rejected_tokens_gpu = torch.where( - num_draft_tokens_gpu > 0, - num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu)) - + num_rejected_tokens_gpu = num_rejected_tokens( + spec_decode_metadata, valid_sampled_tokens_count) + if self.runner.do_log: + logger.info("valid_sampled_tokens_count: %s", + valid_sampled_tokens_count) + logger.info("num_rejected_tokens_gpu: %s", num_rejected_tokens_gpu) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu new_query_len_per_req = (query_start_loc_cpu[1:] - @@ -1147,3 +1151,22 @@ def compute_probs_and_sample_next_token( next_token_ids, ) return next_token_ids, probs + + +def num_rejected_tokens( + spec_decode_metadata: Optional[SpecDecodeMetadata], + valid_sampled_tokens_count: torch.Tensor) -> torch.Tensor: + if spec_decode_metadata is None: + return torch.zeros_like(valid_sampled_tokens_count) + + num_draft_tokens_gpu = torch.cat([ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] - + spec_decode_metadata.cu_num_draft_tokens[:-1] + ]) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu)) + return num_rejected_tokens_gpu diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e4e3418008c0..d68b5d8931e4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -184,7 +184,7 @@ def __init__( vllm_config: VllmConfig, device: torch.device, ): - self.do_log = True + self.do_log = False if self.do_log: from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained( @@ -2472,6 +2472,9 @@ def execute_model( sampler_output = self._sample(logits, spec_decode_metadata) for idx, tokens in enumerate(sampler_output.sampled_token_ids): self.log_tokens(f"Sampler output {idx}", tokens[tokens != -1]) + if self.do_log: + logger.info("n_sampled_tokens: %s", + sampler_output.n_sampled_tokens()) def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None @@ -2487,6 +2490,7 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_common_attn_metadata, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, + sampler_output=sampler_output, ) use_padded_batch = self.speculative_config and \ @@ -2578,6 +2582,7 @@ def propose_draft_token_ids( common_attn_metadata: CommonAttentionMetadata, cudagraph_runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor, + sampler_output: SamplerOutput, ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": @@ -2658,8 +2663,6 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.speculative_config.disable_padded_drafter_batch: - raise NotImplementedError( - "This path cannot be reached using `vllm serve ...`") token_indices_to_sample = None common_attn_metadata, token_indices =\ self.drafter.prepare_inputs( @@ -2709,7 +2712,7 @@ def propose_draft_token_ids( ) if isinstance(self.drafter, DraftModelProposer): propose_kwargs = self.drafter.update_propose_kwargs( - propose_kwargs) + propose_kwargs, sampler_output, 
spec_decode_metadata) draft_token_ids = self.drafter.propose(**propose_kwargs) return draft_token_ids From ac56891396840d80f2480a1ee0f18bf2411dfc86 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 2 Oct 2025 16:37:18 +0200 Subject: [PATCH 39/73] Remove logging Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 31 ------------------------------ vllm/v1/worker/gpu_model_runner.py | 29 ---------------------------- 2 files changed, 60 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6c6ea26af332..171cbbca45eb 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -280,19 +280,6 @@ def propose( last_hidden_states, hidden_states = ret_hidden_states sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) - self.runner.log_tokens("draft input_ids", input_ids) - if self.runner.do_log: - logger.info("draft positions: %s", model_kwargs["positions"]) - logger.info("draft position length: %s", - model_kwargs["positions"].shape[0]) - logger.info("draft last_token_indices: %s", last_token_indices) - logger.info("draft batch_descriptor: %s", - forward_ctx_kwargs["batch_descriptor"]) - logger.info("draft attn_metadata.slot_mapping: %s", - attn_metadata.slot_mapping) - for _idx, _block in enumerate(attn_metadata.block_table): - logger.info("draft attn_metadata.block_table [%d]: %s", _idx, - _block.tolist()) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: @@ -443,18 +430,6 @@ def propose( hidden_states = ret_hidden_states else: last_hidden_states, hidden_states = ret_hidden_states - self.runner.log_tokens("draft input_ids", input_ids) - if self.runner.do_log: - logger.info("draft positions: %s", model_kwargs["positions"]) - logger.info("draft position length: %s", - model_kwargs["positions"].shape[0]) - logger.info("draft batch_descriptor: %s", - forward_ctx_kwargs["batch_descriptor"]) - logger.info("draft attn_metadata.slot_mapping: %s", - attn_metadata.slot_mapping) - for _idx, _block in enumerate(attn_metadata.block_table): - logger.info("draft attn_metadata.block_table [%d]: %s", - _idx, _block.tolist()) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size]) @@ -463,8 +438,6 @@ def propose( # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - for idx, tokens in enumerate(draft_token_ids): - self.runner.log_tokens(f"draft token ids {idx}", tokens) return draft_token_ids def compute_slot_mapping(self, positions: torch.Tensor, @@ -622,10 +595,6 @@ def prepare_inputs_padded(self, """ num_rejected_tokens_gpu = num_rejected_tokens( spec_decode_metadata, valid_sampled_tokens_count) - if self.runner.do_log: - logger.info("valid_sampled_tokens_count: %s", - valid_sampled_tokens_count) - logger.info("num_rejected_tokens_gpu: %s", num_rejected_tokens_gpu) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu new_query_len_per_req = (query_start_loc_cpu[1:] - diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d68b5d8931e4..2ec7f7f52c22 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -174,21 +174,11 @@ def get_output(self) -> ModelRunnerOutput: class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def log_tokens(self, msg: str, tokens: list[int]): - if not self.do_log: - return - logger.info("%s: %s", msg, 
[self.tokenizer.decode(t) for t in tokens]) - def __init__( self, vllm_config: VllmConfig, device: torch.device, ): - self.do_log = False - if self.do_log: - from transformers import AutoTokenizer - self.tokenizer = AutoTokenizer.from_pretrained( - vllm_config.model_config.model) self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -2323,8 +2313,6 @@ def execute_model( scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]: - if self.do_log: - logger.info("======STEP======") with record_function_or_nullcontext("Preprocess"): with self.synchronize_input_prep(): # Update persistent batch states. @@ -2395,18 +2383,6 @@ def execute_model( inputs_embeds=inputs_embeds, **model_kwargs, ) - self.log_tokens("tgt input_ids", input_ids) - if self.do_log: - logger.info("tgt positions: %s", positions) - logger.info("tgt position length: %s", positions.shape[0]) - logger.info("tgt logits_indices: %s", logits_indices) - logger.info("tgt batch_descriptor: %s", batch_descriptor) - _attn_metadata = list(attn_metadata.values())[0] - logger.info("tgt attn_metadata.slot_mapping: %s", - _attn_metadata.slot_mapping) - for _idx, _block in enumerate(_attn_metadata.block_table): - logger.info("tgt attn_metadata.block_table [%d]: %s", _idx, - _block.tolist()) with record_function_or_nullcontext("Postprocess"): if self.use_aux_hidden_state_outputs: @@ -2470,11 +2446,6 @@ def execute_model( with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) - for idx, tokens in enumerate(sampler_output.sampled_token_ids): - self.log_tokens(f"Sampler output {idx}", tokens[tokens != -1]) - if self.do_log: - logger.info("n_sampled_tokens: %s", - sampler_output.n_sampled_tokens()) def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None From 4b4399947a7666407eabaa7d2f5fb97322358d4f Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 2 Oct 2025 16:51:08 +0200 Subject: [PATCH 40/73] use non-blocking cpu move, document and test helper fns Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 15 +++++++++++++++ vllm/v1/attention/backends/utils.py | 4 ++-- vllm/v1/spec_decode/draft_model.py | 17 ++++++++++++----- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 1434f96faa86..48ce3571b72d 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -16,6 +16,7 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.outputs import RequestOutput from vllm.platforms import current_platform +from vllm.v1.spec_decode.draft_model import compute_subrange_indices from vllm.v1.spec_decode.metrics import (compute_acceptance_len, compute_acceptance_rate) @@ -426,3 +427,17 @@ def compute_exact_matches(ref_outputs: list[RequestOutput], print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") return matches / len(ref_outputs) + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_compute_subrange_indices(device: str): + start_locs = torch.tensor([3, 6, 12], device=device) + end_locs = torch.tensor([5, 6, 15], device=device) + # fmt: off + expected_indices = torch.tensor([3, 4, 5, + 6, + 12, 13, 14, 15], + device=device) + # fmt: on + indices = compute_subrange_indices(start_locs, 
end_locs) + assert torch.equal(indices, expected_indices) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 0ed7495ba146..c1ff2ed785f4 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -126,9 +126,9 @@ def extend_all_queries_by_1( new_cad = CommonAttentionMetadata( query_start_loc=new_query_start_loc, - query_start_loc_cpu=new_query_start_loc.to("cpu"), + query_start_loc_cpu=new_query_start_loc.to("cpu", non_blocking=True), seq_lens=new_seq_lens, - seq_lens_cpu=new_seq_lens.to("cpu"), + seq_lens_cpu=new_seq_lens.to("cpu", non_blocking=True), num_reqs=cad.num_reqs, # num requests stays unchanged num_computed_tokens_cpu=cad.num_computed_tokens_cpu + 1, # each request is extended by 1 token -> batch_size tokens are added diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 0ea94f10d4a8..4352be22cbe2 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -191,7 +191,7 @@ def trim_accepted_and_rejected_tokens( new_cad: CommonAttentionMetadata = cad.replace( query_start_loc=new_query_start_loc, - query_start_loc_cpu=new_query_start_loc.to("cpu"), + query_start_loc_cpu=new_query_start_loc.to("cpu", non_blocking=True), num_actual_tokens=new_token_ids.shape[0], max_query_len=new_query_lens.max().item(), slot_mapping=new_slot_mapping, @@ -201,18 +201,25 @@ def trim_accepted_and_rejected_tokens( cad=new_cad) -def compute_subrange_indices(from_locs: torch.Tensor, to_locs: torch.Tensor): +def compute_subrange_indices(start_locs: torch.Tensor, end_locs: torch.Tensor): + """ + Given two tensor of the same length containing start and end locations, + returns a tensor of indices with each subrange. E.g. + start_locs = [s1, s2, s3, ...], and + end_locs = [e1, e2, e3, ...], + return [*s1:e1, *s2:e2, *s3:e3, ...] as a flat tensor + """ # Compute lengths of each subrange - lengths = to_locs - from_locs + 1 + lengths = end_locs - start_locs + 1 # Build an index for each subrange # torch.arange(max_len) creates [0, 1, ..., max_len-1] # broadcasting + masking ensures we only keep valid positions max_len = lengths.max() - offsets = torch.arange(max_len, device=from_locs.device).unsqueeze( + offsets = torch.arange(max_len, device=start_locs.device).unsqueeze( 0) # shape [1, max_len] mask = offsets < lengths.unsqueeze(1) # shape [n, max_len] # Build all indices - all_indices = from_locs.unsqueeze(1) + offsets + all_indices = start_locs.unsqueeze(1) + offsets all_indices = all_indices[mask] # flatten valid indices only return all_indices From 10eb718660410cc42097f2c417afb0a6c1d4f9e3 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 2 Oct 2025 16:56:22 +0200 Subject: [PATCH 41/73] Minimize changes Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 41 ++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 171cbbca45eb..77169fe64ef1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -373,14 +373,27 @@ def propose( common_attn_metadata.seq_lens_cpu - 1 # Compute the slot mapping. 
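        # Illustrative note: the inlined computation below maps each draft
        # position to a physical KV-cache slot as
        #     slot = block_table[req, pos // block_size] * block_size + pos % block_size
        # For example, with block_size=16 and a clamped position of 35, the logical
        # block index is 35 // 16 = 2; if block_table[req, 2] == 7 (hypothetical
        # value for exposition), the slot is 7 * 16 + 35 % 16 = 115. Slots for
        # positions past max_model_len are then masked to PADDING_SLOT_ID so the
        # padding tokens never overwrite real cache entries.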
- slot_mapping = self.compute_slot_mapping( - positions=clamped_positions, - block_table_tensor=common_attn_metadata.block_table_tensor) + if self.uses_mrope: + # all dimensions of positions are the same + block_numbers = clamped_positions[0] // self.block_size + else: + block_numbers = clamped_positions // self.block_size + block_ids = common_attn_metadata.block_table_tensor.gather( + dim=1, index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + if self.uses_mrope: + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions[0] % self.block_size) + else: + common_attn_metadata.slot_mapping = ( + block_ids * self.block_size + + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. - slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) - common_attn_metadata.slot_mapping = slot_mapping + common_attn_metadata.slot_mapping.masked_fill_( + exceeds_max_model_len, PADDING_SLOT_ID) # Rebuild attention metadata attn_metadata = builder.build_for_drafting( # type: ignore @@ -440,24 +453,6 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids - def compute_slot_mapping(self, positions: torch.Tensor, - block_table_tensor: torch.Tensor) -> torch.Tensor: - if self.uses_mrope: - # all dimensions of positions are the same - block_numbers = positions[0] // self.block_size - else: - block_numbers = positions // self.block_size - block_ids = block_table_tensor.gather(dim=1, - index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - if self.uses_mrope: - slot_mapping = (block_ids * self.block_size + - positions[0] % self.block_size) - else: - slot_mapping = (block_ids * self.block_size + - positions % self.block_size) - return slot_mapping - def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, next_token_ids: torch.Tensor, num_tokens: int, last_token_indices: torch.Tensor) -> None: From 4c7eb11ff731965d10c3b29545c6fa3710012bc3 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 2 Oct 2025 17:20:52 +0200 Subject: [PATCH 42/73] Reduce changes footprint Signed-off-by: Tomas Ruiz --- .gitignore | 3 - make-problem.sh | 12 ---- tests/v1/e2e/test_spec_decode.py | 80 +++++++++++----------- vllm/v1/spec_decode/draft_model.py | 105 ++++++++++++++++------------- vllm/v1/spec_decode/eagle.py | 3 + vllm/v1/worker/gpu_model_runner.py | 8 +-- 6 files changed, 103 insertions(+), 108 deletions(-) delete mode 100755 make-problem.sh diff --git a/.gitignore b/.gitignore index d947c675565b..b1df673e83ca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,3 @@ -# Scripts for development -scripts/ - # version file generated by setuptools-scm /vllm/_version.py diff --git a/make-problem.sh b/make-problem.sh deleted file mode 100755 index 6265e0d2ad71..000000000000 --- a/make-problem.sh +++ /dev/null @@ -1,12 +0,0 @@ -export CUDA_LAUNCH_BLOCKING=1 -nohup python examples/offline_inference/spec_decode.py \ - --model-dir Qwen/Qwen3-1.7B \ - --draft-model Qwen/Qwen3-1.7B \ - --method draft_model \ - --num_spec_tokens 3 \ - --dataset-name hf \ - --dataset-path philschmid/mt-bench \ - --num_prompts 80 \ - --temp 0.0 \ - --gpu-memory-utilization 0.9 \ - --enforce-eager > al.log 2>&1 & \ No newline at end of file diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 48ce3571b72d..4859a3cd66fb 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ 
b/tests/v1/e2e/test_spec_decode.py @@ -248,46 +248,6 @@ def test_eagle_correctness( cleanup_dist_env_and_memory() -@dataclass -class ArgsTest: - model: str - draft_model: str - sampling_config: SamplingParams - num_speculative_tokens: int - expected_acceptance_rate: float - expected_acceptance_len: float - expected_same_output_fraction: float - # Defaults - target_tensor_parallel_size: int = 1 - draft_tensor_parallel_size: int = 1 - max_model_len: int = 1024 - gpu_memory_utilization: float = 0.5 - - -cases = [ - # Same model for draft and target, greedy sampling. - ArgsTest( - model="Qwen/Qwen3-0.6B", - draft_model="Qwen/Qwen3-0.6B", - sampling_config=greedy_sampling(), - num_speculative_tokens=3, # K - expected_acceptance_len=3 + 1, # K + 1 - expected_acceptance_rate=1.0, - expected_same_output_fraction=1.0, - ), - # Smaller draft model, stochastic sampling. - ArgsTest( - model="Qwen/Qwen3-1.7B", - draft_model="Qwen/Qwen3-0.6B", - sampling_config=stochastic_sampling(), - num_speculative_tokens=3, - expected_acceptance_len=2.85 + 1, - expected_acceptance_rate=0.9, - expected_same_output_fraction=0.9, - ), -] - - @pytest.mark.parametrize(["model_setup", "mm_enabled"], [ (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False), (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False), @@ -351,6 +311,46 @@ def test_mtp_correctness( cleanup_dist_env_and_memory() +@dataclass +class ArgsTest: + model: str + draft_model: str + sampling_config: SamplingParams + num_speculative_tokens: int + expected_acceptance_rate: float + expected_acceptance_len: float + expected_same_output_fraction: float + # Defaults + target_tensor_parallel_size: int = 1 + draft_tensor_parallel_size: int = 1 + max_model_len: int = 1024 + gpu_memory_utilization: float = 0.5 + + +cases = [ + # Same model for draft and target, greedy sampling. + ArgsTest( + model="Qwen/Qwen3-0.6B", + draft_model="Qwen/Qwen3-0.6B", + sampling_config=greedy_sampling(), + num_speculative_tokens=3, # K + expected_acceptance_len=3 + 1, # K + 1 + expected_acceptance_rate=1.0, + expected_same_output_fraction=1.0, + ), + # Smaller draft model, stochastic sampling. 
+ ArgsTest( + model="Qwen/Qwen3-1.7B", + draft_model="Qwen/Qwen3-0.6B", + sampling_config=stochastic_sampling(), + num_speculative_tokens=3, + expected_acceptance_len=2.85 + 1, + expected_acceptance_rate=0.9, + expected_same_output_fraction=0.9, + ), +] + + @pytest.mark.parametrize("args", cases) @pytest.mark.parametrize("enforce_eager", [True, False]) def test_draft_model_correctness( diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 4352be22cbe2..d28be0491891 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -8,13 +8,14 @@ from vllm.attention.layer import Attention from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context -from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, extend_all_queries_by_1, extend_flat_seqs) from vllm.v1.outputs import SamplerOutput -from vllm.v1.spec_decode.eagle import (PADDING_SLOT_ID, SpecDecodeBaseProposer, +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.spec_decode.eagle import (PADDING_SLOT_ID, CudaGraphArgs, + SpecDecodeBaseProposer, num_rejected_tokens) from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -36,15 +37,60 @@ def __init__( self._raise_if_mrope() self._raise_if_disabled_padded_drafter_batch() - def update_propose_kwargs( - self, propose_kwargs: dict, sampler_output: SamplerOutput, - spec_decode_metadata: Optional[SpecDecodeMetadata]): - return update_propose_kwargs(arange=self.arange, - propose_kwargs=propose_kwargs, - sampler_output=sampler_output, - spec_decode_metadata=spec_decode_metadata, - block_size=self.block_size, - max_model_len=self.max_model_len) + def propose( + self, + # [num_tokens] + target_token_ids: torch.Tensor, + # [num_tokens] or [3, num_tokens] when M-RoPE is enabled + target_positions: torch.Tensor, + # [num_tokens, hidden_size] + target_hidden_states: torch.Tensor, + # [batch_size] + next_token_ids: torch.Tensor, + last_token_indices: Optional[torch.Tensor], + common_attn_metadata: CommonAttentionMetadata, + sampling_metadata: SamplingMetadata, + cudagraph_args: "CudaGraphArgs", + sampler_output: SamplerOutput, + spec_decode_metadata: Optional[SpecDecodeMetadata], + mm_embed_inputs: Optional[tuple[list[torch.Tensor], + torch.Tensor]] = None, + ) -> torch.Tensor: + """ + - Trims unnecessary tokens from the input, like those rejected by + the sampler, or those already processed by the draft model. + - Merges the next_token_ids with the existing token ids into + a flat sequence. 
+ """ + inputs = DraftModelInputs(cad=common_attn_metadata, + token_ids=target_token_ids, + positions=target_positions) + inputs = trim_accepted_and_rejected_tokens( + inputs=inputs, + sampler_output=sampler_output, + spec_decode_metadata=spec_decode_metadata) + inputs = merge_next_token_ids_into_token_ids( + inputs=inputs, + next_token_ids=next_token_ids, + block_size=self.block_size, + max_model_len=self.max_model_len, + arange=self.arange) + + draft_token_ids = super().propose( + target_token_ids=inputs.token_ids, + target_positions=inputs.positions, + common_attn_metadata=inputs.cad, + cudagraph_args=cudagraph_args, + sampling_metadata=sampling_metadata, + sampler_output=sampler_output, + spec_decode_metadata=spec_decode_metadata, + # below are are not used by draft model + target_hidden_states=None, + next_token_ids=None, + last_token_indices=None, + mm_embed_inputs=None, + ) + return draft_token_ids def _raise_if_multimodal(self): if self.supports_mm_inputs: @@ -114,43 +160,6 @@ def load_model(self, target_model: Any) -> None: self.attn_layer_names = list(draft_attn_layer_names) -logger = init_logger(__name__) - - -def update_propose_kwargs(arange: torch.Tensor, propose_kwargs: dict, - sampler_output: SamplerOutput, - spec_decode_metadata: Optional[SpecDecodeMetadata], - block_size: int, max_model_len: int) -> dict: - """ - - Trims unnecessary tokens from the input, like those rejected by - the sampler, or those already processed by the draft model. - - Merges the next_token_ids with the existing token ids into - a flat sequence. - """ - cad: CommonAttentionMetadata = propose_kwargs["common_attn_metadata"] - inputs = DraftModelInputs(cad=cad, - token_ids=propose_kwargs["target_token_ids"], - positions=propose_kwargs["target_positions"]) - inputs = trim_accepted_and_rejected_tokens( - inputs=inputs, - sampler_output=sampler_output, - spec_decode_metadata=spec_decode_metadata) - inputs = merge_next_token_ids_into_token_ids( - inputs=inputs, - next_token_ids=propose_kwargs["next_token_ids"], - block_size=block_size, - max_model_len=max_model_len, - arange=arange) - new_propose_kwargs = dict( - target_token_ids=inputs.token_ids, - target_positions=inputs.positions, - next_token_ids=None, - last_token_indices=None, - common_attn_metadata=inputs.cad, - ) - return propose_kwargs | new_propose_kwargs - - @dataclass class DraftModelInputs: token_ids: torch.Tensor diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 77169fe64ef1..892d52252396 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -29,6 +29,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.utils import CpuGpuBuffer @@ -187,6 +188,8 @@ def propose( common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, cudagraph_args: "CudaGraphArgs", + sampler_output: SamplerOutput, + spec_decode_metadata: Optional[SpecDecodeMetadata], mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2ec7f7f52c22..a44a8ccbe7df 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2670,7 +2670,7 @@ def propose_draft_token_ids( 
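# Illustrative sketch of the draft-model proposal flow invoked below; every name
# in this sketch is a placeholder for exposition, not a vLLM API:
#
#   kept = sampled[sampled != -1]                 # tokens accepted by the target + bonus token
#   seq = torch.cat([prior_tokens, kept])         # rejected drafts dropped, sequence stays flat
#   draft_ids = []
#   for _ in range(num_speculative_tokens):
#       logits = draft_model(seq)                 # forward pass of the smaller draft model
#       nxt = logits[-1].argmax(-1, keepdim=True) # greedy draft token (a sketch; real code samples)
#       draft_ids.append(nxt)
#       seq = torch.cat([seq, nxt])
#
# The actual implementation performs this batched on the GPU: it calls
# trim_accepted_and_rejected_tokens() and merge_next_token_ids_into_token_ids()
# and then hands the flat sequence to SpecDecodeBaseProposer.propose().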
cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ) - propose_kwargs = dict( + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -2680,11 +2680,9 @@ def propose_draft_token_ids( common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, cudagraph_args=cudagraph_args, + sampler_output=sampler_output, + spec_decode_metadata=spec_decode_metadata, ) - if isinstance(self.drafter, DraftModelProposer): - propose_kwargs = self.drafter.update_propose_kwargs( - propose_kwargs, sampler_output, spec_decode_metadata) - draft_token_ids = self.drafter.propose(**propose_kwargs) return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: From d1230189af0969170334f5b210311eea09632c78 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 2 Oct 2025 17:23:09 +0200 Subject: [PATCH 43/73] Reduce changes Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index d28be0491891..a42599fd3c30 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -109,8 +109,6 @@ def _raise_if_disabled_padded_drafter_batch(self): "disabled padded drafter batch yet") def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: - self._raise_if_multimodal() - self._raise_if_mrope() return { "input_ids": self.input_ids[:num_tokens], "positions": self.positions[:num_tokens], @@ -118,7 +116,6 @@ def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: def dummy_run(self, num_tokens: int, forward_ctx_kwargs: dict): model_kwargs = self._model_kwargs(num_tokens) - assert isinstance(self.model, torch.nn.Module) with set_forward_context( vllm_config=self.vllm_config, num_tokens=num_tokens, From 02872ad673f989459c18c2e4b01ceffada850116 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 2 Oct 2025 17:33:37 +0200 Subject: [PATCH 44/73] Minimize changes Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 892d52252396..6a3aafdab0c8 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -65,9 +65,7 @@ def __init__( self.num_speculative_tokens = ( self.speculative_config.num_speculative_tokens) self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens + - vllm_config.scheduler_config.max_num_seqs * - self.num_speculative_tokens) + vllm_config.scheduler_config.max_num_batched_tokens) self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's @@ -266,12 +264,7 @@ def propose( num_tokens=num_input_tokens, ) if self.pass_cudagraph_args_to_forward_ctx: - # Update num_tokens in batch descriptor, eg after cudagraph padding - old_bd: BatchDescriptor = cudagraph_args["batch_descriptor"] - if old_bd is not None: - new_bd = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=old_bd.uniform_decode) - cudagraph_args["batch_descriptor"] = new_bd + update_batch_descriptor(cudagraph_args, num_input_tokens) forward_ctx_kwargs.update(cudagraph_args) with set_forward_context(**forward_ctx_kwargs): @@ -1137,3 +1130,14 @@ def num_rejected_tokens( num_draft_tokens_gpu + 1 - 
valid_sampled_tokens_count, torch.zeros_like(num_draft_tokens_gpu)) return num_rejected_tokens_gpu + + +def update_batch_descriptor(cudagraph_args: CudaGraphArgs, + new_num_tokens: int) -> None: + """The cudagraph padding can change the num_tokens, so the batch descriptor + should be updated. The cudagraph_args is modified in place.""" + old: Optional[BatchDescriptor] = cudagraph_args["batch_descriptor"] + if old is not None: + new = BatchDescriptor(num_tokens=new_num_tokens, + uniform_decode=old.uniform_decode) + cudagraph_args["batch_descriptor"] = new From 33bcc08e2d31ff5ccdcc54d2cc64f97d4d943785 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 6 Oct 2025 10:32:58 +0200 Subject: [PATCH 45/73] ruff Signed-off-by: Tomas Ruiz --- .pre-commit-config.yaml | 12 - pyproject.toml | 127 +- tests/v1/attention/test_attention_backends.py | 322 +-- tests/v1/e2e/test_spec_decode.py | 187 +- tests/v1/spec_decode/test_eagle.py | 390 ++-- tests/v1/test_outputs.py | 3 +- tests/v1/worker/test_utils.py | 97 +- vllm/config/speculative.py | 333 +-- vllm/engine/arg_utils.py | 1202 +++++----- vllm/model_executor/model_loader/__init__.py | 53 +- .../model_loader/base_loader.py | 28 +- .../model_loader/gguf_loader.py | 86 +- .../model_loader/tensorizer_loader.py | 55 +- vllm/v1/attention/backends/utils.py | 389 ++-- vllm/v1/core/sched/scheduler.py | 432 ++-- vllm/v1/outputs.py | 51 +- vllm/v1/spec_decode/draft_model.py | 142 +- vllm/v1/spec_decode/eagle.py | 625 ++--- vllm/v1/spec_decode/metrics.py | 89 +- vllm/v1/worker/gpu_model_runner.py | 2040 ++++++++++------- vllm/v1/worker/utils.py | 49 +- 21 files changed, 3767 insertions(+), 2945 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8ca414ee4269..ea63ef1f528c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,28 +6,16 @@ default_stages: - manual # Run in CI exclude: 'vllm/third_party/.*' repos: -- repo: https://github.com/google/yapf - rev: v0.43.0 - hooks: - - id: yapf - args: [--in-place, --verbose] - # Keep the same list from yapfignore here to avoid yapf failing without any inputs - exclude: '(.buildkite|benchmarks|build|examples)/.*' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.7 hooks: - id: ruff args: [--output-format, github, --fix] - id: ruff-format - files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos rev: v1.35.5 hooks: - id: typos -- repo: https://github.com/PyCQA/isort - rev: 6.0.1 - hooks: - - id: isort - repo: https://github.com/pre-commit/mirrors-clang-format rev: v20.1.3 hooks: diff --git a/pyproject.toml b/pyproject.toml index 034a21f1c12b..2b416d3206c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,27 +52,106 @@ lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:regi where = ["."] include = ["vllm*"] -[tool.yapfignore] -ignore_patterns = [ - ".buildkite/**", - "benchmarks/**", - "build/**", - "examples/**", -] - -[tool.ruff] -# Allow lines to be as long as 80. -line-length = 80 - [tool.ruff.lint.per-file-ignores] "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] -# Python 3.8 typing - skip V0 code -"vllm/attention/**/*.py" = ["UP006", "UP035"] -"vllm/engine/**/*.py" = ["UP006", "UP035"] -"vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/worker/**/*.py" = ["UP006", "UP035"] +# TEMPORARY! 
These ignores will be fixed forward +## Line length violations +"csrc/cutlass_extensions/vllm_cutlass_library_extension.py" = ["E501"] +"tests/compile/piecewise/test_simple.py" = ["E501"] +"tests/compile/piecewise/test_toy_llama.py" = ["E501", "B023"] +"tests/entrypoints/conftest.py" = ["E501"] +"tests/entrypoints/openai/test_audio.py" = ["E501"] +"tests/entrypoints/openai/test_chat.py" = ["E501"] +"tests/entrypoints/openai/test_chat_template.py" = ["E501"] +"tests/entrypoints/openai/test_chat_with_tool_reasoning.py" = ["E501"] +"tests/entrypoints/openai/test_completion_with_function_calling.py" = ["E501"] +"tests/entrypoints/openai/test_video.py" = ["E501"] +"tests/entrypoints/openai/test_vision.py" = ["E501"] +"tests/entrypoints/test_chat_utils.py" = ["E501"] +"tests/kernels/moe/modular_kernel_tools/common.py" = ["E501"] +"tests/models/language/generation/test_gemma.py" = ["E501"] +"tests/models/language/generation/test_mistral.py" = ["E501"] +"tests/models/multimodal/generation/test_ultravox.py" = ["E501"] +"tests/models/multimodal/generation/test_voxtral.py" = ["E501"] +"tests/models/multimodal/generation/vlm_utils/custom_inputs.py" = ["E501"] +"tests/tool_use/test_tool_choice_required.py" = ["E501"] +"tests/v1/attention/utils.py" = ["E501"] +"tests/v1/entrypoints/openai/responses/test_image.py" = ["E501"] +"tests/v1/kv_connector/nixl_integration/test_accuracy.py" = ["E501"] +"tests/v1/kv_connector/unit/test_offloading_connector.py" = ["E501"] +"tests/v1/logits_processors/test_custom_offline.py" = ["E501"] +"vllm/attention/ops/pallas_kv_cache_update.py" = ["E501"] +"vllm/compilation/collective_fusion.py" = ["E501"] +"vllm/compilation/wrapper.py" = ["E501"] +"vllm/config/vllm.py" = ["E501"] +"vllm/distributed/device_communicators/all2all.py" = ["E501"] +"vllm/entrypoints/openai/protocol.py" = ["E501"] +"vllm/lora/layers/vocal_parallel_embedding.py" = ["E501"] +"vllm/model_executor/model_loader/bitsandbytes_loader.py" = ["E501"] +"vllm/model_executor/models/bailing_moe.py" = ["E501"] +"vllm/model_executor/models/hyperclovax_vision.py" = ["E501"] +"vllm/model_executor/models/llama4_eagle.py" = ["E501"] +"vllm/model_executor/models/longcat_flash_mtp.py" = ["E501"] +"vllm/model_executor/models/phi4mm.py" = ["E501"] +"vllm/model_executor/models/qwen3_next.py" = ["E501"] +"vllm/model_executor/layers/quantization/ptpc_fp8.py" = ["E501"] +"vllm/v1/attention/backends/mla/common.py" = ["E501"] +"vllm/v1/engine/utils.py" = ["E501"] +"vllm/v1/utils.py" = ["E501"] +"vllm/v1/worker/gpu_model_runner.py" = ["E501"] +## Simplification rules +"tests/distributed/test_expert_placement.py" = ["SIM108"] +"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"] +"tests/kernels/attention/test_flashmla.py" = ["SIM108"] +"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"] +"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"] +"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"] +"tests/kernels/test_onednn.py" = ["SIM108"] +"tests/kernels/utils.py" = ["SIM108"] +"tests/multimodal/test_processing.py" = ["SIM108"] +"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"] +"vllm/distributed/parallel_state.py" = ["SIM108"] +"vllm/entrypoints/chat_utils.py" = ["SIM108"] +"vllm/entrypoints/llm.py" = ["SIM108"] +"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"] +"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"] 
+"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"] +"vllm/model_executor/layers/layernorm.py" = ["SIM108"] +"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"] +"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"] +"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"] +"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"] +"vllm/utils/__init__.py" = ["SIM108"] +"vllm/v1/sample/ops/bad_words.py" = ["SIM108"] +"vllm/v1/sample/rejection_sampler.py" = ["SIM108"] +"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"] +"vllm/_custom_ops.py" = ["SIM108"] +"tools/profiler/print_layerwise_table.py" = ["SIM118"] +## Loop variable binding issues +"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"] +## Type annotation modernization and other rules +"vllm/attention/backends/abstract.py" = ["UP035", "UP006"] +"vllm/attention/layer.py" = ["UP035", "UP006"] +"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"] +"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"] +"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"] +"vllm/engine/arg_utils.py" = ["UP035", "UP006"] +"vllm/engine/metrics.py" = ["UP035", "UP006"] +"vllm/engine/metrics_types.py" = ["UP035", "UP006"] +"vllm/executor/executor_base.py" = ["UP035", "UP006"] +"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"] +"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"] +"vllm/executor/ray_utils.py" = ["UP035", "UP006"] +"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"] +"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"] +## Type comparison issues +"vllm/multimodal/inputs.py" = ["E721"] +# End of temporary ignores [tool.ruff.lint] select = [ @@ -87,7 +166,7 @@ select = [ # flake8-simplify "SIM", # isort - # "I", + "I", # flake8-logging-format "G", ] @@ -104,21 +183,15 @@ ignore = [ "UP007", ] +[tool.ruff.format] +docstring-code-format = true + [tool.mypy] plugins = ['pydantic.mypy'] ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" -[tool.isort] -skip_glob = [ - ".buildkite/*", - "benchmarks/*", - "examples/*", -] -use_parentheses = true -skip_gitignore = true - [tool.pytest.ini_options] markers = [ "slow_test", diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index b3ae9941fe6b..1c094ce1788d 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Tests for v1 attention backends without GPUModelRunner dependency.""" + from functools import partial from typing import Optional, Union @@ -8,22 +9,31 @@ import torch from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, - create_standard_kv_cache_spec, - create_vllm_config, - get_attention_backend) +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend, +) from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, 
is_torch_equal_or_newer -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - extend_flat_seqs, - set_kv_cache_layout) +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + extend_flat_seqs, + set_kv_cache_layout, +) from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, _Backend.FLASHINFER, _Backend.FLEX_ATTENTION, - _Backend.TRITON_ATTN, _Backend.TREE_ATTN, "FLEX_ATTENTION_SLOW" + _Backend.FLASH_ATTN, + _Backend.FLASHINFER, + _Backend.FLEX_ATTENTION, + _Backend.TRITON_ATTN, + _Backend.TREE_ATTN, + "FLEX_ATTENTION_SLOW", ] # Remove flashinfer from the list if it's not available @@ -50,42 +60,38 @@ def _convert_dtype_to_torch(dtype): # Define common batch configurations BATCH_SPECS = { - "small_decode": - BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), - "small_prefill": - BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), - "mixed_small": - BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), - "medium_decode": - BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], - query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), - "medium_prefill": - BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), - "mixed_medium": - BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], - query_lens=[1, 1, 1, 7, 7, 7]), - "large_decode": - BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), - "large_prefill": - BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), - "single_decode": - BatchSpec(seq_lens=[1024], query_lens=[1]), - "single_prefill": - BatchSpec(seq_lens=[1024], query_lens=[64]), + "small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": BatchSpec( + seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1], + ), + "medium_prefill": BatchSpec( + seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16] + ), + "mixed_medium": BatchSpec( + seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7] + ), + "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), } def create_and_prepopulate_kv_cache( - k_contexts: list[torch.Tensor], - v_contexts: list[torch.Tensor], - block_size: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - num_blocks: int, - common_attn_metadata: CommonAttentionMetadata, - randomize_blocks: bool = True) -> torch.Tensor: + k_contexts: list[torch.Tensor], + v_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True, +) -> torch.Tensor: """Create and prepopulate a KV cache with context data. 
Args: @@ -107,20 +113,18 @@ def create_and_prepopulate_kv_cache( """ batch_size = len(k_contexts) seq_lens = common_attn_metadata.seq_lens_cpu - query_lens = common_attn_metadata.query_start_loc_cpu[ - 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + query_lens = ( + common_attn_metadata.query_start_loc_cpu[1:] + - common_attn_metadata.query_start_loc_cpu[:-1] + ) context_lens = common_attn_metadata.num_computed_tokens_cpu block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping # Create KV cache - kv_cache = torch.empty(2, - num_blocks, - block_size, - num_kv_heads, - head_size, - dtype=dtype, - device=device) + kv_cache = torch.empty( + 2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device + ) kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size) # Populate the cache with the context tokens @@ -169,8 +173,8 @@ def create_and_prepopulate_kv_cache( start = common_attn_metadata.query_start_loc_cpu[i] end = common_attn_metadata.query_start_loc_cpu[i + 1] slot_mapping[start:end] = block_table[ - i, - block_indices] * block_size + token_inter_block_offsets.to(device) + i, block_indices + ] * block_size + token_inter_block_offsets.to(device) return kv_cache @@ -223,20 +227,19 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): # Return mock parameters for a single layer head_size = vllm_config.model_config.get_head_size() return { - layer_name: - PerLayerParameters( + layer_name: PerLayerParameters( window_left=-1, # No sliding window logits_soft_cap=0.0, # No soft cap - sm_scale=1.0 / (head_size**0.5) # Standard scale + sm_scale=1.0 / (head_size**0.5), # Standard scale ) for layer_name in layer_names } with unittest.mock.patch( - 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters', - mock_get_per_layer_parameters): - builder = builder_cls(kv_cache_spec, layer_names, vllm_config, - device) + "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters", + mock_get_per_layer_parameters, + ): + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -253,9 +256,11 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): # Instantiate implementation num_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() scale = 1.0 / (head_size**0.5) impl = impl_cls( @@ -275,13 +280,9 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): # Run forward pass # NOTE: The query, key, and value are already shaped correctly # in the calling test function. - output = impl.forward(mock_layer, - query, - key, - value, - kv_cache, - attn_metadata, - output=output) + output = impl.forward( + mock_layer, query, key, value, kv_cache, attn_metadata, output=output + ) return output @@ -312,10 +313,12 @@ def _test_backend_correctness( 5. Comparing the vLLM backend's output to the ground-truth SDPA output. 
""" current_platform.seed_everything(42) - vllm_config = create_vllm_config(model_name=model, - max_model_len=max(batch_spec.seq_lens), - block_size=block_size, - num_gpu_blocks=8192) + vllm_config = create_vllm_config( + model_name=model, + max_model_len=max(batch_spec.seq_lens), + block_size=block_size, + num_gpu_blocks=8192, + ) device = torch.device("cuda:0") kv_cache_spec = create_standard_kv_cache_spec(vllm_config) @@ -325,9 +328,11 @@ def _test_backend_correctness( seq_lens = batch_spec.seq_lens query_lens = batch_spec.query_lens num_q_heads = vllm_config.model_config.get_num_attention_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) num_kv_heads = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config) + vllm_config.parallel_config + ) head_size = vllm_config.model_config.get_head_size() sliding_window = vllm_config.model_config.get_sliding_window() dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) @@ -345,21 +350,9 @@ def _test_backend_correctness( context_len = s_len - q_len # Generate Q, K, V for the whole sequence to be used in SDPA - q = torch.randn(q_len, - num_q_heads, - head_size, - dtype=dtype, - device=device) - k_full = torch.randn(s_len, - num_kv_heads, - head_size, - dtype=dtype, - device=device) - v_full = torch.randn(s_len, - num_kv_heads, - head_size, - dtype=dtype, - device=device) + q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device) + k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) + v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device) # SDPA expects (N, H, L, D), so unsqueeze batch and permute q_sdpa_in = q.unsqueeze(0).transpose(1, 2) @@ -369,7 +362,8 @@ def _test_backend_correctness( if num_q_heads != num_kv_heads: assert num_q_heads % num_kv_heads == 0, ( f"num_q_heads ({num_q_heads}) must be divisible by " - f"num_kv_heads ({num_kv_heads})") + f"num_kv_heads ({num_kv_heads})" + ) repeats = num_q_heads // num_kv_heads k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1) v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1) @@ -379,18 +373,17 @@ def _test_backend_correctness( kv_len = s_len final_mask_mod = partial(mask_mod, context_len=context_len) - block_mask = create_block_mask(final_mask_mod, - B=None, - H=None, - Q_LEN=q_len, - KV_LEN=kv_len, - device=device) - sdpa_out_i = flex_attention(q_sdpa_in, - k_sdpa_in, - v_sdpa_in, - block_mask=block_mask, - scale=scale, - enable_gqa=True) + block_mask = create_block_mask( + final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device=device + ) + sdpa_out_i = flex_attention( + q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + block_mask=block_mask, + scale=scale, + enable_gqa=True, + ) all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0)) @@ -409,7 +402,8 @@ def _test_backend_correctness( sdpa_output = torch.cat(all_sdpa_outputs, dim=0) common_attn_metadata = create_common_attn_metadata( - batch_spec, vllm_config.cache_config.block_size, device) + batch_spec, vllm_config.cache_config.block_size, device + ) # 3. Simulate Paged KV Cache and a realistic slot_mapping kv_cache = create_and_prepopulate_kv_cache( @@ -422,7 +416,8 @@ def _test_backend_correctness( device=device, num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000, common_attn_metadata=common_attn_metadata, - randomize_blocks=True) + randomize_blocks=True, + ) # 4. 
Run vLLM backends and compare # Note: flex_attention has known Triton kernel compatibility issues @@ -438,8 +433,9 @@ def _test_backend_correctness( kv_cache_for_backend = kv_cache.transpose(0, 1) # For FlashInfer default to HND layout and - kv_cache_for_backend = kv_cache_for_backend.transpose( - 2, 3).contiguous().transpose(2, 3) + kv_cache_for_backend = ( + kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3) + ) set_kv_cache_layout("HND") backend_output = run_attention_backend( @@ -459,32 +455,45 @@ def _test_backend_correctness( # Check shape and dtype consistency assert backend_output.shape == sdpa_output.shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_output.shape}") + f"SDPA shape {sdpa_output.shape}" + ) assert backend_output.dtype == sdpa_output.dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_output.dtype}") + f"SDPA dtype {sdpa_output.dtype}" + ) assert torch.isfinite(backend_output).all(), ( - f"[{backend_name}] produced non-finite values") + f"[{backend_name}] produced non-finite values" + ) # Check numerical similarity def error_msg(msg: str, backend_name: str): - return (f"[{backend_name}] output differs from SDPA baseline. " - f"{msg}") - - torch.testing.assert_close(backend_output, - sdpa_output, - rtol=rtol, - atol=atol, - msg=partial(error_msg, - backend_name=backend_name)) - - -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_small", "medium_decode", - "medium_prefill", "mixed_medium", "large_decode", "large_prefill", - "single_decode", "single_prefill" -]) + return f"[{backend_name}] output differs from SDPA baseline. {msg}" + + torch.testing.assert_close( + backend_output, + sdpa_output, + rtol=rtol, + atol=atol, + msg=partial(error_msg, backend_name=backend_name), + ) + + +@pytest.mark.parametrize( + "batch_spec_name", + [ + "small_decode", + "small_prefill", + "mixed_small", + "medium_decode", + "medium_prefill", + "mixed_medium", + "large_decode", + "large_prefill", + "single_decode", + "single_prefill", + ], +) @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) def test_causal_backend_correctness(batch_spec_name: str, model: str): """Test backend's correctness with causal attention.""" @@ -500,33 +509,33 @@ def causal_mask_mod( return (q_idx + context_len) >= kv_idx batch_spec = BATCH_SPECS[batch_spec_name] - LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION] - if is_torch_equal_or_newer("2.9.0.dev0") else []) + LARGE_BLOCK_BACKENDS = ( + [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + ) SMALL_BLOCK_BACKENDS = [ x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS ] - _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, - causal_mask_mod) + _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod) # Fast FlexAttention needs to run with block_size=128 if LARGE_BLOCK_BACKENDS: - _test_backend_correctness(batch_spec, - model, - LARGE_BLOCK_BACKENDS, - causal_mask_mod, - block_size=128) + _test_backend_correctness( + batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128 + ) SLIDING_WINDOW_BACKENDS_TO_TEST = [ - _Backend.FLASH_ATTN, _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN, - "FLEX_ATTENTION_SLOW" + _Backend.FLASH_ATTN, + _Backend.FLEX_ATTENTION, + _Backend.TRITON_ATTN, + "FLEX_ATTENTION_SLOW", ] -@pytest.mark.parametrize("batch_spec_name", [ - "small_decode", "small_prefill", "mixed_medium", "large_decode", - "large_prefill" -]) 
+@pytest.mark.parametrize( + "batch_spec_name", + ["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"], +) @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"]) def test_sliding_window_backend_correctness(batch_spec_name: str, model: str): """Test backend's correctness with sliding window attention.""" @@ -545,28 +554,31 @@ def sliding_window_mask_mod( return causal_mask & window_mask batch_spec = BATCH_SPECS[batch_spec_name] - model_config = ModelConfig(model=model, - max_model_len=max(batch_spec.seq_lens)) + model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens)) sliding_window = model_config.get_sliding_window() - sliding_window_mask_mod_fn = partial(sliding_window_mask_mod, - sliding_window=sliding_window) + sliding_window_mask_mod_fn = partial( + sliding_window_mask_mod, sliding_window=sliding_window + ) - LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION] - if is_torch_equal_or_newer("2.9.0.dev0") else []) + LARGE_BLOCK_BACKENDS = ( + [_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else [] + ) SMALL_BLOCK_BACKENDS = [ - x for x in SLIDING_WINDOW_BACKENDS_TO_TEST - if x not in LARGE_BLOCK_BACKENDS + x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS ] - _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, - sliding_window_mask_mod_fn) + _test_backend_correctness( + batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn + ) # Fast FlexAttention needs to run with block_size=128 if LARGE_BLOCK_BACKENDS: - _test_backend_correctness(batch_spec, - model, - LARGE_BLOCK_BACKENDS, - sliding_window_mask_mod_fn, - block_size=128) + _test_backend_correctness( + batch_spec, + model, + LARGE_BLOCK_BACKENDS, + sliding_window_mask_mod_fn, + block_size=128, + ) @pytest.mark.parametrize("device", ["cpu", "cuda"]) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4859a3cd66fb..775371897573 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -17,8 +17,7 @@ from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.v1.spec_decode.draft_model import compute_subrange_indices -from vllm.v1.spec_decode.metrics import (compute_acceptance_len, - compute_acceptance_rate) +from vllm.v1.spec_decode.metrics import compute_acceptance_len, compute_acceptance_rate MTP_SIMILARITY_RATE = 0.8 @@ -55,19 +54,17 @@ def get_test_prompts(mm_enabled: bool, quiet: bool = False): give no other output than that simple sentence without quotes. """ elif kind == "mm": - placeholders = [{ - "type": "image_url", - "image_url": { - "url": - f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" - }, - }] + placeholders = [ + { + "type": "image_url", + "image_url": { + "url": f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" + }, + } + ] prompt = [ *placeholders, - { - "type": "text", - "text": "The meaning of the image is" - }, + {"type": "text", "text": "The meaning of the image is"}, ] else: raise ValueError(f"Unknown prompt type: {kind}") @@ -99,10 +96,10 @@ def test_ngram_correctness( sampling_config: SamplingParams, model_name: str, ): - ''' + """ Compare the outputs of an original LLM and a speculative LLM should be the same when using ngram speculative decoding. 
- ''' + """ test_prompts = get_test_prompts(mm_enabled=False) ref_llm = LLM(model=model_name, max_model_len=1024) @@ -144,32 +141,77 @@ def test_ngram_correctness( ["model_setup", "mm_enabled"], [ (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), - pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct", - "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), - False, - marks=pytest.mark.skip(reason="Skipping due to its " \ - "head_dim not being a a multiple of 32")), - (("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), - (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), - pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - False, - marks=large_gpu_mark(min_gb=80)), # works on 4x H100 - pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - True, - marks=large_gpu_mark(min_gb=80)), # works on 4x H100 - (("eagle", "eagle618/deepseek-v3-random", - "eagle618/eagle-deepseek-v3-random", 1), False), + pytest.param( + ( + "eagle3", + "Qwen/Qwen2.5-VL-7B-Instruct", + "Rayzl/qwen2.5-vl-7b-eagle3-sgl", + 1, + ), + False, + marks=pytest.mark.skip( + reason="Skipping due to its head_dim not being a a multiple of 32" + ), + ), + ( + ( + "eagle", + "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", + 1, + ), + False, + ), + ( + ( + "eagle3", + "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + 1, + ), + False, + ), + pytest.param( + ( + "eagle", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + 4, + ), + False, + marks=large_gpu_mark(min_gb=80), + ), # works on 4x H100 + pytest.param( + ( + "eagle", + "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + 4, + ), + True, + marks=large_gpu_mark(min_gb=80), + ), # works on 4x H100 + ( + ( + "eagle", + "eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", + 1, + ), + False, + ), ], ids=[ - "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3", - "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle" - ]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) + "qwen3_eagle3", + "qwen2_5_vl_eagle3", + "llama3_eagle", + "llama3_eagle3", + "llama4_eagle", + "llama4_eagle_mm", + "deepseek_eagle", + ], +) +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, @@ -181,15 +223,16 @@ def test_eagle_correctness( # TODO: Fix this flaky test pytest.skip( "TREE_ATTN is flaky in the test disable for now until it can be " - "resolved (see https://github.com/vllm-project/vllm/issues/22922)") + "resolved (see https://github.com/vllm-project/vllm/issues/22922)" + ) # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) - ''' + """ Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. 
model_setup: (method, model_name, eagle_model_name, tp_size) - ''' + """ with monkeypatch.context() as m: if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN": # Scout requires default backend selection @@ -200,18 +243,20 @@ def test_eagle_correctness( m.setenv("VLLM_MLA_DISABLE", "1") m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): m.setenv("VLLM_ROCM_USE_AITER", "1") method, model_name, spec_model_name, tp_size = model_setup - ref_llm = LLM(model=model_name, - max_model_len=2048, - tensor_parallel_size=tp_size) + ref_llm = LLM( + model=model_name, max_model_len=2048, tensor_parallel_size=tp_size + ) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm torch.cuda.empty_cache() @@ -248,11 +293,14 @@ def test_eagle_correctness( cleanup_dist_env_and_memory() -@pytest.mark.parametrize(["model_setup", "mm_enabled"], [ - (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False), - (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False), -], - ids=["mimo", "deepseek"]) +@pytest.mark.parametrize( + ["model_setup", "mm_enabled"], + [ + (("mtp", "XiaomiMiMo/MiMo-7B-Base", 1), False), + (("mtp", "ZixiQi/DeepSeek-V3-4layers-MTP-FP8", 1), False), + ], + ids=["mimo", "deepseek"], +) def test_mtp_correctness( monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, @@ -261,21 +309,23 @@ def test_mtp_correctness( ): # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) - ''' + """ Compare the outputs of a original LLM and a speculative LLM should be the same when using MTP speculative decoding. 
model_setup: (method, model_name, tp_size) - ''' + """ with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_MLA_DISABLE", "1") method, model_name, tp_size = model_setup - ref_llm = LLM(model=model_name, - max_model_len=2048, - tensor_parallel_size=tp_size, - trust_remote_code=True) + ref_llm = LLM( + model=model_name, + max_model_len=2048, + tensor_parallel_size=tp_size, + trust_remote_code=True, + ) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm torch.cuda.empty_cache() @@ -333,7 +383,7 @@ class ArgsTest: model="Qwen/Qwen3-0.6B", draft_model="Qwen/Qwen3-0.6B", sampling_config=greedy_sampling(), - num_speculative_tokens=3, # K + num_speculative_tokens=3, # K expected_acceptance_len=3 + 1, # K + 1 expected_acceptance_rate=1.0, expected_same_output_fraction=1.0, @@ -408,15 +458,18 @@ def test_draft_model_correctness( match_fraction = compute_exact_matches(ref_outputs, spec_outputs) assert match_fraction >= args.expected_same_output_fraction - print(f"spec-decode: target={args.model}, draft={args.draft_model}, " - f"temperature={args.sampling_config.temperature:.2f}, " - f"acceptance_rate={acceptance_rate:.2f}, " - f"acceptance_len={acceptance_len:.2f}, " - f"match_fraction={match_fraction:.2f}") + print( + f"spec-decode: target={args.model}, draft={args.draft_model}, " + f"temperature={args.sampling_config.temperature:.2f}, " + f"acceptance_rate={acceptance_rate:.2f}, " + f"acceptance_len={acceptance_len:.2f}, " + f"match_fraction={match_fraction:.2f}" + ) -def compute_exact_matches(ref_outputs: list[RequestOutput], - spec_outputs: list[RequestOutput]) -> float: +def compute_exact_matches( + ref_outputs: list[RequestOutput], spec_outputs: list[RequestOutput] +) -> float: """Compute the fraction of the prompts that match exactly""" assert len(ref_outputs) == len(spec_outputs) matches = 0 diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index d317f62bade4..3c748e25bd63 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -8,13 +8,22 @@ import torch from tests.utils import get_attn_backend_list_based_on_platform -from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, - create_standard_kv_cache_spec, - get_attention_backend) +from tests.v1.attention.utils import ( + BatchSpec, + create_common_attn_metadata, + create_standard_kv_cache_spec, + get_attention_backend, +) from vllm.attention.backends.registry import _Backend -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VllmConfig) +from vllm.config import ( + CacheConfig, + DeviceConfig, + ModelConfig, + ParallelConfig, + SchedulerConfig, + SpeculativeConfig, + VllmConfig, +) from vllm.config.load import LoadConfig from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.platforms import current_platform @@ -32,9 +41,7 @@ def _create_proposer( num_speculative_tokens: int, speculative_token_tree: Optional[list[tuple[int, ...]]] = None, ) -> EagleProposer: - model_config = ModelConfig(model=model_dir, - runner="generate", - max_model_len=100) + model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100) # Choose model directory based on method draft_model_dir = eagle_dir if method == "eagle" else eagle3_dir @@ -60,10 +67,10 @@ def _create_proposer( device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - 
scheduler_config=SchedulerConfig()) + scheduler_config=SchedulerConfig(), + ) - return EagleProposer(vllm_config=vllm_config, - device=current_platform.device_type) + return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) def test_prepare_next_token_ids(): @@ -82,7 +89,7 @@ def test_prepare_next_token_ids(): query_lens=[num_speculative_tokens + 1] * num_requests, ) - req_ids = [f"req_{i+1}" for i in range(num_requests)] + req_ids = [f"req_{i + 1}" for i in range(num_requests)] mock_input_batch = mock.MagicMock(spec=InputBatch) mock_input_batch.req_ids = req_ids mock_input_batch.num_reqs = num_requests @@ -101,24 +108,26 @@ def test_prepare_next_token_ids(): [0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled [0, 1, 2, 3, 4], # all accepted, "4" sampled [-1, -1, -1, -1, -1], # sampling skipped, use backup token "30" - [-1, -1, -1, -1, -1] # this request will be discarded + [-1, -1, -1, -1, -1], # this request will be discarded ] - sampled_token_ids_tensor = torch.tensor(sampled_token_ids, - dtype=torch.int32, - device=device) - sampled_token_ids_cpu = [[i for i in seq if i != -1] - for seq in sampled_token_ids] + sampled_token_ids_tensor = torch.tensor( + sampled_token_ids, dtype=torch.int32, device=device + ) + sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids] expected_next_token_ids_cpu = [1, 4, 30, 40] - expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu, - dtype=torch.int32, - device=device) + expected_next_token_ids_tensor = torch.tensor( + expected_next_token_ids_cpu, dtype=torch.int32, device=device + ) proposer = _create_proposer("eagle", num_speculative_tokens) next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu( - sampled_token_ids_cpu, mock_requests, mock_input_batch, - mock_num_scheduled_tokens) + sampled_token_ids_cpu, + mock_requests, + mock_input_batch, + mock_num_scheduled_tokens, + ) assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor) @@ -131,19 +140,23 @@ def test_prepare_next_token_ids(): discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device) num_discarded_reqs = 1 - expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0], - dtype=torch.int32, - device=device) + expected_valid_sampled_tokens_count = torch.tensor( + [2, 5, 0, 0], dtype=torch.int32, device=device + ) - next_token_ids_from_padded, valid_sampled_tokens_count = \ + next_token_ids_from_padded, valid_sampled_tokens_count = ( proposer.prepare_next_token_ids_padded( - common_attn_metadata, sampled_token_ids_tensor, mock_requests, - mock_input_batch, discarded_req_indices, num_discarded_reqs) + common_attn_metadata, + sampled_token_ids_tensor, + mock_requests, + mock_input_batch, + discarded_req_indices, + num_discarded_reqs, + ) + ) - assert torch.equal(next_token_ids_from_padded, - expected_next_token_ids_tensor) - assert torch.equal(valid_sampled_tokens_count, - expected_valid_sampled_tokens_count) + assert torch.equal(next_token_ids_from_padded, expected_next_token_ids_tensor) + assert torch.equal(valid_sampled_tokens_count, expected_valid_sampled_tokens_count) def test_prepare_inputs(): @@ -183,21 +196,27 @@ def test_prepare_inputs(): sampled_token_ids = [ [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN], [ - ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, - REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN + ACCEPT_TOKEN, + ACCEPT_TOKEN, + ACCEPT_TOKEN, + REJECT_TOKEN, + REJECT_TOKEN, + REJECT_TOKEN, + BONUS_TOKEN, ], - [ACCEPT_TOKEN, ACCEPT_TOKEN, 
REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN] + [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN], + ] + sampled_token_ids = [ + [i for i in seq if i != REJECT_TOKEN] for seq in sampled_token_ids ] - sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN] - for seq in sampled_token_ids] # Expected calculations: # query_len_per_req = [4, 7, 5] # num_tokens_per_req = [3, 4, 3] (after subtracting rejected tokens) # Expected cumulative counts: [0, 3, 7, 10] - expected_cu_num_tokens = torch.tensor([0, 3, 7, 10], - dtype=torch.int32, - device=device) + expected_cu_num_tokens = torch.tensor( + [0, 3, 7, 10], dtype=torch.int32, device=device + ) # Expected token indices (mapped from original positions): # First request: indices 0, 1, 2 (keeping first 3 from positions 0-3) @@ -214,17 +233,18 @@ def test_prepare_inputs(): 7, # Second request: 4 tokens (7-3) 11, 12, - 13 # Third request: 3 tokens (5-2) + 13, # Third request: 3 tokens (5-2) ], dtype=torch.int32, - device=device) + device=device, + ) proposer = _create_proposer("eagle", 1) updated_metadata, token_indices = proposer.prepare_inputs( - common_attn_metadata, sampled_token_ids, num_draft_tokens) + common_attn_metadata, sampled_token_ids, num_draft_tokens + ) - assert torch.equal(updated_metadata.query_start_loc, - expected_cu_num_tokens) + assert torch.equal(updated_metadata.query_start_loc, expected_cu_num_tokens) assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() assert torch.equal(token_indices, expected_token_indices) @@ -249,12 +269,12 @@ def test_prepare_inputs_padded(): device = torch.device(current_platform.device_type) - expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], - dtype=torch.int32, - device=device) - expected_token_indices_to_sample = torch.tensor([1, 5, 6], - dtype=torch.int32, - device=device) + expected_token_indices = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device + ) + expected_token_indices_to_sample = torch.tensor( + [1, 5, 6], dtype=torch.int32, device=device + ) num_speculative_tokens = 2 batch_spec = BatchSpec( @@ -269,9 +289,9 @@ def test_prepare_inputs_padded(): ) # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9] - expected_query_start_loc = torch.tensor([0, 3, 6, 9], - dtype=torch.int32, - device=device) + expected_query_start_loc = torch.tensor( + [0, 3, 6, 9], dtype=torch.int32, device=device + ) spec_decode_metadata = SpecDecodeMetadata.make_dummy( draft_token_ids=[[0] * num_speculative_tokens] * 3, device=device, @@ -280,43 +300,48 @@ def test_prepare_inputs_padded(): # num_rejected_tokens = [1, 0, 2] # num_draft_tokens = [2, 2, 2] # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens - valid_sampled_tokens_count = torch.tensor([2, 3, 1], - dtype=torch.int32, - device=device) + valid_sampled_tokens_count = torch.tensor( + [2, 3, 1], dtype=torch.int32, device=device + ) proposer = _create_proposer("eagle", num_speculative_tokens) - output_metadata, token_indices, token_indices_to_sample = \ + output_metadata, token_indices, token_indices_to_sample = ( proposer.prepare_inputs_padded( - common_attn_metadata, - spec_decode_metadata, - valid_sampled_tokens_count) + common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count + ) + ) assert output_metadata.max_query_len == 3 - assert torch.equal(output_metadata.query_start_loc, - expected_query_start_loc) + assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc) assert torch.equal(token_indices, 
expected_token_indices) - assert torch.equal(token_indices_to_sample, - expected_token_indices_to_sample) + assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample) @pytest.mark.parametrize("method", ["eagle", "eagle3"]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("pp_size", [1, 2]) @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False]) -@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group') -@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config') -@mock.patch('vllm.v1.spec_decode.eagle.get_model') -def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, - attn_backend, pp_size, use_distinct_embed_tokens, - monkeypatch): - +@mock.patch("vllm.v1.spec_decode.eagle.get_pp_group") +@mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config") +@mock.patch("vllm.v1.spec_decode.eagle.get_model") +def test_load_model( + mock_get_model, + mock_get_layers, + mock_get_pp_group, + method, + attn_backend, + pp_size, + use_distinct_embed_tokens, + monkeypatch, +): monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -335,20 +360,20 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method, # Setup mocks for attention layers target_attn_layers = { "target_attn_1": mock.MagicMock(), - "target_attn_2": mock.MagicMock() + "target_attn_2": mock.MagicMock(), } target_indx_layers: dict[str, mock.MagicMock] = {} # Draft model has one extra attention layer compared to target model - all_attn_layers = { - **target_attn_layers, "draft_extra_attn": mock.MagicMock() - } + all_attn_layers = {**target_attn_layers, "draft_extra_attn": mock.MagicMock()} all_indx_layers: dict[str, mock.MagicMock] = {} # Make mock_get_layers return different values for each call mock_get_layers.side_effect = [ - target_attn_layers, target_indx_layers, all_attn_layers, - all_indx_layers + target_attn_layers, + target_indx_layers, + all_attn_layers, + all_indx_layers, ] # Setup mock for pp group to return the appropriate value for world size @@ -367,6 +392,7 @@ class _TargetModelStub(LlamaForCausalLM): target_model.model.embed_tokens.weight.shape = (131072, 4096) from vllm.model_executor.models import SupportsMultiModal + assert not isinstance(target_model, SupportsMultiModal) if method == "eagle": @@ -388,30 +414,30 @@ class _TargetModelStub(LlamaForCausalLM): # Verify that the embed tokens are set correctly # If pp_size is > 1, the embed tokens should be distinct if pp_size > 1 or use_distinct_embed_tokens: - assert proposer.model.model.embed_tokens != \ - target_model.model.embed_tokens + assert proposer.model.model.embed_tokens != target_model.model.embed_tokens else: # When pp_size is 1 and the draft and target models have # embed_tokens of the same shape, they should be shared. 
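            # Concretely, the proposer should reuse the target's embedding
            # module (the (131072, 4096) table stubbed above) instead of
            # holding a second copy, which is what the check below verifies.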
- assert proposer.model.model.embed_tokens == \ - target_model.model.embed_tokens + assert proposer.model.model.embed_tokens == target_model.model.embed_tokens @pytest.mark.parametrize("method", ["eagle", "eagle3"]) -@pytest.mark.parametrize("attn_backend", - get_attn_backend_list_based_on_platform()) +@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) @pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8]) def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch): - monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend) - if (attn_backend == "TRITON_ATTN" and not current_platform.is_rocm()): - pytest.skip("TRITON_ATTN does not support " - "multi-token eagle spec decode on current platform") + if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm(): + pytest.skip( + "TRITON_ATTN does not support " + "multi-token eagle spec decode on current platform" + ) - if (attn_backend == "TREE_ATTN"): - pytest.skip("TREE_ATTN is tested separately in test_propose_tree" - "because it requires special input mocking.") + if attn_backend == "TREE_ATTN": + pytest.skip( + "TREE_ATTN is tested separately in test_propose_tree" + "because it requires special input mocking." + ) if attn_backend == "FLASH_ATTN" and current_platform.is_rocm(): monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") @@ -498,31 +524,22 @@ def create_deterministic_logits(token_ids): device=device, ) - target_token_ids = torch.randint(0, - vocab_size, (total_tokens, ), - device=device) - target_positions = torch.cat([ - torch.arange(seq_len_1, device=device), - torch.arange(seq_len_2, device=device) - ]) - target_hidden_states = torch.randn(total_tokens, - hidden_size, - device=device) - next_token_ids = torch.randint(0, - vocab_size, (batch_size, ), - dtype=torch.int32, - device=device) + target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) + target_positions = torch.cat( + [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)] + ) + target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) sampling_metadata = mock.MagicMock() if attn_backend == "FLASH_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.FLASH_ATTN) + attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN) elif attn_backend == "TRITON_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.TRITON_ATTN) + attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TRITON_ATTN) elif attn_backend == "TREE_ATTN": - attn_metadata_builder_cls, _ = get_attention_backend( - _Backend.TREE_ATTN) + attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN) else: raise ValueError(f"Unsupported attention backend: {attn_backend}") @@ -536,19 +553,23 @@ def create_deterministic_logits(token_ids): # Mock runner for attention metadata building proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \ - attn_metadata_builder + proposer.runner.attn_groups[0][ + 0 + ].get_metadata_builder.return_value = attn_metadata_builder proposer._get_attention_metadata_builder = mock.MagicMock( - return_value=attn_metadata_builder) + return_value=attn_metadata_builder + ) - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - 
target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=None, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata, - cudagraph_args=dict()) + result = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + cudagraph_args=dict(), + ) assert result.shape == (batch_size, num_speculative_tokens) @@ -557,13 +578,14 @@ def create_deterministic_logits(token_ids): # Example for num_speculative_tokens=1: # [[42], [60]] expected_tokens = torch.tensor( - [[base_token_ids[0]], [base_token_ids[1]]], device=device) + [[base_token_ids[0]], [base_token_ids[1]]], device=device + ) else: # Example for num_speculative_tokens=3: # [[42, 43, 44], [60, 61, 62]] - expected_tokens = torch.zeros((batch_size, num_speculative_tokens), - dtype=torch.int64, - device=device) + expected_tokens = torch.zeros( + (batch_size, num_speculative_tokens), dtype=torch.int64, device=device + ) for i in range(batch_size): for j in range(num_speculative_tokens): expected_tokens[i, j] = base_token_ids[i] + j @@ -575,12 +597,12 @@ def create_deterministic_logits(token_ids): @pytest.mark.parametrize( "spec_token_tree", [ - [(0, )], # A single token - [(0, ), (0, 0), (0, 0, 0)], # Chain - [(0, ), (1, ), (2, )], # Parallel - [(0, ), (1, ), (2, ), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), - (2, 1)], # Tree - ]) + [(0,)], # A single token + [(0,), (0, 0), (0, 0, 0)], # Chain + [(0,), (1,), (2,)], # Parallel + [(0,), (1,), (2,), (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)], # Tree + ], +) def test_propose_tree(spec_token_tree): # Get GPU device. device = torch.device(current_platform.device_type) @@ -595,9 +617,9 @@ def test_propose_tree(spec_token_tree): num_speculative_tokens = len(spec_token_tree) # Create proposer first so we can use its actual hidden_size. - proposer = _create_proposer("eagle", - num_speculative_tokens, - speculative_token_tree=spec_token_tree) + proposer = _create_proposer( + "eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree + ) # Get the hidden_size from the proposer to ensure consistency. hidden_size = proposer.hidden_size @@ -618,32 +640,31 @@ def create_deterministic_logits(token_ids, k: int): model_mock = mock.MagicMock() # Mock the model forward calls. - forward_returns = [(torch.zeros(total_tokens, hidden_size, device=device), - torch.zeros(total_tokens, hidden_size, device=device))] + forward_returns = [ + ( + torch.zeros(total_tokens, hidden_size, device=device), + torch.zeros(total_tokens, hidden_size, device=device), + ) + ] for cu_num_drafts in proposer.cu_drafts_per_level: - h_logits = torch.zeros(batch_size * cu_num_drafts, - hidden_size, - device=device) - h_states = torch.zeros(batch_size * cu_num_drafts, - hidden_size, - device=device) + h_logits = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device) + h_states = torch.zeros(batch_size * cu_num_drafts, hidden_size, device=device) forward_returns.append((h_logits, h_states)) model_mock.side_effect = forward_returns # Mock the compute_logits calls. 
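    # Note: one deterministic logits batch is queued per tree level, so the
    # top-k children drafted at each level are fixed offsets from the base
    # token IDs and the resulting tree can be checked exactly further down.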
- cu_num_drafts_tensor = torch.tensor([0] + proposer.cu_drafts_per_level, - dtype=torch.int32, - device=device) + cu_num_drafts_tensor = torch.tensor( + [0] + proposer.cu_drafts_per_level, dtype=torch.int32, device=device + ) logits_returns = [] for level, num_children in enumerate(proposer.child_drafts_per_level): token_ids = base_token_ids + cu_num_drafts_tensor[level] - level_num_drafts = cu_num_drafts_tensor[ - level + 1] - cu_num_drafts_tensor[level] + level_num_drafts = cu_num_drafts_tensor[level + 1] - cu_num_drafts_tensor[level] level_logits = [] for i in range(level_num_drafts // num_children): level_logits.append( - create_deterministic_logits(token_ids + i * num_children, - num_children)) + create_deterministic_logits(token_ids + i * num_children, num_children) + ) logits_returns.append(torch.stack(level_logits, dim=1)) model_mock.compute_logits.side_effect = logits_returns @@ -665,29 +686,23 @@ def create_deterministic_logits(token_ids, k: int): # Mock runner for attention metadata building. proposer.runner = mock.MagicMock() proposer.runner.attn_groups.append([mock.MagicMock()]) - proposer.runner.attn_groups[0][0].metadata_builders = [ - attn_metadata_builder - ] - proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \ - attn_metadata_builder + proposer.runner.attn_groups[0][0].metadata_builders = [attn_metadata_builder] + proposer.runner.attn_groups[0][ + 0 + ].get_metadata_builder.return_value = attn_metadata_builder proposer._get_attention_metadata_builder = mock.MagicMock( - return_value=attn_metadata_builder) + return_value=attn_metadata_builder + ) # Setup inputs for the proposer. - target_token_ids = torch.randint(0, - vocab_size, (total_tokens, ), - device=device) - target_positions = torch.cat([ - torch.arange(seq_len_1, device=device), - torch.arange(seq_len_2, device=device) - ]) - target_hidden_states = torch.randn(total_tokens, - hidden_size, - device=device) - next_token_ids = torch.randint(0, - vocab_size, (batch_size, ), - dtype=torch.int32, - device=device) + target_token_ids = torch.randint(0, vocab_size, (total_tokens,), device=device) + target_positions = torch.cat( + [torch.arange(seq_len_1, device=device), torch.arange(seq_len_2, device=device)] + ) + target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) + next_token_ids = torch.randint( + 0, vocab_size, (batch_size,), dtype=torch.int32, device=device + ) batch_spec = BatchSpec( seq_lens=seq_lens, query_lens=seq_lens, @@ -700,20 +715,23 @@ def create_deterministic_logits(token_ids, k: int): sampling_metadata = mock.MagicMock() # Propose draft tokens. - result = proposer.propose(target_token_ids=target_token_ids, - target_positions=target_positions, - target_hidden_states=target_hidden_states, - next_token_ids=next_token_ids, - last_token_indices=None, - common_attn_metadata=common_attn_metadata, - sampling_metadata=sampling_metadata, - cudagraph_args=dict()) + result = proposer.propose( + target_token_ids=target_token_ids, + target_positions=target_positions, + target_hidden_states=target_hidden_states, + next_token_ids=next_token_ids, + last_token_indices=None, + common_attn_metadata=common_attn_metadata, + sampling_metadata=sampling_metadata, + cudagraph_args=dict(), + ) assert result.shape == (batch_size, num_speculative_tokens) # The tokens are expected to be consecutive integers starting # from the base token IDs. 
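    # e.g. (illustrative values) with base_token_ids = torch.tensor([42, 60])
    # and a 3-node tree:
    #   base_token_ids[:, None] + torch.arange(3)
    #   -> tensor([[42, 43, 44],
    #              [60, 61, 62]])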
expected_tokens = base_token_ids[:, None] + torch.arange( - num_speculative_tokens, dtype=torch.int64, device=device) + num_speculative_tokens, dtype=torch.int64, device=device + ) # Verify that the draft tokens match our expectations. assert torch.equal(result, expected_tokens) diff --git a/tests/v1/test_outputs.py b/tests/v1/test_outputs.py index 7556d5e8f4f7..5ddf923eeac1 100644 --- a/tests/v1/test_outputs.py +++ b/tests/v1/test_outputs.py @@ -14,7 +14,6 @@ def test_sampler_output(): [3, 2, -1, -1] ]) # fmt: on - so = SamplerOutput(sampled_token_ids=sampled_token_ids, - logprobs_tensors=None) + so = SamplerOutput(sampled_token_ids=sampled_token_ids, logprobs_tensors=None) expected_n_sampled_tokens = torch.tensor([3, 1, 2]) assert so.n_sampled_tokens().eq(expected_n_sampled_tokens).all() diff --git a/tests/v1/worker/test_utils.py b/tests/v1/worker/test_utils.py index 4a5f91edc747..62be6dad46a2 100644 --- a/tests/v1/worker/test_utils.py +++ b/tests/v1/worker/test_utils.py @@ -10,32 +10,28 @@ def test_bind_kv_cache(): from vllm.attention import Attention ctx = { - 'layers.0.self_attn': Attention(32, 128, 0.1), - 'layers.1.self_attn': Attention(32, 128, 0.1), - 'layers.2.self_attn': Attention(32, 128, 0.1), - 'layers.3.self_attn': Attention(32, 128, 0.1), + "layers.0.self_attn": Attention(32, 128, 0.1), + "layers.1.self_attn": Attention(32, 128, 0.1), + "layers.2.self_attn": Attention(32, 128, 0.1), + "layers.3.self_attn": Attention(32, 128, 0.1), } kv_cache = { - 'layers.0.self_attn': torch.zeros((1, )), - 'layers.1.self_attn': torch.zeros((1, )), - 'layers.2.self_attn': torch.zeros((1, )), - 'layers.3.self_attn': torch.zeros((1, )), + "layers.0.self_attn": torch.zeros((1,)), + "layers.1.self_attn": torch.zeros((1,)), + "layers.2.self_attn": torch.zeros((1,)), + "layers.3.self_attn": torch.zeros((1,)), } runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache(kv_cache, ctx, runner_kv_caches) - assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[ - 'layers.0.self_attn'] - assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[ - 'layers.1.self_attn'] - assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[ - 'layers.2.self_attn'] - assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[ - 'layers.3.self_attn'] + assert ctx["layers.0.self_attn"].kv_cache[0] is kv_cache["layers.0.self_attn"] + assert ctx["layers.1.self_attn"].kv_cache[0] is kv_cache["layers.1.self_attn"] + assert ctx["layers.2.self_attn"].kv_cache[0] is kv_cache["layers.2.self_attn"] + assert ctx["layers.3.self_attn"].kv_cache[0] is kv_cache["layers.3.self_attn"] - assert runner_kv_caches[0] is kv_cache['layers.0.self_attn'] - assert runner_kv_caches[1] is kv_cache['layers.1.self_attn'] - assert runner_kv_caches[2] is kv_cache['layers.2.self_attn'] - assert runner_kv_caches[3] is kv_cache['layers.3.self_attn'] + assert runner_kv_caches[0] is kv_cache["layers.0.self_attn"] + assert runner_kv_caches[1] is kv_cache["layers.1.self_attn"] + assert runner_kv_caches[2] is kv_cache["layers.2.self_attn"] + assert runner_kv_caches[3] is kv_cache["layers.3.self_attn"] def test_bind_kv_cache_non_attention(): @@ -43,53 +39,54 @@ def test_bind_kv_cache_non_attention(): # example from Jamba PP=2 ctx = { - 'model.layers.20.attn': Attention(32, 128, 0.1), - 'model.layers.28.attn': Attention(32, 128, 0.1), + "model.layers.20.attn": Attention(32, 128, 0.1), + "model.layers.28.attn": Attention(32, 128, 0.1), } kv_cache = { - 'model.layers.20.attn': torch.zeros((1, )), - 'model.layers.28.attn': torch.zeros((1, )), + 
"model.layers.20.attn": torch.zeros((1,)), + "model.layers.28.attn": torch.zeros((1,)), } runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache(kv_cache, ctx, runner_kv_caches) - assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[ - 'model.layers.20.attn'] - assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[ - 'model.layers.28.attn'] + assert ctx["model.layers.20.attn"].kv_cache[0] is kv_cache["model.layers.20.attn"] + assert ctx["model.layers.28.attn"].kv_cache[0] is kv_cache["model.layers.28.attn"] - assert runner_kv_caches[0] is kv_cache['model.layers.20.attn'] - assert runner_kv_caches[1] is kv_cache['model.layers.28.attn'] + assert runner_kv_caches[0] is kv_cache["model.layers.20.attn"] + assert runner_kv_caches[1] is kv_cache["model.layers.28.attn"] def test_bind_kv_cache_draft_model(): from vllm.attention import Attention + ctx = { - 'model.layers.0.attn': Attention(32, 128, 0.1), - 'model.layers.1.attn': Attention(32, 128, 0.1), - 'draft_model.layers.0.attn': Attention(32, 128, 0.1), - 'draft_model.layers.1.attn': Attention(32, 128, 0.1), + "model.layers.0.attn": Attention(32, 128, 0.1), + "model.layers.1.attn": Attention(32, 128, 0.1), + "draft_model.layers.0.attn": Attention(32, 128, 0.1), + "draft_model.layers.1.attn": Attention(32, 128, 0.1), } kv_cache = { - 'model.layers.0.attn': torch.zeros((1, )), - 'model.layers.1.attn': torch.zeros((1, )), - 'draft_model.layers.0.attn': torch.zeros((1, )), - 'draft_model.layers.1.attn': torch.zeros((1, )), + "model.layers.0.attn": torch.zeros((1,)), + "model.layers.1.attn": torch.zeros((1,)), + "draft_model.layers.0.attn": torch.zeros((1,)), + "draft_model.layers.1.attn": torch.zeros((1,)), } runner_kv_caches: list[torch.Tensor] = [] bind_kv_cache(kv_cache, ctx, runner_kv_caches) - assert ctx['model.layers.0.attn'].kv_cache[0] is kv_cache[ - 'model.layers.0.attn'] - assert ctx['model.layers.1.attn'].kv_cache[0] is kv_cache[ - 'model.layers.1.attn'] - assert ctx['draft_model.layers.0.attn'].kv_cache[0] is kv_cache[ - 'draft_model.layers.0.attn'] - assert ctx['draft_model.layers.1.attn'].kv_cache[0] is kv_cache[ - 'draft_model.layers.1.attn'] + assert ctx["model.layers.0.attn"].kv_cache[0] is kv_cache["model.layers.0.attn"] + assert ctx["model.layers.1.attn"].kv_cache[0] is kv_cache["model.layers.1.attn"] + assert ( + ctx["draft_model.layers.0.attn"].kv_cache[0] + is kv_cache["draft_model.layers.0.attn"] + ) + assert ( + ctx["draft_model.layers.1.attn"].kv_cache[0] + is kv_cache["draft_model.layers.1.attn"] + ) # caches are ordered by layer_index, interleaving target and draft model - assert runner_kv_caches[0] is kv_cache['model.layers.0.attn'] - assert runner_kv_caches[1] is kv_cache['draft_model.layers.0.attn'] - assert runner_kv_caches[2] is kv_cache['model.layers.1.attn'] - assert runner_kv_caches[3] is kv_cache['draft_model.layers.1.attn'] \ No newline at end of file + assert runner_kv_caches[0] is kv_cache["model.layers.0.attn"] + assert runner_kv_caches[1] is kv_cache["draft_model.layers.0.attn"] + assert runner_kv_caches[2] is kv_cache["model.layers.1.attn"] + assert runner_kv_caches[3] is kv_cache["draft_model.layers.1.attn"] diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index a70ee896225e..521110431f93 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -24,23 +24,41 @@ PretrainedConfig = Any ModelConfig = Any - me_quant = LazyLoader("model_executor", globals(), - "vllm.model_executor.layers.quantization") + me_quant = LazyLoader( + "model_executor", globals(), 
"vllm.model_executor.layers.quantization" + ) logger = init_logger(__name__) -SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", - "mlp_speculator", "draft_model", "deepseek_mtp", - "ernie_mtp", "qwen3_next_mtp", "mimo_mtp", - "longcat_flash_mtp", "mtp"] -MTP_MODEL_TYPES = ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp", "ernie_mtp", - "qwen3_next_mtp", "longcat_flash_mtp") +SpeculativeMethod = Literal[ + "ngram", + "eagle", + "eagle3", + "medusa", + "mlp_speculator", + "draft_model", + "deepseek_mtp", + "ernie_mtp", + "qwen3_next_mtp", + "mimo_mtp", + "longcat_flash_mtp", + "mtp", +] +MTP_MODEL_TYPES = ( + "deepseek_mtp", + "mimo_mtp", + "glm4_moe_mtp", + "ernie_mtp", + "qwen3_next_mtp", + "longcat_flash_mtp", +) @config @dataclass class SpeculativeConfig: """Configuration for speculative decoding.""" + enforce_eager: Optional[bool] = None """Override the default enforce_eager from model_config""" # General speculative decoding control @@ -107,8 +125,7 @@ class SpeculativeConfig: # required configuration params passed from engine target_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the target model.""" - target_parallel_config: SkipValidation[ - ParallelConfig] = None # type: ignore + target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore """The parallel configuration for the target model.""" enable_chunked_prefill: SkipValidation[bool] = None # type: ignore """Whether vLLM is configured to use chunked prefill or not. Used for @@ -120,8 +137,7 @@ class SpeculativeConfig: # params generated in the post-init stage draft_model_config: SkipValidation[ModelConfig] = None # type: ignore """The configuration of the draft model initialized internal.""" - draft_parallel_config: SkipValidation[ - ParallelConfig] = None # type: ignore + draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" def compute_hash(self) -> str: @@ -140,8 +156,7 @@ def compute_hash(self) -> str: # Eagle3 affects the computation graph because it returns intermediate # hidden states in addition to the final hidden state. 
factors.append(self.method == "eagle3") - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() + hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @staticmethod @@ -150,58 +165,57 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.model_type = "deepseek_mtp" if hf_config.model_type == "deepseek_mtp": n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["DeepSeekMTPModel"] - }) + hf_config.update( + {"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]} + ) if hf_config.architectures[0] == "MiMoForCausalLM": hf_config.model_type = "mimo_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "num_hidden_layers": 0, - "n_predict": n_predict, - "architectures": ["MiMoMTPModel"] - }) + hf_config.update( + { + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["MiMoMTPModel"], + } + ) if hf_config.architectures[0] == "Glm4MoeForCausalLM": hf_config.model_type = "glm4_moe_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "num_hidden_layers": 0, - "n_predict": n_predict, - "architectures": ["Glm4MoeMTPModel"] - }) + hf_config.update( + { + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["Glm4MoeMTPModel"], + } + ) if hf_config.model_type == "ernie4_5_moe": hf_config.model_type = "ernie_mtp" if hf_config.model_type == "ernie_mtp": n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["ErnieMTPModel"] - }) + hf_config.update( + {"n_predict": n_predict, "architectures": ["ErnieMTPModel"]} + ) if hf_config.model_type == "qwen3_next": hf_config.model_type = "qwen3_next_mtp" if hf_config.model_type == "qwen3_next_mtp": n_predict = getattr(hf_config, "num_nextn_predict_layers", None) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["Qwen3NextMTP"] - }) + hf_config.update( + {"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]} + ) if hf_config.model_type == "longcat_flash": hf_config.model_type = "longcat_flash_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", 1) - hf_config.update({ - "n_predict": n_predict, - "architectures": ["LongCatFlashMTPModel"] - }) + hf_config.update( + {"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]} + ) return hf_config def __post_init__(self): - # Note: "method" is a new parameter that helps to extend the # configuration of non-model-based proposers, and the "model" parameter # will be used to set the draft model, eagle head, or additional weight @@ -211,17 +225,17 @@ def __post_init__(self): # default. 
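        # For a separate-weights proposer (e.g. the Qwen/Qwen3-0.6B draft used
        # in the e2e tests), none of the special model types below match and
        # the method falls through to "draft_model".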
if self.method in MTP_MODEL_TYPES: - logger.warning("method `%s` is deprecated and replaced with mtp.", - self.method) + logger.warning( + "method `%s` is deprecated and replaced with mtp.", self.method + ) self.method = "mtp" if self.model is None and self.num_speculative_tokens is not None: if self.method == "mtp": - assert ( - self.target_model_config - is not None), "target_model_config must be present for mtp" - if self.target_model_config.hf_text_config.model_type \ - == "deepseek_v32": + assert self.target_model_config is not None, ( + "target_model_config must be present for mtp" + ) + if self.target_model_config.hf_text_config.model_type == "deepseek_v32": # FIXME(luccafong): cudgraph with v32 MTP is not supported, # remove this when the issue is fixed. self.enforce_eager = True @@ -235,21 +249,21 @@ def __post_init__(self): self.model = "ngram" else: raise ValueError( - "num_speculative_tokens was provided but without " - "speculative model.") + "num_speculative_tokens was provided but without speculative model." + ) # Automatically configure the method for ngram when "model" is used # instead of "method" - if self.method is None and (self.model is not None - and self.model in ("ngram", "[ngram]")): + if self.method is None and ( + self.model is not None and self.model in ("ngram", "[ngram]") + ): self.method = "ngram" if self.method in ("ngram", "[ngram]"): # Unified to "ngram" internally self.method = "ngram" # Set default values if not provided - if (self.prompt_lookup_min is None - and self.prompt_lookup_max is None): + if self.prompt_lookup_min is None and self.prompt_lookup_max is None: # TODO(woosuk): Tune these values. They are arbitrarily chosen. self.prompt_lookup_min = 5 self.prompt_lookup_max = 5 @@ -263,14 +277,17 @@ def __post_init__(self): # Validate values if self.prompt_lookup_min < 1: raise ValueError( - f"prompt_lookup_min={self.prompt_lookup_min} must be > 0") + f"prompt_lookup_min={self.prompt_lookup_min} must be > 0" + ) if self.prompt_lookup_max < 1: raise ValueError( - f"prompt_lookup_max={self.prompt_lookup_max} must be > 0") + f"prompt_lookup_max={self.prompt_lookup_max} must be > 0" + ) if self.prompt_lookup_min > self.prompt_lookup_max: raise ValueError( f"prompt_lookup_min={self.prompt_lookup_min} must " - f"be <= prompt_lookup_max={self.prompt_lookup_max}") + f"be <= prompt_lookup_max={self.prompt_lookup_max}" + ) # TODO: current we still need extract vocab_size from target model # config, in future, we may try refactor it out, and set @@ -285,25 +302,21 @@ def __post_init__(self): # TODO: Move this import to the top once `ModelConfig` # lives in `vllm.config.model`. from vllm.config import ModelConfig + self.draft_model_config = ModelConfig( model=self.model, runner="draft", tokenizer=self.target_model_config.tokenizer, tokenizer_mode=self.target_model_config.tokenizer_mode, - trust_remote_code=self.target_model_config. - trust_remote_code, - allowed_local_media_path=self.target_model_config. - allowed_local_media_path, - allowed_media_domains=self.target_model_config. - allowed_media_domains, + trust_remote_code=self.target_model_config.trust_remote_code, + allowed_local_media_path=self.target_model_config.allowed_local_media_path, + allowed_media_domains=self.target_model_config.allowed_media_domains, dtype=self.target_model_config.dtype, seed=self.target_model_config.seed, revision=self.revision, code_revision=self.code_revision, - tokenizer_revision=self.target_model_config. 
- tokenizer_revision, - spec_target_max_model_len=self.target_model_config. - max_model_len, + tokenizer_revision=self.target_model_config.tokenizer_revision, + spec_target_max_model_len=self.target_model_config.max_model_len, quantization=self.quantization, enforce_eager=self.target_model_config.enforce_eager, max_logprobs=self.target_model_config.max_logprobs, @@ -311,7 +324,7 @@ def __post_init__(self): ) # Automatically detect the method - if self.method in ('eagle', 'eagle3'): + if self.method in ("eagle", "eagle3"): pass # examples: # yuhuili/EAGLE-LLaMA3-Instruct-8B @@ -323,27 +336,26 @@ def __post_init__(self): self.method = "eagle3" elif self.draft_model_config.hf_config.model_type == "medusa": self.method = "medusa" - elif (self.draft_model_config.hf_config.model_type == - "mlp_speculator"): + elif self.draft_model_config.hf_config.model_type == "mlp_speculator": self.method = "mlp_speculator" - elif (self.draft_model_config.hf_config.model_type - in MTP_MODEL_TYPES): + elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES: self.method = "mtp" if self.num_speculative_tokens > 1: logger.warning( - "Enabling num_speculative_tokens > 1 will run" \ - "multiple times of forward on same MTP layer" \ - ",which may result in lower acceptance rate" \ - ) - elif (self.draft_model_config.hf_config.model_type - in ("longcat_flash_mtp")): + "Enabling num_speculative_tokens > 1 will run" + "multiple times of forward on same MTP layer" + ",which may result in lower acceptance rate" + ) + elif self.draft_model_config.hf_config.model_type in ( + "longcat_flash_mtp" + ): self.method = "longcat_flash_mtp" if self.num_speculative_tokens > 1: logger.warning( - "LongCat MTP models only have " \ - "one layer. Might need some code changes " \ - "to support multiple layers." - ) + "LongCat MTP models only have " + "one layer. Might need some code changes " + "to support multiple layers." + ) else: self.method = "draft_model" @@ -352,60 +364,67 @@ def __post_init__(self): if self.enable_chunked_prefill and not envs.VLLM_USE_V1: raise ValueError( "Chunked prefill and EAGLE are not compatible " - "when using V0.") + "when using V0." + ) - from vllm.transformers_utils.configs import ( - SpeculatorsConfig) - from vllm.transformers_utils.configs.eagle import ( - EAGLEConfig) + from vllm.transformers_utils.configs import SpeculatorsConfig + from vllm.transformers_utils.configs.eagle import EAGLEConfig - if isinstance(self.draft_model_config.hf_config, - (EAGLEConfig, SpeculatorsConfig)): + if isinstance( + self.draft_model_config.hf_config, + (EAGLEConfig, SpeculatorsConfig), + ): pass else: eagle_config = EAGLEConfig( self.draft_model_config.hf_config, method=self.method, - model_type="eagle") + model_type="eagle", + ) self.draft_model_config.hf_config = eagle_config - if (self.num_speculative_tokens is not None - and hasattr(self.draft_model_config.hf_config, - "num_lookahead_tokens")): - self.draft_model_config.hf_config.num_lookahead_tokens = \ - self.num_speculative_tokens + if self.num_speculative_tokens is not None and hasattr( + self.draft_model_config.hf_config, "num_lookahead_tokens" + ): + self.draft_model_config.hf_config.num_lookahead_tokens = ( + self.num_speculative_tokens + ) - n_predict = getattr(self.draft_model_config.hf_config, - "n_predict", None) + n_predict = getattr( + self.draft_model_config.hf_config, "n_predict", None + ) if n_predict is not None: if self.num_speculative_tokens is None: # Default to max value defined in draft model config. 
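                    # e.g. an MTP checkpoint exposing n_predict=1 defaults to
                    # num_speculative_tokens=1 here; a user-supplied value
                    # larger than n_predict must be a multiple of it, which is
                    # checked just below.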
self.num_speculative_tokens = n_predict - elif self.num_speculative_tokens > n_predict and \ - self.num_speculative_tokens % n_predict != 0: + elif ( + self.num_speculative_tokens > n_predict + and self.num_speculative_tokens % n_predict != 0 + ): # Ensure divisibility for MTP module reuse. raise ValueError( f"num_speculative_tokens:{self.num_speculative_tokens}" - f" must be divisible by {n_predict=}") + f" must be divisible by {n_predict=}" + ) if self.speculative_token_tree is None: # Generate chain of tokens. - self.speculative_token_tree = str([ - (i + 1) * (0, ) - for i in range(self.num_speculative_tokens) - ]) + self.speculative_token_tree = str( + [(i + 1) * (0,) for i in range(self.num_speculative_tokens)] + ) else: # Sort the token tree breadth-first. - tree_choices = ast.literal_eval( - self.speculative_token_tree) + tree_choices = ast.literal_eval(self.speculative_token_tree) self.speculative_token_tree = str( - sorted(tree_choices, key=lambda t: (len(t), t))) + sorted(tree_choices, key=lambda t: (len(t), t)) + ) - self.draft_tensor_parallel_size = \ + self.draft_tensor_parallel_size = ( SpeculativeConfig._verify_and_get_draft_tp( self.target_parallel_config, self.draft_tensor_parallel_size, - self.draft_model_config.hf_config + self.draft_model_config.hf_config, + ) ) self.draft_model_config.max_model_len = ( @@ -413,12 +432,14 @@ def __post_init__(self): self.max_model_len, self.draft_model_config.max_model_len, self.target_model_config.max_model_len, - )) + ) + ) self.draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( - self.target_parallel_config, - self.draft_tensor_parallel_size)) + self.target_parallel_config, self.draft_tensor_parallel_size + ) + ) @staticmethod def _maybe_override_draft_max_model_len( @@ -439,14 +460,17 @@ def _maybe_override_draft_max_model_len( """ if speculative_max_model_len is not None: - if speculative_max_model_len > draft_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {draft_max_model_len=}") + raise ValueError( + f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}" + ) if speculative_max_model_len > target_max_model_len: - raise ValueError(f"{speculative_max_model_len=} cannot be " - f"larger than {target_max_model_len=}") + raise ValueError( + f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}" + ) return speculative_max_model_len @@ -457,9 +481,10 @@ def _maybe_override_draft_max_model_len( @staticmethod def _verify_and_get_draft_tp( - target_parallel_config: ParallelConfig, - speculative_draft_tensor_parallel_size: Optional[int], - draft_hf_config: PretrainedConfig) -> int: + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig, + ) -> int: """ Verifies and adjusts the tensor parallel size for a draft model specified using speculative_draft_tensor_parallel_size. 
@@ -473,15 +498,20 @@ def _verify_and_get_draft_tp( logger.warning( "%s cannot currently be run with tp>1; " "setting speculative_draft_tensor_parallel_size=1", - draft_hf_config.model_type) + draft_hf_config.model_type, + ) else: - speculative_draft_tensor_parallel_size = \ + speculative_draft_tensor_parallel_size = ( target_parallel_config.tensor_parallel_size + ) elif speculative_draft_tensor_parallel_size not in ( - 1, target_parallel_config.tensor_parallel_size): + 1, + target_parallel_config.tensor_parallel_size, + ): raise ValueError( f"{speculative_draft_tensor_parallel_size=} cannot be " - f"other value than 1 or target model tensor_parallel_size") + f"other value than 1 or target model tensor_parallel_size" + ) return speculative_draft_tensor_parallel_size @staticmethod @@ -494,52 +524,57 @@ def create_draft_parallel_config( This is mostly a copy of the target parallel config, except the tp_size. """ draft_parallel_config = ParallelConfig( - pipeline_parallel_size=target_parallel_config. - pipeline_parallel_size, + pipeline_parallel_size=target_parallel_config.pipeline_parallel_size, tensor_parallel_size=speculative_draft_tensor_parallel_size, - distributed_executor_backend=target_parallel_config. - distributed_executor_backend, - max_parallel_loading_workers=target_parallel_config. - max_parallel_loading_workers, - disable_custom_all_reduce=target_parallel_config. - disable_custom_all_reduce, - ray_workers_use_nsight=target_parallel_config. - ray_workers_use_nsight, + distributed_executor_backend=target_parallel_config.distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce, + ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight, placement_group=target_parallel_config.placement_group, ) return draft_parallel_config - @model_validator(mode='after') + @model_validator(mode="after") def _verify_args(self) -> Self: if self.num_speculative_tokens is None: raise ValueError( "num_speculative_tokens must be provided with " "speculative model unless the draft model config contains an " - "n_predict parameter.") + "n_predict parameter." + ) if self.num_speculative_tokens <= 0: - raise ValueError("Expected num_speculative_tokens to be greater " - f"than zero ({self.num_speculative_tokens}).") + raise ValueError( + "Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens})." 
+ ) if self.draft_model_config: self.draft_model_config.verify_with_parallel_config( - self.draft_parallel_config) + self.draft_parallel_config + ) - if (self.disable_by_batch_size is not None - and self.disable_by_batch_size < 2): - raise ValueError("Expect the batch size threshold of disabling " - "speculative decoding is > 1, but got " - f"{self.disable_by_batch_size=}") + if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2: + raise ValueError( + "Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{self.disable_by_batch_size=}" + ) eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"] - if self.method == "eagle3" and self.target_model_config and not any( - supported_model in - self.target_model_config.hf_text_config.model_type - for supported_model in eagle3_target_supported): + if ( + self.method == "eagle3" + and self.target_model_config + and not any( + supported_model in self.target_model_config.hf_text_config.model_type + for supported_model in eagle3_target_supported + ) + ): raise ValueError( f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 - f"Got {self.target_model_config.hf_text_config.model_type=}") + f"Got {self.target_model_config.hf_text_config.model_type=}" + ) return self diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4ccb64e29cf2..442531420f1b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -10,9 +10,22 @@ import sys from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations -from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List, - Literal, Optional, Type, TypeVar, Union, cast, get_args, - get_origin) +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Type, + TypeVar, + Union, + cast, + get_args, + get_origin, +) import huggingface_hub import regex as re @@ -21,17 +34,42 @@ from typing_extensions import TypeIs, deprecated import vllm.envs as envs -from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, - ConfigType, ConvertOption, DetailedTraceModules, - Device, DeviceConfig, DistributedExecutorBackend, - EPLBConfig, HfOverrides, KVEventsConfig, - KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, - ModelDType, ObservabilityConfig, ParallelConfig, - PoolerConfig, PrefixCachingHashAlgo, RunnerOption, - SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - StructuredOutputsConfig, TaskOption, TokenizerMode, - VllmConfig, get_attr_docs) +from vllm.config import ( + BlockSize, + CacheConfig, + CacheDType, + CompilationConfig, + ConfigType, + ConvertOption, + DetailedTraceModules, + Device, + DeviceConfig, + DistributedExecutorBackend, + EPLBConfig, + HfOverrides, + KVEventsConfig, + KVTransferConfig, + LoadConfig, + LogprobsMode, + LoRAConfig, + MambaDType, + MMEncoderTPMode, + ModelConfig, + ModelDType, + ObservabilityConfig, + ParallelConfig, + PoolerConfig, + PrefixCachingHashAlgo, + RunnerOption, + SchedulerConfig, + SchedulerPolicy, + SpeculativeConfig, + StructuredOutputsConfig, + TaskOption, + TokenizerMode, + VllmConfig, + get_attr_docs, +) from vllm.config.multimodal import MMCacheType, MultiModalConfig from vllm.config.parallel import ExpertPlacementStrategy from vllm.config.utils import get_field @@ -41,11 +79,13 @@ from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager 
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 -from vllm.transformers_utils.config import (get_model_path, is_interleaved, - maybe_override_with_speculators) +from vllm.transformers_utils.config import ( + get_model_path, + is_interleaved, + maybe_override_with_speculators, +) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip, - is_in_ray_actor) +from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor from vllm.v1.sample.logits_processor import LogitsProcessor # yapf: enable @@ -70,20 +110,18 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: - def _parse_type(val: str) -> T: try: return return_type(val) except ValueError as e: raise argparse.ArgumentTypeError( - f"Value {val} cannot be converted to {return_type}.") from e + f"Value {val} cannot be converted to {return_type}." + ) from e return _parse_type -def optional_type( - return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: - +def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: def _optional_type(val: str) -> Optional[T]: if val == "" or val == "None": return None @@ -124,7 +162,8 @@ def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]: if not all(isinstance(option, option_type) for option in options): raise ValueError( "All options must be of the same type. " - f"Got {options} with types {[type(c) for c in options]}") + f"Got {options} with types {[type(c) for c in options]}" + ) kwarg = "metavar" if contains_type(type_hints, str) else "choices" return {"type": option_type, kwarg: sorted(options)} @@ -191,8 +230,9 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name] = {"default": default, "help": help} # Set other kwargs based on the type hints - json_tip = ("Should either be a valid JSON string or JSON keys passed " - "individually.") + json_tip = ( + "Should either be a valid JSON string or JSON keys passed individually." + ) if dataclass_cls is not None: def parse_dataclass(val: str, cls=dataclass_cls) -> Any: @@ -214,7 +254,8 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: tuple_type = types[0] assert all(t is tuple_type for t in types if t is not Ellipsis), ( "All non-Ellipsis tuple elements must be of the same " - f"type. Got {types}.") + f"type. Got {types}." 
+ ) kwargs[name]["type"] = tuple_type kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) elif contains_type(type_hints, list): @@ -240,19 +281,20 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}" elif contains_type(type_hints, float): kwargs[name]["type"] = float - elif (contains_type(type_hints, dict) - and (contains_type(type_hints, str) - or any(is_not_builtin(th) for th in type_hints))): + elif contains_type(type_hints, dict) and ( + contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints) + ): kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): kwargs[name]["type"] = parse_type(json.loads) kwargs[name]["help"] += f"\n\n{json_tip}" - elif (contains_type(type_hints, str) - or any(is_not_builtin(th) for th in type_hints)): + elif contains_type(type_hints, str) or any( + is_not_builtin(th) for th in type_hints + ): kwargs[name]["type"] = str else: - raise ValueError( - f"Unsupported type {type_hints} for argument {name}.") + raise ValueError(f"Unsupported type {type_hints} for argument {name}.") # If the type hint was a sequence of literals, use the helper function # to update the type and choices @@ -284,9 +326,9 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: @dataclass class EngineArgs: """Arguments for vLLM engine.""" + model: str = ModelConfig.model - served_model_name: Optional[Union[ - str, List[str]]] = ModelConfig.served_model_name + served_model_name: Optional[Union[str, List[str]]] = ModelConfig.served_model_name tokenizer: Optional[str] = ModelConfig.tokenizer hf_config_path: Optional[str] = ModelConfig.hf_config_path runner: RunnerOption = ModelConfig.runner @@ -297,8 +339,7 @@ class EngineArgs: tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode trust_remote_code: bool = ModelConfig.trust_remote_code allowed_local_media_path: str = ModelConfig.allowed_local_media_path - allowed_media_domains: Optional[ - list[str]] = ModelConfig.allowed_media_domains + allowed_media_domains: Optional[list[str]] = ModelConfig.allowed_media_domains download_dir: Optional[str] = LoadConfig.download_dir safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy load_format: Union[str, LoadFormats] = LoadConfig.load_format @@ -307,19 +348,17 @@ class EngineArgs: kv_cache_dtype: CacheDType = CacheConfig.cache_dtype seed: Optional[int] = ModelConfig.seed max_model_len: Optional[int] = ModelConfig.max_model_len - cuda_graph_sizes: list[int] = get_field(SchedulerConfig, - "cuda_graph_sizes") + cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes") # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. 
- distributed_executor_backend: Optional[Union[ - str, DistributedExecutorBackend, - Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend + distributed_executor_backend: Optional[ + Union[str, DistributedExecutorBackend, Type[ExecutorBase]] + ] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size - decode_context_parallel_size: int = \ - ParallelConfig.decode_context_parallel_size + decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: Optional[int] = None data_parallel_start_rank: Optional[int] = None @@ -330,38 +369,37 @@ class EngineArgs: data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_dbo: bool = ParallelConfig.enable_dbo - dbo_decode_token_threshold: int = \ - ParallelConfig.dbo_decode_token_threshold - dbo_prefill_token_threshold: int = \ - ParallelConfig.dbo_prefill_token_threshold + dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold + dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb - expert_placement_strategy: ExpertPlacementStrategy = \ + expert_placement_strategy: ExpertPlacementStrategy = ( ParallelConfig.expert_placement_strategy + ) _api_process_count: int = ParallelConfig._api_process_count _api_process_rank: int = ParallelConfig._api_process_rank num_redundant_experts: int = EPLBConfig.num_redundant_experts eplb_window_size: int = EPLBConfig.window_size eplb_step_interval: int = EPLBConfig.step_interval eplb_log_balancedness: bool = EPLBConfig.log_balancedness - max_parallel_loading_workers: Optional[ - int] = ParallelConfig.max_parallel_loading_workers + max_parallel_loading_workers: Optional[int] = ( + ParallelConfig.max_parallel_loading_workers + ) block_size: Optional[BlockSize] = CacheConfig.block_size enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching - prefix_caching_hash_algo: PrefixCachingHashAlgo = \ + prefix_caching_hash_algo: PrefixCachingHashAlgo = ( CacheConfig.prefix_caching_hash_algo + ) disable_sliding_window: bool = ModelConfig.disable_sliding_window disable_cascade_attn: bool = ModelConfig.disable_cascade_attn swap_space: float = CacheConfig.swap_space cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes - max_num_batched_tokens: Optional[ - int] = SchedulerConfig.max_num_batched_tokens + max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills - long_prefill_token_threshold: int = \ - SchedulerConfig.long_prefill_token_threshold + long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs max_logprobs: int = ModelConfig.max_logprobs logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode @@ -376,20 +414,22 @@ class EngineArgs: quantization: Optional[QuantizationMethods] = 
ModelConfig.quantization enforce_eager: bool = ModelConfig.enforce_eager disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce - limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = \ - get_field(MultiModalConfig, "limit_per_prompt") + limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = get_field( + MultiModalConfig, "limit_per_prompt" + ) interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings - media_io_kwargs: dict[str, dict[str, - Any]] = get_field(MultiModalConfig, - "media_io_kwargs") - mm_processor_kwargs: Optional[Dict[str, Any]] = \ - MultiModalConfig.mm_processor_kwargs + media_io_kwargs: dict[str, dict[str, Any]] = get_field( + MultiModalConfig, "media_io_kwargs" + ) + mm_processor_kwargs: Optional[Dict[str, Any]] = MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb - mm_processor_cache_type: Optional[MMCacheType] = \ + mm_processor_cache_type: Optional[MMCacheType] = ( MultiModalConfig.mm_processor_cache_type - mm_shm_cache_max_object_size_mb: int = \ + ) + mm_shm_cache_max_object_size_mb: int = ( MultiModalConfig.mm_shm_cache_max_object_size_mb + ) mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode io_processor_plugin: Optional[str] = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling @@ -399,31 +439,28 @@ class EngineArgs: enable_lora_bias: bool = LoRAConfig.bias_enabled max_loras: int = LoRAConfig.max_loras max_lora_rank: int = LoRAConfig.max_lora_rank - default_mm_loras: Optional[Dict[str, str]] = \ - LoRAConfig.default_mm_loras + default_mm_loras: Optional[Dict[str, str]] = LoRAConfig.default_mm_loras fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight - num_gpu_blocks_override: Optional[ - int] = CacheConfig.num_gpu_blocks_override + num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots - model_loader_extra_config: dict = \ - get_field(LoadConfig, "model_loader_extra_config") - ignore_patterns: Optional[Union[str, - List[str]]] = LoadConfig.ignore_patterns + model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") + ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns - enable_chunked_prefill: Optional[ - bool] = SchedulerConfig.enable_chunked_prefill + enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input disable_hybrid_kv_cache_manager: bool = ( - SchedulerConfig.disable_hybrid_kv_cache_manager) + SchedulerConfig.disable_hybrid_kv_cache_manager + ) structured_outputs_config: StructuredOutputsConfig = get_field( - VllmConfig, "structured_outputs_config") + VllmConfig, "structured_outputs_config" + ) reasoning_parser: str = StructuredOutputsConfig.reasoning_parser # Deprecated guided decoding fields guided_decoding_backend: Optional[str] = None @@ -431,25 +468,25 @@ class EngineArgs: guided_decoding_disable_any_whitespace: Optional[bool] = None guided_decoding_disable_additional_properties: Optional[bool] = None - logits_processor_pattern: Optional[ - str] = 
ModelConfig.logits_processor_pattern + logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern speculative_config: Optional[Dict[str, Any]] = None - show_hidden_metrics_for_version: Optional[str] = \ + show_hidden_metrics_for_version: Optional[str] = ( ObservabilityConfig.show_hidden_metrics_for_version - otlp_traces_endpoint: Optional[str] = \ - ObservabilityConfig.otlp_traces_endpoint - collect_detailed_traces: Optional[list[DetailedTraceModules]] = \ + ) + otlp_traces_endpoint: Optional[str] = ObservabilityConfig.otlp_traces_endpoint + collect_detailed_traces: Optional[list[DetailedTraceModules]] = ( ObservabilityConfig.collect_detailed_traces + ) scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config - override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ + override_pooler_config: Optional[Union[dict, PoolerConfig]] = ( ModelConfig.override_pooler_config - compilation_config: CompilationConfig = \ - get_field(VllmConfig, "compilation_config") + ) + compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls @@ -458,8 +495,9 @@ class EngineArgs: generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode - override_generation_config: dict[str, Any] = \ - get_field(ModelConfig, "override_generation_config") + override_generation_config: dict[str, Any] = get_field( + ModelConfig, "override_generation_config" + ) model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype @@ -467,8 +505,7 @@ class EngineArgs: mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype - additional_config: dict[str, Any] = \ - get_field(VllmConfig, "additional_config") + additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location @@ -476,34 +513,36 @@ class EngineArgs: # DEPRECATED enable_multimodal_encoder_data_parallel: bool = False - logits_processors: Optional[list[Union[ - str, type[LogitsProcessor]]]] = ModelConfig.logits_processors + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = ( + ModelConfig.logits_processors + ) """Custom logitproc types""" async_scheduling: bool = SchedulerConfig.async_scheduling - kv_sharing_fast_prefill: bool = \ - CacheConfig.kv_sharing_fast_prefill + kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object if isinstance(self.compilation_config, dict): - self.compilation_config = CompilationConfig( - **self.compilation_config) + self.compilation_config = CompilationConfig(**self.compilation_config) if isinstance(self.eplb_config, dict): self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins from vllm.plugins import load_general_plugins + load_general_plugins() # when use hf offline,replace model id to local model path if huggingface_hub.constants.HF_HUB_OFFLINE: model_id = self.model self.model = get_model_path(self.model, self.revision) logger.info( - "HF_HUB_OFFLINE is True, replace 
model_id [%s] " \ - "to model_path [%s]",model_id, self.model) + "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]", + model_id, + self.model, + ) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -515,86 +554,92 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: title="ModelConfig", description=ModelConfig.__doc__, ) - if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]): + if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]): model_group.add_argument("--model", **model_kwargs["model"]) model_group.add_argument("--runner", **model_kwargs["runner"]) model_group.add_argument("--convert", **model_kwargs["convert"]) - model_group.add_argument("--task", - **model_kwargs["task"], - deprecated=True) + model_group.add_argument("--task", **model_kwargs["task"], deprecated=True) model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) - model_group.add_argument("--tokenizer-mode", - **model_kwargs["tokenizer_mode"]) - model_group.add_argument("--trust-remote-code", - **model_kwargs["trust_remote_code"]) + model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"]) + model_group.add_argument( + "--trust-remote-code", **model_kwargs["trust_remote_code"] + ) model_group.add_argument("--dtype", **model_kwargs["dtype"]) model_group.add_argument("--seed", **model_kwargs["seed"]) - model_group.add_argument("--hf-config-path", - **model_kwargs["hf_config_path"]) - model_group.add_argument("--allowed-local-media-path", - **model_kwargs["allowed_local_media_path"]) - model_group.add_argument("--allowed-media-domains", - **model_kwargs["allowed_media_domains"]) + model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"]) + model_group.add_argument( + "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"] + ) + model_group.add_argument( + "--allowed-media-domains", **model_kwargs["allowed_media_domains"] + ) model_group.add_argument("--revision", **model_kwargs["revision"]) - model_group.add_argument("--code-revision", - **model_kwargs["code_revision"]) - model_group.add_argument("--rope-scaling", - **model_kwargs["rope_scaling"]) + model_group.add_argument("--code-revision", **model_kwargs["code_revision"]) + model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"]) model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"]) - model_group.add_argument("--tokenizer-revision", - **model_kwargs["tokenizer_revision"]) - model_group.add_argument("--max-model-len", - **model_kwargs["max_model_len"]) - model_group.add_argument("--quantization", "-q", - **model_kwargs["quantization"]) - model_group.add_argument("--enforce-eager", - **model_kwargs["enforce_eager"]) - model_group.add_argument("--max-logprobs", - **model_kwargs["max_logprobs"]) - model_group.add_argument("--logprobs-mode", - **model_kwargs["logprobs_mode"]) - model_group.add_argument("--disable-sliding-window", - **model_kwargs["disable_sliding_window"]) - model_group.add_argument("--disable-cascade-attn", - **model_kwargs["disable_cascade_attn"]) - model_group.add_argument("--skip-tokenizer-init", - **model_kwargs["skip_tokenizer_init"]) - model_group.add_argument("--enable-prompt-embeds", - **model_kwargs["enable_prompt_embeds"]) - model_group.add_argument("--served-model-name", - **model_kwargs["served_model_name"]) - model_group.add_argument("--config-format", - **model_kwargs["config_format"]) + model_group.add_argument( + "--tokenizer-revision", 
**model_kwargs["tokenizer_revision"] + ) + model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"]) + model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"]) + model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"]) + model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) + model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"]) + model_group.add_argument( + "--disable-sliding-window", **model_kwargs["disable_sliding_window"] + ) + model_group.add_argument( + "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"] + ) + model_group.add_argument( + "--skip-tokenizer-init", **model_kwargs["skip_tokenizer_init"] + ) + model_group.add_argument( + "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"] + ) + model_group.add_argument( + "--served-model-name", **model_kwargs["served_model_name"] + ) + model_group.add_argument("--config-format", **model_kwargs["config_format"]) # This one is a special case because it can bool # or str. TODO: Handle this in get_kwargs - model_group.add_argument("--hf-token", - type=str, - nargs="?", - const=True, - default=model_kwargs["hf_token"]["default"], - help=model_kwargs["hf_token"]["help"]) - model_group.add_argument("--hf-overrides", - **model_kwargs["hf_overrides"]) - model_group.add_argument("--pooler-config", - **model_kwargs["pooler_config"]) - model_group.add_argument("--override-pooler-config", - **model_kwargs["override_pooler_config"], - deprecated=True) - model_group.add_argument("--logits-processor-pattern", - **model_kwargs["logits_processor_pattern"]) - model_group.add_argument("--generation-config", - **model_kwargs["generation_config"]) - model_group.add_argument("--override-generation-config", - **model_kwargs["override_generation_config"]) - model_group.add_argument("--enable-sleep-mode", - **model_kwargs["enable_sleep_mode"]) + model_group.add_argument( + "--hf-token", + type=str, + nargs="?", + const=True, + default=model_kwargs["hf_token"]["default"], + help=model_kwargs["hf_token"]["help"], + ) + model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) + model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"]) + model_group.add_argument( + "--override-pooler-config", + **model_kwargs["override_pooler_config"], + deprecated=True, + ) + model_group.add_argument( + "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"] + ) + model_group.add_argument( + "--generation-config", **model_kwargs["generation_config"] + ) + model_group.add_argument( + "--override-generation-config", **model_kwargs["override_generation_config"] + ) + model_group.add_argument( + "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"] + ) model_group.add_argument("--model-impl", **model_kwargs["model_impl"]) - model_group.add_argument("--override-attention-dtype", - **model_kwargs["override_attention_dtype"]) - model_group.add_argument("--logits-processors", - **model_kwargs["logits_processors"]) - model_group.add_argument("--io-processor-plugin", - **model_kwargs["io_processor_plugin"]) + model_group.add_argument( + "--override-attention-dtype", **model_kwargs["override_attention_dtype"] + ) + model_group.add_argument( + "--logits-processors", **model_kwargs["logits_processors"] + ) + model_group.add_argument( + "--io-processor-plugin", **model_kwargs["io_processor_plugin"] + ) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -603,18 +648,18 @@ def 
add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: description=LoadConfig.__doc__, ) load_group.add_argument("--load-format", **load_kwargs["load_format"]) - load_group.add_argument("--download-dir", - **load_kwargs["download_dir"]) - load_group.add_argument("--safetensors-load-strategy", - **load_kwargs["safetensors_load_strategy"]) - load_group.add_argument("--model-loader-extra-config", - **load_kwargs["model_loader_extra_config"]) - load_group.add_argument("--ignore-patterns", - **load_kwargs["ignore_patterns"]) - load_group.add_argument("--use-tqdm-on-load", - **load_kwargs["use_tqdm_on_load"]) - load_group.add_argument('--pt-load-map-location', - **load_kwargs["pt_load_map_location"]) + load_group.add_argument("--download-dir", **load_kwargs["download_dir"]) + load_group.add_argument( + "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"] + ) + load_group.add_argument( + "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"] + ) + load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"]) + load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"]) + load_group.add_argument( + "--pt-load-map-location", **load_kwargs["pt_load_map_location"] + ) # Structured outputs arguments structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) @@ -626,7 +671,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--reasoning-parser", # This choice is a special case because it's not static choices=list(ReasoningParserManager.reasoning_parsers), - **structured_outputs_kwargs["reasoning_parser"]) + **structured_outputs_kwargs["reasoning_parser"], + ) # Deprecated guided decoding arguments for arg, type in [ ("--guided-decoding-backend", str), @@ -638,7 +684,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: arg, type=type, help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."), - deprecated=True) + deprecated=True, + ) # Parallel arguments parallel_kwargs = get_kwargs(ParallelConfig) @@ -648,111 +695,128 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parallel_group.add_argument( "--distributed-executor-backend", - **parallel_kwargs["distributed_executor_backend"]) + **parallel_kwargs["distributed_executor_backend"], + ) parallel_group.add_argument( - "--pipeline-parallel-size", "-pp", - **parallel_kwargs["pipeline_parallel_size"]) - parallel_group.add_argument("--tensor-parallel-size", "-tp", - **parallel_kwargs["tensor_parallel_size"]) + "--pipeline-parallel-size", + "-pp", + **parallel_kwargs["pipeline_parallel_size"], + ) parallel_group.add_argument( - "--decode-context-parallel-size", "-dcp", - **parallel_kwargs["decode_context_parallel_size"]) - parallel_group.add_argument("--data-parallel-size", "-dp", - **parallel_kwargs["data_parallel_size"]) + "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"] + ) parallel_group.add_argument( - '--data-parallel-rank', - '-dpn', + "--decode-context-parallel-size", + "-dcp", + **parallel_kwargs["decode_context_parallel_size"], + ) + parallel_group.add_argument( + "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] + ) + parallel_group.add_argument( + "--data-parallel-rank", + "-dpn", type=int, - help='Data parallel rank of this instance. 
' - 'When set, enables external load balancer mode.') - parallel_group.add_argument('--data-parallel-start-rank', - '-dpr', - type=int, - help='Starting data parallel rank ' - 'for secondary nodes.') - parallel_group.add_argument('--data-parallel-size-local', - '-dpl', - type=int, - help='Number of data parallel replicas ' - 'to run on this node.') - parallel_group.add_argument('--data-parallel-address', - '-dpa', - type=str, - help='Address of data parallel cluster ' - 'head-node.') - parallel_group.add_argument('--data-parallel-rpc-port', - '-dpp', - type=int, - help='Port for data parallel RPC ' - 'communication.') - parallel_group.add_argument('--data-parallel-backend', - '-dpb', - type=str, - default='mp', - help='Backend for data parallel, either ' - '"mp" or "ray".') + help="Data parallel rank of this instance. " + "When set, enables external load balancer mode.", + ) parallel_group.add_argument( - "--data-parallel-hybrid-lb", - **parallel_kwargs["data_parallel_hybrid_lb"]) + "--data-parallel-start-rank", + "-dpr", + type=int, + help="Starting data parallel rank for secondary nodes.", + ) + parallel_group.add_argument( + "--data-parallel-size-local", + "-dpl", + type=int, + help="Number of data parallel replicas to run on this node.", + ) + parallel_group.add_argument( + "--data-parallel-address", + "-dpa", + type=str, + help="Address of data parallel cluster head-node.", + ) + parallel_group.add_argument( + "--data-parallel-rpc-port", + "-dpp", + type=int, + help="Port for data parallel RPC communication.", + ) + parallel_group.add_argument( + "--data-parallel-backend", + "-dpb", + type=str, + default="mp", + help='Backend for data parallel, either "mp" or "ray".', + ) parallel_group.add_argument( - "--enable-expert-parallel", - **parallel_kwargs["enable_expert_parallel"]) - parallel_group.add_argument("--enable-dbo", - **parallel_kwargs["enable_dbo"]) + "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"] + ) + parallel_group.add_argument( + "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"] + ) + parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) parallel_group.add_argument( "--dbo-decode-token-threshold", - **parallel_kwargs["dbo_decode_token_threshold"]) + **parallel_kwargs["dbo_decode_token_threshold"], + ) parallel_group.add_argument( "--dbo-prefill-token-threshold", - **parallel_kwargs["dbo_prefill_token_threshold"]) - parallel_group.add_argument("--enable-eplb", - **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--eplb-config", - **parallel_kwargs["eplb_config"]) + **parallel_kwargs["dbo_prefill_token_threshold"], + ) + parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) + parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"]) parallel_group.add_argument( "--expert-placement-strategy", - **parallel_kwargs["expert_placement_strategy"]) + **parallel_kwargs["expert_placement_strategy"], + ) parallel_group.add_argument( "--num-redundant-experts", type=int, - help= - "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--eplb-window-size", type=int, help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.", - deprecated=True) + deprecated=True, + ) parallel_group.add_argument( "--eplb-step-interval", type=int, - help= - "[DEPRECATED] --eplb-step-interval will be removed in 
v0.12.0.", - deprecated=True) + help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--eplb-log-balancedness", action=argparse.BooleanOptionalAction, - help= - "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--max-parallel-loading-workers", - **parallel_kwargs["max_parallel_loading_workers"]) + **parallel_kwargs["max_parallel_loading_workers"], + ) parallel_group.add_argument( - "--ray-workers-use-nsight", - **parallel_kwargs["ray_workers_use_nsight"]) + "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"] + ) parallel_group.add_argument( "--disable-custom-all-reduce", - **parallel_kwargs["disable_custom_all_reduce"]) - parallel_group.add_argument("--worker-cls", - **parallel_kwargs["worker_cls"]) - parallel_group.add_argument("--worker-extension-cls", - **parallel_kwargs["worker_extension_cls"]) + **parallel_kwargs["disable_custom_all_reduce"], + ) + parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"]) + parallel_group.add_argument( + "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"] + ) parallel_group.add_argument( "--enable-multimodal-encoder-data-parallel", action="store_true", - deprecated=True) + deprecated=True, + ) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -761,29 +825,36 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: description=CacheConfig.__doc__, ) cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) - cache_group.add_argument("--gpu-memory-utilization", - **cache_kwargs["gpu_memory_utilization"]) - cache_group.add_argument("--kv-cache-memory-bytes", - **cache_kwargs["kv_cache_memory_bytes"]) + cache_group.add_argument( + "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"] + ) + cache_group.add_argument( + "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"] + ) cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) - cache_group.add_argument("--kv-cache-dtype", - **cache_kwargs["cache_dtype"]) - cache_group.add_argument("--num-gpu-blocks-override", - **cache_kwargs["num_gpu_blocks_override"]) - cache_group.add_argument("--enable-prefix-caching", - **cache_kwargs["enable_prefix_caching"]) - cache_group.add_argument("--prefix-caching-hash-algo", - **cache_kwargs["prefix_caching_hash_algo"]) - cache_group.add_argument("--cpu-offload-gb", - **cache_kwargs["cpu_offload_gb"]) - cache_group.add_argument("--calculate-kv-scales", - **cache_kwargs["calculate_kv_scales"]) - cache_group.add_argument("--kv-sharing-fast-prefill", - **cache_kwargs["kv_sharing_fast_prefill"]) - cache_group.add_argument("--mamba-cache-dtype", - **cache_kwargs["mamba_cache_dtype"]) - cache_group.add_argument("--mamba-ssm-cache-dtype", - **cache_kwargs["mamba_ssm_cache_dtype"]) + cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"]) + cache_group.add_argument( + "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"] + ) + cache_group.add_argument( + "--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"] + ) + cache_group.add_argument( + "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"] + ) + cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"]) + cache_group.add_argument( + "--calculate-kv-scales", 
**cache_kwargs["calculate_kv_scales"] + ) + cache_group.add_argument( + "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"] + ) + cache_group.add_argument( + "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"] + ) + cache_group.add_argument( + "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"] + ) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -791,35 +862,41 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: title="MultiModalConfig", description=MultiModalConfig.__doc__, ) - multimodal_group.add_argument("--limit-mm-per-prompt", - **multimodal_kwargs["limit_per_prompt"]) - multimodal_group.add_argument("--media-io-kwargs", - **multimodal_kwargs["media_io_kwargs"]) multimodal_group.add_argument( - "--mm-processor-kwargs", - **multimodal_kwargs["mm_processor_kwargs"]) + "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"] + ) + multimodal_group.add_argument( + "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"] + ) + multimodal_group.add_argument( + "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"] + ) + multimodal_group.add_argument( + "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"] + ) multimodal_group.add_argument( - "--mm-processor-cache-gb", - **multimodal_kwargs["mm_processor_cache_gb"]) - multimodal_group.add_argument("--disable-mm-preprocessor-cache", - action="store_true", - deprecated=True) + "--disable-mm-preprocessor-cache", action="store_true", deprecated=True + ) multimodal_group.add_argument( - "--mm-processor-cache-type", - **multimodal_kwargs["mm_processor_cache_type"]) + "--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"] + ) multimodal_group.add_argument( "--mm-shm-cache-max-object-size-mb", - **multimodal_kwargs["mm_shm_cache_max_object_size_mb"]) + **multimodal_kwargs["mm_shm_cache_max_object_size_mb"], + ) multimodal_group.add_argument( - "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) + "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"] + ) + multimodal_group.add_argument( + "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"] + ) multimodal_group.add_argument( - "--interleave-mm-strings", - **multimodal_kwargs["interleave_mm_strings"]) - multimodal_group.add_argument("--skip-mm-profiling", - **multimodal_kwargs["skip_mm_profiling"]) + "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"] + ) multimodal_group.add_argument( - "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]) + "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] + ) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -830,24 +907,23 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: lora_group.add_argument( "--enable-lora", action=argparse.BooleanOptionalAction, - help="If True, enable handling of LoRA adapters.") - lora_group.add_argument("--enable-lora-bias", - **lora_kwargs["bias_enabled"]) + help="If True, enable handling of LoRA adapters.", + ) + lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"]) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) - lora_group.add_argument("--max-lora-rank", - **lora_kwargs["max_lora_rank"]) - lora_group.add_argument("--lora-extra-vocab-size", - **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) + lora_group.add_argument( + 
"--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"] + ) lora_group.add_argument( "--lora-dtype", **lora_kwargs["lora_dtype"], ) - lora_group.add_argument("--max-cpu-loras", - **lora_kwargs["max_cpu_loras"]) - lora_group.add_argument("--fully-sharded-loras", - **lora_kwargs["fully_sharded_loras"]) - lora_group.add_argument("--default-mm-loras", - **lora_kwargs["default_mm_loras"]) + lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"]) + lora_group.add_argument( + "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"] + ) + lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) # Observability arguments observability_kwargs = get_kwargs(ObservabilityConfig) @@ -857,21 +933,22 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) observability_group.add_argument( "--show-hidden-metrics-for-version", - **observability_kwargs["show_hidden_metrics_for_version"]) + **observability_kwargs["show_hidden_metrics_for_version"], + ) observability_group.add_argument( - "--otlp-traces-endpoint", - **observability_kwargs["otlp_traces_endpoint"]) + "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"] + ) # TODO: generalise this special case choices = observability_kwargs["collect_detailed_traces"]["choices"] metavar = f"{{{','.join(choices)}}}" observability_kwargs["collect_detailed_traces"]["metavar"] = metavar observability_kwargs["collect_detailed_traces"]["choices"] += [ - ",".join(p) - for p in permutations(get_args(DetailedTraceModules), r=2) + ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2) ] observability_group.add_argument( "--collect-detailed-traces", - **observability_kwargs["collect_detailed_traces"]) + **observability_kwargs["collect_detailed_traces"], + ) # Scheduler arguments scheduler_kwargs = get_kwargs(SchedulerConfig) @@ -880,40 +957,49 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: description=SchedulerConfig.__doc__, ) scheduler_group.add_argument( - "--max-num-batched-tokens", - **scheduler_kwargs["max_num_batched_tokens"]) - scheduler_group.add_argument("--max-num-seqs", - **scheduler_kwargs["max_num_seqs"]) + "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"] + ) scheduler_group.add_argument( - "--max-num-partial-prefills", - **scheduler_kwargs["max_num_partial_prefills"]) + "--max-num-seqs", **scheduler_kwargs["max_num_seqs"] + ) + scheduler_group.add_argument( + "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"] + ) scheduler_group.add_argument( "--max-long-partial-prefills", - **scheduler_kwargs["max_long_partial_prefills"]) - scheduler_group.add_argument('--cuda-graph-sizes', - **scheduler_kwargs["cuda_graph_sizes"]) + **scheduler_kwargs["max_long_partial_prefills"], + ) + scheduler_group.add_argument( + "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"] + ) scheduler_group.add_argument( "--long-prefill-token-threshold", - **scheduler_kwargs["long_prefill_token_threshold"]) - scheduler_group.add_argument("--num-lookahead-slots", - **scheduler_kwargs["num_lookahead_slots"]) + **scheduler_kwargs["long_prefill_token_threshold"], + ) + scheduler_group.add_argument( + "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"] + ) # multi-step scheduling has been removed; corresponding arguments # are no longer supported. 
- scheduler_group.add_argument("--scheduling-policy", - **scheduler_kwargs["policy"]) scheduler_group.add_argument( - "--enable-chunked-prefill", - **scheduler_kwargs["enable_chunked_prefill"]) + "--scheduling-policy", **scheduler_kwargs["policy"] + ) scheduler_group.add_argument( - "--disable-chunked-mm-input", - **scheduler_kwargs["disable_chunked_mm_input"]) - scheduler_group.add_argument("--scheduler-cls", - **scheduler_kwargs["scheduler_cls"]) + "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"] + ) + scheduler_group.add_argument( + "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"] + ) + scheduler_group.add_argument( + "--scheduler-cls", **scheduler_kwargs["scheduler_cls"] + ) scheduler_group.add_argument( "--disable-hybrid-kv-cache-manager", - **scheduler_kwargs["disable_hybrid_kv_cache_manager"]) - scheduler_group.add_argument("--async-scheduling", - **scheduler_kwargs["async_scheduling"]) + **scheduler_kwargs["disable_hybrid_kv_cache_manager"], + ) + scheduler_group.add_argument( + "--async-scheduling", **scheduler_kwargs["async_scheduling"] + ) # vLLM arguments vllm_kwargs = get_kwargs(VllmConfig) @@ -925,23 +1011,29 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: # create_engine_config. So we set the type to a JSON string here to # delay the Pydantic validation that comes with SpeculativeConfig. vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads) - vllm_group.add_argument("--speculative-config", - **vllm_kwargs["speculative_config"]) - vllm_group.add_argument("--kv-transfer-config", - **vllm_kwargs["kv_transfer_config"]) - vllm_group.add_argument('--kv-events-config', - **vllm_kwargs["kv_events_config"]) - vllm_group.add_argument("--compilation-config", "-O", - **vllm_kwargs["compilation_config"]) - vllm_group.add_argument("--additional-config", - **vllm_kwargs["additional_config"]) - vllm_group.add_argument('--structured-outputs-config', - **vllm_kwargs["structured_outputs_config"]) + vllm_group.add_argument( + "--speculative-config", **vllm_kwargs["speculative_config"] + ) + vllm_group.add_argument( + "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"] + ) + vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument( + "--compilation-config", "-O", **vllm_kwargs["compilation_config"] + ) + vllm_group.add_argument( + "--additional-config", **vllm_kwargs["additional_config"] + ) + vllm_group.add_argument( + "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] + ) # Other arguments - parser.add_argument('--disable-log-stats', - action='store_true', - help='Disable logging statistics.') + parser.add_argument( + "--disable-log-stats", + action="store_true", + help="Disable logging statistics.", + ) return parser @@ -950,10 +1042,9 @@ def from_cli_args(cls, args: argparse.Namespace): # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. 
- engine_args = cls(**{ - attr: getattr(args, attr) - for attr in attrs if hasattr(args, attr) - }) + engine_args = cls( + **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)} + ) return engine_args def create_model_config(self) -> ModelConfig: @@ -962,15 +1053,20 @@ def create_model_config(self) -> ModelConfig: self.quantization = self.load_format = "gguf" # NOTE: This is to allow model loading from S3 in CI - if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 - and self.model in MODELS_ON_S3 and self.load_format == "auto"): + if ( + not isinstance(self, AsyncEngineArgs) + and envs.VLLM_CI_USE_S3 + and self.model in MODELS_ON_S3 + and self.load_format == "auto" + ): self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" if self.disable_mm_preprocessor_cache: logger.warning( "`--disable-mm-preprocessor-cache` is deprecated " "and will be removed in v0.13. " - "Please use `--mm-processor-cache-gb 0` instead.", ) + "Please use `--mm-processor-cache-gb 0` instead.", + ) self.mm_processor_cache_gb = 0 elif envs.VLLM_MM_INPUT_CACHE_GIB != 4: @@ -987,7 +1083,8 @@ def create_model_config(self) -> ModelConfig: logger.warning( "--enable-multimodal-encoder-data-parallel` is deprecated " "and will be removed in v0.13. " - "Please use `--mm-encoder-tp-mode data` instead.") + "Please use `--mm-encoder-tp-mode data` instead." + ) self.mm_encoder_tp_mode = "data" @@ -1029,8 +1126,7 @@ def create_model_config(self) -> ModelConfig: mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, mm_processor_cache_type=self.mm_processor_cache_type, - mm_shm_cache_max_object_size_mb=self. - mm_shm_cache_max_object_size_mb, + mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, pooler_config=self.pooler_config, override_pooler_config=self.override_pooler_config, @@ -1046,33 +1142,34 @@ def create_model_config(self) -> ModelConfig: ) def validate_tensorizer_args(self): - from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig) + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + for key in self.model_loader_extra_config: if key in TensorizerConfig._fields: - self.model_loader_extra_config["tensorizer_config"][ - key] = self.model_loader_extra_config[key] + self.model_loader_extra_config["tensorizer_config"][key] = ( + self.model_loader_extra_config[key] + ) def create_load_config(self) -> LoadConfig: - if self.quantization == "bitsandbytes": self.load_format = "bitsandbytes" if self.load_format == "tensorizer": if hasattr(self.model_loader_extra_config, "to_serializable"): self.model_loader_extra_config = ( - self.model_loader_extra_config.to_serializable()) + self.model_loader_extra_config.to_serializable() + ) self.model_loader_extra_config["tensorizer_config"] = {} - self.model_loader_extra_config["tensorizer_config"][ - "tensorizer_dir"] = self.model + self.model_loader_extra_config["tensorizer_config"]["tensorizer_dir"] = ( + self.model + ) self.validate_tensorizer_args() return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, safetensors_load_strategy=self.safetensors_load_strategy, - device="cpu" - if is_online_quantization(self.quantization) else None, + device="cpu" if is_online_quantization(self.quantization) else None, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, use_tqdm_on_load=self.use_tqdm_on_load, @@ -1100,12 +1197,14 @@ def create_speculative_config( 
# Note(Shangming): These parameters are not obtained from the cli arg # '--speculative-config' and must be passed in when creating the engine # config. - self.speculative_config.update({ - "target_model_config": target_model_config, - "target_parallel_config": target_parallel_config, - "enable_chunked_prefill": enable_chunked_prefill, - "disable_log_stats": disable_log_stats, - }) + self.speculative_config.update( + { + "target_model_config": target_model_config, + "target_parallel_config": target_parallel_config, + "enable_chunked_prefill": enable_chunked_prefill, + "disable_log_stats": disable_log_stats, + } + ) return SpeculativeConfig(**self.speculative_config) def create_engine_config( @@ -1128,21 +1227,21 @@ def create_engine_config( """ current_platform.pre_register_and_update() - device_config = DeviceConfig( - device=cast(Device, current_platform.device_type)) + device_config = DeviceConfig(device=cast(Device, current_platform.device_type)) model_config = self.create_model_config() self.model = model_config.model self.tokenizer = model_config.tokenizer - (self.model, self.tokenizer, - self.speculative_config) = maybe_override_with_speculators( - model=self.model, - tokenizer=self.tokenizer, - revision=self.revision, - trust_remote_code=self.trust_remote_code, - vllm_speculative_config=self.speculative_config, - ) + (self.model, self.tokenizer, self.speculative_config) = ( + maybe_override_with_speculators( + model=self.model, + tokenizer=self.tokenizer, + revision=self.revision, + trust_remote_code=self.trust_remote_code, + vllm_speculative_config=self.speculative_config, + ) + ) # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" # and fall back to V0 for experimental or unsupported features. @@ -1164,12 +1263,17 @@ def create_engine_config( # Set default arguments for V1 Engine. self._set_default_args(usage_context, model_config) # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 - if current_platform.is_cpu() and current_platform.get_cpu_architecture( - ) in (CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM, - CpuArchEnum.RISCV): - logger.info("Chunked prefill is not supported for ARM and POWER, " - "S390X and RISC-V CPUs; " - "disabling it for V1 backend.") + if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( + CpuArchEnum.POWERPC, + CpuArchEnum.S390X, + CpuArchEnum.ARM, + CpuArchEnum.RISCV, + ): + logger.info( + "Chunked prefill is not supported for ARM and POWER, " + "S390X and RISC-V CPUs; " + "disabling it for V1 backend." + ) self.enable_chunked_prefill = False assert self.enable_chunked_prefill is not None @@ -1185,8 +1289,7 @@ def create_engine_config( # because the world size does not change by dcp, it simply # reuses the GPUs of TP group, and split one TP group into # tp_size//dcp_size DCP groups. - assert self.tensor_parallel_size % self.decode_context_parallel_size \ - == 0, ( + assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, ( f"tp_size={self.tensor_parallel_size} must be divisible by" f"dcp_size={self.decode_context_parallel_size}." ) @@ -1215,6 +1318,7 @@ def create_engine_config( # of a Ray task, therefore we check is_ray_initialized() # as opposed to is_in_ray_actor(). 
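As a reading aid for `create_speculative_config` above: on the command line the value arrives as a JSON string (its argparse type is `optional_type(json.loads)`, set in `add_cli_args`), and the parsed dict is then extended with the engine-side entries before `SpeculativeConfig(**...)` is constructed. A hedged sketch; the draft-model keys are assumptions about what `SpeculativeConfig` accepts, and the model name is a placeholder:

    import json

    # Roughly what `--speculative-config '{...}'` yields after argparse parsing:
    cli_value = ('{"method": "draft_model", "model": "your-org/draft-model", '
                 '"num_speculative_tokens": 2}')
    speculative_config = json.loads(cli_value)

    # create_speculative_config() then injects target_model_config,
    # target_parallel_config, enable_chunked_prefill and disable_log_stats
    # before calling SpeculativeConfig(**speculative_config).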
import ray + ray_runtime_env = ray.get_runtime_context().runtime_env logger.info("Using ray runtime env: %s", ray_runtime_env) @@ -1230,15 +1334,15 @@ def create_engine_config( placement_group = ray.util.get_current_placement_group() assert not headless or not self.data_parallel_hybrid_lb, ( - "data_parallel_hybrid_lb is not applicable in " - "headless mode") + "data_parallel_hybrid_lb is not applicable in headless mode" + ) data_parallel_external_lb = self.data_parallel_rank is not None # Local DP rank = 1, use pure-external LB. if data_parallel_external_lb: assert self.data_parallel_size_local in (1, None), ( - "data_parallel_size_local must be 1 when data_parallel_rank " - "is set") + "data_parallel_size_local must be 1 when data_parallel_rank is set" + ) data_parallel_size_local = 1 # Use full external lb if we have local_size of 1. self.data_parallel_hybrid_lb = False @@ -1261,8 +1365,8 @@ def create_engine_config( self.data_parallel_rank = self.data_parallel_start_rank or 0 else: assert not self.data_parallel_hybrid_lb, ( - "data_parallel_size_local must be set to use " - "data_parallel_hybrid_lb.") + "data_parallel_size_local must be set to use data_parallel_hybrid_lb." + ) # Local DP size defaults to global DP size if not set. data_parallel_size_local = self.data_parallel_size @@ -1273,39 +1377,46 @@ def create_engine_config( if self.data_parallel_backend == "ray": host_ip = get_ip() logger.info( - "Using host IP %s as ray-based data parallel address", - host_ip) + "Using host IP %s as ray-based data parallel address", host_ip + ) data_parallel_address = host_ip else: assert self.data_parallel_backend == "mp", ( "data_parallel_backend can only be ray or mp, got %s", - self.data_parallel_backend) + self.data_parallel_backend, + ) data_parallel_address = ParallelConfig.data_parallel_master_ip else: data_parallel_address = self.data_parallel_address # This port is only used when there are remote data parallel engines, # otherwise the local IPC transport is used. - data_parallel_rpc_port = self.data_parallel_rpc_port if ( + data_parallel_rpc_port = ( self.data_parallel_rpc_port - is not None) else ParallelConfig.data_parallel_rpc_port + if (self.data_parallel_rpc_port is not None) + else ParallelConfig.data_parallel_rpc_port + ) if self.async_scheduling: # Async scheduling does not work with the uniprocess backend. if self.distributed_executor_backend is None: self.distributed_executor_backend = "mp" - logger.info("Defaulting to mp-based distributed executor " - "backend for async scheduling.") + logger.info( + "Defaulting to mp-based distributed executor " + "backend for async scheduling." + ) if self.pipeline_parallel_size > 1: - raise ValueError("Async scheduling is not supported with " - "pipeline-parallel-size > 1.") + raise ValueError( + "Async scheduling is not supported with pipeline-parallel-size > 1." + ) # Currently, async scheduling does not support speculative decoding. # TODO(woosuk): Support it. if self.speculative_config is not None: raise ValueError( "Currently, speculative decoding is not supported with " - "async scheduling.") + "async scheduling." + ) # Forward the deprecated CLI args to the EPLB config. 
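To make the load-balancing branches above easier to follow: setting `data_parallel_rank` selects pure external load balancing, where each engine process owns a single local replica and `data_parallel_hybrid_lb` is forced off. A hedged sketch of one replica in a hypothetical 4-way deployment; the model, address, and port are placeholders:

    from vllm.engine.arg_utils import EngineArgs

    # One engine process per DP rank; an external balancer routes requests.
    args = EngineArgs(
        model="your-org/your-model",       # placeholder
        data_parallel_size=4,              # global DP size
        data_parallel_rank=2,              # this process serves rank 2 only
        data_parallel_address="10.0.0.1",  # head-node address (placeholder)
        data_parallel_rpc_port=13345,      # placeholder port
    )
    # data_parallel_size_local is left unset, so it defaults to 1 in this mode,
    # matching the pure-external-LB branch above.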
if self.num_redundant_experts is not None: @@ -1372,33 +1483,38 @@ def create_engine_config( disable_chunked_mm_input=self.disable_chunked_mm_input, is_multimodal_model=model_config.is_multimodal_model, is_encoder_decoder=model_config.is_encoder_decoder, - send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER - and parallel_config.use_ray), + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, scheduler_cls=self.scheduler_cls, max_num_partial_prefills=self.max_num_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, - disable_hybrid_kv_cache_manager=self. - disable_hybrid_kv_cache_manager, + disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager, async_scheduling=self.async_scheduling, ) if not model_config.is_multimodal_model and self.default_mm_loras: raise ValueError( "Default modality-specific LoRA(s) were provided for a " - "non multimodal model") - - lora_config = LoRAConfig( - bias_enabled=self.enable_lora_bias, - max_lora_rank=self.max_lora_rank, - max_loras=self.max_loras, - default_mm_loras=self.default_mm_loras, - fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, - lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras - and self.max_cpu_loras > 0 else None) if self.enable_lora else None + "non multimodal model" + ) + + lora_config = ( + LoRAConfig( + bias_enabled=self.enable_lora_bias, + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + default_mm_loras=self.default_mm_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras + if self.max_cpu_loras and self.max_cpu_loras > 0 + else None, + ) + if self.enable_lora + else None + ) # bitsandbytes pre-quantized model need a specific model loader if model_config.quantization == "bitsandbytes": @@ -1408,27 +1524,27 @@ def create_engine_config( # Pass reasoning_parser into StructuredOutputsConfig if self.reasoning_parser: - self.structured_outputs_config.reasoning_parser = \ - self.reasoning_parser + self.structured_outputs_config.reasoning_parser = self.reasoning_parser # Forward the deprecated CLI args to the StructuredOutputsConfig so_config = self.structured_outputs_config if self.guided_decoding_backend is not None: - so_config.guided_decoding_backend = \ - self.guided_decoding_backend + so_config.guided_decoding_backend = self.guided_decoding_backend if self.guided_decoding_disable_fallback is not None: - so_config.guided_decoding_disable_fallback = \ - self.guided_decoding_disable_fallback + so_config.guided_decoding_disable_fallback = ( + self.guided_decoding_disable_fallback + ) if self.guided_decoding_disable_any_whitespace is not None: - so_config.guided_decoding_disable_any_whitespace = \ - self.guided_decoding_disable_any_whitespace + so_config.guided_decoding_disable_any_whitespace = ( + self.guided_decoding_disable_any_whitespace + ) if self.guided_decoding_disable_additional_properties is not None: - so_config.guided_decoding_disable_additional_properties = \ - self.guided_decoding_disable_additional_properties + so_config.guided_decoding_disable_additional_properties = ( + self.guided_decoding_disable_additional_properties + ) observability_config = ObservabilityConfig( - show_hidden_metrics_for_version=( - self.show_hidden_metrics_for_version), + 
show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version), otlp_traces_endpoint=self.otlp_traces_endpoint, collect_detailed_traces=self.collect_detailed_traces, ) @@ -1458,25 +1574,28 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: ############################################################# # Unsupported Feature Flags on V1. - if (self.logits_processor_pattern - != EngineArgs.logits_processor_pattern): - _raise_or_fallback(feature_name="--logits-processor-pattern", - recommend_to_remove=False) + if self.logits_processor_pattern != EngineArgs.logits_processor_pattern: + _raise_or_fallback( + feature_name="--logits-processor-pattern", recommend_to_remove=False + ) return False # No Mamba or Encoder-Decoder so far. if not model_config.is_v1_compatible: - _raise_or_fallback(feature_name=model_config.architectures, - recommend_to_remove=False) + _raise_or_fallback( + feature_name=model_config.architectures, recommend_to_remove=False + ) return False # No Concurrent Partial Prefills so far. - if (self.max_num_partial_prefills - != SchedulerConfig.max_num_partial_prefills - or self.max_long_partial_prefills - != SchedulerConfig.max_long_partial_prefills): - _raise_or_fallback(feature_name="Concurrent Partial Prefill", - recommend_to_remove=False) + if ( + self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills + or self.max_long_partial_prefills + != SchedulerConfig.max_long_partial_prefills + ): + _raise_or_fallback( + feature_name="Concurrent Partial Prefill", recommend_to_remove=False + ) return False V1_BACKENDS = [ @@ -1496,8 +1615,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "XFORMERS", "ROCM_ATTN", ] - if (envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): + if ( + envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS + ): name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" _raise_or_fallback(feature_name=name, recommend_to_remove=True) return False @@ -1506,30 +1627,36 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # Experimental Features - allow users to opt in. 
if self.pipeline_parallel_size > 1: - supports_pp = getattr(self.distributed_executor_backend, - 'supports_pp', False) + supports_pp = getattr( + self.distributed_executor_backend, "supports_pp", False + ) if not supports_pp and self.distributed_executor_backend not in ( - ParallelConfig.distributed_executor_backend, "ray", "mp", - "external_launcher"): - name = "Pipeline Parallelism without Ray distributed " \ - "executor or multiprocessing executor or external " \ - "launcher" - _raise_or_fallback(feature_name=name, - recommend_to_remove=False) + ParallelConfig.distributed_executor_backend, + "ray", + "mp", + "external_launcher", + ): + name = ( + "Pipeline Parallelism without Ray distributed " + "executor or multiprocessing executor or external " + "launcher" + ) + _raise_or_fallback(feature_name=name, recommend_to_remove=False) return False - if (current_platform.is_cpu() - and model_config.get_sliding_window() is not None): - _raise_or_fallback(feature_name="sliding window (CPU backend)", - recommend_to_remove=False) + if current_platform.is_cpu() and model_config.get_sliding_window() is not None: + _raise_or_fallback( + feature_name="sliding window (CPU backend)", recommend_to_remove=False + ) return False ############################################################# return True - def _set_default_args(self, usage_context: UsageContext, - model_config: ModelConfig) -> None: + def _set_default_args( + self, usage_context: UsageContext, model_config: ModelConfig + ) -> None: """Set Default Arguments for V1 Engine.""" # V1 always uses chunked prefills and prefix caching @@ -1540,12 +1667,12 @@ def _set_default_args(self, usage_context: UsageContext, # TODO: When prefix caching supports prompt embeds inputs, this # check can be removed. - if (self.enable_prompt_embeds - and self.enable_prefix_caching is not False): + if self.enable_prompt_embeds and self.enable_prefix_caching is not False: logger.warning( "--enable-prompt-embeds and --enable-prefix-caching " "are not supported together in V1. Prefix caching has " - "been disabled.") + "been disabled." + ) self.enable_prefix_caching = False if self.enable_prefix_caching is None: @@ -1556,15 +1683,15 @@ def _set_default_args(self, usage_context: UsageContext, else: self.enable_prefix_caching = True else: - pooling_type = model_config.pooler_config.pooling_type is_causal = getattr(model_config.hf_config, "is_causal", True) - incremental_prefill_supported = (pooling_type is not None - and pooling_type.lower() == "last" - and is_causal) + incremental_prefill_supported = ( + pooling_type is not None + and pooling_type.lower() == "last" + and is_causal + ) - action = "Enabling" if \ - incremental_prefill_supported else "Disabling" + action = "Enabling" if incremental_prefill_supported else "Disabling" if self.enable_chunked_prefill is None: self.enable_chunked_prefill = incremental_prefill_supported @@ -1598,6 +1725,7 @@ def _set_default_args(self, usage_context: UsageContext, # throughput, see PR #17885 for more details. # So here we do an extra device name check to prevent such regression. from vllm.usage.usage_lib import UsageContext + if device_memory >= 70 * GiB_bytes and "a100" not in device_name: # For GPUs like H100 and MI300x, use larger default values. 
default_max_num_batched_tokens = { @@ -1623,15 +1751,15 @@ def _set_default_args(self, usage_context: UsageContext, if current_platform.is_tpu(): default_max_num_batched_tokens_tpu = { UsageContext.LLM_CLASS: { - 'V6E': 2048, - 'V5E': 1024, - 'V5P': 512, + "V6E": 2048, + "V5E": 1024, + "V5P": 512, }, UsageContext.OPENAI_API_SERVER: { - 'V6E': 1024, - 'V5E': 512, - 'V5P': 256, - } + "V6E": 1024, + "V5E": 512, + "V5P": 256, + }, } # cpu specific default values. @@ -1647,47 +1775,58 @@ def _set_default_args(self, usage_context: UsageContext, } use_context_value = usage_context.value if usage_context else None - if (self.max_num_batched_tokens is None - and usage_context in default_max_num_batched_tokens): + if ( + self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens + ): if current_platform.is_tpu(): chip_name = current_platform.get_device_name() - if chip_name in default_max_num_batched_tokens_tpu[ - usage_context]: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens_tpu[ - usage_context][chip_name] + if chip_name in default_max_num_batched_tokens_tpu[usage_context]: + self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[ + usage_context + ][chip_name] else: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens[usage_context] + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context + ] else: if not self.enable_chunked_prefill: self.max_num_batched_tokens = model_config.max_model_len else: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens[usage_context] + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context + ] logger.debug( "Setting max_num_batched_tokens to %d for %s usage context.", - self.max_num_batched_tokens, use_context_value) + self.max_num_batched_tokens, + use_context_value, + ) - if (self.max_num_seqs is None - and usage_context in default_max_num_seqs): - self.max_num_seqs = min(default_max_num_seqs[usage_context], - self.max_num_batched_tokens or sys.maxsize) + if self.max_num_seqs is None and usage_context in default_max_num_seqs: + self.max_num_seqs = min( + default_max_num_seqs[usage_context], + self.max_num_batched_tokens or sys.maxsize, + ) - logger.debug("Setting max_num_seqs to %d for %s usage context.", - self.max_num_seqs, use_context_value) + logger.debug( + "Setting max_num_seqs to %d for %s usage context.", + self.max_num_seqs, + use_context_value, + ) @dataclass class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" + enable_log_requests: bool = False @property @deprecated( "`disable_log_requests` is deprecated and has been replaced with " "`enable_log_requests`. This will be removed in v0.12.0. Please use " - "`enable_log_requests` instead.") + "`enable_log_requests` instead." + ) def disable_log_requests(self) -> bool: return not self.enable_log_requests @@ -1695,28 +1834,34 @@ def disable_log_requests(self) -> bool: @deprecated( "`disable_log_requests` is deprecated and has been replaced with " "`enable_log_requests`. This will be removed in v0.12.0. Please use " - "`enable_log_requests` instead.") + "`enable_log_requests` instead." 
+ ) def disable_log_requests(self, value: bool): self.enable_log_requests = not value @staticmethod - def add_cli_args(parser: FlexibleArgumentParser, - async_args_only: bool = False) -> FlexibleArgumentParser: + def add_cli_args( + parser: FlexibleArgumentParser, async_args_only: bool = False + ) -> FlexibleArgumentParser: # Initialize plugin to update the parser, for example, The plugin may # add a new kind of quantization method to --quantization argument or # a new device to --device argument. load_general_plugins() if not async_args_only: parser = EngineArgs.add_cli_args(parser) - parser.add_argument('--enable-log-requests', - action=argparse.BooleanOptionalAction, - default=AsyncEngineArgs.enable_log_requests, - help='Enable logging requests.') - parser.add_argument('--disable-log-requests', - action=argparse.BooleanOptionalAction, - default=not AsyncEngineArgs.enable_log_requests, - help='[DEPRECATED] Disable logging requests.', - deprecated=True) + parser.add_argument( + "--enable-log-requests", + action=argparse.BooleanOptionalAction, + default=AsyncEngineArgs.enable_log_requests, + help="Enable logging requests.", + ) + parser.add_argument( + "--disable-log-requests", + action=argparse.BooleanOptionalAction, + default=not AsyncEngineArgs.enable_log_requests, + help="[DEPRECATED] Disable logging requests.", + deprecated=True, + ) current_platform.pre_register_and_update(parser) return parser @@ -1724,7 +1869,8 @@ def add_cli_args(parser: FlexibleArgumentParser, def _raise_or_fallback(feature_name: str, recommend_to_remove: bool): if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: raise NotImplementedError( - f"VLLM_USE_V1=1 is not supported with {feature_name}.") + f"VLLM_USE_V1=1 is not supported with {feature_name}." + ) msg = f"{feature_name} is not supported by the V1 Engine. " msg += "Falling back to V0. " if recommend_to_remove: @@ -1743,17 +1889,17 @@ def human_readable_int(value): - '25.6k' -> 25,600 """ value = value.strip() - match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value) + match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value) if match: decimal_multiplier = { - 'k': 10**3, - 'm': 10**6, - 'g': 10**9, + "k": 10**3, + "m": 10**6, + "g": 10**9, } binary_multiplier = { - 'K': 2**10, - 'M': 2**20, - 'G': 2**30, + "K": 2**10, + "M": 2**20, + "G": 2**30, } number, suffix = match.groups() @@ -1766,9 +1912,11 @@ def human_readable_int(value): try: return int(number) * mult except ValueError as e: - raise argparse.ArgumentTypeError("Decimals are not allowed " \ - f"with binary suffixes like {suffix}. Did you mean to use " \ - f"{number}{suffix.lower()} instead?") from e + raise argparse.ArgumentTypeError( + "Decimals are not allowed " + f"with binary suffixes like {suffix}. Did you mean to use " + f"{number}{suffix.lower()} instead?" + ) from e # Regular plain number. 
return int(value) diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index 53ef69f6949d..fb9e46b942b7 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -9,18 +9,20 @@ from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader -from vllm.model_executor.model_loader.bitsandbytes_loader import ( - BitsAndBytesModelLoader) +from vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader from vllm.model_executor.model_loader.runai_streamer_loader import ( - RunaiModelStreamerLoader) -from vllm.model_executor.model_loader.sharded_state_loader import ( - ShardedStateLoader) + RunaiModelStreamerLoader, +) +from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader from vllm.model_executor.model_loader.utils import ( - get_architecture_class_name, get_model_architecture, get_model_cls) + get_architecture_class_name, + get_model_architecture, + get_model_cls, +) logger = init_logger(__name__) @@ -69,7 +71,10 @@ def register_model_loader(load_format: str): Examples: >>> from vllm.config.load import LoadConfig - >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader + >>> from vllm.model_executor.model_loader import ( + ... get_model_loader, + ... register_model_loader, + ... ) >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader >>> >>> @register_model_loader("my_loader") @@ -89,14 +94,20 @@ def _wrapper(model_loader_cls): if load_format in _LOAD_FORMAT_TO_MODEL_LOADER: logger.warning( "Load format `%s` is already registered, and will be " - "overwritten by the new loader class `%s`.", load_format, - model_loader_cls) + "overwritten by the new loader class `%s`.", + load_format, + model_loader_cls, + ) if not issubclass(model_loader_cls, BaseModelLoader): - raise ValueError("The model loader must be a subclass of " - "`BaseModelLoader`.") + raise ValueError( + "The model loader must be a subclass of `BaseModelLoader`." 
+ ) _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls - logger.info("Registered model loader `%s` with load format `%s`", - model_loader_cls, load_format) + logger.info( + "Registered model loader `%s` with load format `%s`", + model_loader_cls, + load_format, + ) return model_loader_cls return _wrapper @@ -110,16 +121,18 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config) -def get_model(*, - vllm_config: VllmConfig, - model_config: Optional[ModelConfig] = None, - prefix: str = "") -> nn.Module: +def get_model( + *, + vllm_config: VllmConfig, + model_config: Optional[ModelConfig] = None, + prefix: str = "", +) -> nn.Module: loader = get_model_loader(vllm_config.load_config) if model_config is None: model_config = vllm_config.model_config - return loader.load_model(vllm_config=vllm_config, - model_config=model_config, - prefix=prefix) + return loader.load_model( + vllm_config=vllm_config, model_config=model_config, prefix=prefix + ) __all__ = [ diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 8f14383a3b82..4c7f88ee4875 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -9,7 +9,10 @@ from vllm.config.load import LoadConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) + initialize_model, + process_weights_after_loading, + set_default_torch_dtype, +) logger = init_logger(__name__) @@ -26,27 +29,26 @@ def download_model(self, model_config: ModelConfig) -> None: raise NotImplementedError @abstractmethod - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: - """Load weights into a model. This standalone API allows + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: + """Load weights into a model. 
This standalone API allows inplace weights loading for an already-initialized model""" raise NotImplementedError - def load_model(self, - vllm_config: VllmConfig, - model_config: ModelConfig, - prefix: str = "") -> nn.Module: + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" + ) -> nn.Module: """Load a model with the given configurations.""" device_config = vllm_config.device_config load_config = vllm_config.load_config - load_device = device_config.device if load_config.device is None else \ - load_config.device + load_device = ( + device_config.device if load_config.device is None else load_config.device + ) target_device = torch.device(load_device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config, - model_config=model_config, - prefix=prefix) + model = initialize_model( + vllm_config=vllm_config, model_config=model_config, prefix=prefix + ) logger.debug("Loading weights on %s ...", load_device) # Quantization does not happen in `load_weights` but after it diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index f41ead9ae539..c1560941d5b2 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -13,10 +13,15 @@ from vllm.config.load import LoadConfig from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.utils import ( - initialize_model, process_weights_after_loading, set_default_torch_dtype) + initialize_model, + process_weights_after_loading, + set_default_torch_dtype, +) from vllm.model_executor.model_loader.weight_utils import ( - get_gguf_extra_tensor_names, get_gguf_weight_type_map, - gguf_quant_weights_iterator) + get_gguf_extra_tensor_names, + get_gguf_weight_type_map, + gguf_quant_weights_iterator, +) class GGUFModelLoader(BaseModelLoader): @@ -29,15 +34,18 @@ class GGUFModelLoader(BaseModelLoader): def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: - raise ValueError(f"Model loader extra config is not supported for " - f"load format {load_config.load_format}") + raise ValueError( + f"Model loader extra config is not supported for " + f"load format {load_config.load_format}" + ) def _prepare_weights(self, model_name_or_path: str): if os.path.isfile(model_name_or_path): return model_name_or_path # for raw HTTPS link if model_name_or_path.startswith( - ("http://", "https://")) and model_name_or_path.endswith(".gguf"): + ("http://", "https://") + ) and model_name_or_path.endswith(".gguf"): return hf_hub_download(url=model_name_or_path) # repo id/filename.gguf if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"): @@ -46,7 +54,8 @@ def _prepare_weights(self, model_name_or_path: str): else: raise ValueError( f"Unrecognised GGUF reference: {model_name_or_path} " - "(expected local file, raw URL, or /.gguf)") + "(expected local file, raw URL, or /.gguf)" + ) def _get_gguf_weights_map(self, model_config: ModelConfig): """ @@ -68,25 +77,32 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): # GGUF layer map assumes that we will have a merged expert weights # so we need to map them manually for idx in range(config.num_hidden_layers): - gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \ - f"model.layers.{idx}.mlp.gate.e_score_correction_bias" - gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ - 
f"model.layers.{idx}.mlp.experts.0.down_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = ( + f"model.layers.{idx}.mlp.gate.e_score_correction_bias" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + ) if model_type in ("qwen2_moe", "qwen3_moe"): model_type = model_type.replace("_", "") # GGUF layer map assumes that we will have a merged expert weights # so we need to map them manually for idx in range(config.num_hidden_layers): - gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.down_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" - gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \ - f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.down_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.gate_proj.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( + f"model.layers.{idx}.mlp.experts.0.up_proj.weight" + ) arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): @@ -99,7 +115,8 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): name_map = gguf.get_tensor_name_map(arch, num_layers) with torch.device("meta"): dummy_model = AutoModelForCausalLM.from_config( - config, trust_remote_code=model_config.trust_remote_code) + config, trust_remote_code=model_config.trust_remote_code + ) state_dict = dummy_model.state_dict() for hf_name in state_dict: @@ -111,33 +128,31 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): def _get_weights_iterator( self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str] ) -> Generator[tuple[str, torch.Tensor], None, None]: - return gguf_quant_weights_iterator(model_name_or_path, - gguf_to_hf_name_map) + return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) model.load_weights( - self._get_weights_iterator(local_model_path, gguf_weights_map)) + self._get_weights_iterator(local_model_path, gguf_weights_map) + ) - def load_model(self, - vllm_config: VllmConfig, - model_config: ModelConfig, - prefix: str = "") -> nn.Module: + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" + ) -> nn.Module: device_config = vllm_config.device_config local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) # we can only know if tie word embeddings after mapping weights if "lm_head.weight" in 
get_gguf_extra_tensor_names( - local_model_path, gguf_weights_map): + local_model_path, gguf_weights_map + ): model_config.hf_config.update({"tie_word_embeddings": True}) - weight_type_map = get_gguf_weight_type_map(model_config.model, - gguf_weights_map) + weight_type_map = get_gguf_weight_type_map(model_config.model, gguf_weights_map) # filter out unquantized modules to skip unquant_names = [ @@ -150,8 +165,7 @@ def load_model(self, target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): with target_device: - model = initialize_model(vllm_config=vllm_config, - prefix=prefix) + model = initialize_model(vllm_config=vllm_config, prefix=prefix) self.load_weights(model, model_config) process_weights_after_loading(model, model_config, target_device) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 1475f075b54b..29f946650671 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -13,11 +13,18 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model, - is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator) -from vllm.model_executor.model_loader.utils import (get_model_architecture, - initialize_model, - set_default_torch_dtype) + TensorizerConfig, + deserialize_tensorizer_model, + init_tensorizer_model, + is_vllm_tensorized, + serialize_vllm_model, + tensorizer_weights_iterator, +) +from vllm.model_executor.model_loader.utils import ( + get_model_architecture, + initialize_model, + set_default_torch_dtype, +) logger = init_logger(__name__) @@ -44,15 +51,18 @@ def __init__(self, load_config: LoadConfig): else: validate_config(load_config.model_loader_extra_config) self.tensorizer_config = TensorizerConfig( - **load_config.model_loader_extra_config["tensorizer_config"]) + **load_config.model_loader_extra_config["tensorizer_config"] + ) - def _verify_config(self, model_config: ModelConfig, - parallel_config: ParallelConfig): + def _verify_config( + self, model_config: ModelConfig, parallel_config: ParallelConfig + ): self.tensorizer_config.verify_with_model_config(model_config) self.tensorizer_config.verify_with_parallel_config(parallel_config) def _get_weights_iterator( - self, ) -> Generator[tuple[str, torch.Tensor], None, None]: + self, + ) -> Generator[tuple[str, torch.Tensor], None, None]: tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) @@ -72,8 +82,7 @@ def _load_model_serialized_cpu( model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): - model = initialize_model(vllm_config=vllm_config, - prefix=prefix) + model = initialize_model(vllm_config=vllm_config, prefix=prefix) model.load_weights(self._get_weights_iterator()) return model.eval() @@ -84,8 +93,7 @@ def download_model(self, model_config: ModelConfig) -> None: with self.tensorizer_config.open_stream(): pass - def _patch_tensorizer_config( - self, model_config: ModelConfig) -> TensorizerConfig: + def _patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig: model_class = get_model_architecture(model_config)[0] tensorizer_config = copy.copy(self.tensorizer_config) 
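        # (Illustrative note) copy.copy leaves the loader's own
        # TensorizerConfig untouched; the per-load copy is specialized below
        # with the resolved model class and dtype.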
tensorizer_config.model_class = model_class @@ -93,8 +101,7 @@ def _patch_tensorizer_config( tensorizer_config.dtype = model_config.dtype return tensorizer_config - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: """Load serialized model weights with tensorizer. Expects a vLLM-tensorized model. See the @@ -106,10 +113,9 @@ def load_weights(self, model: nn.Module, else: model.load_weights(self._get_weights_iterator()) - def load_model(self, - vllm_config: VllmConfig, - model_config: ModelConfig, - prefix: str = "") -> nn.Module: + def load_model( + self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = "" + ) -> nn.Module: parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) @@ -117,8 +123,8 @@ def load_model(self, from vllm.distributed import get_tensor_model_parallel_rank self.tensorizer_config.tensorizer_uri = ( - self.tensorizer_config.tensorizer_uri % - get_tensor_model_parallel_rank()) + self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank() + ) if is_vllm_tensorized(self.tensorizer_config): tensorizer_config = self._patch_tensorizer_config(model_config) @@ -126,12 +132,11 @@ def load_model(self, with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = init_tensorizer_model( - tensorizer_config=tensorizer_config, - vllm_config=vllm_config) + tensorizer_config=tensorizer_config, vllm_config=vllm_config + ) self.load_weights(model, model_config) return model - return self._load_model_serialized_cpu(vllm_config=vllm_config, - prefix=prefix) + return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix) @staticmethod def save_model( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index c1ff2ed785f4..3186a6f71b0d 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -5,8 +5,18 @@ import functools from abc import abstractmethod from dataclasses import dataclass, fields, make_dataclass, replace -from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional, - Protocol, TypeVar, Union, get_args) +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Literal, + Optional, + Protocol, + TypeVar, + Union, + get_args, +) import numpy as np import torch @@ -21,11 +31,11 @@ from vllm.v1.worker.gpu_input_batch import InputBatch import vllm.envs as envs -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) +from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata from vllm.attention.layer import Attention from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_kv_connector_cache_layout) + get_kv_connector_cache_layout, +) from vllm.logger import init_logger from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.ubatch_utils import UBatchSlice @@ -46,7 +56,7 @@ class CommonAttentionMetadata: """ Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. - + For many of the tensors we keep both GPU and CPU versions. """ @@ -98,19 +108,23 @@ def slice_query_start_locs( request_slice: slice, ) -> torch.Tensor: """ - Creates a new query_start_loc that corresponds to the requests in + Creates a new query_start_loc that corresponds to the requests in request_slice. 
Note: This function creates a new tensor to hold the new query_start_locs. This will break cudagraph compatibility. """ - return query_start_loc[request_slice.start: request_slice.stop + 1] -\ - query_start_loc[request_slice.start] + return ( + query_start_loc[request_slice.start : request_slice.stop + 1] + - query_start_loc[request_slice.start] + ) def extend_all_queries_by_1( - common_attn_metadata: CommonAttentionMetadata, arange: torch.Tensor, - new_slot_mapping: torch.Tensor) -> CommonAttentionMetadata: + common_attn_metadata: CommonAttentionMetadata, + arange: torch.Tensor, + new_slot_mapping: torch.Tensor, +) -> CommonAttentionMetadata: """ Creates a new CommonAttentionMetadata with all query lengths increased by 1. Also all seq lens are increased by 1. @@ -120,8 +134,7 @@ def extend_all_queries_by_1( """ cad = common_attn_metadata # query start loc must be increased by [+0, +1, +2, ..., +batch_size] - new_query_start_loc = cad.query_start_loc \ - + arange[:len(cad.query_start_loc)] + new_query_start_loc = cad.query_start_loc + arange[: len(cad.query_start_loc)] new_seq_lens = cad.seq_lens + 1 new_cad = CommonAttentionMetadata( @@ -143,10 +156,11 @@ def extend_all_queries_by_1( return new_cad -def extend_flat_seqs(seqs: torch.Tensor, end_locs: torch.Tensor, - new_vals: torch.Tensor) -> torch.Tensor: +def extend_flat_seqs( + seqs: torch.Tensor, end_locs: torch.Tensor, new_vals: torch.Tensor +) -> torch.Tensor: """ - This function appends a single new value into multiple sequences + This function appends a single new value into multiple sequences that are stored in a flat format. E.g. [x1, x2, y1] and [x3, y2] become [x1, x2, x3, y1, y2] """ @@ -160,8 +174,7 @@ def extend_flat_seqs(seqs: torch.Tensor, end_locs: torch.Tensor, seqs_new_idxs = seqs_new_idxs.cumsum(0) - 1 # indices for new values - new_val_idxs = end_locs + 1 + torch.arange(new_vals.shape[0], - device=seqs.device) + new_val_idxs = end_locs + 1 + torch.arange(new_vals.shape[0], device=seqs.device) # assign seqs and new vals new_seqs[seqs_new_idxs] = seqs new_seqs[new_val_idxs] = new_vals @@ -170,15 +183,14 @@ def extend_flat_seqs(seqs: torch.Tensor, end_locs: torch.Tensor, def _make_metadata_with_slice( - ubatch_slice: UBatchSlice, - attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata: + ubatch_slice: UBatchSlice, attn_metadata: CommonAttentionMetadata +) -> CommonAttentionMetadata: """ - This function creates a new CommonAttentionMetadata that corresponds to + This function creates a new CommonAttentionMetadata that corresponds to the requests included in ubatch_slice """ - assert not ubatch_slice.is_empty(), ( - f"Ubatch slice {ubatch_slice} is empty") + assert not ubatch_slice.is_empty(), f"Ubatch slice {ubatch_slice} is empty" request_slice = ubatch_slice.request_slice token_slice = ubatch_slice.token_slice @@ -189,10 +201,12 @@ def _make_metadata_with_slice( last_req = request_slice.stop - 1 last_tok = token_slice.stop - 1 - assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], \ + assert start_locs[first_req] <= first_tok < start_locs[first_req + 1], ( "Token slice start outside of first request" - assert start_locs[last_req] <= last_tok < start_locs[last_req+1], \ + ) + assert start_locs[last_req] <= last_tok < start_locs[last_req + 1], ( "Token slice end outside of last request" + ) # If the "middle" request has tokens in both ubatches, we have to split it. 
# If ubatch_slice is the first ubatch then we will be splitting the last @@ -202,12 +216,13 @@ def _make_metadata_with_slice( splits_last_request = last_tok < start_locs[last_req + 1] - 1 query_start_loc_cpu = slice_query_start_locs(start_locs, request_slice) - query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc, - request_slice) + query_start_loc = slice_query_start_locs( + attn_metadata.query_start_loc, request_slice + ) assert len(query_start_loc) >= 2, ( - f"query_start_loc must have at least 2 elements, " - f"got {len(query_start_loc)}") + f"query_start_loc must have at least 2 elements, got {len(query_start_loc)}" + ) if splits_first_request: tokens_skipped = first_tok - start_locs[first_req] @@ -229,14 +244,13 @@ def _make_metadata_with_slice( seq_lens_cpu[-1] -= tokens_skipped max_seq_len = int(seq_lens_cpu.max()) - num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[ - request_slice] + num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[request_slice] num_requests = request_slice.stop - request_slice.start num_actual_tokens = token_slice.stop - token_slice.start max_query_len = int( - torch.max(torch.abs(query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1])).item()) + torch.max(torch.abs(query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])).item() + ) # This is to account for the case where we are in a dummy # run and query_start_loc_cpu is full of 0s @@ -266,15 +280,14 @@ def split_attn_metadata( common_attn_metadata: CommonAttentionMetadata, ) -> list[CommonAttentionMetadata]: """ - Creates a new CommonAttentionMetadata instance that corresponds to the + Creates a new CommonAttentionMetadata instance that corresponds to the requests for each UBatchSlice in ubatch_slices. Note: This function does not modify common_attn_metadata """ results = [] for ubatch_slice in ubatch_slices: - results.append( - _make_metadata_with_slice(ubatch_slice, common_attn_metadata)) + results.append(_make_metadata_with_slice(ubatch_slice, common_attn_metadata)) return results @@ -283,7 +296,7 @@ def split_attn_metadata( class AttentionCGSupport(enum.Enum): - """ Constants for the cudagraph support of the attention backend + """Constants for the cudagraph support of the attention backend Here we do not consider the cascade attention, as currently it is never cudagraph supported.""" @@ -301,46 +314,53 @@ class AttentionCGSupport(enum.Enum): class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention (default: no). - cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.NEVER + cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER # Does this backend/builder reorder the batch? # If not, set this to None. Otherwise set it to the query # length that will be pulled into the front of the batch. 
reorder_batch_threshold: Optional[int] = None @abstractmethod - def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], - vllm_config: VllmConfig, device: torch.device): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): self.kv_cache_spec = kv_cache_spec self.layer_names = layer_names self.vllm_config = vllm_config self.device = device def _init_reorder_batch_threshold( - self, - reorder_batch_threshold: int = 1, - supports_spec_as_decode: bool = False) -> None: + self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False + ) -> None: self.reorder_batch_threshold = reorder_batch_threshold - if self.reorder_batch_threshold is not None \ - and supports_spec_as_decode: + if self.reorder_batch_threshold is not None and supports_spec_as_decode: # If the backend supports spec-as-decode kernels, then we can set # the reorder_batch_threshold based on the number of speculative # tokens from the config. speculative_config = self.vllm_config.speculative_config - if (speculative_config is not None - and speculative_config.num_speculative_tokens is not None): - self.reorder_batch_threshold = \ + if ( + speculative_config is not None + and speculative_config.num_speculative_tokens is not None + ): + self.reorder_batch_threshold = ( 1 + speculative_config.num_speculative_tokens + ) @abstractmethod - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. - + Args: common_prefix_len: The length of the common prefix of the batch. common_attn_metadata: The common attention metadata. @@ -350,8 +370,9 @@ def build(self, """ raise NotImplementedError - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch( + self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput" + ) -> bool: """ Update the order of requests in the batch based on the attention backend's needs. For example, some attention backends (namely MLA) may @@ -368,14 +389,16 @@ def reorder_batch(self, input_batch: "InputBatch", raise NotImplementedError def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, common_attn_metadata: CommonAttentionMetadata + ) -> M: """ Build attention metadata for CUDA graph capture. Uses build by default. Subclasses that override this method should call self.build or super().build_for_cudagraph_capture. """ - return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + return self.build( + common_prefix_len=0, common_attn_metadata=common_attn_metadata + ) def build_for_drafting( self, @@ -384,7 +407,7 @@ def build_for_drafting( ) -> M: """ Build attention metadata for draft model. Uses build by default. - + Args: common_attn_metadata: The common attention metadata. draft_index: The index of the current draft operation. @@ -393,9 +416,11 @@ def build_for_drafting( For tree-based attention, this index instead refers to the draft attempt for the i-th level in the tree of tokens. 
""" - return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - fast_build=True) + return self.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + fast_build=True, + ) def use_cascade_attention( self, @@ -418,8 +443,11 @@ def get_kv_cache_layout(): if _KV_CACHE_LAYOUT_OVERRIDE is not None: cache_layout = _KV_CACHE_LAYOUT_OVERRIDE - logger.info_once("`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " \ - "Setting KV cache layout to %s.", cache_layout) + logger.info_once( + "`_KV_CACHE_LAYOUT_OVERRIDE` variable detected. " + "Setting KV cache layout to %s.", + cache_layout, + ) return cache_layout # Format specified by the user. @@ -429,8 +457,11 @@ def get_kv_cache_layout(): cache_layout = get_kv_connector_cache_layout() else: assert is_valid_kv_cache_layout(cache_layout) - logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ - "detected. Setting KV cache layout to %s.", cache_layout) + logger.info_once( + "`VLLM_KV_CACHE_LAYOUT` environment variable " + "detected. Setting KV cache layout to %s.", + cache_layout, + ) return cache_layout @@ -455,8 +486,8 @@ class PerLayerParameters: def get_per_layer_parameters( - vllm_config: VllmConfig, layer_names: list[str], - cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]: + vllm_config: VllmConfig, layer_names: list[str], cls_: type["AttentionImpl"] +) -> dict[str, PerLayerParameters]: """ Scan layers in `layer_names` and determine some hyperparameters to use during `plan`. @@ -476,17 +507,18 @@ def get_per_layer_parameters( sm_scale = impl.scale has_sinks = getattr(impl, "sinks", None) is not None - per_layer_params[key] = PerLayerParameters(window_left, - logits_soft_cap, sm_scale, - has_sinks) + per_layer_params[key] = PerLayerParameters( + window_left, logits_soft_cap, sm_scale, has_sinks + ) return per_layer_params def infer_global_hyperparameters( - per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: + per_layer_params: dict[str, PerLayerParameters], +) -> PerLayerParameters: """ - Currently, FlashInfer backend other than trtllm-gen + Currently, FlashInfer backend other than trtllm-gen only support models in which all layers share the same values for the following hyperparameters: - `window_left` @@ -507,13 +539,15 @@ def infer_global_hyperparameters( for params in param_sets: if params.window_left != global_params.window_left: raise ValueError( - "Window left is not the same for all layers. " \ - "One potential fix is to set disable_sliding_window=True") + "Window left is not the same for all layers. " + "One potential fix is to set disable_sliding_window=True" + ) assert params == global_params, ( "FlashInfer backend currently only supports models in which all" "layers share the same values " "for the following hyperparameters:" - "`window_left`, `logits_soft_cap`, `sm_scale`.") + "`window_left`, `logits_soft_cap`, `sm_scale`." 
+ ) return global_params @@ -595,11 +629,10 @@ def make_local_attention_virtual_batches( # new_tokens_in_first_block = [2, 1, 4] # local_blocks = [2, 4, 2] q_tokens_in_first_block = np.minimum( - attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), - q_seqlens).astype(np.int32) + attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens + ).astype(np.int32) tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) - local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, - attn_chunk_size) + local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size) # Once we know the number of local blocks we can compute the request spans # for each batch idx, we can figure out the number of "virtual" requests we @@ -620,14 +653,13 @@ def make_local_attention_virtual_batches( rarange = np.repeat(local_blocks, local_blocks) - arange - 1 # Then we can compute the seqlens_q_local, handling the fact that the # first and last blocks could be partial - seqlens_q_local = \ - np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + seqlens_q_local = np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) # set the first block since this may be a partial block seqlens_q_local[arange == 0] = q_tokens_in_first_block # set the remaining blocks seqlens_q_local[arange > 0] = np.minimum( - seqlens_q_local - attn_chunk_size * (arange - 1), - attn_chunk_size)[arange > 0] + seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size + )[arange > 0] # convert from q_seqlens to cu_seqlens_q cu_seqlens_q_local = np.empty(virtual_batches + 1, dtype=np.int32) @@ -639,22 +671,20 @@ def make_local_attention_virtual_batches( # batch # For our example this will be: # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] - seqlens_k_local = np.full(cu_num_blocks[-1], - attn_chunk_size, - dtype=np.int32) + seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block num_computed_tokens_local = seqlens_k_local - seqlens_q_local - k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ - (rarange * attn_chunk_size + \ - np.repeat(tokens_in_last_block, local_blocks)) + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - ( + rarange * attn_chunk_size + np.repeat(tokens_in_last_block, local_blocks) + ) # For the example the local attention blocks start at: # _b0_ _____b1_____ _b2_ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] block_starts = k_seqstarts_absolute // block_size - assert attn_chunk_size % block_size == 0, \ - f"attn_chunk_size {attn_chunk_size} is not " \ - f"divisible by block_size {block_size}" + assert attn_chunk_size % block_size == 0, ( + f"attn_chunk_size {attn_chunk_size} is not divisible by block_size {block_size}" + ) pages_per_local_batch = attn_chunk_size // block_size # Create a block_table for the local attention blocks @@ -675,12 +705,14 @@ def make_local_attention_virtual_batches( # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) # ] - block_indices = (block_starts[:, None] + - np.arange(pages_per_local_batch, dtype=np.int32)) - block_indices = block_indices.reshape(-1).clip(max=block_table.shape[1] - - 1) - batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), - local_blocks * pages_per_local_batch) + block_indices = block_starts[:, None] + np.arange( + pages_per_local_batch, dtype=np.int32 + ) + block_indices = 
block_indices.reshape(-1).clip(max=block_table.shape[1] - 1) + batch_indices = np.repeat( + np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch, + ) # NOTE: https://github.com/pytorch/pytorch/pull/160256 causes performance # regression when using numpy arrays (batch and block indices) to index into @@ -688,8 +720,9 @@ def make_local_attention_virtual_batches( # tensor first, which recovers perf. batch_indices_torch = torch.from_numpy(batch_indices) block_indices_torch = torch.from_numpy(block_indices) - block_table_local = block_table[batch_indices_torch, block_indices_torch]\ - .view(virtual_batches, -1) + block_table_local = block_table[batch_indices_torch, block_indices_torch].view( + virtual_batches, -1 + ) query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) seq_lens_cpu = torch.from_numpy(seqlens_k_local) @@ -697,8 +730,7 @@ def make_local_attention_virtual_batches( return CommonAttentionMetadata( query_start_loc_cpu=query_start_loc_cpu, - query_start_loc=query_start_loc_cpu.to(device=device, - non_blocking=True), + query_start_loc=query_start_loc_cpu.to(device=device, non_blocking=True), seq_lens_cpu=seq_lens_cpu, seq_lens=seq_lens_cpu.to(device=device, non_blocking=True), num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), @@ -738,9 +770,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( # Find how many decode indices belong to each request # request_ids: [0, 1, 1, 2] - request_ids = torch.bucketize(logits_indices, - query_start_loc[1:], - right=True) + request_ids = torch.bucketize(logits_indices, query_start_loc[1:], right=True) # Figure out how many tokens are in each request # num_decode_tokens: [1, 2, 1] @@ -748,9 +778,9 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( # Calculate new query_start_loc with tokens in generation_indices # decode_query_start_loc: [0, 1, 3, 4] - decode_query_start_loc = torch.empty(num_reqs + 1, - device=query_start_loc.device, - dtype=query_start_loc.dtype) + decode_query_start_loc = torch.empty( + num_reqs + 1, device=query_start_loc.device, dtype=query_start_loc.dtype + ) decode_query_start_loc[0] = 0 decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) @@ -759,8 +789,7 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( common_attn_metadata = CommonAttentionMetadata( query_start_loc=decode_query_start_loc, - query_start_loc_cpu=decode_query_start_loc.to("cpu", - non_blocking=True), + query_start_loc_cpu=decode_query_start_loc.to("cpu", non_blocking=True), seq_lens=seq_lens, seq_lens_cpu=seq_lens.to("cpu", non_blocking=True), num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, @@ -776,22 +805,25 @@ def make_kv_sharing_fast_prefill_common_attn_metadata( def subclass_attention_backend( - name_prefix: str, attention_backend_cls: type[AttentionBackend], - builder_cls: type[AttentionMetadataBuilder[M]] + name_prefix: str, + attention_backend_cls: type[AttentionBackend], + builder_cls: type[AttentionMetadataBuilder[M]], ) -> type[AttentionBackend]: """ Return a new subclass where `get_builder_cls` returns `builder_cls`. 
""" name: str = name_prefix + attention_backend_cls.__name__ # type: ignore - return type(name, (attention_backend_cls, ), - {"get_builder_cls": lambda: builder_cls}) + return type( + name, (attention_backend_cls,), {"get_builder_cls": lambda: builder_cls} + ) def split_decodes_and_prefills( - common_attn_metadata: CommonAttentionMetadata, - decode_threshold: int = 1, - require_uniform: bool = False) -> tuple[int, int, int, int]: + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, + require_uniform: bool = False, +) -> tuple[int, int, int, int]: """ Assuming a reordered batch, finds the boundary between prefill and decode requests. @@ -815,8 +847,9 @@ def split_decodes_and_prefills( num_tokens = common_attn_metadata.num_actual_tokens query_start_loc = common_attn_metadata.query_start_loc_cpu - if max_query_len <= decode_threshold and \ - (not require_uniform or decode_threshold <= 1): + if max_query_len <= decode_threshold and ( + not require_uniform or decode_threshold <= 1 + ): return num_reqs, 0, num_tokens, 0 query_lens = query_start_loc[1:] - query_start_loc[:-1] @@ -849,7 +882,7 @@ def reorder_batch_to_split_decodes_and_prefills( """ Reorders the batch to split into prefill and decode requests; places all requests with <= decode_threshold tokens at the front of the batch. - + Returns: True if the batch was modified, False otherwise. """ @@ -904,8 +937,7 @@ def reorder_batch_to_split_decodes_and_prefills( return modified_batch -def reshape_query_for_spec_decode(query: torch.Tensor, - batch_size: int) -> torch.Tensor: +def reshape_query_for_spec_decode(query: torch.Tensor, batch_size: int) -> torch.Tensor: """ Reshapes the query tensor for the specified batch size, so that it has shape (batch_size, seq_len, num_heads, head_dim). @@ -915,13 +947,13 @@ def reshape_query_for_spec_decode(query: torch.Tensor, num_heads = query.shape[1] head_dim = query.shape[2] assert total_tokens % batch_size == 0, ( - f"{total_tokens=} is not divisible by {batch_size=}") + f"{total_tokens=} is not divisible by {batch_size=}" + ) seq_len = total_tokens // batch_size return query.view(batch_size, seq_len, num_heads, head_dim) -def reshape_attn_output_for_spec_decode( - attn_output: torch.Tensor) -> torch.Tensor: +def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tensor: """ Reshapes the attention output tensor, so that the batch_size and seq_len dimensions are combined. 
@@ -929,16 +961,14 @@ def reshape_attn_output_for_spec_decode( if attn_output.dim() == 3: # Already in the correct shape return attn_output - assert attn_output.dim() == 4, \ - f"attn_output must be 4D, got {attn_output.dim()}D" + assert attn_output.dim() == 4, f"attn_output must be 4D, got {attn_output.dim()}D" total_tokens = attn_output.shape[0] * attn_output.shape[1] - return attn_output.view(total_tokens, attn_output.shape[2], - attn_output.shape[3]) + return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3]) KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ - ('logits_indices_padded', Optional[torch.Tensor], None), - ('num_logits_indices', int, 0), + ("logits_indices_padded", Optional[torch.Tensor], None), + ("num_logits_indices", int, 0), ] @@ -951,7 +981,7 @@ def subclass_attention_metadata( Return a new subclass of `metadata_cls` with additional fields """ name: str = name_prefix + metadata_cls.__name__ # type: ignore - Wrapped = make_dataclass(name, fields, bases=(metadata_cls, )) + Wrapped = make_dataclass(name, fields, bases=(metadata_cls,)) return Wrapped @@ -965,55 +995,55 @@ def create_fast_prefill_custom_backend( prefix: str, underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]: - underlying_builder = underlying_attn_backend.get_builder_cls() class FastPrefillAttentionBuilder(underlying_builder): # type: ignore - - def build(self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> AttentionMetadata: - new_common_attn_metadata =\ - make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) - metadata = super().build(common_prefix_len, - new_common_attn_metadata, fast_build) + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> AttentionMetadata: + new_common_attn_metadata = ( + make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) + ) + metadata = super().build( + common_prefix_len, new_common_attn_metadata, fast_build + ) class KVSharingFastPrefillAttentionMetadata( - metadata.__class__, # type: ignore - KVSharingFastPrefillMetadata): - + metadata.__class__, # type: ignore + KVSharingFastPrefillMetadata, + ): def __init__(self, metadata, common_attn_metadata): # Shallow copy all fields in metadata cls for field in fields(metadata.__class__): - setattr(self, field.name, - getattr(metadata, field.name)) + setattr(self, field.name, getattr(metadata, field.name)) # Set additional fields that will be used in model code - assert (common_attn_metadata.logits_indices_padded - is not None - and common_attn_metadata.num_logits_indices - is not None) - self.logits_indices_padded = \ + assert ( + common_attn_metadata.logits_indices_padded is not None + and common_attn_metadata.num_logits_indices is not None + ) + self.logits_indices_padded = ( common_attn_metadata.logits_indices_padded - self.num_logits_indices = \ - common_attn_metadata.num_logits_indices + ) + self.num_logits_indices = common_attn_metadata.num_logits_indices - return KVSharingFastPrefillAttentionMetadata( - metadata, common_attn_metadata) + return KVSharingFastPrefillAttentionMetadata(metadata, common_attn_metadata) attn_backend = subclass_attention_backend( name_prefix=prefix, attention_backend_cls=underlying_attn_backend, - builder_cls=FastPrefillAttentionBuilder) + builder_cls=FastPrefillAttentionBuilder, + ) return attn_backend def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): - # Needed for 
causal_conv1d - seqlens = query_start_loc_p.diff().to('cpu') + seqlens = query_start_loc_p.diff().to("cpu") nums_dict = {} # type: ignore batch_ptr = None token_chunk_offset_ptr = None @@ -1021,40 +1051,39 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): for BLOCK_M in [8]: # cover all BLOCK_M values nums = -(-seqlens // BLOCK_M) nums_dict[BLOCK_M] = {} - nums_dict[BLOCK_M]['nums'] = nums - nums_dict[BLOCK_M]['tot'] = nums.sum().item() + nums_dict[BLOCK_M]["nums"] = nums + nums_dict[BLOCK_M]["tot"] = nums.sum().item() mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) - nums_dict[BLOCK_M]['mlist'] = mlist - mlist_len = len(nums_dict[BLOCK_M]['mlist']) - nums_dict[BLOCK_M]['mlist_len'] = mlist_len + nums_dict[BLOCK_M]["mlist"] = mlist + mlist_len = len(nums_dict[BLOCK_M]["mlist"]) + nums_dict[BLOCK_M]["mlist_len"] = mlist_len MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 offsetlist = [] # type: ignore for idx, num in enumerate(nums): offsetlist.extend(range(num)) offsetlist = torch.tensor(offsetlist, dtype=torch.int32) - nums_dict[BLOCK_M]['offsetlist'] = offsetlist + nums_dict[BLOCK_M]["offsetlist"] = offsetlist if batch_ptr is None: # Update default value after class definition - batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) - token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ), - PAD_SLOT_ID, - dtype=torch.int32, - device=device) + batch_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device + ) + token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=device + ) else: if batch_ptr.nelement() < MAX_NUM_PROGRAMS: batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) token_chunk_offset_ptr.resize_( # type: ignore - MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + MAX_NUM_PROGRAMS + ).fill_(PAD_SLOT_ID) batch_ptr[0:mlist_len].copy_(mlist) token_chunk_offset_ptr[ # type: ignore - 0:mlist_len].copy_(offsetlist) - nums_dict[BLOCK_M]['batch_ptr'] = batch_ptr - nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = (token_chunk_offset_ptr - ) # type: ignore + 0:mlist_len + ].copy_(offsetlist) + nums_dict[BLOCK_M]["batch_ptr"] = batch_ptr + nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore return nums_dict, batch_ptr, token_chunk_offset_ptr diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 1510b953c3e6..1ae79fb78471 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -11,25 +11,24 @@ from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, - compute_encoder_budget) +from vllm.v1.core.encoder_cache_manager import ( + EncoderCacheManager, + compute_encoder_budget, +) from 
vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface -from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, - SchedulerOutput) -from vllm.v1.core.sched.request_queue import (SchedulingPolicy, - create_request_queue) +from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput +from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.utils import check_stop, remove_all -from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, - EngineCoreOutputs) +from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput @@ -41,7 +40,6 @@ class Scheduler(SchedulerInterface): - def __init__( self, vllm_config: VllmConfig, @@ -67,16 +65,17 @@ def __init__( # by update_from_outputs(). This is currently used in the multi-engine # case to track request lifetimes efficiently. self.finished_req_ids_dict: Optional[dict[int, set[str]]] = ( - defaultdict(set) if include_finished_set else None) + defaultdict(set) if include_finished_set else None + ) # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens + self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens self.max_model_len = self.scheduler_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events) + and self.kv_events_config.enable_kv_cache_events + ) # Create KVConnector for the Scheduler. Note that each Worker # will have a corresponding KVConnector with Role=WORKER. @@ -85,12 +84,14 @@ def __init__( if self.vllm_config.kv_transfer_config is not None: assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " - "with KV connectors") + "with KV connectors" + ) assert not self.is_encoder_decoder, ( - "Encoder-decoder models are not currently supported " - "with KV connectors") + "Encoder-decoder models are not currently supported with KV connectors" + ) self.connector = KVConnectorFactory.create_connector( - config=self.vllm_config, role=KVConnectorRole.SCHEDULER) + config=self.vllm_config, role=KVConnectorRole.SCHEDULER + ) self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, @@ -102,8 +103,7 @@ def __init__( self.block_size = self.cache_config.block_size - self.dcp_world_size = \ - vllm_config.parallel_config.decode_context_parallel_size + self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size # Note(hc): The scheduler’s block_size must be multiplied # by dcp_world_size, since block hashes are computed on the # original full token sequence at a granularity of @@ -120,7 +120,8 @@ def __init__( self.policy = SchedulingPolicy.FCFS else: raise ValueError( - f"Unknown scheduling policy: {self.scheduler_config.policy}") + f"Unknown scheduling policy: {self.scheduler_config.policy}" + ) # Priority queues for requests. 
self.waiting = create_request_queue(self.policy) self.running: list[Request] = [] @@ -153,8 +154,7 @@ def __init__( # NOTE: For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized because cache size is 0 # for these models. - self.encoder_cache_manager = EncoderCacheManager( - cache_size=encoder_cache_size) + self.encoder_cache_manager = EncoderCacheManager(cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config self.use_eagle = False @@ -213,30 +213,35 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = (request.num_tokens_with_spec + - request.num_output_placeholders - - request.num_computed_tokens) - if (0 < self.scheduler_config.long_prefill_token_threshold < - num_new_tokens): - num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + num_new_tokens = ( + request.num_tokens_with_spec + + request.num_output_placeholders + - request.num_computed_tokens + ) + if 0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens: + num_new_tokens = self.scheduler_config.long_prefill_token_threshold num_new_tokens = min(num_new_tokens, token_budget) # Make sure the input position does not exceed the max model len. # This is necessary when using spec decoding. num_new_tokens = min( - num_new_tokens, - self.max_model_len - request.num_computed_tokens) + num_new_tokens, self.max_model_len - request.num_computed_tokens + ) # Schedule encoder inputs. encoder_inputs_to_schedule = None new_encoder_compute_budget = encoder_compute_budget if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, request.num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -259,7 +264,8 @@ def schedule(self) -> SchedulerOutput: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) + num_lookahead_tokens=self.num_lookahead_tokens, + ) if new_blocks is not None: # The request can be scheduled. @@ -284,8 +290,9 @@ def schedule(self) -> SchedulerOutput: preempted_req.num_computed_tokens = 0 preempted_req.num_preemptions += 1 if self.log_stats: - preempted_req.record_event(EngineCoreEventType.PREEMPTED, - scheduled_timestamp) + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp + ) self.waiting.prepend_request(preempted_req) preempted_reqs.append(preempted_req) @@ -306,19 +313,21 @@ def schedule(self) -> SchedulerOutput: # Speculative decode related. if request.spec_token_ids: - num_scheduled_spec_tokens = (num_new_tokens + - request.num_computed_tokens - - request.num_tokens) + num_scheduled_spec_tokens = ( + num_new_tokens + request.num_computed_tokens - request.num_tokens + ) if num_scheduled_spec_tokens > 0: # Trim spec_token_ids list to num_scheduled_spec_tokens. del request.spec_token_ids[num_scheduled_spec_tokens:] scheduled_spec_decode_tokens[request.request_id] = ( - request.spec_token_ids) + request.spec_token_ids + ) # Encoder-related. 
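            # Worked example for the speculative-token trimming above
            # (illustrative numbers): a request with num_tokens=100,
            # num_computed_tokens=99 and three draft tokens asks for 4 new
            # tokens; if the token budget clamps num_new_tokens to 3, then
            # num_scheduled_spec_tokens = 3 + 99 - 100 = 2 and the third
            # draft token is dropped from spec_token_ids.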
if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -328,8 +337,10 @@ def schedule(self) -> SchedulerOutput: scheduled_loras: set[int] = set() if self.lora_config: scheduled_loras = set( - req.lora_request.lora_int_id for req in scheduled_running_reqs - if req.lora_request and req.lora_request.lora_int_id > 0) + req.lora_request.lora_int_id + for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0 + ) assert len(scheduled_loras) <= self.lora_config.max_loras # Use a temporary RequestQueue to collect requests that need to be @@ -352,7 +363,8 @@ def schedule(self) -> SchedulerOutput: else: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) + request.request_id, + ) self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -370,9 +382,14 @@ def schedule(self) -> SchedulerOutput: # Check that adding the request still respects the max_loras # constraint. - if (self.lora_config and request.lora_request and - (len(scheduled_loras) == self.lora_config.max_loras and - request.lora_request.lora_int_id not in scheduled_loras)): + if ( + self.lora_config + and request.lora_request + and ( + len(scheduled_loras) == self.lora_config.max_loras + and request.lora_request.lora_int_id not in scheduled_loras + ) + ): # Scheduling would exceed max_loras, skip. self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) @@ -384,15 +401,17 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request) + ) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: num_external_computed_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( - request, num_new_local_computed_tokens)) + request, num_new_local_computed_tokens + ) + ) if num_external_computed_tokens is None: # The request cannot be scheduled because @@ -403,13 +422,15 @@ def schedule(self) -> SchedulerOutput: continue # Total computed tokens (local + external). - num_computed_tokens = (num_new_local_computed_tokens + - num_external_computed_tokens) + num_computed_tokens = ( + num_new_local_computed_tokens + num_external_computed_tokens + ) # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. else: new_computed_blocks = ( - self.kv_cache_manager.create_empty_block_list()) + self.kv_cache_manager.create_empty_block_list() + ) num_new_local_computed_tokens = 0 num_computed_tokens = request.num_computed_tokens @@ -426,15 +447,21 @@ def schedule(self) -> SchedulerOutput: # `request.num_prompt_tokens` to consider the resumed # requests, which have output tokens. 
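                # (Illustrative numbers) a resumed request with a 100-token
                # prompt plus 20 generated tokens, of which 64 were found in
                # the local or external prefix cache, gets
                # num_new_tokens = 120 - 64 = 56 here.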
num_new_tokens = request.num_tokens - num_computed_tokens - if (0 < self.scheduler_config.long_prefill_token_threshold - < num_new_tokens): + if ( + 0 + < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens + ): num_new_tokens = ( - self.scheduler_config.long_prefill_token_threshold) + self.scheduler_config.long_prefill_token_threshold + ) # chunked prefill has to be enabled explicitly to allow # pooling requests to be chunked - if not self.scheduler_config.chunked_prefill_enabled and \ - num_new_tokens > token_budget: + if ( + not self.scheduler_config.chunked_prefill_enabled + and num_new_tokens > token_budget + ): self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue @@ -444,11 +471,16 @@ def schedule(self) -> SchedulerOutput: # Schedule encoder inputs. if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_compute_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_compute_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_compute_budget, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_compute_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -458,9 +490,9 @@ def schedule(self) -> SchedulerOutput: # extra block gets allocated which # creates a mismatch between the number # of local and remote blocks. - effective_lookahead_tokens = (0 if request.num_computed_tokens - == 0 else - self.num_lookahead_tokens) + effective_lookahead_tokens = ( + 0 if request.num_computed_tokens == 0 else self.num_lookahead_tokens + ) # Determine if we need to allocate cross-attention blocks. if self.is_encoder_decoder and request.has_encoder_inputs: @@ -468,8 +500,9 @@ def schedule(self) -> SchedulerOutput: # always padded to the maximum length. If we support other # encoder-decoder models, this will need to be updated if we # want to only allocate what is needed. - num_encoder_tokens =\ + num_encoder_tokens = ( self.scheduler_config.max_num_encoder_input_tokens + ) else: num_encoder_tokens = 0 @@ -511,20 +544,21 @@ def schedule(self) -> SchedulerOutput: req_index += 1 self.running.append(request) if self.log_stats: - request.record_event(EngineCoreEventType.SCHEDULED, - scheduled_timestamp) + request.record_event( + EngineCoreEventType.SCHEDULED, scheduled_timestamp + ) if request.status == RequestStatus.WAITING: scheduled_new_reqs.append(request) elif request.status == RequestStatus.PREEMPTED: scheduled_resumed_reqs.append(request) else: - raise RuntimeError( - f"Invalid request status: {request.status}") + raise RuntimeError(f"Invalid request status: {request.status}") if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) req_to_new_blocks[request.request_id] = ( - self.kv_cache_manager.get_blocks(request.request_id)) + self.kv_cache_manager.get_blocks(request.request_id) + ) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING @@ -535,7 +569,8 @@ def schedule(self) -> SchedulerOutput: # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request.request_id] = ( - encoder_inputs_to_schedule) + encoder_inputs_to_schedule + ) # Allocate the encoder cache. 
for i in encoder_inputs_to_schedule: self.encoder_cache_manager.allocate(request, i) @@ -553,23 +588,26 @@ def schedule(self) -> SchedulerOutput: # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) <= len(self.running)) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs + ) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. - num_common_prefix_blocks = [0] * len( - self.kv_cache_config.kv_cache_groups) + num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups) if self.running: any_request = self.running[0] num_common_prefix_blocks = ( self.kv_cache_manager.get_num_common_prefix_blocks( - any_request, len(self.running))) + any_request, len(self.running) + ) + ) # Construct the scheduler output. new_reqs_data = [ NewRequestData.from_request( - req, req_to_new_blocks[req.request_id].get_block_ids()) + req, req_to_new_blocks[req.request_id].get_block_ids() + ) for req in scheduled_new_reqs ] cached_reqs_data = self._make_cached_request_data( @@ -579,11 +617,12 @@ def schedule(self) -> SchedulerOutput: scheduled_spec_decode_tokens, req_to_new_blocks, ) - scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + - scheduled_resumed_reqs) - structured_output_request_ids, grammar_bitmask = ( - self.get_grammar_bitmask(scheduled_requests, - scheduled_spec_decode_tokens)) + scheduled_requests = ( + scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs + ) + structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask( + scheduled_requests, scheduled_spec_decode_tokens + ) scheduler_output = SchedulerOutput( scheduled_new_reqs=new_reqs_data, scheduled_cached_reqs=cached_reqs_data, @@ -597,8 +636,7 @@ def schedule(self) -> SchedulerOutput: # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_mm_hashes=self.encoder_cache_manager. - get_freed_mm_hashes(), + free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -680,16 +718,18 @@ def _make_cached_request_data( for req in itertools.chain(running_reqs, resumed_reqs): req_id = req.request_id req_ids.append(req_id) - num_tokens = (num_scheduled_tokens[req_id] - - len(spec_decode_tokens.get(req_id, ()))) + num_tokens = num_scheduled_tokens[req_id] - len( + spec_decode_tokens.get(req_id, ()) + ) if self.use_pp: # When using PP, the scheduler sends the sampled tokens back, # because there's no direct communication between the first- # stage worker and the last-stage worker. Otherwise, we don't # need to send the sampled tokens back because the model runner # will cache them. - token_ids = req.all_token_ids[req.num_computed_tokens:req. - num_computed_tokens + num_tokens] + token_ids = req.all_token_ids[ + req.num_computed_tokens : req.num_computed_tokens + num_tokens + ] new_token_ids.append(token_ids) elif use_connector: # When using a KVConnector, we add a placeholder to avoid index @@ -697,7 +737,8 @@ def _make_cached_request_data( # is updated to handle token IDs properly. 
new_token_ids.append([]) new_block_ids.append( - req_to_new_blocks[req_id].get_block_ids(allow_none=True)) + req_to_new_blocks[req_id].get_block_ids(allow_none=True) + ) num_computed_tokens.append(req.num_computed_tokens) num_output_tokens.append(len(req.output_token_ids)) # Because resumed_reqs is usually empty, it is more efficient to do @@ -766,7 +807,8 @@ def _try_schedule_encoder_inputs( if self.is_encoder_decoder and num_computed_tokens > 0: assert start_pos == 0, ( "Encoder input should be processed at the beginning of " - "the sequence when encoder-decoder models are used.") + "the sequence when encoder-decoder models are used." + ) # Encoder input has already been computed # The calculation here is a bit different. We don't turn encoder # output into tokens that get processed by the decoder and @@ -790,8 +832,7 @@ def _try_schedule_encoder_inputs( # current step. continue - if self.encoder_cache_manager.check_and_update_cache( - request, i): + if self.encoder_cache_manager.check_and_update_cache(request, i): # The encoder input is already computed and cached from a # previous step. continue @@ -799,16 +840,18 @@ def _try_schedule_encoder_inputs( # If no encoder input chunking is allowed, we do not want to # partially schedule a multimodal item. If the scheduled range would # only cover part of the mm input, roll back to before the mm item. - if (self.scheduler_config.disable_chunked_mm_input - and num_computed_tokens < start_pos - and (num_computed_tokens + num_new_tokens) - < (start_pos + num_encoder_tokens)): + if ( + self.scheduler_config.disable_chunked_mm_input + and num_computed_tokens < start_pos + and (num_computed_tokens + num_new_tokens) + < (start_pos + num_encoder_tokens) + ): num_new_tokens = start_pos - num_computed_tokens break if not self.encoder_cache_manager.can_allocate( - request, i, encoder_compute_budget, - num_tokens_to_schedule): + request, i, encoder_compute_budget, num_tokens_to_schedule + ): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses @@ -881,8 +924,9 @@ def update_from_output( outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None - kv_connector_stats = (kv_connector_output.kv_connector_stats - if kv_connector_output else None) + kv_connector_stats = ( + kv_connector_output.kv_connector_stats if kv_connector_output else None + ) failed_kv_load_req_ids = None if kv_connector_output and kv_connector_output.invalid_block_ids: @@ -890,7 +934,8 @@ def update_from_output( # load. Identify affected requests and adjust their computed token # count to trigger recomputation of the invalid blocks. failed_kv_load_req_ids = self._handle_invalid_blocks( - kv_connector_output.invalid_block_ids) + kv_connector_output.invalid_block_ids + ) # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, # the below loop can be a performance bottleneck. 
We should do our best @@ -910,11 +955,13 @@ def update_from_output( continue req_index = model_runner_output.req_id_to_index[req_id] - generated_token_ids = sampled_token_ids[ - req_index] if sampled_token_ids else [] + generated_token_ids = ( + sampled_token_ids[req_index] if sampled_token_ids else [] + ) scheduled_spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id)) + scheduler_output.scheduled_spec_decode_tokens.get(req_id) + ) if scheduled_spec_token_ids: num_draft_tokens = len(scheduled_spec_token_ids) num_accepted = len(generated_token_ids) - 1 @@ -928,7 +975,8 @@ def update_from_output( spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted) + num_accepted_tokens=num_accepted, + ) stopped = False new_logprobs = None @@ -939,14 +987,14 @@ def update_from_output( # Check for stop and update request status. if new_token_ids: new_token_ids, stopped = self._update_request_with_output( - request, new_token_ids) + request, new_token_ids + ) # Stop checking for pooler models. pooler_output = None if pooler_outputs: pooler_output = pooler_outputs[req_index] - stopped = check_stop(request, self.max_model_len, - pooler_output) + stopped = check_stop(request, self.max_model_len, pooler_output) if stopped: kv_transfer_params = self._free_request(request) @@ -956,28 +1004,29 @@ def update_from_output( stopped_preempted_reqs.add(request) # Extract sample logprobs if needed. - if request.sampling_params is not None \ - and request.sampling_params.logprobs is not None and logprobs: + if ( + request.sampling_params is not None + and request.sampling_params.logprobs is not None + and logprobs + ): # NOTE: once we support N tokens per step (spec decode), # the outer lists can be of length > 1. new_logprobs = logprobs.slice(req_index, req_index + 1) - if new_token_ids and self.structured_output_manager.should_advance( - request): + if new_token_ids and self.structured_output_manager.should_advance(request): # NOTE: structured_output_request # should not be None if use_structured_output, we have # checked above, so safe to ignore type warning request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] - req_id, new_token_ids) + req_id, new_token_ids + ) if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) - if new_token_ids or pooler_output is not None \ - or kv_transfer_params: - + if new_token_ids or pooler_output is not None or kv_transfer_params: # Add EngineCoreOutput for this Request. outputs[request.client_index].append( EngineCoreOutput( @@ -992,7 +1041,8 @@ def update_from_output( kv_transfer_params=kv_transfer_params, trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, - )) + ) + ) else: # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors @@ -1025,11 +1075,13 @@ def update_from_output( eco.finished_requests = finished_set else: engine_core_outputs[client_index] = EngineCoreOutputs( - finished_requests=finished_set) + finished_requests=finished_set + ) finished_req_ids.clear() - if (stats := self.make_stats(spec_decoding_stats, - kv_connector_stats)) is not None: + if ( + stats := self.make_stats(spec_decoding_stats, kv_connector_stats) + ) is not None: # Return stats to only one of the front-ends. 
if (eco := next(iter(engine_core_outputs.values()), None)) is None: # We must return the stats even if there are no request @@ -1060,8 +1112,9 @@ def _update_request_with_output( return new_token_ids, stopped def _free_encoder_inputs(self, request: Request) -> None: - cached_encoder_input_ids = ( - self.encoder_cache_manager.get_cached_input_ids(request)) + cached_encoder_input_ids = self.encoder_cache_manager.get_cached_input_ids( + request + ) # OPTIMIZATION: Avoid list(set) if the set is empty. if not cached_encoder_input_ids: return @@ -1076,21 +1129,19 @@ def _free_encoder_inputs(self, request: Request) -> None: # With Whisper, as soon as we've generated a single token, # we know we're done with the encoder input. Cross Attention # KVs have been calculated and cached already. - self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) elif start_pos + num_tokens <= request.num_computed_tokens: # The encoder output is already processed and stored # in the decoder's KV cache. - self.encoder_cache_manager.free_encoder_input( - request, input_id) + self.encoder_cache_manager.free_encoder_input(request, input_id) def update_draft_token_ids( self, draft_token_ids: DraftTokenIds, ) -> None: for req_id, spec_token_ids in zip( - draft_token_ids.req_ids, - draft_token_ids.draft_token_ids, + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, ): request = self.requests.get(req_id) if request is None or request.is_finished(): @@ -1104,7 +1155,8 @@ def update_draft_token_ids( elif self.structured_output_manager.should_advance(request): metadata = request.structured_output_request request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids) + spec_token_ids + ) else: request.spec_token_ids = spec_token_ids @@ -1130,7 +1182,7 @@ def finish_requests( """ assert RequestStatus.is_finished(finished_status) if isinstance(request_ids, str): - request_ids = (request_ids, ) + request_ids = (request_ids,) else: request_ids = set(request_ids) @@ -1200,15 +1252,15 @@ def make_stats( return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None - return SchedulerStats(num_running_reqs=len(self.running), - num_waiting_reqs=len(self.waiting), - kv_cache_usage=self.kv_cache_manager.usage, - prefix_cache_stats=prefix_cache_stats, - spec_decoding_stats=spec_decoding_stats, - num_corrupted_reqs=sum(req.is_output_corrupted - for req in self.running), - kv_connector_stats=kv_connector_stats.data - if kv_connector_stats else None) + return SchedulerStats( + num_running_reqs=len(self.running), + num_waiting_reqs=len(self.waiting), + kv_cache_usage=self.kv_cache_manager.usage, + prefix_cache_stats=prefix_cache_stats, + spec_decoding_stats=spec_decoding_stats, + num_corrupted_reqs=sum(req.is_output_corrupted for req in self.running), + kv_connector_stats=kv_connector_stats.data if kv_connector_stats else None, + ) def make_spec_decoding_stats( self, @@ -1221,8 +1273,8 @@ def make_spec_decoding_stats( if spec_decoding_stats is None: spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) spec_decoding_stats.observe_draft( - num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens + ) return spec_decoding_stats def shutdown(self) -> None: @@ -1239,7 +1291,8 @@ def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: return 
self.connector def _connector_finished( - self, request: Request) -> tuple[bool, Optional[dict[str, Any]]]: + self, request: Request + ) -> tuple[bool, Optional[dict[str, Any]]]: """ Invoke the KV connector request_finished() method if applicable. @@ -1249,7 +1302,7 @@ def _connector_finished( if self.connector is None: return False, None - (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: @@ -1273,8 +1326,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: # updated in _update_requests_with_invalid_blocks if request.num_computed_tokens: # Cache any valid computed tokens. - self.kv_cache_manager.cache_blocks(request, - request.num_computed_tokens) + self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens) else: # No valid computed tokens, release allocated blocks. # There may be a local cache hit on retry. @@ -1283,8 +1335,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: self.failed_recving_kv_req_ids.remove(request.request_id) else: # Now that the blocks are ready, actually cache them. - (block_ids, ) = self.kv_cache_manager.get_block_ids( - request.request_id) + (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) num_computed_tokens = len(block_ids) * self.block_size # Handle the case where num request tokens less than one block. num_computed_tokens = min(num_computed_tokens, request.num_tokens) @@ -1300,8 +1351,7 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: self.finished_recving_kv_req_ids.remove(request.request_id) return True - def _update_from_kv_xfer_finished(self, - kv_connector_output: KVConnectorOutput): + def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """ KV Connector: update the scheduler state based on the output. @@ -1316,21 +1366,23 @@ def _update_from_kv_xfer_finished(self, self.connector.update_connector_output(kv_connector_output) # KV Connector:: update recv and send status from last step. - for req_id in (kv_connector_output.finished_recving or ()): + for req_id in kv_connector_output.finished_recving or (): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id) - for req_id in (kv_connector_output.finished_sending or ()): + for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) if req_id not in self.requests: logger.warning( "Got finished sending KV transfer for request %s," - "but the request is already freed.", req_id) + "but the request is already freed.", + req_id, + ) else: self._free_blocks(self.requests[req_id]) def _update_requests_with_invalid_blocks( - self, requests: Iterable[Request], - invalid_block_ids: set[int]) -> tuple[set[str], int]: + self, requests: Iterable[Request], invalid_block_ids: set[int] + ) -> tuple[set[str], int]: """ Identify and update requests affected by invalid KV cache blocks. 
@@ -1361,25 +1413,25 @@ def _update_requests_with_invalid_blocks( marked_invalid_block = False req_id = request.request_id # TODO (davidb): add support for hybrid memory allocator - (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id) + (req_block_ids,) = self.kv_cache_manager.get_block_ids(req_id) # We iterate only over blocks that may contain externally computed # tokens if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: # Async loading. If num_computed_tokens is set it implies we # already processed some block failures for it in a prior step req_num_computed_tokens = ( - request.num_computed_tokens if req_id - in self.failed_recving_kv_req_ids else len(req_block_ids) * - self.block_size) + request.num_computed_tokens + if req_id in self.failed_recving_kv_req_ids + else len(req_block_ids) * self.block_size + ) else: # Sync loading. num_computed_tokens includes new tokens req_num_computed_tokens = request.num_cached_tokens - req_num_computed_blocks = (req_num_computed_tokens + - self.block_size - 1) // self.block_size - for idx, block_id in zip(range(req_num_computed_blocks), - req_block_ids): - + req_num_computed_blocks = ( + req_num_computed_tokens + self.block_size - 1 + ) // self.block_size + for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids): if block_id not in invalid_block_ids: continue @@ -1404,8 +1456,9 @@ def _update_requests_with_invalid_blocks( marked_invalid_block = True # Truncate the computed tokens at the first failed block request.num_computed_tokens = idx * self.block_size - total_affected_tokens += (req_num_computed_tokens - - request.num_computed_tokens) + total_affected_tokens += ( + req_num_computed_tokens - request.num_computed_tokens + ) if is_affected: if not marked_invalid_block: @@ -1414,8 +1467,9 @@ def _update_requests_with_invalid_blocks( # Revert to considering only cached tokens as computed. 
# Currently this only applies to sync loading; Async # loading does not yet support block sharing - total_affected_tokens += (request.num_computed_tokens - - request.num_cached_tokens) + total_affected_tokens += ( + request.num_computed_tokens - request.num_cached_tokens + ) request.num_computed_tokens = request.num_cached_tokens affected_req_ids.add(request.request_id) @@ -1428,11 +1482,15 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) --- async_load_reqs = ( - req for req in self.waiting - if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + req + for req in self.waiting + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + ) async_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(async_load_reqs, - invalid_block_ids)) + self._update_requests_with_invalid_blocks( + async_load_reqs, invalid_block_ids + ) + ) total_requests_to_reschedule += len(async_affected_req_ids) total_tokens_to_reschedule += num_tokens_to_reschedule @@ -1443,8 +1501,8 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: # --- Handle sync KV loads (running requests) --- sync_affected_req_ids, num_tokens_to_reschedule = ( - self._update_requests_with_invalid_blocks(self.running, - invalid_block_ids)) + self._update_requests_with_invalid_blocks(self.running, invalid_block_ids) + ) total_requests_to_reschedule += len(sync_affected_req_ids) total_tokens_to_reschedule += num_tokens_to_reschedule @@ -1453,7 +1511,9 @@ def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: logger.warning( "Recovered from KV load failure: " "%d request(s) rescheduled (%d tokens affected).", - total_requests_to_reschedule, total_tokens_to_reschedule) + total_requests_to_reschedule, + total_tokens_to_reschedule, + ) # Return the IDs of affected running requests to skip in # update_from_output. 
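The acceptance accounting in `Scheduler.update_from_output` above is what feeds the spec-decoding stats: each step contributes `len(scheduled_spec_token_ids)` proposed draft tokens and `len(generated_token_ids) - 1` accepted ones, because the final sampled token always comes from the target model rather than from the draft. Dividing accumulated accepted tokens by accumulated drafted tokens gives the aggregate acceptance rate reported via `SpecDecodingStats.observe_draft`. A minimal standalone sketch of that bookkeeping follows; the helper names are illustrative only and are not part of vLLM's API.

# Illustrative sketch of the per-step acceptance accounting used in
# Scheduler.update_from_output (function names here are hypothetical).

def step_acceptance(scheduled_spec_token_ids: list[int],
                    generated_token_ids: list[int]) -> tuple[int, int]:
    """Return (num_draft_tokens, num_accepted) for one scheduler step."""
    num_draft_tokens = len(scheduled_spec_token_ids)
    # The last generated token is the "bonus" token sampled by the target
    # model, so only the preceding tokens count as accepted draft tokens.
    num_accepted = len(generated_token_ids) - 1
    return num_draft_tokens, num_accepted


def acceptance_rate(steps: list[tuple[list[int], list[int]]]) -> float:
    """Aggregate acceptance rate over many steps: accepted / proposed."""
    drafted = accepted = 0
    for spec_ids, gen_ids in steps:
        d, a = step_acceptance(spec_ids, gen_ids)
        drafted += d
        accepted += a
    return accepted / drafted if drafted else 0.0


# Example: 3 drafts proposed, 2 accepted plus the bonus token -> rate 2/3.
assert acceptance_rate([([11, 12, 13], [11, 12, 99])]) == 2 / 3
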
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index b8c2ed5b49d0..3d9c8d147090 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -8,12 +8,10 @@ import torch if TYPE_CHECKING: - from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( - KVConnectorStats) + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats class LogprobsLists(NamedTuple): - # [num_reqs, max_num_logprobs + 1] logprob_token_ids: list[list[int]] # [num_reqs, max_num_logprobs + 1] @@ -30,7 +28,6 @@ def slice(self, start: int, end: int): class LogprobsTensors(NamedTuple): - # [num_reqs, max_num_logprobs + 1] logprob_token_ids: torch.Tensor # [num_reqs, max_num_logprobs + 1] @@ -46,18 +43,18 @@ def tolists(self): ) @staticmethod - def empty_cpu(num_positions: int, - num_tokens_per_position: int) -> "LogprobsTensors": + def empty_cpu( + num_positions: int, num_tokens_per_position: int + ) -> "LogprobsTensors": """Create empty LogprobsTensors on CPU.""" logprob_token_ids = torch.empty( - (num_positions, num_tokens_per_position), - dtype=torch.int32, - device="cpu") + (num_positions, num_tokens_per_position), dtype=torch.int32, device="cpu" + ) logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32) - selected_token_ranks = torch.empty(num_positions, - dtype=torch.int32, - device="cpu") + selected_token_ranks = torch.empty( + num_positions, dtype=torch.int32, device="cpu" + ) return LogprobsTensors( logprob_token_ids=logprob_token_ids, logprobs=logprobs, @@ -72,7 +69,6 @@ def empty_cpu(num_positions: int, @dataclass class SamplerOutput: - # [num_reqs, max_num_generated_tokens] # Different requests can have different number of generated tokens. # All requests are padded to max_num_generated_tokens. @@ -95,15 +91,18 @@ class KVConnectorOutput: invalid_block_ids: set[int] = field(default_factory=set) def is_empty(self): - return (not self.finished_sending and not self.finished_recving - and not self.kv_connector_stats and not self.invalid_block_ids) + return ( + not self.finished_sending + and not self.finished_recving + and not self.kv_connector_stats + and not self.invalid_block_ids + ) # ModelRunnerOutput is serialized and sent to the scheduler process. # This is expensive for torch.Tensor so prefer to use list instead. @dataclass class ModelRunnerOutput: - # [num_reqs] req_ids: list[str] # req_id -> index @@ -137,11 +136,10 @@ class ModelRunnerOutput: # ModelRunnerOutput wrapper for async scheduling. class AsyncModelRunnerOutput(ABC): - @abstractmethod def get_output(self) -> ModelRunnerOutput: """Get the ModelRunnerOutput for this async output. - + This is a blocking call that waits until the results are ready, which might involve copying device tensors to the host. This method should only be called once per AsyncModelRunnerOutput. 
@@ -151,17 +149,18 @@ def get_output(self) -> ModelRunnerOutput: @dataclass class DraftTokenIds: - # [num_reqs] req_ids: list[str] # num_reqs x num_draft_tokens draft_token_ids: list[list[int]] -EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], - req_id_to_index={}, - sampled_token_ids=[], - logprobs=None, - prompt_logprobs_dict={}, - pooler_output=[], - num_nans_in_logits=None) +EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + num_nans_in_logits=None, +) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index a42599fd3c30..27e31eed7775 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -9,30 +9,36 @@ from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model -from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, - extend_all_queries_by_1, - extend_flat_seqs) +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, + extend_all_queries_by_1, + extend_flat_seqs, +) from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.eagle import (PADDING_SLOT_ID, CudaGraphArgs, - SpecDecodeBaseProposer, - num_rejected_tokens) +from vllm.v1.spec_decode.eagle import ( + PADDING_SLOT_ID, + CudaGraphArgs, + SpecDecodeBaseProposer, + num_rejected_tokens, +) from vllm.v1.spec_decode.metadata import SpecDecodeMetadata class DraftModelProposer(SpecDecodeBaseProposer): - def __init__( self, vllm_config: VllmConfig, device: torch.device, runner=None, ): - super().__init__(vllm_config=vllm_config, - device=device, - pass_hidden_states_to_model=False, - pass_cudagraph_args_to_forward_ctx=True, - runner=runner) + super().__init__( + vllm_config=vllm_config, + device=device, + pass_hidden_states_to_model=False, + pass_cudagraph_args_to_forward_ctx=True, + runner=runner, + ) self._raise_if_multimodal() self._raise_if_mrope() self._raise_if_disabled_padded_drafter_batch() @@ -53,28 +59,31 @@ def propose( cudagraph_args: "CudaGraphArgs", sampler_output: SamplerOutput, spec_decode_metadata: Optional[SpecDecodeMetadata], - mm_embed_inputs: Optional[tuple[list[torch.Tensor], - torch.Tensor]] = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: """ - - Trims unnecessary tokens from the input, like those rejected by + - Trims unnecessary tokens from the input, like those rejected by the sampler, or those already processed by the draft model. - - Merges the next_token_ids with the existing token ids into + - Merges the next_token_ids with the existing token ids into a flat sequence. 
""" - inputs = DraftModelInputs(cad=common_attn_metadata, - token_ids=target_token_ids, - positions=target_positions) + inputs = DraftModelInputs( + cad=common_attn_metadata, + token_ids=target_token_ids, + positions=target_positions, + ) inputs = trim_accepted_and_rejected_tokens( inputs=inputs, sampler_output=sampler_output, - spec_decode_metadata=spec_decode_metadata) + spec_decode_metadata=spec_decode_metadata, + ) inputs = merge_next_token_ids_into_token_ids( inputs=inputs, next_token_ids=next_token_ids, block_size=self.block_size, max_model_len=self.max_model_len, - arange=self.arange) + arange=self.arange, + ) draft_token_ids = super().propose( target_token_ids=inputs.token_ids, @@ -94,19 +103,23 @@ def propose( def _raise_if_multimodal(self): if self.supports_mm_inputs: - raise NotImplementedError("Speculative Decoding with draft models " - "does not support multimodal models yet") + raise NotImplementedError( + "Speculative Decoding with draft models " + "does not support multimodal models yet" + ) def _raise_if_mrope(self): if self.draft_model_config.uses_mrope: - raise NotImplementedError("Speculative Decoding with draft models " - "does not support M-RoPE yet") + raise NotImplementedError( + "Speculative Decoding with draft models does not support M-RoPE yet" + ) def _raise_if_disabled_padded_drafter_batch(self): if self.vllm_config.speculative_config.disable_padded_drafter_batch: raise NotImplementedError( "Speculative Decoding with draft models does not support " - "disabled padded drafter batch yet") + "disabled padded drafter batch yet" + ) def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: return { @@ -117,28 +130,35 @@ def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: def dummy_run(self, num_tokens: int, forward_ctx_kwargs: dict): model_kwargs = self._model_kwargs(num_tokens) with set_forward_context( - vllm_config=self.vllm_config, - num_tokens=num_tokens, - **forward_ctx_kwargs, + vllm_config=self.vllm_config, + num_tokens=num_tokens, + **forward_ctx_kwargs, ): self.model(**model_kwargs) - def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, - next_token_ids: torch.Tensor, num_tokens: int, - last_token_indices: torch.Tensor) -> None: + def set_input_ids_first_pass( + self, + target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, + num_tokens: int, + last_token_indices: torch.Tensor, + ) -> None: self.input_ids[:num_tokens] = target_token_ids def load_model(self, target_model: Any) -> None: """Takes target_model to satisfy the type checker.""" draft_model_config: ModelConfig = ( - self.vllm_config.speculative_config.draft_model_config) + self.vllm_config.speculative_config.draft_model_config + ) vllm_config_draft: VllmConfig = replace( - self.vllm_config, model_config=draft_model_config) + self.vllm_config, model_config=draft_model_config + ) # This must be computed before loading the draft model # because that mutates the forward_context of the vllm_config target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + ) from vllm.compilation.backends import set_model_tag @@ -152,8 +172,9 @@ def load_model(self, target_model: Any) -> None: # This must be computed after loading the draft model # because that mutates the forward_context of the vllm_config draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names) + get_layers_from_vllm_config(self.vllm_config, 
Attention).keys() + - target_attn_layer_names + ) self.attn_layer_names = list(draft_attn_layer_names) @@ -165,21 +186,23 @@ class DraftModelInputs: def trim_accepted_and_rejected_tokens( - inputs: DraftModelInputs, sampler_output: SamplerOutput, - spec_decode_metadata: Optional[SpecDecodeMetadata] + inputs: DraftModelInputs, + sampler_output: SamplerOutput, + spec_decode_metadata: Optional[SpecDecodeMetadata], ) -> DraftModelInputs: """ Removes from the input.token_ids any tokens that have already been processed by the draft model, as well as tokens rejected by the sampler. - Adjusts the positions accordingly, the slot mapping, + Adjusts the positions accordingly, the slot mapping, and the common_attn_metadata. """ cad: CommonAttentionMetadata = inputs.cad # Compute the new token ids and positions n_accepted_tokens = sampler_output.n_sampled_tokens() - 1 - n_rejected_tokens = num_rejected_tokens(spec_decode_metadata, - sampler_output.n_sampled_tokens()) + n_rejected_tokens = num_rejected_tokens( + spec_decode_metadata, sampler_output.n_sampled_tokens() + ) from_loc = cad.query_start_loc[:-1] + n_accepted_tokens to_loc = cad.query_start_loc[1:] - 1 - n_rejected_tokens idxs = compute_subrange_indices(from_loc, to_loc) @@ -202,9 +225,9 @@ def trim_accepted_and_rejected_tokens( max_query_len=new_query_lens.max().item(), slot_mapping=new_slot_mapping, ) - return DraftModelInputs(token_ids=new_token_ids, - positions=new_positions, - cad=new_cad) + return DraftModelInputs( + token_ids=new_token_ids, positions=new_positions, cad=new_cad + ) def compute_subrange_indices(start_locs: torch.Tensor, end_locs: torch.Tensor): @@ -222,7 +245,8 @@ def compute_subrange_indices(start_locs: torch.Tensor, end_locs: torch.Tensor): # broadcasting + masking ensures we only keep valid positions max_len = lengths.max() offsets = torch.arange(max_len, device=start_locs.device).unsqueeze( - 0) # shape [1, max_len] + 0 + ) # shape [1, max_len] mask = offsets < lengths.unsqueeze(1) # shape [n, max_len] # Build all indices all_indices = start_locs.unsqueeze(1) + offsets @@ -239,28 +263,27 @@ def merge_next_token_ids_into_token_ids( ) -> DraftModelInputs: """ Merges the next token ids with the existing token ids into a flat sequence. - Does the same for the positions, computes new slot mapping, + Does the same for the positions, computes new slot mapping, and updates the common_attn_metadata. 
""" cad: CommonAttentionMetadata = inputs.cad # merge token_ids and next_token_ids query_end_locs = cad.query_start_loc[1:] - 1 - new_token_ids = extend_flat_seqs(seqs=inputs.token_ids, - end_locs=query_end_locs, - new_vals=next_token_ids) + new_token_ids = extend_flat_seqs( + seqs=inputs.token_ids, end_locs=query_end_locs, new_vals=next_token_ids + ) # append new positions positions_to_append = inputs.positions[query_end_locs] + 1 - new_positions = extend_flat_seqs(seqs=inputs.positions, - end_locs=query_end_locs, - new_vals=positions_to_append) + new_positions = extend_flat_seqs( + seqs=inputs.positions, end_locs=query_end_locs, new_vals=positions_to_append + ) # recompute slot mapping batch_size, n_blocks_per_req = cad.block_table_tensor.shape req_indices = torch.arange(batch_size, device=cad.query_start_loc.device) req_indices = torch.repeat_interleave(req_indices, cad.query_lens() + 1) - block_table_indices = (req_indices * n_blocks_per_req + - new_positions // block_size) + block_table_indices = req_indices * n_blocks_per_req + new_positions // block_size block_nums = cad.block_table_tensor.view(-1)[block_table_indices] block_offsets = new_positions % block_size new_slot_mapping = block_nums * block_size + block_offsets @@ -270,7 +293,8 @@ def merge_next_token_ids_into_token_ids( # update common_attn_metadata new_cad: CommonAttentionMetadata = extend_all_queries_by_1( - cad, arange=arange, new_slot_mapping=new_slot_mapping) - return DraftModelInputs(token_ids=new_token_ids, - positions=new_positions, - cad=new_cad) + cad, arange=arange, new_slot_mapping=new_slot_mapping + ) + return DraftModelInputs( + token_ids=new_token_ids, positions=new_positions, cad=new_cad + ) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5db9bb203c9d..cd0a3c2d5312 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -10,8 +10,7 @@ import torch.nn as nn from vllm.attention.layer import Attention -from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config) +from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import BatchDescriptor, set_forward_context @@ -24,11 +23,15 @@ from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata, - TreeAttentionMetadataBuilder) +from vllm.v1.attention.backends.tree_attn import ( + TreeAttentionMetadata, + TreeAttentionMetadataBuilder, +) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata @@ -43,7 +46,6 @@ class SpecDecodeBaseProposer: - def __init__( self, vllm_config: VllmConfig, @@ -58,18 +60,15 @@ def __init__( self.draft_model_config = self.speculative_config.draft_model_config self.method = self.speculative_config.method self.pass_hidden_states_to_model = pass_hidden_states_to_model - self.pass_cudagraph_args_to_forward_ctx \ - = pass_cudagraph_args_to_forward_ctx + 
self.pass_cudagraph_args_to_forward_ctx = pass_cudagraph_args_to_forward_ctx self.runner = runner self.device = device self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size - self.num_speculative_tokens = ( - self.speculative_config.num_speculative_tokens) - self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens) + self.num_speculative_tokens = self.speculative_config.num_speculative_tokens + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's @@ -79,62 +78,64 @@ def __init__( # Multi-modal data support self.mm_registry = MULTIMODAL_REGISTRY self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - vllm_config.model_config) + vllm_config.model_config + ) self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None - self.draft_indexer_metadata_builder: Optional[ - AttentionMetadataBuilder] = None + self.draft_indexer_metadata_builder: Optional[AttentionMetadataBuilder] = None self.attn_layer_names: list[str] = [] self.indexer_layer_names: list[str] = [] - self.use_cuda_graph = (not current_platform.is_xpu() - and self.vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE and - not self.vllm_config.model_config.enforce_eager - and not self.speculative_config.enforce_eager) - self.cudagraph_batch_sizes = list( - reversed(self.vllm_config.compilation_config. - cudagraph_capture_sizes)) if self.use_cuda_graph else [] + self.use_cuda_graph = ( + not current_platform.is_xpu() + and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE + and not self.vllm_config.model_config.enforce_eager + and not self.speculative_config.enforce_eager + ) + self.cudagraph_batch_sizes = ( + list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes)) + if self.use_cuda_graph + else [] + ) # persistent buffers for cuda graph - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=device) + self.input_ids = torch.zeros( + self.max_num_tokens, dtype=torch.int32, device=device + ) self.uses_mrope = self.vllm_config.model_config.uses_mrope if self.uses_mrope: # M-RoPE need (3, max_num_tokens) - self.mrope_positions = torch.zeros((3, self.max_num_tokens), - dtype=torch.int64, - device=device) + self.mrope_positions = torch.zeros( + (3, self.max_num_tokens), dtype=torch.int64, device=device + ) else: # RoPE need (max_num_tokens,) - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=device) + self.positions = torch.zeros( + self.max_num_tokens, dtype=torch.int64, device=device + ) self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. 
max_batch_size = vllm_config.scheduler_config.max_num_seqs max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) - self.arange = torch.arange(max_num_slots_for_arange, - device=device, - dtype=torch.int32) + self.arange = torch.arange( + max_num_slots_for_arange, device=device, dtype=torch.int32 + ) self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=device) + (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device + ) self.backup_next_token_ids = CpuGpuBuffer( max_batch_size, dtype=torch.int32, pin_memory=is_pin_memory_available(), device=device, - with_numpy=True) + with_numpy=True, + ) # Determine allowed attention backends once during initialization. self.allowed_attn_types: Optional[tuple] = None @@ -143,14 +144,15 @@ def __init__( # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"): from vllm.v1.attention.backends.rocm_aiter_fa import ( - AiterFlashAttentionMetadata) + AiterFlashAttentionMetadata, + ) + rocm_types.append(AiterFlashAttentionMetadata) self.allowed_attn_types = tuple(rocm_types) # Parse the speculative token tree. spec_token_tree = self.speculative_config.speculative_token_tree - self.tree_choices: list[tuple[int, - ...]] = ast.literal_eval(spec_token_tree) + self.tree_choices: list[tuple[int, ...]] = ast.literal_eval(spec_token_tree) tree_depth = len(self.tree_choices[-1]) # Precompute per-level properties of the tree. num_drafts_per_level = [0] * tree_depth @@ -159,10 +161,12 @@ def __init__( self.cu_drafts_per_level = [num_drafts_per_level[0]] self.child_drafts_per_level = [num_drafts_per_level[0]] for level in range(1, tree_depth): - self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] + - num_drafts_per_level[level]) - self.child_drafts_per_level.append(num_drafts_per_level[level] // - num_drafts_per_level[level - 1]) + self.cu_drafts_per_level.append( + self.cu_drafts_per_level[-1] + num_drafts_per_level[level] + ) + self.child_drafts_per_level.append( + num_drafts_per_level[level] // num_drafts_per_level[level - 1] + ) # Precompute draft position offsets in flattened tree. 
self.tree_draft_pos_offsets = torch.arange( 1, @@ -198,8 +202,7 @@ def propose( cudagraph_args: "CudaGraphArgs", sampler_output: SamplerOutput, spec_decode_metadata: Optional[SpecDecodeMetadata], - mm_embed_inputs: Optional[tuple[list[torch.Tensor], - torch.Tensor]] = None, + mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = common_attn_metadata.batch_size() @@ -210,27 +213,32 @@ def propose( if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) target_hidden_states = self.model.combine_hidden_states( - target_hidden_states) + target_hidden_states + ) assert target_hidden_states.shape[-1] == self.hidden_size - self.set_input_ids_first_pass(target_token_ids, next_token_ids, - num_tokens, last_token_indices) + self.set_input_ids_first_pass( + target_token_ids, next_token_ids, num_tokens, last_token_indices + ) assert self.runner is not None # FIXME: need to consider multiple kv_cache_groups ubatch_id = dbo_current_ubatch_id() - attn_metadata_builder = \ - self.runner.attn_groups[0][0].metadata_builders[ubatch_id] + attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ + ubatch_id + ] attn_metadata = attn_metadata_builder.build_for_drafting( - common_attn_metadata=common_attn_metadata, draft_index=0) + common_attn_metadata=common_attn_metadata, draft_index=0 + ) # FIXME: support hybrid kv for draft model (remove separate indexer) if self.draft_indexer_metadata_builder: draft_indexer_metadata = ( self.draft_indexer_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0, - )) + ) + ) else: draft_indexer_metadata = None # At this moment, we assume all eagle layers belong to the same KV @@ -242,8 +250,7 @@ def propose( assert draft_indexer_metadata is not None per_layer_attn_metadata[layer_name] = draft_indexer_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens @@ -275,8 +282,7 @@ def propose( "inputs_embeds": inputs_embeds, } if self.pass_hidden_states_to_model: - model_kwargs[ - "hidden_states"] = self.hidden_states[:num_input_tokens] + model_kwargs["hidden_states"] = self.hidden_states[:num_input_tokens] forward_ctx_kwargs = dict( attn_metadata=per_layer_attn_metadata, @@ -327,28 +333,30 @@ def propose( draft_token_ids = logits.argmax(dim=-1) - if self.allowed_attn_types is not None and \ - not isinstance(attn_metadata, self.allowed_attn_types): + if self.allowed_attn_types is not None and not isinstance( + attn_metadata, self.allowed_attn_types + ): raise ValueError( f"Unsupported attention metadata type for speculative " "decoding with num_speculative_tokens > 1: " f"{type(attn_metadata)}. Supported types are: " - f"{self.allowed_attn_types}") + f"{self.allowed_attn_types}" + ) # Generate the remaining draft tokens. 
draft_token_ids_list = [draft_token_ids] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 - common_attn_metadata.query_start_loc = self.arange[:batch_size + 1] + common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] common_attn_metadata.query_start_loc_cpu = torch.from_numpy( - self.token_arange_np[:batch_size + 1]).clone() + self.token_arange_np[: batch_size + 1] + ).clone() for token_index in range(self.num_speculative_tokens - 1): # Update the inputs. # cast to int32 is crucial when eagle model is compiled. @@ -367,14 +375,15 @@ def propose( exceeds_max_model_len = positions[0] >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. - clamped_positions = torch.where\ - (exceeds_max_model_len.unsqueeze(0), \ - torch.zeros_like(positions), positions) + clamped_positions = torch.where( + exceeds_max_model_len.unsqueeze(0), + torch.zeros_like(positions), + positions, + ) else: positions += 1 exceeds_max_model_len = positions >= self.max_model_len - clamped_positions = torch.where(exceeds_max_model_len, 0, - positions) + clamped_positions = torch.where(exceeds_max_model_len, 0, positions) # Increment the sequence lengths. common_attn_metadata.seq_lens += 1 @@ -382,11 +391,11 @@ def propose( # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, - 1) + common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) - common_attn_metadata.num_computed_tokens_cpu = \ + common_attn_metadata.num_computed_tokens_cpu = ( common_attn_metadata.seq_lens_cpu - 1 + ) # Compute the slot mapping. if self.uses_mrope: @@ -395,26 +404,28 @@ def propose( else: block_numbers = clamped_positions // self.block_size block_ids = common_attn_metadata.block_table_tensor.gather( - dim=1, index=block_numbers.view(-1, 1)) + dim=1, index=block_numbers.view(-1, 1) + ) block_ids = block_ids.view(-1) if self.uses_mrope: common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions[0] % self.block_size) + block_ids * self.block_size + clamped_positions[0] % self.block_size + ) else: common_attn_metadata.slot_mapping = ( - block_ids * self.block_size + - clamped_positions % self.block_size) + block_ids * self.block_size + clamped_positions % self.block_size + ) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. 
common_attn_metadata.slot_mapping.masked_fill_( - exceeds_max_model_len, PADDING_SLOT_ID) + exceeds_max_model_len, PADDING_SLOT_ID + ) # Rebuild attention metadata attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore - common_attn_metadata=common_attn_metadata, - draft_index=token_index + 1) + common_attn_metadata=common_attn_metadata, draft_index=token_index + 1 + ) for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata @@ -423,8 +434,9 @@ def propose( self._set_positions(batch_size, clamped_positions) self.hidden_states[:batch_size] = hidden_states if self.supports_mm_inputs: - self.inputs_embeds[:batch_size] = \ - self.model.get_input_embeddings(input_ids) + self.inputs_embeds[:batch_size] = self.model.get_input_embeddings( + input_ids + ) input_ids = None inputs_embeds = self.inputs_embeds[:input_batch_size] @@ -439,8 +451,7 @@ def propose( "inputs_embeds": inputs_embeds, } if self.pass_hidden_states_to_model: - model_kwargs[ - "hidden_states"] = self.hidden_states[:input_batch_size] + model_kwargs["hidden_states"] = self.hidden_states[:input_batch_size] forward_ctx_kwargs = dict( attn_metadata=per_layer_attn_metadata, @@ -449,7 +460,8 @@ def propose( ) if self.pass_cudagraph_args_to_forward_ctx: cudagraph_args = self.decoding_cudagraph_args( - num_tokens=input_batch_size) + num_tokens=input_batch_size + ) forward_ctx_kwargs.update(cudagraph_args) with set_forward_context(**forward_ctx_kwargs): @@ -469,12 +481,16 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids - def set_input_ids_first_pass(self, target_token_ids: torch.Tensor, - next_token_ids: torch.Tensor, num_tokens: int, - last_token_indices: torch.Tensor) -> None: + def set_input_ids_first_pass( + self, + target_token_ids: torch.Tensor, + next_token_ids: torch.Tensor, + num_tokens: int, + last_token_indices: torch.Tensor, + ) -> None: # Shift the input ids by one token. # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] + self.input_ids[: num_tokens - 1] = target_token_ids[1:] # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids @@ -483,20 +499,22 @@ def model_returns_tuple(self) -> bool: return self.method not in ("mtp", "draft_model") def decoding_cudagraph_args(self, num_tokens: int) -> "CudaGraphArgs": - batch_descriptor = BatchDescriptor(num_tokens=num_tokens, - uniform_decode=True) + batch_descriptor = BatchDescriptor(num_tokens=num_tokens, uniform_decode=True) cudagraph_runtime_mode, batch_descriptor = ( - self.runner.cudagraph_dispatcher.dispatch(batch_descriptor)) + self.runner.cudagraph_dispatcher.dispatch(batch_descriptor) + ) return CudaGraphArgs( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ) def prepare_next_token_ids_cpu( - self, sampled_token_ids: list[list[int]], - requests: dict[str, - CachedRequestState], gpu_input_batch: InputBatch, - num_scheduled_tokens: dict[str, int]) -> torch.Tensor: + self, + sampled_token_ids: list[list[int]], + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + num_scheduled_tokens: dict[str, int], + ) -> torch.Tensor: """ This function is used to prepare the inputs for speculative decoding. 
It calculates the next token ids for each request based on the sampled @@ -515,23 +533,23 @@ def prepare_next_token_ids_cpu( # Get the next token id from the request state. req_id = req_ids[i] req_state = requests[req_id] - seq_len = (req_state.num_computed_tokens + - num_scheduled_tokens[req_id]) + seq_len = req_state.num_computed_tokens + num_scheduled_tokens[req_id] next_token_id = req_state.get_token_id(seq_len) next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.input_ids.device) + next_token_ids = torch.tensor( + next_token_ids, dtype=torch.int32, device=self.input_ids.device + ) return next_token_ids - def prepare_next_token_ids_padded(self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: torch.Tensor, - requests: dict[str, CachedRequestState], - gpu_input_batch: InputBatch, - discard_request_indices: torch.Tensor, - num_discarded_requests: int) -> \ - tuple[torch.Tensor, torch.Tensor]: + def prepare_next_token_ids_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + sampled_token_ids: torch.Tensor, + requests: dict[str, CachedRequestState], + gpu_input_batch: InputBatch, + discard_request_indices: torch.Tensor, + num_discarded_requests: int, + ) -> tuple[torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding. It calculates the next token ids and the number of valid sampled tokens @@ -545,30 +563,34 @@ def prepare_next_token_ids_padded(self, # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs - self.backup_next_token_ids.np[:num_reqs] = np.array([ - requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item()) - for i in range(num_reqs) - ]) + self.backup_next_token_ids.np[:num_reqs] = np.array( + [ + requests[gpu_input_batch.req_ids[i]].get_token_id( + common_attn_metadata.seq_lens_cpu[i].item() + ) + for i in range(num_reqs) + ] + ) self.backup_next_token_ids.copy_to_gpu(num_reqs) # Mask out the sampled tokens indices that should not be sampled. 
- discard_sampled_tokens_req_indices = \ - discard_request_indices[:num_discarded_requests] + discard_sampled_tokens_req_indices = discard_request_indices[ + :num_discarded_requests + ] valid_sampled_token_ids_gpu = sampled_token_ids.clone() valid_sampled_token_ids_gpu.index_fill_( - 0, discard_sampled_tokens_req_indices, -1) + 0, discard_sampled_tokens_req_indices, -1 + ) # Generate a mask for all valid tokens within those requests max_gen_len = sampled_token_ids.shape[-1] if max_gen_len == 1: - valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, - dtype=torch.bool) + valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool) else: - valid_mask = ( - (valid_sampled_token_ids_gpu != -1) & - (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)) + valid_mask = (valid_sampled_token_ids_gpu != -1) & ( + valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size + ) # Count the number of valid tokens in each request valid_sampled_tokens_count = valid_mask.sum(dim=1) @@ -580,22 +602,25 @@ def prepare_next_token_ids_padded(self, # Get last valid token from each row # (assume undefined state where there is no valid token) selected_tokens = torch.gather( - valid_sampled_token_ids_gpu, 1, - last_valid_indices_safe.unsqueeze(1)).squeeze(1) + valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1) + ).squeeze(1) # Use last token if valid, pre-computed backup if not batch_size = valid_sampled_token_ids_gpu.shape[0] next_token_ids = torch.where( - last_valid_indices != -1, selected_tokens, - self.backup_next_token_ids.gpu[:batch_size]) + last_valid_indices != -1, + selected_tokens, + self.backup_next_token_ids.gpu[:batch_size], + ) return next_token_ids, valid_sampled_tokens_count - def prepare_inputs_padded(self, - common_attn_metadata: CommonAttentionMetadata, - spec_decode_metadata: SpecDecodeMetadata, - valid_sampled_tokens_count: torch.Tensor) -> \ - tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: + def prepare_inputs_padded( + self, + common_attn_metadata: CommonAttentionMetadata, + spec_decode_metadata: SpecDecodeMetadata, + valid_sampled_tokens_count: torch.Tensor, + ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding It updates the common_attn_metadata for speculative decoding, @@ -605,11 +630,11 @@ def prepare_inputs_padded(self, No blocking CPU operations should be introduced in this function. """ num_rejected_tokens_gpu = num_rejected_tokens( - spec_decode_metadata, valid_sampled_tokens_count) + spec_decode_metadata, valid_sampled_tokens_count + ) query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] total_num_tokens = query_start_loc_cpu[-1].item() token_indices = self.arange[:total_num_tokens] @@ -619,8 +644,7 @@ def prepare_inputs_padded(self, seq_lens=common_attn_metadata.seq_lens, query_start_loc_cpu=query_start_loc_cpu, seq_lens_cpu=common_attn_metadata.seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. 
- num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), @@ -630,8 +654,9 @@ def prepare_inputs_padded(self, causal=True, ) - token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ - - num_rejected_tokens_gpu + token_indices_to_sample = ( + common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu + ) return spec_common_attn_metadata, token_indices, token_indices_to_sample @@ -646,10 +671,10 @@ def propose_tree( hidden_states: torch.Tensor, common_attn_metadata: CommonAttentionMetadata, ) -> list[torch.Tensor]: - tree_attn_metadata_builder = \ - self.runner.attn_groups[0][0].get_metadata_builder() - assert isinstance(tree_attn_metadata_builder, - TreeAttentionMetadataBuilder) + tree_attn_metadata_builder = self.runner.attn_groups[0][ + 0 + ].get_metadata_builder() + assert isinstance(tree_attn_metadata_builder, TreeAttentionMetadataBuilder) total_num_drafts = self.cu_drafts_per_level[0] level_num_drafts = total_num_drafts @@ -658,31 +683,31 @@ def propose_tree( if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: - draft_token_ids = torch.topk(logits, num_children, - dim=-1).indices.view(batch_size, -1) + draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view( + batch_size, -1 + ) draft_token_ids_list = [draft_token_ids] draft_hidden_states = hidden_states.view(batch_size, 1, -1) # Initialize empty tensors for concatenation with the level outputs. - tree_input_ids = torch.empty(0, - device=self.input_ids.device, - dtype=self.input_ids.dtype) - tree_positions = torch.empty(0, - device=self.positions.device, - dtype=self.positions.dtype) - tree_hidden_states = torch.empty(0, - device=self.hidden_states.device, - dtype=self.hidden_states.dtype) + tree_input_ids = torch.empty( + 0, device=self.input_ids.device, dtype=self.input_ids.dtype + ) + tree_positions = torch.empty( + 0, device=self.positions.device, dtype=self.positions.dtype + ) + tree_hidden_states = torch.empty( + 0, device=self.hidden_states.device, dtype=self.hidden_states.dtype + ) # Precompute the draft token positions. flattened_draft_positions = ( - positions.view(batch_size, -1) + - self.tree_draft_pos_offsets[:batch_size, :]) + positions.view(batch_size, -1) + self.tree_draft_pos_offsets[:batch_size, :] + ) tree_depth = len(self.cu_drafts_per_level) for level in range(tree_depth - 1): # Get draft positions for RoPE. draft_positions = positions + (level + 1) - exceeds_max_model_len = (positions + - total_num_drafts) >= self.max_model_len + exceeds_max_model_len = (positions + total_num_drafts) >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. draft_positions = torch.where( @@ -694,27 +719,28 @@ def propose_tree( if level_num_drafts > 1: # Repeat the positions for each draft at this level. draft_positions = draft_positions.repeat_interleave( - level_num_drafts, dim=1) + level_num_drafts, dim=1 + ) if num_children > 1: # Repeat draft hidden states for each child. draft_hidden_states = draft_hidden_states.repeat_interleave( - num_children, dim=1) + num_children, dim=1 + ) # Concatenate the draft tokens, positions, and hidden states. 
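The per-level child sampling used above in propose_tree can be read in isolation; a minimal sketch under the assumption that logits hold one row per parent draft (names are illustrative):

import torch

def sample_children(logits: torch.Tensor, batch_size: int,
                    num_children: int) -> torch.Tensor:
    # logits: [batch_size * num_parents, vocab_size]
    if num_children == 1:
        draft_ids = logits.argmax(dim=-1)
    else:
        draft_ids = torch.topk(logits, num_children, dim=-1).indices
    # one row per request, num_parents * num_children drafts at this level
    return draft_ids.view(batch_size, -1)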
- tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], - dim=1) - tree_positions = torch.cat([tree_positions, draft_positions], - dim=1) + tree_input_ids = torch.cat([tree_input_ids, draft_token_ids], dim=1) + tree_positions = torch.cat([tree_positions, draft_positions], dim=1) tree_hidden_states = torch.cat( - [tree_hidden_states, draft_hidden_states], dim=1) + [tree_hidden_states, draft_hidden_states], dim=1 + ) # Build new attention metadata for the next level of drafts. # This is necessary to support tree attention. query_len = total_num_drafts common_attn_metadata = replace( common_attn_metadata, - query_start_loc=query_len * self.arange[:batch_size + 1], + query_start_loc=query_len * self.arange[: batch_size + 1], seq_lens=common_attn_metadata.seq_lens + level_num_drafts, num_actual_tokens=batch_size * query_len, max_query_len=query_len, @@ -730,20 +756,20 @@ def propose_tree( per_layer_attn_metadata[layer_name] = attn_metadata # Consider max model length. - attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, - self.max_model_len) + attn_metadata.max_seq_len = min( + attn_metadata.max_seq_len, self.max_model_len + ) # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. - query_positions = flattened_draft_positions[:, level:level + - query_len] + query_positions = flattened_draft_positions[:, level : level + query_len] block_numbers = query_positions // self.block_size - block_ids = attn_metadata.block_table.gather(dim=1, - index=block_numbers) - slot_mapping = (block_ids * self.block_size + - query_positions % self.block_size) + block_ids = attn_metadata.block_table.gather(dim=1, index=block_numbers) + slot_mapping = ( + block_ids * self.block_size + query_positions % self.block_size + ) # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. @@ -755,19 +781,16 @@ def propose_tree( input_ids = tree_input_ids.view(-1) self.input_ids[:num_tokens] = input_ids self.positions[:num_tokens] = tree_positions.view(-1) - self.hidden_states[:num_tokens] = tree_hidden_states.view( - num_tokens, -1) + self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1) - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_tokens) + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: + num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens + ): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], @@ -777,28 +800,29 @@ def propose_tree( # Get the output hidden states for the draft tokens. draft_hidden_states = hidden_states[:num_tokens].view( - batch_size, query_len, -1)[:, -level_num_drafts:] + batch_size, query_len, -1 + )[:, -level_num_drafts:] draft_last_hidden_states = last_hidden_states[:num_tokens].view( - batch_size, query_len, -1)[:, -level_num_drafts:] + batch_size, query_len, -1 + )[:, -level_num_drafts:] # Get the output logits for the draft tokens. 
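The slot-mapping arithmetic reformatted just above maps each draft position to a physical KV-cache slot; a self-contained sketch of that mapping, assuming a paged KV cache with a per-request block table (illustrative helper, not a vLLM API):

import torch

def compute_slot_mapping(positions: torch.Tensor,
                         block_table: torch.Tensor,
                         block_size: int) -> torch.Tensor:
    # positions:   [batch, query_len] absolute token positions for this level
    # block_table: [batch, max_blocks] physical block ids per request
    block_numbers = positions // block_size
    block_ids = block_table.gather(dim=1, index=block_numbers)
    return block_ids * block_size + positions % block_size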
logits = self.model.compute_logits( - draft_last_hidden_states.reshape(batch_size * level_num_drafts, - -1)) + draft_last_hidden_states.reshape(batch_size * level_num_drafts, -1) + ) # Sample a draft token for each child at the next tree level. num_children = self.child_drafts_per_level[level + 1] if num_children == 1: draft_token_ids = logits.argmax(dim=-1).view(batch_size, -1) else: - draft_token_ids = torch.topk(logits, num_children, - dim=-1).indices.view( - batch_size, -1) + draft_token_ids = torch.topk(logits, num_children, dim=-1).indices.view( + batch_size, -1 + ) draft_token_ids_list.append(draft_token_ids) # Update the # drafts counters for the next tree level. - level_num_drafts = self.cu_drafts_per_level[level + - 1] - total_num_drafts + level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts total_num_drafts = self.cu_drafts_per_level[level + 1] return draft_token_ids_list @@ -834,17 +858,14 @@ def prepare_inputs( n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, - dtype=torch.int32) + num_rejected_tokens = torch.tensor(num_rejected_tokens, dtype=torch.int32) device = common_attn_metadata.query_start_loc.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ - - num_rejected_tokens + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = (query_start_loc_cpu[1:] - - query_start_loc_cpu[:-1]) + new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() @@ -854,7 +875,8 @@ def prepare_inputs( new_query_start_loc_cpu = torch.zeros( query_start_loc_cpu.shape, dtype=torch.int32, - pin_memory=is_pin_memory_available()) + pin_memory=is_pin_memory_available(), + ) new_query_start_loc_np = new_query_start_loc_cpu.numpy() np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) @@ -864,36 +886,36 @@ def prepare_inputs( # [0, 2, 6, 9] -> # [0, 0, 2, 2, 2, 2, 6, 6, 6] # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) + new_query_start_locs_expanded = np.repeat( + new_query_start_loc_np[:-1], new_num_tokens_per_req_np + ) # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> # [0, 1, 0, 1, 2, 3, 0, 1, 2] # _r1_ ____r2____ ___r3__ - token_offests = self.token_arange_np[:total_num_tokens] \ - - new_query_start_locs_expanded + token_offests = ( + self.token_arange_np[:total_num_tokens] - new_query_start_locs_expanded + ) # Expand starting positions to match token pattern # [0, q1, q1 + q2] -> # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] # _r1_ _____r2_______ ___________r3____________ old_query_start_locs_expanded = np.repeat( - query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np + ) # Final token indices are: # [0, 1, // req 1 # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to( - device, non_blocking=True) + token_indices = torch.from_numpy(token_indices_np).to(device, non_blocking=True) spec_common_attn_metadata = 
CommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, - non_blocking=True), + query_start_loc=new_query_start_loc_cpu.to(device, non_blocking=True), seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), query_start_loc_cpu=new_query_start_loc_cpu, seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. - num_computed_tokens_cpu, + num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu, num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), @@ -906,45 +928,52 @@ def prepare_inputs( return spec_common_attn_metadata, token_indices def get_model_name(self, model: nn.Module) -> str: - if hasattr(model, 'module'): # multi-GPU + if hasattr(model, "module"): # multi-GPU model = model.module return model.__class__.__name__ def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + draft_model_config = self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + ) # FIXME: support hybrid kv for draft model target_indexer_layer_names = set( - get_layers_from_vllm_config(self.vllm_config, - DeepseekV32IndexerCache).keys()) + get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache + ).keys() + ) from vllm.compilation.backends import set_model_tag + with set_model_tag("eagle_head"): - self.model = get_model(vllm_config=self.vllm_config, - model_config=draft_model_config) + self.model = get_model( + vllm_config=self.vllm_config, model_config=draft_model_config + ) draft_attn_layer_names = ( - get_layers_from_vllm_config(self.vllm_config, Attention).keys() - - target_attn_layer_names) - indexer_layers = get_layers_from_vllm_config(self.vllm_config, - DeepseekV32IndexerCache) - draft_indexer_layer_names = (indexer_layers.keys() - - target_indexer_layer_names) + get_layers_from_vllm_config(self.vllm_config, Attention).keys() + - target_attn_layer_names + ) + indexer_layers = get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache + ) + draft_indexer_layer_names = indexer_layers.keys() - target_indexer_layer_names self.attn_layer_names = list(draft_attn_layer_names) self.indexer_layer_names = list(draft_indexer_layer_names) if self.indexer_layer_names: first_layer = self.indexer_layer_names[0] self.draft_indexer_metadata_builder = ( - indexer_layers[first_layer].get_attn_backend().get_builder_cls( - )( + indexer_layers[first_layer] + .get_attn_backend() + .get_builder_cls()( indexer_layers[first_layer].get_kv_cache_spec(), self.indexer_layer_names, self.vllm_config, self.device, - )) + ) + ) else: self.draft_indexer_metadata_builder = None @@ -952,38 +981,41 @@ def load_model(self, target_model: nn.Module) -> None: # Even if the target model is multimodal, we can also use # text-only draft models try: - dummy_input_ids = torch.tensor([[1]], - device=self.input_ids.device) - self.model.get_input_embeddings(dummy_input_ids, - multimodal_embeddings=None) + dummy_input_ids = torch.tensor([[1]], device=self.input_ids.device) + self.model.get_input_embeddings( + dummy_input_ids, multimodal_embeddings=None + ) except (NotImplementedError, AttributeError, TypeError): logger.warning( "Draft model does not support multimodal inputs, " - "falling back to text-only mode") + "falling back to text-only mode" + ) 
self.supports_mm_inputs = False if supports_multimodal(target_model): # handle multimodality - if (self.get_model_name(target_model) == - "Qwen2_5_VLForConditionalGeneration"): - self.model.config.image_token_index = ( - target_model.config.image_token_id) + if ( + self.get_model_name(target_model) + == "Qwen2_5_VLForConditionalGeneration" + ): + self.model.config.image_token_index = target_model.config.image_token_id else: self.model.config.image_token_index = ( - target_model.config.image_token_index) + target_model.config.image_token_index + ) target_language_model = target_model.get_language_model() else: target_language_model = target_model # share embed_tokens with the target model if needed if get_pp_group().world_size == 1: - if hasattr(target_language_model.model, 'embed_tokens'): + if hasattr(target_language_model.model, "embed_tokens"): target_embed_tokens = target_language_model.model.embed_tokens - elif hasattr(target_language_model.model, 'embedding'): + elif hasattr(target_language_model.model, "embedding"): target_embed_tokens = target_language_model.model.embedding else: raise AttributeError( - "Target model does not have 'embed_tokens' or 'embedding' " - "attribute") + "Target model does not have 'embed_tokens' or 'embedding' attribute" + ) # Check if shapes match and we found the embedding eagle_shape = self.model.model.embed_tokens.weight.shape @@ -991,47 +1023,53 @@ def load_model(self, target_model: nn.Module) -> None: if eagle_shape == target_shape: logger.info( "Assuming the EAGLE head shares the same vocab embedding" - " with the target model.") + " with the target model." + ) del self.model.model.embed_tokens self.model.model.embed_tokens = target_embed_tokens else: logger.info( "The EAGLE head's vocab embedding will be loaded separately" - " from the target model.") + " from the target model." + ) else: logger.info( "The EAGLE head's vocab embedding will be loaded separately" - " from the target model.") + " from the target model." + ) # share lm_head with the target model if needed # some model definition do not define lm_head explicitly # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM if self.vllm_config.speculative_config.method != "eagle3": if hasattr(target_language_model, "lm_head"): - logger.info( - "Loading EAGLE LM head weights from the target model.") + logger.info("Loading EAGLE LM head weights from the target model.") self.model.lm_head = target_language_model.lm_head else: - if (hasattr(self.model, "lm_head") - and hasattr(target_language_model, "lm_head") - and self.model.lm_head.weight.shape - == target_language_model.lm_head.weight.shape): - logger.info("Assuming the EAGLE head shares the same lm_head" - " with the target model.") + if ( + hasattr(self.model, "lm_head") + and hasattr(target_language_model, "lm_head") + and self.model.lm_head.weight.shape + == target_language_model.lm_head.weight.shape + ): + logger.info( + "Assuming the EAGLE head shares the same lm_head" + " with the target model." + ) del self.model.lm_head self.model.lm_head = target_language_model.lm_head else: logger.info( "The EAGLE head's lm_head will be loaded separately" - " from the target model.") + " from the target model." 
+ ) @torch.inference_mode() def dummy_run( self, num_tokens: int, ) -> None: - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): if self.supports_mm_inputs: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -1046,8 +1084,7 @@ def dummy_run( inputs_embeds=inputs_embeds, ) - def _get_attention_metadata_builder( - self) -> list[AttentionMetadataBuilder]: + def _get_attention_metadata_builder(self) -> list[AttentionMetadataBuilder]: """Find and return the attention metadata builders for EAGLE layers. Returns: @@ -1068,11 +1105,11 @@ def _get_attention_metadata_builder( break assert builder is not None, ( - "Failed to find attention metadata builder for EAGLE layers.") + "Failed to find attention metadata builder for EAGLE layers." + ) return builder - def validate_same_kv_cache_group(self, - kv_cache_config: KVCacheConfig) -> None: + def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ Validate that all eagle layers belong to the same KVCacheGroup. Need this assumption to ensure all eagle layers can use the @@ -1083,12 +1120,17 @@ def validate_same_kv_cache_group(self, for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): for layer_name in kv_cache_group.layer_names: kv_cache_groups[layer_name] = id - assert len( - set([ - kv_cache_groups[layer_name] - for layer_name in self.attn_layer_names - ]) - ) == 1, "All eagle layers should belong to the same kv cache group" + assert ( + len( + set( + [ + kv_cache_groups[layer_name] + for layer_name in self.attn_layer_names + ] + ) + ) + == 1 + ), "All eagle layers should belong to the same kv cache group" class CudaGraphArgs(TypedDict): @@ -1097,18 +1139,19 @@ class CudaGraphArgs(TypedDict): class EagleProposer(SpecDecodeBaseProposer): - def __init__( self, vllm_config: VllmConfig, device: torch.device, runner=None, ): - super().__init__(vllm_config, - device, - pass_hidden_states_to_model=True, - pass_cudagraph_args_to_forward_ctx=False, - runner=runner) + super().__init__( + vllm_config, + device, + pass_hidden_states_to_model=True, + pass_cudagraph_args_to_forward_ctx=False, + runner=runner, + ) # NOTE(woosuk): Currently, the below code is not used and we always use argmax @@ -1155,30 +1198,34 @@ def compute_probs_and_sample_next_token( def num_rejected_tokens( - spec_decode_metadata: Optional[SpecDecodeMetadata], - valid_sampled_tokens_count: torch.Tensor) -> torch.Tensor: + spec_decode_metadata: Optional[SpecDecodeMetadata], + valid_sampled_tokens_count: torch.Tensor, +) -> torch.Tensor: if spec_decode_metadata is None: return torch.zeros_like(valid_sampled_tokens_count) - num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1] - ]) + num_draft_tokens_gpu = torch.cat( + [ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] + - spec_decode_metadata.cu_num_draft_tokens[:-1], + ] + ) num_rejected_tokens_gpu = torch.where( num_draft_tokens_gpu > 0, num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu)) + torch.zeros_like(num_draft_tokens_gpu), + ) return num_rejected_tokens_gpu -def update_batch_descriptor(cudagraph_args: CudaGraphArgs, - new_num_tokens: int) -> None: +def update_batch_descriptor(cudagraph_args: CudaGraphArgs, new_num_tokens: int) -> None: """The cudagraph padding 
can change the num_tokens, so the batch descriptor should be updated. The cudagraph_args is modified in place.""" old: Optional[BatchDescriptor] = cudagraph_args["batch_descriptor"] if old is not None: - new = BatchDescriptor(num_tokens=new_num_tokens, - uniform_decode=old.uniform_decode) + new = BatchDescriptor( + num_tokens=new_num_tokens, uniform_decode=old.uniform_decode + ) cudagraph_args["batch_descriptor"] = new diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 437a9cf9f6e6..c4cd2cfbfa49 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -32,8 +32,10 @@ class SpecDecodingStats: @classmethod def new(cls, num_spec_tokens: int) -> "SpecDecodingStats": - return cls(num_spec_tokens=num_spec_tokens, - num_accepted_tokens_per_pos=[0] * num_spec_tokens) + return cls( + num_spec_tokens=num_spec_tokens, + num_accepted_tokens_per_pos=[0] * num_spec_tokens, + ) def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int): self.num_drafts += 1 @@ -65,10 +67,10 @@ def reset(self): def observe(self, spec_decoding_stats: SpecDecodingStats): self.num_drafts.append(spec_decoding_stats.num_drafts) self.num_draft_tokens.append(spec_decoding_stats.num_draft_tokens) - self.num_accepted_tokens.append( - spec_decoding_stats.num_accepted_tokens) + self.num_accepted_tokens.append(spec_decoding_stats.num_accepted_tokens) self.accepted_tokens_per_pos_lists.append( - spec_decoding_stats.num_accepted_tokens_per_pos) + spec_decoding_stats.num_accepted_tokens_per_pos + ) def log(self, log_fn=logger.info): if not self.num_drafts: @@ -84,8 +86,11 @@ def log(self, log_fn=logger.info): draft_throughput = num_draft_tokens / elapsed_time accepted_throughput = num_accepted_tokens / elapsed_time - draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * - 100 if num_draft_tokens > 0 else float("nan")) + draft_acceptance_rate = ( + num_accepted_tokens / num_draft_tokens * 100 + if num_draft_tokens > 0 + else float("nan") + ) # Conventionally, mean acceptance length includes the bonus token mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts) @@ -150,27 +155,36 @@ def __init__( counter_drafts = self._counter_cls( name="vllm:spec_decode_num_drafts", documentation="Number of spec decoding drafts.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_spec_decode_num_drafts = make_per_engine( - counter_drafts, per_engine_labelvalues) + counter_drafts, per_engine_labelvalues + ) counter_draft_tokens = self._counter_cls( name="vllm:spec_decode_num_draft_tokens", documentation="Number of draft tokens.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_spec_decode_num_draft_tokens = make_per_engine( - counter_draft_tokens, per_engine_labelvalues) + counter_draft_tokens, per_engine_labelvalues + ) counter_accepted_tokens = self._counter_cls( name="vllm:spec_decode_num_accepted_tokens", documentation="Number of accepted tokens.", - labelnames=labelnames) + labelnames=labelnames, + ) self.counter_spec_decode_num_accepted_tokens = make_per_engine( - counter_accepted_tokens, per_engine_labelvalues) + counter_accepted_tokens, per_engine_labelvalues + ) assert speculative_config is not None - num_spec_tokens = (speculative_config.num_speculative_tokens - if self.spec_decoding_enabled else 0) + num_spec_tokens = ( + speculative_config.num_speculative_tokens + if self.spec_decoding_enabled + else 0 + ) pos_labelnames = labelnames + ["position"] base_counter = self._counter_cls( 
name="vllm:spec_decode_num_accepted_tokens_per_pos", @@ -178,52 +192,49 @@ def __init__( labelnames=pos_labelnames, ) self.counter_spec_decode_num_accepted_tokens_per_pos: dict[ - int, list[prometheus_client.Counter]] = { - idx: [ - base_counter.labels(*lv, str(pos)) - for pos in range(num_spec_tokens) - ] - for idx, lv in per_engine_labelvalues.items() - } - - def observe(self, - spec_decoding_stats: SpecDecodingStats, - engine_idx: int = 0): + int, list[prometheus_client.Counter] + ] = { + idx: [base_counter.labels(*lv, str(pos)) for pos in range(num_spec_tokens)] + for idx, lv in per_engine_labelvalues.items() + } + + def observe(self, spec_decoding_stats: SpecDecodingStats, engine_idx: int = 0): if not self.spec_decoding_enabled: return self.counter_spec_decode_num_drafts[engine_idx].inc( - spec_decoding_stats.num_drafts) + spec_decoding_stats.num_drafts + ) self.counter_spec_decode_num_draft_tokens[engine_idx].inc( - spec_decoding_stats.num_draft_tokens) + spec_decoding_stats.num_draft_tokens + ) self.counter_spec_decode_num_accepted_tokens[engine_idx].inc( - spec_decoding_stats.num_accepted_tokens) + spec_decoding_stats.num_accepted_tokens + ) for pos, counter in enumerate( - self. - counter_spec_decode_num_accepted_tokens_per_pos[engine_idx]): + self.counter_spec_decode_num_accepted_tokens_per_pos[engine_idx] + ): counter.inc(spec_decoding_stats.num_accepted_tokens_per_pos[pos]) def compute_acceptance_rate(metrics: list[Metric]) -> float: name2metric = {metric.name: metric for metric in metrics} - n_draft_toks = name2metric[ - "vllm:spec_decode_num_draft_tokens"].value # type: ignore - n_accepted_toks = name2metric[ - "vllm:spec_decode_num_accepted_tokens"].value # type: ignore + n_draft_toks = name2metric["vllm:spec_decode_num_draft_tokens"].value # type: ignore + n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore return n_accepted_toks / n_draft_toks def compute_acceptance_len(metrics: list[Metric]) -> float: name2metric = {metric.name: metric for metric in metrics} n_drafts = name2metric["vllm:spec_decode_num_drafts"].value # type: ignore - n_accepted_toks = name2metric[ - "vllm:spec_decode_num_accepted_tokens"].value # type: ignore + n_accepted_toks = name2metric["vllm:spec_decode_num_accepted_tokens"].value # type: ignore if n_drafts == 0: return 1 return 1 + (n_accepted_toks / n_drafts) -def make_per_engine(counter: prometheus_client.Counter, - per_engine_labelvalues: dict[int, list[str]]): +def make_per_engine( + counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[str]] +): """Create a counter for each label value.""" return { idx: counter.labels(*labelvalues) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4728829e8725..487bced54c60 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -24,70 +24,112 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config, update_config) +from vllm.config import ( + CompilationLevel, + CUDAGraphMode, + VllmConfig, + get_layers_from_vllm_config, + update_config, +) from vllm.distributed.eplb.eplb_state import EplbState -from vllm.distributed.kv_transfer import (get_kv_transfer_group, - has_kv_transfer_group) +from vllm.distributed.kv_transfer import get_kv_transfer_group, 
has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, graph_capture, is_global_first_rank, - prepare_communication_buffer_for_model) -from vllm.forward_context import (BatchDescriptor, DPMetadata, - set_forward_context) + get_pp_group, + get_tp_group, + graph_capture, + is_global_first_rank, + prepare_communication_buffer_for_model, +) +from vllm.forward_context import BatchDescriptor, DPMetadata, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache + # yapf conflicts with isort for this block # yapf: disable -from vllm.model_executor.models.interfaces import (SupportsMultiModal, - is_mixture_of_experts, - supports_eagle3, - supports_mrope, - supports_multimodal_pruning, - supports_transcription) +from vllm.model_executor.models.interfaces import ( + SupportsMultiModal, + is_mixture_of_experts, + supports_eagle3, + supports_mrope, + supports_multimodal_pruning, + supports_transcription, +) + # yapf: enable from vllm.model_executor.models.interfaces_base import ( - VllmModelForPooling, is_pooling_model, is_text_generation_model) + VllmModelForPooling, + is_pooling_model, + is_text_generation_model, +) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem, - PlaceholderRange) +from vllm.multimodal.inputs import ( + BatchedTensorInputs, + MultiModalKwargsItem, + PlaceholderRange, +) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, cdiv, check_use_alibi, get_dtype_size, - is_pin_memory_available, - length_from_prompt_token_ids_or_embeds, round_up, - supports_dynamo) +from vllm.utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + DeviceMemoryProfiler, + GiB_bytes, + cdiv, + check_use_alibi, + get_dtype_size, + is_pin_memory_available, + length_from_prompt_token_ids_or_embeds, + round_up, + supports_dynamo, +) from vllm.utils.jsontree import json_map_leaves from vllm.v1.attention.backends.flash_attn import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.utils import ( - AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, create_fast_prefill_custom_backend, - reorder_batch_to_split_decodes_and_prefills, split_attn_metadata) + reorder_batch_to_split_decodes_and_prefills, + split_attn_metadata, +) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher + # yapf conflicts with isort for this block # yapf: disable -from vllm.v1.kv_cache_interface import (AttentionSpec, - ChunkedLocalAttentionSpec, - CrossAttentionSpec, - EncoderOnlyAttentionSpec, - FullAttentionSpec, KVCacheConfig, - KVCacheGroupSpec, KVCacheSpec, - MambaSpec, MLAAttentionSpec, - SlidingWindowSpec, - 
UniformTypeKVCacheSpecs) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + ChunkedLocalAttentionSpec, + CrossAttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, + KVCacheConfig, + KVCacheGroupSpec, + KVCacheSpec, + MambaSpec, + MLAAttentionSpec, + SlidingWindowSpec, + UniformTypeKVCacheSpecs, +) + # yapf: enable -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, - DraftTokenIds, LogprobsLists, LogprobsTensors, - ModelRunnerOutput, PoolerOutput, SamplerOutput) +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + AsyncModelRunnerOutput, + DraftTokenIds, + LogprobsLists, + LogprobsTensors, + ModelRunnerOutput, + PoolerOutput, + SamplerOutput, +) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -102,18 +144,21 @@ from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper -from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin) +from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.ubatch_splitting import (check_ubatch_thresholds, - ubatch_split) +from vllm.v1.worker.ubatch_splitting import check_ubatch_thresholds, ubatch_split from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices from vllm.v1.worker.utils import is_residual_scattered_for_sp -from .utils import (AttentionGroup, MultiModalBudget, - add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, - gather_mm_placeholders, sanity_check_mm_encoder_outputs, - scatter_mm_placeholders) +from .utils import ( + AttentionGroup, + MultiModalBudget, + add_kv_sharing_layers_to_kv_cache_groups, + bind_kv_cache, + gather_mm_placeholders, + sanity_check_mm_encoder_outputs, + scatter_mm_placeholders, +) if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -123,13 +168,11 @@ AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata] # list when ubatching is enabled -PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], - AttnMetadataDict] +PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict], AttnMetadataDict] # Wrapper for ModelRunnerOutput to support overlapped execution. class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput): - def __init__( self, model_runner_output: ModelRunnerOutput, @@ -152,12 +195,13 @@ def __init__( with torch.cuda.stream(async_output_copy_stream): async_output_copy_stream.wait_stream(default_stream) self._sampled_token_ids_cpu = self._sampled_token_ids.to( - 'cpu', non_blocking=True) + "cpu", non_blocking=True + ) self._async_copy_ready_event.record() def get_output(self) -> ModelRunnerOutput: """Copy the device tensors to the host and return a ModelRunnerOutput. - + This function blocks until the copy is finished. 
""" self._async_copy_ready_event.synchronize() @@ -175,7 +219,6 @@ def get_output(self) -> ModelRunnerOutput: class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( self, vllm_config: VllmConfig, @@ -193,10 +236,10 @@ def __init__( self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes - set_cpu_offload_max_bytes( - int(self.cache_config.cpu_offload_gb * 1024**3)) - from vllm.model_executor.layers.batch_invariant import ( - init_batch_invariance) + + set_cpu_offload_max_bytes(int(self.cache_config.cpu_offload_gb * 1024**3)) + from vllm.model_executor.layers.batch_invariant import init_batch_invariance + init_batch_invariance() model_config = self.model_config @@ -209,13 +252,13 @@ def __init__( if cache_config.cache_dtype == "auto": self.kv_cache_dtype = self.dtype else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - cache_config.cache_dtype] + self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - self.is_pooling_model = (model_config.runner_type == 'pooling') + self.is_pooling_model = model_config.runner_type == "pooling" self.enable_prompt_embeds = model_config.enable_prompt_embeds self.is_multimodal_raw_input_only_model = ( - model_config.is_multimodal_raw_input_only_model) + model_config.is_multimodal_raw_input_only_model + ) # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len @@ -228,12 +271,12 @@ def __init__( # TODO: Support overlapping mirco-batches # https://github.com/vllm-project/vllm/issues/18019 self.broadcast_pp_output = ( - self.parallel_config.distributed_executor_backend - == "external_launcher" and len(get_pp_group().ranks) > 0) + self.parallel_config.distributed_executor_backend == "external_launcher" + and len(get_pp_group().ranks) > 0 + ) # Model-related. - self.num_query_heads = model_config.get_num_attention_heads( - parallel_config) + self.num_query_heads = model_config.get_num_attention_heads(parallel_config) self.hidden_size = model_config.get_hidden_size() self.attention_chunk_size = model_config.attention_chunk_size # Only relevant for models using ALiBi (e.g, MPT) @@ -245,13 +288,13 @@ def __init__( self.mm_registry = MULTIMODAL_REGISTRY self.uses_mrope = model_config.uses_mrope self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config) + model_config + ) if self.model_config.is_encoder_decoder: # Maximum length of the encoder input, only for encoder-decoder # models. 
- self.max_encoder_len = scheduler_config.\ - max_num_encoder_input_tokens + self.max_encoder_len = scheduler_config.max_num_encoder_input_tokens else: self.max_encoder_len = 0 @@ -291,17 +334,18 @@ def __init__( runner=self, ) # type: ignore elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, - self) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( - vllm_config=self.vllm_config, - device=self.device) # type: ignore + vllm_config=self.vllm_config, device=self.device + ) # type: ignore else: - raise ValueError("Unknown speculative decoding method: " - f"{self.speculative_config.method}") + raise ValueError( + "Unknown speculative decoding method: " + f"{self.speculative_config.method}" + ) self.rejection_sampler = RejectionSampler() # Request states. @@ -329,58 +373,64 @@ def __init__( block_sizes=[self.cache_config.block_size], is_spec_decode=bool(self.vllm_config.speculative_config), logitsprocs=build_logitsprocs( - self.vllm_config, self.device, self.pin_memory, + self.vllm_config, + self.device, + self.pin_memory, self.is_pooling_model, - self.vllm_config.model_config.logits_processors), + self.vllm_config.model_config.logits_processors, + ), is_pooling_model=self.is_pooling_model, ) self.use_async_scheduling = self.scheduler_config.async_scheduling - self.async_output_copy_stream = torch.cuda.Stream() if \ - self.use_async_scheduling else None + self.async_output_copy_stream = ( + torch.cuda.Stream() if self.use_async_scheduling else None + ) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. # The convention is different. # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. - if self.compilation_config.cudagraph_capture_sizes and \ - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE: + if ( + self.compilation_config.cudagraph_capture_sizes + and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ): self.cudagraph_batch_sizes = list( - reversed(self.compilation_config.cudagraph_capture_sizes)) + reversed(self.compilation_config.cudagraph_capture_sizes) + ) # Cache the device properties. self._init_device_properties() # Persistent buffers for CUDA graphs. - self.input_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, - dtype=torch.int64) - self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, - dtype=torch.int32) + self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) + self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.query_start_loc = self._make_buffer( + self.max_num_reqs + 1, dtype=torch.int32 + ) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. 
- self.inputs_embeds = self._make_buffer(self.max_num_tokens, - self.hidden_size, - dtype=self.dtype, - numpy=False) - self.is_token_ids = self._make_buffer(self.max_num_tokens, - dtype=torch.bool) - self.discard_request_indices = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + self.inputs_embeds = self._make_buffer( + self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False + ) + self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool) + self.discard_request_indices = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) self.num_discarded_requests = 0 - self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int32) - self.num_accepted_tokens = self._make_buffer(self.max_num_reqs, - dtype=torch.int64) + self.num_decode_draft_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.num_accepted_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int64 + ) # Only relevant for multimodal models if self.supports_mm_inputs: - self.is_mm_embed = self._make_buffer(self.max_num_tokens, - dtype=torch.bool) + self.is_mm_embed = self._make_buffer(self.max_num_tokens, dtype=torch.bool) # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -395,7 +445,8 @@ def __init__( # 1D-RoPE. # See page 5 of https://arxiv.org/abs/2409.12191 self.mrope_positions = self._make_buffer( - (3, self.max_num_tokens + 1), dtype=torch.int64) + (3, self.max_num_tokens + 1), dtype=torch.int64 + ) # CUDA event to synchronize use of reused CPU tensors between steps # when async scheduling is enabled. @@ -410,10 +461,10 @@ def __init__( # OPTIMIZATION: Cache the tensors rather than creating them every step. # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(max(self.max_num_reqs + 1, - self.max_model_len, - self.max_num_tokens), - dtype=np.int64) + self.arange_np = np.arange( + max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), + dtype=np.int64, + ) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -425,19 +476,27 @@ def __init__( self.kv_sharing_fast_prefill_logits_indices = None if self.cache_config.kv_sharing_fast_prefill: self.kv_sharing_fast_prefill_logits_indices = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=self.device) + self.max_num_tokens, dtype=torch.int32, device=self.device + ) - self.uniform_decode_query_len = 1 if not self.speculative_config else \ - 1 + self.speculative_config.num_speculative_tokens + self.uniform_decode_query_len = ( + 1 + if not self.speculative_config + else 1 + self.speculative_config.num_speculative_tokens + ) # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) - self.mm_budget = MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) if self.supports_mm_inputs else None + self.mm_budget = ( + MultiModalBudget( + self.model_config, + self.scheduler_config, + self.mm_registry, + ) + if self.supports_mm_inputs + else None + ) self.reorder_batch_threshold: Optional[int] = None @@ -447,14 +506,14 @@ def __init__( self.runner_only_attn_layers: set[str] = set() # Cached outputs. 
- self._draft_token_ids: Optional[Union[list[list[int]], - torch.Tensor]] = None + self._draft_token_ids: Optional[Union[list[list[int]], torch.Tensor]] = None self.transfer_event = torch.cuda.Event() self.sampled_token_ids_pinned_cpu = torch.empty( (self.max_model_len, 1), dtype=torch.int64, device="cpu", - pin_memory=self.pin_memory) + pin_memory=self.pin_memory, + ) def _get_positions(self, num_tokens: Any): if isinstance(num_tokens, int): @@ -466,15 +525,16 @@ def _get_positions(self, num_tokens: Any): return self.mrope_positions.gpu[:, num_tokens] return self.positions.gpu[num_tokens] - def _make_buffer(self, - *size: Union[int, torch.SymInt], - dtype: torch.dtype, - numpy: bool = True) -> CpuGpuBuffer: - return CpuGpuBuffer(*size, - dtype=dtype, - device=self.device, - pin_memory=self.pin_memory, - with_numpy=numpy) + def _make_buffer( + self, *size: Union[int, torch.SymInt], dtype: torch.dtype, numpy: bool = True + ) -> CpuGpuBuffer: + return CpuGpuBuffer( + *size, + dtype=dtype, + device=self.device, + pin_memory=self.pin_memory, + with_numpy=numpy, + ) def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() @@ -487,9 +547,11 @@ def _init_model_kwargs(self, num_tokens: int): token_type_id_requests = dict[int, Any]() for i, param in enumerate(pooling_params): - if param.extra_kwargs is not None and \ - (token_types := param.extra_kwargs.get( - "compressed_token_type_ids")) is not None: + if ( + param.extra_kwargs is not None + and (token_types := param.extra_kwargs.get("compressed_token_type_ids")) + is not None + ): token_type_id_requests[i] = token_types if len(token_type_id_requests) == 0: @@ -504,7 +566,8 @@ def _init_model_kwargs(self, num_tokens: int): token_type_ids.append(ids) model_kwargs["token_type_ids"] = torch.concat(token_type_ids).to( - device=self.device) + device=self.device + ) return model_kwargs def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: @@ -530,17 +593,18 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: # required for DCP with q_len > 1, so we assert here. Remove this # assert once the custom mask is support is added to FA3. if self.dcp_world_size > 1: - assert self.reorder_batch_threshold == 1, \ + assert self.reorder_batch_threshold == 1, ( "DCP not support reorder_batch_threshold > 1 now." + ) reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, - decode_threshold=self.reorder_batch_threshold) + decode_threshold=self.reorder_batch_threshold, + ) # Note: used for model runner override. def _init_device_properties(self) -> None: - """Initialize attributes from torch.cuda.get_device_properties - """ + """Initialize attributes from torch.cuda.get_device_properties""" self.device_properties = torch.cuda.get_device_properties(self.device) self.num_sms = self.device_properties.multi_processor_count @@ -596,8 +660,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params - if sampling_params and \ - sampling_params.sampling_type == SamplingType.RANDOM_SEED: + if ( + sampling_params + and sampling_params.sampling_type == SamplingType.RANDOM_SEED + ): generator = torch.Generator(device=self.device) generator.manual_seed(sampling_params.seed) else: @@ -654,14 +720,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: new_token_ids = req_data.new_token_ids[i] # Add the sampled token(s) from the previous step (if any). 
# This doesn't include "unverified" tokens like spec tokens. - num_new_tokens = (num_computed_tokens + len(new_token_ids) - - req_state.num_tokens) + num_new_tokens = ( + num_computed_tokens + len(new_token_ids) - req_state.num_tokens + ) if num_new_tokens == 1: # Avoid slicing list in most common case. req_state.output_token_ids.append(new_token_ids[-1]) elif num_new_tokens > 0: - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) + req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) elif num_output_tokens < len(req_state.output_token_ids): # Some output tokens were discarded due to a sync-KV-load # failure. Align the cached state. @@ -669,21 +735,22 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is not None: - old_end_idx = self.input_batch.num_tokens_no_spec[ - req_index] - end_idx = self.input_batch.num_prompt_tokens[ - req_index] + num_output_tokens + old_end_idx = self.input_batch.num_tokens_no_spec[req_index] + end_idx = ( + self.input_batch.num_prompt_tokens[req_index] + + num_output_tokens + ) self.input_batch.num_tokens[req_index] = end_idx self.input_batch.num_tokens_no_spec[req_index] = end_idx - self.input_batch.is_token_ids[req_index, - end_idx:old_end_idx] = False + self.input_batch.is_token_ids[req_index, end_idx:old_end_idx] = ( + False + ) # Update the block IDs. if not resumed_from_preemption: if new_block_ids is not None: # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, - new_block_ids): + for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): block_ids.extend(new_ids) else: assert new_block_ids is not None @@ -700,11 +767,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: continue # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = ( - num_computed_tokens) + self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens if new_block_ids is not None: - self.input_batch.block_table.append_row( - new_block_ids, req_index) + self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu # because the sampled tokens are already cached. @@ -713,21 +778,22 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: start_token_index = num_computed_tokens end_token_index = num_computed_tokens + len(new_token_ids) self.input_batch.token_ids_cpu[ - req_index, - start_token_index:end_token_index] = new_token_ids - self.input_batch.num_tokens_no_spec[ - req_index] = end_token_index + req_index, start_token_index:end_token_index + ] = new_token_ids + self.input_batch.num_tokens_no_spec[req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index # Add spec_token_ids to token_ids_cpu. - spec_token_ids = ( - scheduler_output.scheduled_spec_decode_tokens.get(req_id, ())) + spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get( + req_id, () + ) if spec_token_ids: num_spec_tokens = len(spec_token_ids) start_index = self.input_batch.num_tokens_no_spec[req_index] end_token_index = start_index + num_spec_tokens self.input_batch.token_ids_cpu[ - req_index, start_index:end_token_index] = spec_token_ids + req_index, start_index:end_token_index + ] = spec_token_ids # NOTE(woosuk): `num_tokens` here may include spec tokens. 
self.input_batch.num_tokens[req_index] += num_spec_tokens @@ -744,7 +810,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _update_states_after_model_execute( - self, output_token_ids: torch.Tensor) -> None: + self, output_token_ids: torch.Tensor + ) -> None: """Update the cached states after model execution. This is used for MTP/EAGLE for hybrid models, as in linear attention, @@ -757,14 +824,26 @@ def _update_states_after_model_execute( return # Find the number of accepted tokens for each sequence. - num_accepted_tokens = (torch.cat( - [ - output_token_ids, - torch.full((output_token_ids.size(0), 1), - -1, - device=output_token_ids.device), - ], - dim=1) == -1).int().argmax(-1).cpu().numpy() + num_accepted_tokens = ( + ( + torch.cat( + [ + output_token_ids, + torch.full( + (output_token_ids.size(0), 1), + -1, + device=output_token_ids.device, + ), + ], + dim=1, + ) + == -1 + ) + .int() + .argmax(-1) + .cpu() + .numpy() + ) for i, num_tokens in enumerate(num_accepted_tokens): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens @@ -791,7 +870,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): use_audio_in_video = True if supports_mrope(self.model): - req_state.mrope_positions, req_state.mrope_position_delta = \ + req_state.mrope_positions, req_state.mrope_position_delta = ( self.model.get_mrope_input_positions( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, @@ -801,8 +880,9 @@ def _init_mrope_positions(self, req_state: CachedRequestState): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + ) else: - req_state.mrope_positions, req_state.mrope_position_delta = \ + req_state.mrope_positions, req_state.mrope_position_delta = ( MRotaryEmbedding.get_input_positions_tensor( req_state.prompt_token_ids, hf_config=self.model_config.hf_config, @@ -812,6 +892,7 @@ def _init_mrope_positions(self, req_state: CachedRequestState): audio_feature_lengths=audio_feature_lengths, use_audio_in_video=use_audio_in_video, ) + ) def _extract_mm_kwargs( self, @@ -830,10 +911,10 @@ def _extract_mm_kwargs( model = cast(SupportsMultiModal, self.model) mm_kwargs_combined: BatchedTensorInputs = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): mm_kwargs_combined.update(mm_kwargs_group) @@ -869,10 +950,11 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange - def _prepare_input_ids(self, total_num_scheduled_tokens: int, - cu_num_tokens: np.ndarray) -> None: + def _prepare_input_ids( + self, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray + ) -> None: """Prepare the input IDs for the current batch. - + Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the GPU need to be copied into the corresponding slots into input_ids.""" @@ -901,7 +983,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # last token in each common request. 
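The accepted-token count per request is recovered above by appending a -1 sentinel column and locating the first -1 in each row; the same trick in standalone form (simplified, illustrative names):

import torch

def count_accepted_tokens(output_token_ids: torch.Tensor) -> torch.Tensor:
    # output_token_ids: [num_reqs, k]; rejected slots already hold -1
    sentinel = torch.full((output_token_ids.size(0), 1), -1,
                          device=output_token_ids.device,
                          dtype=output_token_ids.dtype)
    padded = torch.cat([output_token_ids, sentinel], dim=1)
    # index of the first -1 in each row equals the number of accepted tokens
    return (padded == -1).int().argmax(dim=-1)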
flattened_index = cu_num_tokens[cur_index].item() - 1 flattened_indices.append(flattened_index) - indices_match &= (prev_index == flattened_index) + indices_match &= prev_index == flattened_index max_flattened_index = max(max_flattened_index, flattened_index) num_commmon_tokens = len(flattened_indices) if num_commmon_tokens < total_num_scheduled_tokens: @@ -921,28 +1003,27 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int, # The indices are both the same permutation of 0..N-1 so # we can copy directly using a single slice. self.input_ids.gpu[:num_commmon_tokens].copy_( - self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, - 0], - non_blocking=True) + self.input_batch.prev_sampled_token_ids[:num_commmon_tokens, 0], + non_blocking=True, + ) if self.enable_prompt_embeds: self.is_token_ids.gpu[:num_commmon_tokens] = True return # Upload the index tensors asynchronously # so the scatter can be non-blocking. - input_ids_index_tensor = torch.tensor(flattened_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to( - self.device, - non_blocking=True) + input_ids_index_tensor = torch.tensor( + flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( - prev_common_req_indices, - dtype=torch.int64, - pin_memory=self.pin_memory).to(self.device, non_blocking=True) + prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory + ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, index=input_ids_index_tensor, src=self.input_batch.prev_sampled_token_ids[ - prev_common_req_indices_tensor, 0]) + prev_common_req_indices_tensor, 0 + ], + ) def _get_encoder_seq_lens( self, @@ -964,10 +1045,17 @@ def _get_encoder_seq_lens( def _prepare_inputs( self, scheduler_output: "SchedulerOutput" - ) -> tuple[PerLayerAttnMetadata, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray, - Optional[CommonAttentionMetadata], int, Optional[UBatchSlices], - Optional[torch.Tensor], bool]: + ) -> tuple[ + PerLayerAttnMetadata, + torch.Tensor, + Optional[SpecDecodeMetadata], + np.ndarray, + Optional[CommonAttentionMetadata], + int, + Optional[UBatchSlices], + Optional[torch.Tensor], + bool, + ]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -993,19 +1081,19 @@ def _prepare_inputs( # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange( - num_scheduled_tokens) + cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) # Get positions. positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + np.add( + self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np, + ) # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) @@ -1016,24 +1104,28 @@ def _prepare_inputs( # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. 
- token_indices = (positions_np + - req_indices * self.input_batch.token_ids_cpu.shape[1]) + token_indices = ( + positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] + ) token_indices_tensor = torch.from_numpy(token_indices) # NOTE(woosuk): We use torch.index_select instead of np.take here # because torch.index_select is much faster than np.take for large # tensors. - torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - token_indices_tensor, - out=self.input_ids.cpu[:total_num_scheduled_tokens]) + torch.index_select( + self.input_batch.token_ids_cpu_tensor.flatten(), + 0, + token_indices_tensor, + out=self.input_ids.cpu[:total_num_scheduled_tokens], + ) if self.enable_prompt_embeds: is_token_ids = self.input_batch.is_token_ids.flatten() torch.index_select( is_token_ids, 0, token_indices_tensor, - out=self.is_token_ids.cpu[:total_num_scheduled_tokens]) + out=self.is_token_ids.cpu[:total_num_scheduled_tokens], + ) # Because we did not pre-allocate a massive prompt_embeds CPU tensor on # the InputBatch, we need to fill in the prompt embeds into the expected @@ -1067,52 +1159,49 @@ def _prepare_inputs( actual_num_sched = actual_end - start_pos if actual_num_sched > 0: - self.inputs_embeds.cpu[output_idx:output_idx + - actual_num_sched].copy_( - req_embeds[start_pos:actual_end] - ) + self.inputs_embeds.cpu[ + output_idx : output_idx + actual_num_sched + ].copy_(req_embeds[start_pos:actual_end]) output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping( - req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping( - total_num_scheduled_tokens) + self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) # Prepare the attention metadata. self.query_start_loc.np[0] = 0 - self.query_start_loc.np[1:num_reqs + 1] = cu_num_tokens + self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens # Note: pad query_start_loc to be non-decreasing, as kernels # like FlashAttention requires that - self.query_start_loc.np[num_reqs + 1:].fill(cu_num_tokens[-1]) + self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1]) self.query_start_loc.copy_to_gpu() - query_start_loc = self.query_start_loc.gpu[:num_reqs + 1] + query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens num_tokens_padded = num_tokens_unpadded + self.get_local_padding( - num_tokens_unpadded) - uniform_decode = \ - (max_num_scheduled_tokens == self.uniform_decode_query_len) and \ - (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) - ubatch_slices, num_tokens_after_padding = \ - ubatch_split(num_scheduled_tokens, - num_tokens_unpadded, - num_tokens_padded, - uniform_decode=uniform_decode, - vllm_config=self.vllm_config) + num_tokens_unpadded + ) + uniform_decode = ( + max_num_scheduled_tokens == self.uniform_decode_query_len + ) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens) + ubatch_slices, num_tokens_after_padding = ubatch_split( + num_scheduled_tokens, + num_tokens_unpadded, + num_tokens_padded, + uniform_decode=uniform_decode, + vllm_config=self.vllm_config, + ) self.seq_lens.np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) + self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + ) # Fill unused with 0 for full cuda graph mode. 
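# Toy illustration of the query_start_loc preparation above: the unused tail
# of the prefix-sum buffer is filled with the last value so the array stays
# non-decreasing, as kernels like FlashAttention require. The buffer length
# (max_num_reqs == 6) is an invented example.
import numpy as np

cu_num_tokens = np.array([2, 7, 10], dtype=np.int32)
num_reqs = len(cu_num_tokens)
query_start_loc = np.zeros(6 + 1, dtype=np.int32)
query_start_loc[1:num_reqs + 1] = cu_num_tokens
query_start_loc[num_reqs + 1:] = cu_num_tokens[-1]
print(query_start_loc.tolist())  # [0, 2, 7, 10, 10, 10, 10]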
self.seq_lens.np[num_reqs:].fill(0) self.seq_lens.copy_to_gpu() seq_lens = self.seq_lens.gpu[:num_reqs] max_seq_len = self.seq_lens.np[:num_reqs].max().item() - num_tokens = [ - self.requests[r].num_tokens for r in self.input_batch.req_ids - ] + num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) # Record the index of requests that should not be sampled, @@ -1120,8 +1209,9 @@ def _prepare_inputs( discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np discard_request_indices = np.nonzero(discard_requests_mask)[0] self.num_discarded_requests = len(discard_request_indices) - self.discard_request_indices.np[:self.num_discarded_requests] = ( - discard_request_indices) + self.discard_request_indices.np[: self.num_discarded_requests] = ( + discard_request_indices + ) self.discard_request_indices.copy_to_gpu(self.num_discarded_requests) @@ -1132,13 +1222,13 @@ def _prepare_inputs( # Only relevant for models using M-RoPE (e.g, Qwen2-VL) self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_( self.mrope_positions.cpu[:, :total_num_scheduled_tokens], - non_blocking=True) + non_blocking=True, + ) else: # Common case (1D positions) self.positions.copy_to_gpu(total_num_scheduled_tokens) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -1156,27 +1246,35 @@ def _prepare_inputs( # For chunked prefills, use -1 as mask rather than 0, as guided # decoding may rollback speculative tokens. num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): + for ( + req_id, + draft_token_ids, + ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] num_draft_tokens[req_idx] = len(draft_token_ids) - num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if ( - self.input_batch.num_computed_tokens_cpu[req_idx] - >= self.input_batch.num_prompt_tokens[req_idx]) else -1) + num_decode_draft_tokens[req_idx] = ( + len(draft_token_ids) + if ( + self.input_batch.num_computed_tokens_cpu[req_idx] + >= self.input_batch.num_prompt_tokens[req_idx] + ) + else -1 + ) spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) + num_draft_tokens, cu_num_tokens + ) logits_indices = spec_decode_metadata.logits_indices # For DECODE only cuda graph of some attention backends (e.g., GDN). - self.num_decode_draft_tokens.np[: - num_reqs] = num_decode_draft_tokens + self.num_decode_draft_tokens.np[:num_reqs] = num_decode_draft_tokens self.num_decode_draft_tokens.np[num_reqs:].fill(-1) self.num_decode_draft_tokens.copy_to_gpu() logits_indices_padded = None if self.cache_config.kv_sharing_fast_prefill: logits_indices_padded = self._prepare_kv_sharing_fast_prefill( - logits_indices) + logits_indices + ) attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: @@ -1184,26 +1282,29 @@ def _prepare_inputs( use_cascade_attn = False # Used in the below loop. 
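# Small sketch of the discard mask built above: a request whose computed +
# scheduled length is still below its total token count is mid chunked
# prefill, so no token is sampled for it this step. Numbers are invented.
import numpy as np

seq_lens = np.array([8, 16, 5], dtype=np.int32)      # computed + scheduled
num_tokens = np.array([12, 16, 9], dtype=np.int32)   # prompt + generated so far
discard_requests_mask = seq_lens < num_tokens
discard_request_indices = np.nonzero(discard_requests_mask)[0]
print(discard_request_indices.tolist())  # [0, 2] -> skip sampling for 0 and 2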
- query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1] + query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1] seq_lens_cpu = self.seq_lens.cpu[:num_reqs] - num_computed_tokens_cpu = ( - self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]) + num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ] spec_decode_common_attn_metadata = None if use_spec_decode: self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs]) + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) self.num_accepted_tokens.np[num_reqs:].fill(1) self.num_accepted_tokens.copy_to_gpu() # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): encoder_seq_lens = self._get_encoder_seq_lens( - scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs) + scheduler_output, kv_cache_group_spec.kv_cache_spec, num_reqs + ) - if isinstance(kv_cache_group_spec.kv_cache_spec, - EncoderOnlyAttentionSpec): + if isinstance(kv_cache_group_spec.kv_cache_spec, EncoderOnlyAttentionSpec): # Encoder-only layers do not have KV cache, so we need to # create a dummy block table and slot mapping for them. blk_table_tensor = torch.zeros( @@ -1212,7 +1313,7 @@ def _prepare_inputs( device=self.device, ) slot_mapping = torch.zeros( - (total_num_scheduled_tokens, ), + (total_num_scheduled_tokens,), dtype=torch.int64, device=self.device, ) @@ -1220,16 +1321,14 @@ def _prepare_inputs( else: blk_table = self.input_batch.block_table[kv_cache_group_id] blk_table_tensor = blk_table.get_device_tensor(num_reqs) - slot_mapping = blk_table.slot_mapping.gpu[: - total_num_scheduled_tokens] + slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens] # Fill unused with -1. Needed for reshape_and_cache in full cuda # graph mode. - blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_( - -1) - num_common_prefix_blocks = ( - scheduler_output. - num_common_prefix_blocks[kv_cache_group_id]) + blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + num_common_prefix_blocks = scheduler_output.num_common_prefix_blocks[ + kv_cache_group_id + ] common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -1249,11 +1348,12 @@ def _prepare_inputs( encoder_seq_lens=encoder_seq_lens, ) - if (self.speculative_config - and spec_decode_common_attn_metadata is None): + if self.speculative_config and spec_decode_common_attn_metadata is None: if isinstance(self.drafter, EagleProposer): - if (self.drafter.attn_layer_names[0] - in kv_cache_group_spec.layer_names): + if ( + self.drafter.attn_layer_names[0] + in kv_cache_group_spec.layer_names + ): spec_decode_common_attn_metadata = common_attn_metadata else: spec_decode_common_attn_metadata = common_attn_metadata @@ -1271,24 +1371,27 @@ def _prepare_inputs( ) extra_attn_metadata_args = {} - if use_spec_decode and isinstance(builder, - GDNAttentionMetadataBuilder): + if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder): extra_attn_metadata_args = dict( - num_accepted_tokens=self.num_accepted_tokens. - gpu[:num_reqs], - num_decode_draft_tokens_cpu=self. 
- num_decode_draft_tokens.cpu[:num_reqs], + num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs], + num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[ + :num_reqs + ], ) if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata) + ubatch_slices, common_attn_metadata + ) for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list): - attn_metadata_i = (attn_group.get_metadata_builder( - ubatch_id=ubid).build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata)) + common_attn_metadata_list + ): + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + ) for layer_name in kv_cache_group_spec.layer_names: assert type(attn_metadata) is list attn_metadata[ubid][layer_name] = attn_metadata_i @@ -1297,9 +1400,9 @@ def _prepare_inputs( attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - **extra_attn_metadata_args) - use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", - False) + **extra_attn_metadata_args, + ) + use_cascade_attn |= getattr(attn_metadata_i, "use_cascade", False) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -1311,10 +1414,17 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens, ubatch_slices, - num_tokens_after_padding, use_cascade_attn) + return ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens, + spec_decode_common_attn_metadata, + max_num_scheduled_tokens, + ubatch_slices, + num_tokens_after_padding, + use_cascade_attn, + ) def _compute_cascade_attn_prefix_len( self, @@ -1386,18 +1496,20 @@ def _compute_cascade_attn_prefix_len( # this case. num_reqs = len(num_scheduled_tokens) common_prefix_len = min( - common_prefix_len, - self.input_batch.num_computed_tokens_cpu[:num_reqs].min()) + common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min() + ) # common_prefix_len should be a multiple of the block size. 
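# Toy arithmetic for the cascade-attention prefix handling above: the shared
# prefix is clamped to what every request has already computed, then rounded
# down to a whole number of KV-cache blocks. A block size of 16 and the token
# counts are example values only.
block_size = 16
num_computed_tokens = [45, 90, 77]
common_prefix_len = 60                                                 # from the scheduler
common_prefix_len = min(common_prefix_len, min(num_computed_tokens))   # -> 45
common_prefix_len = common_prefix_len // block_size * block_size       # -> 32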
- common_prefix_len = (common_prefix_len // kv_cache_spec.block_size * - kv_cache_spec.block_size) - use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or - (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.sliding_window is not None)) - use_local_attention = ( - isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) - or (isinstance(kv_cache_spec, FullAttentionSpec) - and kv_cache_spec.attention_chunk_size is not None)) + common_prefix_len = ( + common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size + ) + use_sliding_window = isinstance(kv_cache_spec, SlidingWindowSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.sliding_window is not None + ) + use_local_attention = isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) or ( + isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None + ) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -1417,18 +1529,15 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): req = self.requests[req_id] assert req.mrope_positions is not None - num_computed_tokens = \ - self.input_batch.num_computed_tokens_cpu[index] - num_scheduled_tokens = \ - scheduler_output.num_scheduled_tokens[req_id] + num_computed_tokens = self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - req.prompt_token_ids, req.prompt_embeds) + req.prompt_token_ids, req.prompt_embeds + ) if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: - prompt_part_len = max(0, - num_prompt_tokens - num_computed_tokens) - completion_part_len = max( - 0, num_scheduled_tokens - prompt_part_len) + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, num_scheduled_tokens - prompt_part_len) else: prompt_part_len = num_scheduled_tokens completion_part_len = 0 @@ -1442,8 +1551,9 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"): src_start = num_computed_tokens src_end = num_computed_tokens + prompt_part_len - self.mrope_positions.cpu[:, dst_start:dst_end] = ( - req.mrope_positions[:, src_start:src_end]) + self.mrope_positions.cpu[:, dst_start:dst_end] = req.mrope_positions[ + :, src_start:src_end + ] mrope_pos_ptr += prompt_part_len if completion_part_len > 0: @@ -1483,10 +1593,12 @@ def _calc_spec_decode_metadata( # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( - num_sampled_tokens, cumsum_dtype=np.int32) + num_sampled_tokens, cumsum_dtype=np.int32 + ) # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( - cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens + ) # Step 3. 
[0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] logits_indices += arange @@ -1497,22 +1609,28 @@ def _calc_spec_decode_metadata( # cu_num_draft_tokens: [3, 3, 5, 5, 6] # arange: [0, 1, 2, 0, 1, 0] cu_num_draft_tokens, arange = self._get_cumsum_and_arange( - num_draft_tokens, cumsum_dtype=np.int32) + num_draft_tokens, cumsum_dtype=np.int32 + ) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( - cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens + ) # [0, 1, 2, 5, 6, 9] target_logits_indices += arange # TODO: Optimize the CPU -> GPU copy. cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( - self.device, non_blocking=True) - logits_indices = torch.from_numpy(logits_indices).to(self.device, - non_blocking=True) + self.device, non_blocking=True + ) + logits_indices = torch.from_numpy(logits_indices).to( + self.device, non_blocking=True + ) target_logits_indices = torch.from_numpy(target_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) bonus_logits_indices = torch.from_numpy(bonus_logits_indices).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Compute the draft token ids. # draft_token_indices: [ 1, 2, 3, 105, 106, 208] @@ -1536,23 +1654,26 @@ def _prepare_kv_sharing_fast_prefill( assert self.kv_sharing_fast_prefill_logits_indices is not None num_logits = logits_indices.shape[0] assert num_logits > 0 - self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_( - logits_indices) + self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(logits_indices) # There might have leftover indices in logits_indices[num_logits:] # from previous iterations, whose values may be greater than the # batch size in the current iteration. To ensure indices are always # valid, we fill the padded indices with the last index. self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( - logits_indices[-1].item()) - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1]): + logits_indices[-1].item() + ) + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_logits <= self.cudagraph_batch_sizes[-1] + ): # Use piecewise CUDA graphs. # Add padding to the batch size. num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) else: num_logits_padded = num_logits - logits_indices_padded = ( - self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]) + logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ + :num_logits_padded + ] return logits_indices_padded def _batch_mm_kwargs_from_scheduler( @@ -1591,7 +1712,8 @@ def _batch_mm_kwargs_from_scheduler( def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # Batch the multi-modal inputs using the helper method. 
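# Runnable NumPy version of the worked example in _calc_spec_decode_metadata
# above, using the same numbers as its comments (5 requests with 3/0/2/0/1
# draft tokens and cumulative scheduled tokens [4, 104, 107, 207, 209]).
import numpy as np

num_draft_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)
cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209], dtype=np.int32)

num_sampled_tokens = num_draft_tokens + 1
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)              # [4 5 8 9 11]
starts = cu_num_sampled_tokens - num_sampled_tokens
arange = np.arange(cu_num_sampled_tokens[-1]) - np.repeat(starts, num_sampled_tokens)

logits_indices = np.repeat(
    cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens) + arange
# -> [0 1 2 3 103 104 105 106 206 207 208]

cu_num_draft_tokens = np.cumsum(num_draft_tokens)                  # [3 3 5 5 6]
draft_starts = cu_num_draft_tokens - num_draft_tokens
draft_arange = np.arange(cu_num_draft_tokens[-1]) - np.repeat(
    draft_starts, num_draft_tokens)
target_logits_indices = np.repeat(
    cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens) + draft_arange
# -> [0 1 2 5 6 9]
bonus_logits_indices = cu_num_sampled_tokens - 1  # last sampled slot per request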
mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler( - scheduler_output) + scheduler_output + ) if not mm_kwargs: return @@ -1606,10 +1728,10 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): model = cast(SupportsMultiModal, self.model) encoder_outputs = [] for modality, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # (ekhvedchenia): Temporary hack to limit peak memory usage when # processing multimodal data.This solves the issue with scheduler @@ -1623,11 +1745,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): micro_batch_size = 1 for i in range(0, num_items, micro_batch_size): micro_batch_mm_inputs = dict( - (k, v[i:i + micro_batch_size]) - for k, v in mm_kwargs_group.items()) + (k, v[i : i + micro_batch_size]) + for k, v in mm_kwargs_group.items() + ) micro_batch_outputs = model.get_multimodal_embeddings( - **micro_batch_mm_inputs) + **micro_batch_mm_inputs + ) curr_group_outputs.extend(micro_batch_outputs) else: @@ -1638,8 +1762,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): # 2. A list or tuple (length: num_items) of tensors, # each of shape (feature_size, hidden_size) in case the feature # size is dynamic depending on the input multimodal items. - curr_group_outputs = model.get_multimodal_embeddings( - **mm_kwargs_group) + curr_group_outputs = model.get_multimodal_embeddings(**mm_kwargs_group) sanity_check_mm_encoder_outputs( curr_group_outputs, @@ -1671,11 +1794,9 @@ def _gather_mm_embeddings( for req_id in self.input_batch.req_ids: mm_embeds_req: list[torch.Tensor] = [] - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ - req_id] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] req_state = self.requests[req_id] - num_computed_tokens = \ - req_state.num_computed_tokens + shift_computed_tokens + num_computed_tokens = req_state.num_computed_tokens + shift_computed_tokens for mm_feature in req_state.mm_features: pos_info = mm_feature.mm_position @@ -1703,15 +1824,15 @@ def _gather_mm_embeddings( mm_hash = mm_feature.identifier encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None,\ - f"Encoder cache miss for {mm_hash}." + assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." 
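# Toy version of the micro-batching workaround in _execute_mm_encoder above:
# each entry of the grouped multimodal kwargs is sliced along the item
# dimension so the encoder processes one item at a time, trading speed for a
# lower peak memory footprint. The keys and shapes below are invented.
import torch

mm_kwargs_group = {
    "pixel_values": torch.randn(4, 3, 32, 32),
    "image_sizes": torch.tensor([[32, 32]] * 4),
}
num_items, micro_batch_size = 4, 1
for i in range(0, num_items, micro_batch_size):
    micro_batch = {k: v[i:i + micro_batch_size] for k, v in mm_kwargs_group.items()}
    # every value now has a leading dimension of 1; the encoder runs on it and
    # its outputs extend the per-group output list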
if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \ - = True if is_embed is None else is_embed + is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( + True if is_embed is None else is_embed + ) mm_embeds_item = gather_mm_placeholders( encoder_output[start_idx:end_idx], @@ -1728,7 +1849,8 @@ def _gather_mm_embeddings( multimodal_embeddings=mm_embeds_req, mrope_positions=req_state.mrope_positions, num_computed_tokens=req_state.num_computed_tokens, - )) + ) + ) req_state.mrope_positions.copy_(new_mrope_positions) req_state.mrope_position_delta = new_delta @@ -1762,10 +1884,10 @@ def _extract_encoder_inputs( model = cast(SupportsMultiModal, self.model) encoder_features = {} for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, + mm_kwargs, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, ): # Add the grouped features to encoder_features dict # This allows the model to receive them as kwargs (e.g., @@ -1802,21 +1924,24 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: supported_tasks = list(model.pooler.get_supported_tasks()) - if (self.scheduler_config.chunked_prefill_enabled - and "encode" in supported_tasks): + if ( + self.scheduler_config.chunked_prefill_enabled + and "encode" in supported_tasks + ): supported_tasks.remove("encode") - logger.debug_once("Chunked prefill is not supported with " - "encode task which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it.") + logger.debug_once( + "Chunked prefill is not supported with " + "encode task which using ALL pooling. " + "Please turn off chunked prefill by " + "`--no-enable-chunked-prefill` before using it." 
+ ) if "score" in supported_tasks: num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: supported_tasks.remove("score") - logger.debug_once( - "Score API is only enabled for num_labels == 1.") + logger.debug_once("Score API is only enabled for num_labels == 1.") return supported_tasks @@ -1831,9 +1956,11 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return tuple(tasks) def sync_and_slice_intermediate_tensors( - self, num_tokens: int, intermediate_tensors: IntermediateTensors, - sync_self: bool) -> IntermediateTensors: - + self, + num_tokens: int, + intermediate_tensors: IntermediateTensors, + sync_self: bool, + ) -> IntermediateTensors: assert self.intermediate_tensors is not None tp = self.vllm_config.parallel_config.tensor_parallel_size @@ -1845,21 +1972,21 @@ def sync_and_slice_intermediate_tensors( assert intermediate_tensors is not None for k, v in intermediate_tensors.items(): is_scattered = k == "residual" and is_rs - copy_len = num_tokens // tp if is_scattered else \ - num_tokens + copy_len = num_tokens // tp if is_scattered else num_tokens self.intermediate_tensors[k][:copy_len].copy_( - v[:copy_len], non_blocking=True) - - return IntermediateTensors({ - k: - v[:num_tokens // - tp] if k == "residual" and is_rs else v[:num_tokens] - for k, v in self.intermediate_tensors.items() - }) - - def eplb_step(self, - is_dummy: bool = False, - is_profile: bool = False) -> None: + v[:copy_len], non_blocking=True + ) + + return IntermediateTensors( + { + k: v[: num_tokens // tp] + if k == "residual" and is_rs + else v[:num_tokens] + for k, v in self.intermediate_tensors.items() + } + ) + + def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None: """ Step for the EPLB (Expert Parallelism Load Balancing) state. """ @@ -1876,8 +2003,7 @@ def eplb_step(self, log_stats=self.parallel_config.eplb_config.log_balancedness, ) - def get_dp_padding(self, - num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: + def get_dp_padding(self, num_tokens: int) -> tuple[int, Optional[torch.Tensor]]: """ Determines the total number of tokens that each rank will run. All ranks will be padded out so that they run with the same number @@ -1904,31 +2030,33 @@ def get_dp_padding(self, return 0, None num_tokens_across_dp = DPMetadata.num_tokens_across_dp( - num_tokens, dp_size, dp_rank) + num_tokens, dp_size, dp_rank + ) max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item() - num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] * - dp_size, - device="cpu", - dtype=torch.int32) + num_tokens_after_padding = torch.tensor( + [max_tokens_across_dp_cpu] * dp_size, device="cpu", dtype=torch.int32 + ) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding def get_local_padding(self, num_tokens_unpadded: int) -> int: - num_tokens_padded = num_tokens_unpadded - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]): + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1] + ): # Use piecewise CUDA graphs. # Add padding to the batch size. - num_tokens_padded = self.vllm_config.pad_for_cudagraph( - num_tokens_unpadded) + num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded) else: # Eager mode. 
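# Minimal sketch of get_dp_padding above: every data-parallel rank pads its
# batch to the largest per-rank token count so all ranks execute the same
# shape. The per-rank counts are toy values; the real code gathers them
# through DPMetadata.num_tokens_across_dp.
import torch

num_tokens_across_dp = torch.tensor([7, 12, 9, 12], dtype=torch.int32)
max_tokens_across_dp = int(num_tokens_across_dp.max())
num_tokens_after_padding = torch.full((4,), max_tokens_across_dp, dtype=torch.int32)
pad_for_rank_0 = max_tokens_across_dp - int(num_tokens_across_dp[0])  # 5 extra tokens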
# Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if self.vllm_config.compilation_config.pass_config. \ - enable_sequence_parallelism and tp_size > 1: + if ( + self.vllm_config.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): num_tokens_padded = round_up(num_tokens_unpadded, tp_size) num_pad_tokens = num_tokens_padded - num_tokens_unpadded @@ -1938,12 +2066,13 @@ def get_local_padding(self, num_tokens_unpadded: int) -> int: # Should be called after attention metadata creation. This just pads # the second ubatch slice out to the total number of tokens # (num_tokens + padding) - def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, - num_total_tokens: int): - padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start, - num_total_tokens) - ubatch_slices[1] = UBatchSlice(padded_second_ubatch_slice, - padded_second_ubatch_slice) + def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int): + padded_second_ubatch_slice = slice( + ubatch_slices[1].token_slice.start, num_total_tokens + ) + ubatch_slices[1] = UBatchSlice( + padded_second_ubatch_slice, padded_second_ubatch_slice + ) def _pool( self, @@ -1951,16 +2080,16 @@ def _pool( num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, ) -> ModelRunnerOutput: - assert self.input_batch.num_reqs ==\ - len(self.input_batch.pooling_params), \ - "Either all or none of the requests in" \ - " a batch must be pooling request" + assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), ( + "Either all or none of the requests in a batch must be pooling request" + ) hidden_states = hidden_states[:num_scheduled_tokens] pooling_metadata = self.input_batch.get_pooling_metadata() - pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(), - device=hidden_states.device) - seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs] + pooling_metadata.build_pooling_cursor( + num_scheduled_tokens_np.tolist(), device=hidden_states.device + ) + seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] model = cast(VllmModelForPooling, self.model) raw_pooler_output: PoolerOutput = model.pooler( @@ -1975,8 +2104,8 @@ def _pool( pooler_output: list[Optional[torch.Tensor]] = [] for raw_output, seq_len, prompt_len in zip( - raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - + raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens + ): output = raw_output if seq_len == prompt_len else None pooler_output.append(output) @@ -1990,11 +2119,13 @@ def _pool( ) def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: - if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH - and hasattr(self, "cudagraph_batch_sizes") - and self.cudagraph_batch_sizes - and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + if ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH + and hasattr(self, "cudagraph_batch_sizes") + and self.cudagraph_batch_sizes + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1] + ): # Use CUDA graphs. # Add padding to the batch size. 
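# Sketch of the two padding rules above. pad_to_cudagraph_size is a stand-in
# written for this example (assumed behaviour: round up to the nearest
# captured batch size); the real helper is vllm_config.pad_for_cudagraph.
# The capture sizes and tp_size are invented.
from bisect import bisect_left

cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32]
tp_size = 4

def pad_to_cudagraph_size(num_tokens: int) -> int:
    return cudagraph_batch_sizes[bisect_left(cudagraph_batch_sizes, num_tokens)]

def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

print(pad_to_cudagraph_size(11))  # 16 -> CUDA-graph bucket
print(round_up(11, tp_size))      # 12 -> sequence-parallel eager fallback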
return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens) @@ -2003,8 +2134,10 @@ def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: # Pad tokens to multiple of tensor_parallel_size when # enabled collective fusion for SP tp_size = self.vllm_config.parallel_config.tensor_parallel_size - if (self.compilation_config.pass_config.enable_sequence_parallelism - and tp_size > 1): + if ( + self.compilation_config.pass_config.enable_sequence_parallelism + and tp_size > 1 + ): return round_up(num_scheduled_tokens, tp_size) return num_scheduled_tokens @@ -2014,10 +2147,16 @@ def _preprocess( intermediate_tensors: Optional[IntermediateTensors] = None, ubatch_slices: Optional[UBatchSlices] = None, num_tokens_after_padding: Optional[torch.Tensor] = None, - ) -> tuple[int, int, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], torch.Tensor, - Optional[IntermediateTensors], dict[str, Any]]: - + ) -> tuple[ + int, + int, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + Optional[IntermediateTensors], + dict[str, Any], + ]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if ubatch_slices: assert num_tokens_after_padding is not None @@ -2025,18 +2164,19 @@ def _preprocess( self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens) elif ubatch_slices is None: num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens) - num_pad, num_tokens_after_padding = self.get_dp_padding( - num_input_tokens) + num_pad, num_tokens_after_padding = self.get_dp_padding(num_input_tokens) num_input_tokens += num_pad # _prepare_inputs may reorder the batch, so we must gather multi # modal outputs after that to ensure the correct order - if (self.supports_mm_inputs and get_pp_group().is_first_rank - and not self.model_config.is_encoder_decoder): + if ( + self.supports_mm_inputs + and get_pp_group().is_first_rank + and not self.model_config.is_encoder_decoder + ): # Run the multimodal encoder if any. self._execute_mm_encoder(scheduler_output) - mm_embeds, is_mm_embed = self._gather_mm_embeddings( - scheduler_output) + mm_embeds, is_mm_embed = self._gather_mm_embeddings(scheduler_output) # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) @@ -2048,8 +2188,7 @@ def _preprocess( ) # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds.gpu[:num_scheduled_tokens].copy_( - inputs_embeds_scheduled) + self.inputs_embeds.gpu[:num_scheduled_tokens].copy_(inputs_embeds_scheduled) input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2070,14 +2209,15 @@ def _preprocess( # If a batch only has token ids, then including the embedding layer # in the CUDA graph will be more performant (like in the else case # below). 
- token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \ - .nonzero(as_tuple=False) \ + token_ids_idx = ( + self.is_token_ids.gpu[:num_scheduled_tokens] + .nonzero(as_tuple=False) .squeeze(1) + ) # Some tokens ids may need to become embeds if token_ids_idx.numel() > 0: token_ids = self.input_ids.gpu[token_ids_idx] - tokens_to_embeds = self.model.get_input_embeddings( - input_ids=token_ids) + tokens_to_embeds = self.model.get_input_embeddings(input_ids=token_ids) self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens] @@ -2100,10 +2240,13 @@ def _preprocess( intermediate_tensors = None else: intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_input_tokens, intermediate_tensors, True) + num_input_tokens, intermediate_tensors, True + ) - if (self.model_config.is_encoder_decoder - and scheduler_output.scheduled_encoder_inputs): + if ( + self.model_config.is_encoder_decoder + and scheduler_output.scheduled_encoder_inputs + ): encoder_inputs = self._extract_encoder_inputs(scheduler_output) model_kwargs.update(encoder_inputs) @@ -2119,8 +2262,9 @@ def _preprocess( ) def _sample( - self, logits: Optional[torch.Tensor], - spec_decode_metadata: Optional[SpecDecodeMetadata] + self, + logits: Optional[torch.Tensor], + spec_decode_metadata: Optional[SpecDecodeMetadata], ) -> SamplerOutput: # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata @@ -2159,24 +2303,28 @@ def _sample( return sampler_output def _bookkeeping_sync( - self, scheduler_output: "SchedulerOutput", - sampler_output: SamplerOutput, logits: Optional[torch.Tensor], - hidden_states: torch.Tensor, num_scheduled_tokens: int + self, + scheduler_output: "SchedulerOutput", + sampler_output: SamplerOutput, + logits: Optional[torch.Tensor], + hidden_states: torch.Tensor, + num_scheduled_tokens: int, ) -> tuple[ - dict[str, int], - Optional[LogprobsLists], - list[list[int]], - dict[str, Optional[LogprobsTensors]], - list[str], - dict[str, int], - list[int], + dict[str, int], + Optional[LogprobsLists], + list[list[int]], + dict[str, Optional[LogprobsTensors]], + list[str], + dict[str, int], + list[int], ]: num_nans_in_logits = {} if envs.VLLM_COMPUTE_NANS_IN_LOGITS: num_nans_in_logits = self._get_nans_in_logits(logits) - discard_sampled_tokens_req_indices = \ - self.discard_request_indices.np[:self.num_discarded_requests] + discard_sampled_tokens_req_indices = self.discard_request_indices.np[ + : self.num_discarded_requests + ] for i in discard_sampled_tokens_req_indices: gen = self.input_batch.generators.get(int(i)) if gen is not None: @@ -2185,14 +2333,14 @@ def _bookkeeping_sync( # Copy some objects so they don't get modified after returning. # This is important when using async scheduling. req_ids_output_copy = self.input_batch.req_ids.copy() - req_id_to_index_output_copy = \ - self.input_batch.req_id_to_index.copy() + req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() # NOTE: GPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. logprobs_tensors = sampler_output.logprobs_tensors - logprobs_lists = logprobs_tensors.tolists() \ - if logprobs_tensors is not None else None + logprobs_lists = ( + logprobs_tensors.tolists() if logprobs_tensors is not None else None + ) # Compute prompt logprobs if needed. 
prompt_logprobs_dict = self._get_prompt_logprobs_dict( @@ -2227,10 +2375,10 @@ def _bookkeeping_sync( # Cache the sampled tokens on the GPU and avoid CPU sync. # These will be copied into input_ids in the next step # when preparing inputs. - self.input_batch.prev_sampled_token_ids = \ - sampled_token_ids - self.input_batch.prev_sampled_token_ids_invalid_indices = \ + self.input_batch.prev_sampled_token_ids = sampled_token_ids + self.input_batch.prev_sampled_token_ids_invalid_indices = ( invalid_req_indices_set + ) self.input_batch.prev_req_id_to_index = { req_id: i for i, req_id in enumerate(self.input_batch.req_ids) @@ -2245,8 +2393,7 @@ def _bookkeeping_sync( req_ids = self.input_batch.req_ids for req_idx in range(num_sampled_tokens): if self.use_async_scheduling: - sampled_ids = [-1] if \ - req_idx not in invalid_req_indices_set else None + sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None else: sampled_ids = valid_sampled_token_ids[req_idx] if not sampled_ids: @@ -2257,7 +2404,8 @@ def _bookkeeping_sync( assert end_idx <= self.max_model_len + 1, ( "Sampled token IDs exceed the max model length + 1. " f"Total number of tokens: {end_idx} > max_model_len + 1: " - f"{self.max_model_len + 1}") + f"{self.max_model_len + 1}" + ) n_tokens_cache = len(sampled_ids) @@ -2270,11 +2418,12 @@ def _bookkeeping_sync( if end_idx == self.max_model_len + 1: n_tokens_cache -= 1 - self.input_batch.token_ids_cpu[req_idx, start_idx:( - start_idx + n_tokens_cache)] = sampled_ids[:n_tokens_cache] - self.input_batch.is_token_ids[req_idx, - start_idx:(start_idx + - n_tokens_cache)] = True + self.input_batch.token_ids_cpu[ + req_idx, start_idx : (start_idx + n_tokens_cache) + ] = sampled_ids[:n_tokens_cache] + self.input_batch.is_token_ids[ + req_idx, start_idx : (start_idx + n_tokens_cache) + ] = True self.input_batch.num_tokens_no_spec[req_idx] = end_idx self.input_batch.num_tokens[req_idx] = end_idx @@ -2319,7 +2468,7 @@ def _model_forward( """Helper method to call the model forward pass. This method can be overridden by subclasses for model execution. - Motivation: We can inspect only this method versus + Motivation: We can inspect only this method versus the whole execute_model, which has additional logic. Args: @@ -2356,18 +2505,27 @@ def execute_model( # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT return self.kv_connector_no_forward( - scheduler_output, self.vllm_config) + scheduler_output, self.vllm_config + ) if self.cache_config.kv_sharing_fast_prefill: assert not self.input_batch.num_prompt_logprobs, ( "--kv-sharing-fast-prefill produces incorrect " "logprobs for prompt tokens, tokens, please disable " - "it when the requests need prompt logprobs") + "it when the requests need prompt logprobs" + ) # Prepare the decoder inputs. 
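# Toy version of the sampled-token write-back in _bookkeeping_sync above: the
# CPU token-id buffer is max_model_len wide, so when a request reaches exactly
# max_model_len + 1 tokens the final token is not cached (the request has hit
# its length limit and that token is never fed back in). Sizes and ids are
# invented.
import numpy as np

max_model_len = 8
token_ids_cpu = np.zeros((1, max_model_len), dtype=np.int64)
start_idx = 6                           # tokens already stored for request 0
sampled_ids = [101, 102, 103]           # bonus token + accepted draft tokens
end_idx = start_idx + len(sampled_ids)  # 9 == max_model_len + 1
assert end_idx <= max_model_len + 1
n_tokens_cache = len(sampled_ids)
if end_idx == max_model_len + 1:
    n_tokens_cache -= 1                 # drop the token that would overflow
token_ids_cpu[0, start_idx:start_idx + n_tokens_cache] = sampled_ids[:n_tokens_cache]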
- (attn_metadata, logits_indices, spec_decode_metadata, - num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len, ubatch_slices, num_tokens_after_padding, - use_cascade_attn) = self._prepare_inputs(scheduler_output) + ( + attn_metadata, + logits_indices, + spec_decode_metadata, + num_scheduled_tokens_np, + spec_decode_common_attn_metadata, + max_query_len, + ubatch_slices, + num_tokens_after_padding, + use_cascade_attn, + ) = self._prepare_inputs(scheduler_output) ( num_scheduled_tokens, @@ -2378,26 +2536,33 @@ def execute_model( positions, intermediate_tensors, model_kwargs, - ) = self._preprocess(scheduler_output, intermediate_tensors, - ubatch_slices, num_tokens_after_padding) - - uniform_decode = (max_query_len - == self.uniform_decode_query_len) and ( - num_scheduled_tokens - == self.input_batch.num_reqs * max_query_len) - batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, - uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor, - use_cascade_attn) + ) = self._preprocess( + scheduler_output, + intermediate_tensors, + ubatch_slices, + num_tokens_after_padding, + ) + + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + num_scheduled_tokens == self.input_batch.num_reqs * max_query_len + ) + batch_descriptor = BatchDescriptor( + num_tokens=num_input_tokens, uniform_decode=uniform_decode + ) + cudagraph_runtime_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch(batch_descriptor, use_cascade_attn) + ) # Set cudagraph mode to none if calc_kv_scales is true. if attn_metadata is not None: - metadata_list = (attn_metadata.values() if isinstance( - attn_metadata, dict) else [attn_metadata]) + metadata_list = ( + attn_metadata.values() + if isinstance(attn_metadata, dict) + else [attn_metadata] + ) if any( - getattr(m, 'enable_kv_scales_calculation', False) - for m in metadata_list): + getattr(m, "enable_kv_scales_calculation", False) for m in metadata_list + ): cudagraph_runtime_mode = CUDAGraphMode.NONE # This is currently to get around the assert in the DPMetadata @@ -2407,7 +2572,8 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - with (set_forward_context( + with ( + set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_input_tokens, @@ -2415,9 +2581,10 @@ def execute_model( cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ubatch_slices=ubatch_slices, - ), record_function_or_nullcontext("Forward"), - self.maybe_get_kv_connector_output(scheduler_output) as - kv_connector_output): + ), + record_function_or_nullcontext("Forward"), + self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, + ): model_output = self._model_forward( input_ids=input_ids, positions=positions, @@ -2445,8 +2612,9 @@ def execute_model( if self.is_pooling_model: # Return the pooling output. 
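# Toy check for the uniform_decode flag that feeds the BatchDescriptor above:
# the batch is a uniform decode when every request contributes exactly
# uniform_decode_query_len tokens (e.g. 1 bonus + num_speculative_tokens with
# a drafter attached). The numbers below are examples.
uniform_decode_query_len = 3
num_scheduled_tokens = [3, 3, 3, 3]
num_reqs = len(num_scheduled_tokens)
max_query_len = max(num_scheduled_tokens)
uniform_decode = (max_query_len == uniform_decode_query_len) and (
    sum(num_scheduled_tokens) == num_reqs * max_query_len
)
print(uniform_decode)  # True; a prefill chunk in the batch would make it False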
- output = self._pool(hidden_states, num_scheduled_tokens, - num_scheduled_tokens_np) + output = self._pool( + hidden_states, num_scheduled_tokens, num_scheduled_tokens_np + ) output.kv_connector_output = kv_connector_output return output @@ -2458,14 +2626,15 @@ def execute_model( if not get_pp_group().is_last_rank: all_gather_tensors = { - "residual": - not is_residual_scattered_for_sp( - self.vllm_config, num_input_tokens) + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_input_tokens + ) } get_pp_group().send_tensor_dict( hidden_states.tensors, all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors) + all_gather_tensors=all_gather_tensors, + ) logits = None else: sample_hidden_states = hidden_states[logits_indices] @@ -2475,16 +2644,17 @@ def execute_model( if logits is not None: model_output_broadcast_data["logits"] = logits.contiguous() - model_output_broadcast_data = get_pp_group( - ).broadcast_tensor_dict(model_output_broadcast_data, - src=len(get_pp_group().ranks) - 1) + model_output_broadcast_data = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 + ) assert model_output_broadcast_data is not None logits = model_output_broadcast_data["logits"] # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: - apply_grammar_bitmask(scheduler_output, self.input_batch, - logits, self.device) + apply_grammar_bitmask( + scheduler_output, self.input_batch, logits, self.device + ) with record_function_or_nullcontext("Sample"): sampler_output = self._sample(logits, spec_decode_metadata) @@ -2506,23 +2676,30 @@ def propose_draft_token_ids(sampled_token_ids): sampler_output=sampler_output, ) - use_padded_batch = self.speculative_config and \ - (self.speculative_config.use_eagle() - or self.speculative_config.uses_draft_model()) and \ - not self.speculative_config.disable_padded_drafter_batch + use_padded_batch = ( + self.speculative_config + and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ) + and not self.speculative_config.disable_padded_drafter_batch + ) effective_drafter_max_model_len = self.max_model_len if effective_drafter_max_model_len is None: effective_drafter_max_model_len = self.model_config.max_model_len - if (self.speculative_config - and self.speculative_config.draft_model_config is not None - and self.speculative_config.draft_model_config.max_model_len - is not None): + if ( + self.speculative_config + and self.speculative_config.draft_model_config is not None + and self.speculative_config.draft_model_config.max_model_len is not None + ): effective_drafter_max_model_len = ( - self.speculative_config.draft_model_config.max_model_len) + self.speculative_config.draft_model_config.max_model_len + ) input_fits_in_drafter = spec_decode_common_attn_metadata and ( - spec_decode_common_attn_metadata.max_seq_len + - self.speculative_config.num_speculative_tokens - <= effective_drafter_max_model_len) + spec_decode_common_attn_metadata.max_seq_len + + self.speculative_config.num_speculative_tokens + <= effective_drafter_max_model_len + ) if use_padded_batch and input_fits_in_drafter: # EAGLE and draft model speculative decoding can use the # GPU sampled tokens as inputs, and does not need @@ -2538,12 +2715,15 @@ def propose_draft_token_ids(sampled_token_ids): req_ids_output_copy, req_id_to_index_output_copy, invalid_req_indices, - ) = self._bookkeeping_sync(scheduler_output, sampler_output, - logits, hidden_states, - 
num_scheduled_tokens) + ) = self._bookkeeping_sync( + scheduler_output, + sampler_output, + logits, + hidden_states, + num_scheduled_tokens, + ) - if (self.speculative_config and not use_padded_batch - and input_fits_in_drafter): + if self.speculative_config and not use_padded_batch and input_fits_in_drafter: # ngram and other speculative decoding methods use the sampled # tokens on the CPU, so they are run after bookkeeping. propose_draft_token_ids(valid_sampled_token_ids) @@ -2602,10 +2782,12 @@ def propose_draft_token_ids( assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.drafter.propose( - sampled_token_ids, self.input_batch.req_ids, + sampled_token_ids, + self.input_batch.req_ids, self.input_batch.num_tokens_no_spec, self.input_batch.token_ids_cpu, - self.input_batch.spec_decode_unsupported_reqs) + self.input_batch.spec_decode_unsupported_reqs, + ) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -2618,8 +2800,8 @@ def propose_draft_token_ids( offset = 0 assert spec_decode_metadata is not None for num_draft, tokens in zip( - spec_decode_metadata.num_draft_tokens, - sampled_token_ids): + spec_decode_metadata.num_draft_tokens, sampled_token_ids + ): indices.append(offset + len(tokens) - 1) offset += num_draft + 1 indices = torch.tensor(indices, device=self.device) @@ -2629,38 +2811,45 @@ def propose_draft_token_ids( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) - elif (self.speculative_config.use_eagle() - or self.speculative_config.uses_draft_model()): - assert isinstance(self.drafter, - (EagleProposer, DraftModelProposer)) + elif ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, (EagleProposer, DraftModelProposer)) if self.speculative_config.disable_padded_drafter_batch: # When padded-batch is disabled, the sampled_token_ids should be # the cpu-side list[list[int]] of valid sampled tokens for each # request, with invalid requests having empty lists. - assert isinstance(sampled_token_ids, list), \ - "sampled_token_ids should be a python list when" \ + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when" "padded-batch is disabled." + ) next_token_ids = self.drafter.prepare_next_token_ids_cpu( - sampled_token_ids, self.requests, self.input_batch, - scheduler_output.num_scheduled_tokens) + sampled_token_ids, + self.requests, + self.input_batch, + scheduler_output.num_scheduled_tokens, + ) else: # When using padded-batch, the sampled_token_ids should be # the gpu tensor of sampled tokens for each request, of shape # (num_reqs, num_spec_tokens + 1) with rejected tokens having # value -1. - assert isinstance(sampled_token_ids, torch.Tensor), \ - "sampled_token_ids should be a torch.Tensor when" \ + assert isinstance(sampled_token_ids, torch.Tensor), ( + "sampled_token_ids should be a torch.Tensor when" "padded-batch is enabled." 
- next_token_ids, valid_sampled_tokens_count = \ + ) + next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( common_attn_metadata, sampled_token_ids, self.requests, self.input_batch, self.discard_request_indices.gpu, - self.num_discarded_requests + self.num_discarded_requests, ) + ) if spec_decode_metadata is None: token_indices_to_sample = None @@ -2670,25 +2859,26 @@ def propose_draft_token_ids( if self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[:num_scheduled_tokens] for h in aux_hidden_states], - dim=-1) + [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[:num_scheduled_tokens] else: if self.speculative_config.disable_padded_drafter_batch: token_indices_to_sample = None - common_attn_metadata, token_indices =\ - self.drafter.prepare_inputs( - common_attn_metadata, - sampled_token_ids, - spec_decode_metadata.num_draft_tokens) + common_attn_metadata, token_indices = self.drafter.prepare_inputs( + common_attn_metadata, + sampled_token_ids, + spec_decode_metadata.num_draft_tokens, + ) else: - common_attn_metadata, token_indices, \ - token_indices_to_sample =\ + common_attn_metadata, token_indices, token_indices_to_sample = ( self.drafter.prepare_inputs_padded( common_attn_metadata, spec_decode_metadata, - valid_sampled_tokens_count) + valid_sampled_tokens_count, + ) + ) target_token_ids = self.input_ids.gpu[token_indices] target_positions = self._get_positions(token_indices) @@ -2697,7 +2887,8 @@ def propose_draft_token_ids( elif self.use_aux_hidden_state_outputs: assert aux_hidden_states is not None target_hidden_states = torch.cat( - [h[token_indices] for h in aux_hidden_states], dim=-1) + [h[token_indices] for h in aux_hidden_states], dim=-1 + ) else: target_hidden_states = hidden_states[token_indices] @@ -2730,9 +2921,10 @@ def propose_draft_token_ids( def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, \ - f"Config `{config_name}` not supported. " \ + assert config_name in allowed_config_names, ( + f"Config `{config_name}` not supported. 
" f"Allowed configs: {allowed_config_names}" + ) config = getattr(self, config_name) new_config = update_config(config, config_overrides) setattr(self, config_name, new_config) @@ -2745,26 +2937,24 @@ def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) if eep_scale_up: from vllm.distributed.parallel_state import get_ep_group - num_local_physical_experts = torch.empty(1, - dtype=torch.int32, - device="cpu") - torch.distributed.broadcast(num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0) + + num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") + torch.distributed.broadcast( + num_local_physical_experts, group=get_ep_group().cpu_group, group_src=0 + ) num_local_physical_experts = int(num_local_physical_experts.item()) new_ep_size = get_ep_group().world_size - global_expert_load, old_global_expert_indices = ( - EplbState.recv_state()) + global_expert_load, old_global_expert_indices = EplbState.recv_state() num_logical_experts = global_expert_load.shape[1] self.parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts) - assert old_global_expert_indices.shape[ - 1] % num_local_physical_experts == 0 - old_ep_size = old_global_expert_indices.shape[ - 1] // num_local_physical_experts + num_local_physical_experts * new_ep_size - num_logical_experts + ) + assert old_global_expert_indices.shape[1] % num_local_physical_experts == 0 + old_ep_size = ( + old_global_expert_indices.shape[1] // num_local_physical_experts + ) rank_mapping = { - old_ep_rank: old_ep_rank - for old_ep_rank in range(old_ep_size) + old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size) } else: global_expert_load = None @@ -2776,10 +2966,12 @@ def load_model(self, eep_scale_up: bool = False) -> None: model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") self.model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config) + vllm_config=self.vllm_config, model_config=self.model_config + ) if self.lora_config: - self.model = self.load_lora_model(self.model, self.vllm_config, - self.device) + self.model = self.load_lora_model( + self.model, self.vllm_config, self.device + ) if hasattr(self, "drafter"): logger.info("Loading drafter model...") if self.speculative_config.use_eagle(): @@ -2792,26 +2984,29 @@ def load_model(self, eep_scale_up: bool = False) -> None: if self.use_aux_hidden_state_outputs: if supports_eagle3(self.model): self.model.set_aux_hidden_state_layers( - self.model.get_eagle3_aux_hidden_state_layers()) + self.model.get_eagle3_aux_hidden_state_layers() + ) else: raise RuntimeError( "Model does not support EAGLE3 interface but " - "aux_hidden_state_outputs was requested") + "aux_hidden_state_outputs was requested" + ) time_after_load = time.perf_counter() self.model_memory_usage = m.consumed_memory - logger.info("Model loading took %.4f GiB and %.6f seconds", - self.model_memory_usage / GiB_bytes, - time_after_load - time_before_load) + logger.info( + "Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load, + ) prepare_communication_buffer_for_model(self.model) - self.is_multimodal_pruning_enabled = (supports_multimodal_pruning( - self.model) and self.model_config.multimodal_config. 
- is_multimodal_pruning_enabled()) + self.is_multimodal_pruning_enabled = ( + supports_multimodal_pruning(self.model) + and self.model_config.multimodal_config.is_multimodal_pruning_enabled() + ) - if is_mixture_of_experts( - self.model) and self.parallel_config.enable_eplb: - logger.info("EPLB is enabled for model %s.", - self.model_config.model) + if is_mixture_of_experts(self.model) and self.parallel_config.enable_eplb: + logger.info("EPLB is enabled for model %s.", self.model_config.model) self.eplb_state = EplbState.build( self.model, self.device, @@ -2822,11 +3017,10 @@ def load_model(self, eep_scale_up: bool = False) -> None: ) if ( - self.vllm_config.compilation_config.level == \ - CompilationLevel.DYNAMO_AS_IS and supports_dynamo() + self.vllm_config.compilation_config.level == CompilationLevel.DYNAMO_AS_IS + and supports_dynamo() ): - backend = self.vllm_config.compilation_config.init_backend( - self.vllm_config) + backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) compilation_counter.dynamo_as_is_count += 1 self.model.compile(fullgraph=True, backend=backend) return @@ -2834,26 +3028,30 @@ def load_model(self, eep_scale_up: bool = False) -> None: # CudagraphWraper and CudagraphDispatcher of vllm. # wrap the model with full cudagraph wrapper if needed. - if self.compilation_config.cudagraph_mode.has_full_cudagraphs() \ - and not self.parallel_config.enable_dbo: - self.model = CUDAGraphWrapper(self.model, - self.vllm_config, - runtime_mode=CUDAGraphMode.FULL) + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and not self.parallel_config.enable_dbo + ): + self.model = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) elif self.parallel_config.enable_dbo: if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.FULL, self.device) + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.FULL, self.device + ) else: - self.model = UBatchWrapper(self.model, self.vllm_config, - CUDAGraphMode.NONE, self.device) + self.model = UBatchWrapper( + self.model, self.vllm_config, CUDAGraphMode.NONE, self.device + ) def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, \ + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." + ) model_loader = get_model_loader(self.load_config) logger.info("Reloading weights inplace...") - model_loader.load_weights(self.get_model(), - model_config=self.model_config) + model_loader.load_weights(self.get_model(), model_config=self.model_config) def save_tensorized_model( self, @@ -2891,7 +3089,8 @@ def _get_prompt_logprobs_dict( num_prompt_tokens = len(request.prompt_token_ids) prompt_token_ids = torch.tensor(request.prompt_token_ids).to( - self.device, non_blocking=True) + self.device, non_blocking=True + ) # Set up target LogprobsTensors object. logprobs_tensors = in_progress_dict.get(req_id) @@ -2899,7 +3098,8 @@ def _get_prompt_logprobs_dict( # Create empty logprobs CPU tensors for the entire prompt. # If chunked, we'll copy in slice by slice. logprobs_tensors = LogprobsTensors.empty_cpu( - num_prompt_tokens - 1, num_prompt_logprobs + 1) + num_prompt_tokens - 1, num_prompt_logprobs + 1 + ) in_progress_dict[req_id] = logprobs_tensors # Determine number of logits to retrieve. @@ -2929,27 +3129,29 @@ def _get_prompt_logprobs_dict( # then there is prompt logprob generated for each index. 
req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc.np[req_idx].item() - prompt_hidden_states = hidden_states[offset:offset + num_logits] + prompt_hidden_states = hidden_states[offset : offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want # to gather the logprob for. - tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits] + tgt_token_ids = prompt_token_ids[start_tok : start_tok + num_logits] # Compute prompt logprobs. logprobs = self.sampler.compute_logprobs(logits) token_ids, logprobs, ranks = self.sampler.gather_logprobs( - logprobs, num_prompt_logprobs, tgt_token_ids) + logprobs, num_prompt_logprobs, tgt_token_ids + ) # Transfer GPU->CPU async. chunk_slice = slice(start_idx, start_idx + num_logits) logprobs_tensors.logprob_token_ids[chunk_slice].copy_( - token_ids, non_blocking=True) - logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, - non_blocking=True) + token_ids, non_blocking=True + ) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, non_blocking=True) logprobs_tensors.selected_token_ranks[chunk_slice].copy_( - ranks, non_blocking=True) + ranks, non_blocking=True + ) # Remove requests that have completed prefill from the batch # num_prompt_logprobs_dict. @@ -2977,8 +3179,9 @@ def _get_nans_in_logits( req_index = self.input_batch.req_id_to_index[req_id] num_nans_in_logits[req_id] = ( int(num_nans_for_index[req_index]) - if num_nans_for_index is not None - and req_index < logits.shape[0] else 0) + if num_nans_for_index is not None and req_index < logits.shape[0] + else 0 + ) return num_nans_in_logits except IndexError: return {} @@ -3004,11 +3207,11 @@ def rand_input_ids() -> torch.Tensor: self.input_ids.gpu, low=0, high=self.model_config.get_vocab_size(), - dtype=input_ids.dtype) + dtype=input_ids.dtype, + ) logger.debug_once("Randomizing dummy data for DP Rank") - input_ids.copy_(rand_input_ids()[:input_ids.size(0)], - non_blocking=True) + input_ids.copy_(rand_input_ids()[: input_ids.size(0)], non_blocking=True) yield input_ids.fill_(0) @@ -3033,13 +3236,15 @@ def _get_mm_dummy_batch( dummy_mm_items = [dummy_mm_item] * max_items_per_batch model = cast(SupportsMultiModal, self.model) - return next(mm_kwargs_group - for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - merge_by_field_config=model.merge_by_field_config, - )) + return next( + mm_kwargs_group + for _, _, mm_kwargs_group in group_mm_kwargs_by_modality( + dummy_mm_items, + device=self.device, + pin_memory=self.pin_memory, + merge_by_field_config=model.merge_by_field_config, + ) + ) @torch.inference_mode() def _dummy_run( @@ -3076,8 +3281,10 @@ def _dummy_run( (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run """ - assert cudagraph_runtime_mode is None or \ - cudagraph_runtime_mode.valid_runtime_modes() + assert ( + cudagraph_runtime_mode is None + or cudagraph_runtime_mode.valid_runtime_modes() + ) # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -3092,8 +3299,7 @@ def _dummy_run( # When setting max_query_len = 1, we switch to and capture the optimized # routine of FA2 for pure decode, i.e., Flashdecode + an optimization # for GQA/MQA. 
- max_query_len = self.uniform_decode_query_len if uniform_decode else \ - num_tokens + max_query_len = self.uniform_decode_query_len if uniform_decode else num_tokens # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively @@ -3109,9 +3315,7 @@ def _dummy_run( num_reqs = num_decode_tokens + 1 # Create decode requests (1 token each) followed by prefill request - num_scheduled_tokens_list = [1] * num_decode_tokens + [ - num_prefill_tokens - ] + num_scheduled_tokens_list = [1] * num_decode_tokens + [num_prefill_tokens] # Note: Overriding max_query_len to be the prefill tokens max_query_len = num_prefill_tokens elif uniform_decode: @@ -3128,8 +3332,7 @@ def _dummy_run( assert sum(num_scheduled_tokens_list) == num_tokens assert len(num_scheduled_tokens_list) == num_reqs - num_scheduled_tokens = np.array(num_scheduled_tokens_list, - dtype=np.int32) + num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) total_num_scheduled_tokens = int(num_scheduled_tokens.sum()) ubatch_slices = None @@ -3183,56 +3386,61 @@ def _dummy_run( self.seq_lens.np[num_reqs:] = 0 self.seq_lens.copy_to_gpu() - cum_num_tokens, _ = self._get_cumsum_and_arange( - num_scheduled_tokens) - self.query_start_loc.np[1:num_reqs + 1] = cum_num_tokens + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): + self.kv_cache_config.kv_cache_groups + ): common_attn_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + - 1], + query_start_loc=self.query_start_loc.gpu[: num_reqs + 1], + query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs + 1], seq_lens=self.seq_lens.gpu[:num_reqs], seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - num_computed_tokens_cpu=self.input_batch. - num_computed_tokens_cpu_tensor[:num_reqs], + num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ + :num_reqs + ], num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=max_query_len, max_seq_len=self.max_model_len, - block_table_tensor=self.input_batch. 
- block_table[kv_cache_group_id].get_device_tensor(num_reqs), + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id + ].get_device_tensor(num_reqs), slot_mapping=self.input_batch.block_table[ - kv_cache_group_id].slot_mapping.gpu[:num_tokens], - causal=True) + kv_cache_group_id + ].slot_mapping.gpu[:num_tokens], + causal=True, + ) for attn_group in self.attn_groups[kv_cache_group_id]: if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( - ubatch_slices, common_attn_metadata) + ubatch_slices, common_attn_metadata + ) for ubid, common_attn_metadata in enumerate( - common_attn_metadata_list): + common_attn_metadata_list + ): assert common_attn_metadata.max_query_len == 1 - attn_metadata_i = (attn_group\ - .get_metadata_builder(ubatch_id=ubid)\ - .build_for_cudagraph_capture(common_attn_metadata)) + attn_metadata_i = attn_group.get_metadata_builder( + ubatch_id=ubid + ).build_for_cudagraph_capture(common_attn_metadata) for layer_name in attn_group.layer_names: assert type(attn_metadata) is list - attn_metadata[ubid][ - layer_name] = attn_metadata_i + attn_metadata[ubid][layer_name] = attn_metadata_i else: assert type(attn_metadata) is dict - attn_metadata_i = attn_group.get_metadata_builder()\ - .build_for_cudagraph_capture(common_attn_metadata) + attn_metadata_i = attn_group.get_metadata_builder().build_for_cudagraph_capture( + common_attn_metadata + ) for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i - with self.maybe_dummy_run_with_lora(self.lora_config, - num_scheduled_tokens, remove_lora): + with self.maybe_dummy_run_with_lora( + self.lora_config, num_scheduled_tokens, remove_lora + ): model_kwargs = self._init_model_kwargs(num_tokens) - if (self.supports_mm_inputs - and not self.model_config.is_encoder_decoder): + if self.supports_mm_inputs and not self.model_config.is_encoder_decoder: input_ids = None inputs_embeds = self.inputs_embeds.gpu[:num_tokens] model_kwargs = { @@ -3260,23 +3468,35 @@ def _dummy_run( self.model.make_empty_intermediate_tensors( batch_size=self.max_num_tokens, dtype=self.model_config.dtype, - device=self.device)) + device=self.device, + ) + ) intermediate_tensors = self.sync_and_slice_intermediate_tensors( - num_tokens, None, False) + num_tokens, None, False + ) # filter out the valid batch descriptor - _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch( - BatchDescriptor(num_tokens=num_tokens_after_padding, - uniform_decode=uniform_decode)) \ - if not is_profile else (CUDAGraphMode.NONE, None) + _cg_mode, batch_descriptor = ( + self.cudagraph_dispatcher.dispatch( + BatchDescriptor( + num_tokens=num_tokens_after_padding, + uniform_decode=uniform_decode, + ) + ) + if not is_profile + else (CUDAGraphMode.NONE, None) + ) if cudagraph_runtime_mode is not None: # we allow forcing NONE when the dispatcher disagrees to support # warm ups for cudagraph capture - assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \ - cudagraph_runtime_mode == _cg_mode, ( + assert ( + cudagraph_runtime_mode == CUDAGraphMode.NONE + or cudagraph_runtime_mode == _cg_mode + ), ( f"Cudagraph runtime mode mismatch at dummy_run. " - f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") + f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}." 
+ ) else: cudagraph_runtime_mode = _cg_mode @@ -3288,14 +3508,18 @@ def _dummy_run( if num_tokens_across_dp is not None: num_tokens_across_dp[:] = num_tokens_after_padding - with self.maybe_randomize_inputs(input_ids), set_forward_context( + with ( + self.maybe_randomize_inputs(input_ids), + set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens_after_padding, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, - ubatch_slices=ubatch_slices): + ubatch_slices=ubatch_slices, + ), + ): outputs = self.model( input_ids=input_ids, positions=positions, @@ -3313,8 +3537,7 @@ def _dummy_run( assert isinstance(self.drafter, EagleProposer) self.drafter.dummy_run(num_tokens) - if (self.speculative_config - and self.speculative_config.uses_draft_model()): + if self.speculative_config and self.speculative_config.uses_draft_model(): assert isinstance(self.drafter, DraftModelProposer) forward_ctx_kwargs = { "attn_metadata": attn_metadata, @@ -3349,8 +3572,7 @@ def _dummy_sampler_run( logits = self.model.compute_logits(hidden_states) num_reqs = logits.size(0) - dummy_tensors = lambda v: torch.full( - (num_reqs, ), v, device=self.device) + dummy_tensors = lambda v: torch.full((num_reqs,), v, device=self.device) dummy_metadata = SamplingMetadata( temperature=dummy_tensors(0.5), @@ -3371,37 +3593,39 @@ def _dummy_sampler_run( logitsprocs=LogitsProcessors(), ) try: - sampler_output = self.sampler(logits=logits, - sampling_metadata=dummy_metadata) + sampler_output = self.sampler( + logits=logits, sampling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up sampler with " f"{num_reqs} dummy requests. Please try lowering " "`max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e if self.speculative_config: draft_token_ids = [[0] for _ in range(num_reqs)] dummy_spec_decode_metadata = SpecDecodeMetadata.make_dummy( - draft_token_ids, self.device) + draft_token_ids, self.device + ) num_tokens = sum(len(ids) for ids in draft_token_ids) # draft_probs = torch.randn( # num_tokens, logits.shape[-1], device=self.device, # dtype=logits.dtype) draft_probs = None - target_logits = torch.randn(num_tokens, - logits.shape[-1], - device=self.device, - dtype=logits.dtype) + target_logits = torch.randn( + num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype + ) # NOTE(woosuk): Here, we should use int32 because the sampler uses # int32 for bonus_token_ids. If the dtype mismatches, re-compilation # will occur at runtime. 
- bonus_token_ids = torch.zeros(num_reqs, - device=self.device, - dtype=torch.int32) + bonus_token_ids = torch.zeros( + num_reqs, device=self.device, dtype=torch.int32 + ) self.rejection_sampler( dummy_spec_decode_metadata, draft_probs, @@ -3431,9 +3655,9 @@ def _dummy_pooler_run_task( num_scheduled_tokens_list, device="cpu", ) - dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), - dtype=torch.int32, - device=self.device) + dummy_token_ids = torch.zeros( + (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device + ) model = cast(VllmModelForPooling, self.get_model()) dummy_pooling_params = PoolingParams(task=task) @@ -3447,19 +3671,22 @@ def _dummy_pooler_run_task( pooling_params=[dummy_pooling_params] * num_reqs, ) - dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list, - device=hidden_states.device) + dummy_metadata.build_pooling_cursor( + num_scheduled_tokens_list, device=hidden_states.device + ) try: - return model.pooler(hidden_states=hidden_states, - pooling_metadata=dummy_metadata) + return model.pooler( + hidden_states=hidden_states, pooling_metadata=dummy_metadata + ) except RuntimeError as e: - if 'out of memory' in str(e): + if "out of memory" in str(e): raise RuntimeError( "CUDA out of memory occurred when warming up pooler " f"({task=}) with {num_reqs} dummy requests. Please try " "lowering `max_num_seqs` or `gpu_memory_utilization` when " - "initializing the engine.") from e + "initializing the engine." + ) from e else: raise e @@ -3485,7 +3712,8 @@ def profile_run(self) -> None: if self.model_config.multimodal_config.skip_mm_profiling: logger.info( "Skipping memory profiling for multimodal encoder and " - "encoder cache.") + "encoder cache." + ) else: mm_budget = self.mm_budget assert mm_budget is not None @@ -3495,8 +3723,9 @@ def profile_run(self) -> None: # modality with the max possible input tokens even when # it supports multiple. dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget \ - .max_items_per_batch_by_modality[dummy_modality] + max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + dummy_modality + ] logger.info( "Encoder cache will be initialized with a budget of " @@ -3514,9 +3743,9 @@ def profile_run(self) -> None: ) # Run multimodal encoder. - dummy_encoder_outputs = \ - self.model.get_multimodal_embeddings( - **batched_dummy_mm_inputs) + dummy_encoder_outputs = self.model.get_multimodal_embeddings( + **batched_dummy_mm_inputs + ) sanity_check_mm_encoder_outputs( dummy_encoder_outputs, @@ -3533,7 +3762,8 @@ def profile_run(self) -> None: expanded_outputs = [] for output in dummy_encoder_outputs: expanded = output.new_zeros( - (encoder_budget, encoder_output_shape[-1])) + (encoder_budget, encoder_output_shape[-1]) + ) num_tokens = output.shape[0] expanded[:num_tokens].copy_(output) expanded_outputs.append(expanded) @@ -3541,12 +3771,12 @@ def profile_run(self) -> None: dummy_encoder_outputs = expanded_outputs # Cache the dummy encoder outputs. 
- self.encoder_cache["tmp"] = dict( - enumerate(dummy_encoder_outputs)) + self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states \ - = self._dummy_run(self.max_num_tokens, is_profile=True) + hidden_states, last_hidden_states = self._dummy_run( + self.max_num_tokens, is_profile=True + ) if get_pp_group().is_last_rank: if self.is_pooling_model: output = self._dummy_pooler_run(hidden_states) @@ -3563,7 +3793,8 @@ def capture_model(self) -> int: if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: logger.warning( "Skipping CUDA graph capture. To turn on CUDA graph capture, " - "ensure `cudagraph_mode` was not manually set to `NONE`") + "ensure `cudagraph_mode` was not manually set to `NONE`" + ) return 0 else: self.initialize_cudagraph_capture() @@ -3603,24 +3834,29 @@ def freeze_gc(): self._capture_cudagraphs( compilation_cases, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False) + uniform_decode=False, + ) # Capture full cudagraph for uniform decode batches if we # don't already have full mixed prefill-decode cudagraphs. - if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - cudagraph_mode.separate_routine(): - max_num_tokens = self.scheduler_config.max_num_seqs * \ - self.uniform_decode_query_len + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() + ): + max_num_tokens = ( + self.scheduler_config.max_num_seqs * self.uniform_decode_query_len + ) decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if - x <= max_num_tokens and x >= self.uniform_decode_query_len + x + for x in self.cudagraph_batch_sizes + if x <= max_num_tokens and x >= self.uniform_decode_query_len ] - compilation_cases_decode = list( - reversed(decode_cudagraph_batch_sizes)) + compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes)) self._capture_cudagraphs( compilation_cases=compilation_cases_decode, cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True) + uniform_decode=True, + ) torch.cuda.synchronize() end_free_gpu_memory = torch.cuda.mem_get_info()[0] @@ -3636,16 +3872,23 @@ def freeze_gc(): elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. 
- logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", - elapsed_time, cuda_graph_size / (1 << 30)) + logger.info( + "Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, + cuda_graph_size / (1 << 30), + ) return cuda_graph_size - def _capture_cudagraphs(self, compilation_cases: list[int], - cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool): - assert cudagraph_runtime_mode != CUDAGraphMode.NONE and \ - cudagraph_runtime_mode.valid_runtime_modes(), \ - f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" + def _capture_cudagraphs( + self, + compilation_cases: list[int], + cudagraph_runtime_mode: CUDAGraphMode, + uniform_decode: bool, + ): + assert ( + cudagraph_runtime_mode != CUDAGraphMode.NONE + and cudagraph_runtime_mode.valid_runtime_modes() + ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" # Only rank 0 should print progress bar during capture if is_global_first_rank(): @@ -3654,7 +3897,9 @@ def _capture_cudagraphs(self, compilation_cases: list[int], disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", - cudagraph_runtime_mode.name)) + cudagraph_runtime_mode.name, + ), + ) # We skip EPLB here since we don't want to record dummy metrics for num_tokens in compilation_cases: @@ -3662,14 +3907,16 @@ def _capture_cudagraphs(self, compilation_cases: list[int], # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched # version of the graph - allow_microbatching = self.parallel_config.enable_dbo \ - and cudagraph_runtime_mode == CUDAGraphMode.FULL \ - and uniform_decode \ + allow_microbatching = ( + self.parallel_config.enable_dbo + and cudagraph_runtime_mode == CUDAGraphMode.FULL + and uniform_decode and check_ubatch_thresholds( config=self.vllm_config.parallel_config, num_tokens=num_tokens, uniform_decode=uniform_decode, ) + ) for _ in range(self.compilation_config.cudagraph_num_of_warmups): # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. @@ -3677,29 +3924,31 @@ def _capture_cudagraphs(self, compilation_cases: list[int], # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. - force_attention = ( - cudagraph_runtime_mode == CUDAGraphMode.FULL) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) - self._dummy_run(num_tokens, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, - allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False) + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=force_attention, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=uniform_decode, + allow_microbatching=allow_microbatching, + skip_eplb=True, + remove_lora=False, + ) self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders. 
""" - assert len(self.attn_groups) == 0, \ - "Attention backends are already initialized" + assert len(self.attn_groups) == 0, "Attention backends are already initialized" class AttentionGroupKey(NamedTuple): attn_backend: type[AttentionBackend] @@ -3709,8 +3958,8 @@ def get_attn_backends_for_group( kv_cache_group_spec: KVCacheGroupSpec, ) -> dict[AttentionGroupKey, list[str]]: layers = get_layers_from_vllm_config( - self.vllm_config, AttentionLayerBase, - kv_cache_group_spec.layer_names) + self.vllm_config, AttentionLayerBase, kv_cache_group_spec.layer_names + ) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than @@ -3730,23 +3979,19 @@ def get_attn_backends_for_group( full_cls_name = attn_backend.full_cls_name() layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec if isinstance(layer_kv_cache_spec, UniformTypeKVCacheSpecs): - layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[ - layer_name] + layer_kv_cache_spec = layer_kv_cache_spec.kv_cache_specs[layer_name] key = (full_cls_name, layer_kv_cache_spec) - attn_backends[key] = AttentionGroupKey(attn_backend, - layer_kv_cache_spec) + attn_backends[key] = AttentionGroupKey( + attn_backend, layer_kv_cache_spec + ) attn_backend_layers[key].append(layer_name) - return { - attn_backends[k]: v - for k, v in attn_backend_layers.items() - } + return {attn_backends[k]: v for k, v in attn_backend_layers.items()} def create_attn_groups( attn_backends_map: dict[AttentionGroupKey, list[str]], ) -> list[AttentionGroup]: attn_groups: list[AttentionGroup] = [] - for (attn_backend, - kv_cache_spec), layer_names in attn_backends_map.items(): + for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): attn_group = AttentionGroup.create_with_metadata_builders( attn_backend, layer_names, @@ -3754,7 +3999,8 @@ def create_attn_groups( self.vllm_config, self.device, num_metadata_builders=1 - if not self.parallel_config.enable_dbo else 2, + if not self.parallel_config.enable_dbo + else 2, ) attn_groups.append(attn_group) @@ -3769,7 +4015,7 @@ def create_attn_groups( def initialize_cudagraph_capture(self) -> None: """ - Resolve the cudagraph_mode when there are multiple attention + Resolve the cudagraph_mode when there are multiple attention backends with potential conflicting CUDA graph support. Then initialize the cudagraph_dispatcher based on the resolved cudagraph_mode. @@ -3785,81 +4031,110 @@ def initialize_cudagraph_capture(self) -> None: # Flexible resolve the cudagraph mode cudagraph_mode = self.compilation_config.cudagraph_mode # check cudagraph for mixed batch is supported - if cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL \ - and min_cg_support != AttentionCGSupport.ALWAYS: - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") + if ( + cudagraph_mode.mixed_mode() == CUDAGraphMode.FULL + and min_cg_support != AttentionCGSupport.ALWAYS + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) if min_cg_support == AttentionCGSupport.NEVER: # if not supported any full cudagraphs, just raise it. 
- msg += "; please try cudagraph_mode=PIECEWISE, and "\ + msg += ( + "; please try cudagraph_mode=PIECEWISE, and " "make sure compilation level is piecewise" + ) raise ValueError(msg) # attempt to resolve the full cudagraph related mode if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=FULL_AND_PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE + ) else: msg += "; setting cudagraph_mode=FULL_DECODE_ONLY" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_DECODE_ONLY + ) logger.warning(msg) # check that if we are doing decode full-cudagraphs it is supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and min_cg_support == AttentionCGSupport.NEVER): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported " - f"with {min_cg_builder_name} backend (support: " - f"{min_cg_support})") - if (self.compilation_config.level == CompilationLevel.PIECEWISE and - (self.compilation_config.splitting_ops_contain_attention() - or self.compilation_config.use_inductor_graph_partition)): - msg += "; setting cudagraph_mode=PIECEWISE because "\ + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and min_cg_support == AttentionCGSupport.NEVER + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported " + f"with {min_cg_builder_name} backend (support: " + f"{min_cg_support})" + ) + if self.compilation_config.level == CompilationLevel.PIECEWISE and ( + self.compilation_config.splitting_ops_contain_attention() + or self.compilation_config.use_inductor_graph_partition + ): + msg += ( + "; setting cudagraph_mode=PIECEWISE because " "attention is compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: - msg += "; setting cudagraph_mode=NONE because "\ + msg += ( + "; setting cudagraph_mode=NONE because " "attention is not compiled piecewise" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + ) + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) logger.warning(msg) # check that if we are doing spec-decode + decode full-cudagraphs it is # supported - if (cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and self.uniform_decode_query_len > 1 and min_cg_support.value - < AttentionCGSupport.UNIFORM_BATCH.value): - msg = (f"CUDAGraphMode.{cudagraph_mode.name} is not supported" - f" with spec-decode for attention backend " - f"{min_cg_builder_name} (support: {min_cg_support})") + if ( + cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.uniform_decode_query_len > 1 + and min_cg_support.value < AttentionCGSupport.UNIFORM_BATCH.value + ): + msg = ( + f"CUDAGraphMode.{cudagraph_mode.name} is not supported" + f" with spec-decode for attention backend " + f"{min_cg_builder_name} (support: {min_cg_support})" + ) if self.compilation_config.splitting_ops_contain_attention(): msg += "; setting cudagraph_mode=PIECEWISE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.PIECEWISE + ) else: msg += "; setting cudagraph_mode=NONE" - cudagraph_mode = self.compilation_config.cudagraph_mode = \ + cudagraph_mode = self.compilation_config.cudagraph_mode = ( CUDAGraphMode.NONE + ) 
logger.warning(msg) # double check that we can support full cudagraph if they are requested # even after automatic downgrades - if cudagraph_mode.has_full_cudagraphs() \ - and min_cg_support == AttentionCGSupport.NEVER: - raise ValueError(f"CUDAGraphMode.{cudagraph_mode.name} is not " - f"supported with {min_cg_builder_name} backend (" - f"support:{min_cg_support}) " - "; please try cudagraph_mode=PIECEWISE, " - "and make sure compilation level is piecewise") + if ( + cudagraph_mode.has_full_cudagraphs() + and min_cg_support == AttentionCGSupport.NEVER + ): + raise ValueError( + f"CUDAGraphMode.{cudagraph_mode.name} is not " + f"supported with {min_cg_builder_name} backend (" + f"support:{min_cg_support}) " + "; please try cudagraph_mode=PIECEWISE, " + "and make sure compilation level is piecewise" + ) # Trigger cudagraph dispatching keys initialization here (after # initializing attn backends). self.cudagraph_dispatcher.initialize_cudagraph_keys( - self.compilation_config.cudagraph_mode, - self.uniform_decode_query_len) + self.compilation_config.cudagraph_mode, self.uniform_decode_query_len + ) def calculate_reorder_batch_threshold(self) -> None: """ @@ -3871,22 +4146,20 @@ def calculate_reorder_batch_threshold(self) -> None: # check that if any backends reorder batches; that the reordering # is compatible (e.g., decode threshold is the same) - reorder_batch_threshold_i = ( - attn_metadata_builder_i.reorder_batch_threshold) + reorder_batch_threshold_i = attn_metadata_builder_i.reorder_batch_threshold if reorder_batch_threshold_i is not None: if self.reorder_batch_threshold is not None: - if reorder_batch_threshold_i != \ - self.reorder_batch_threshold: + if reorder_batch_threshold_i != self.reorder_batch_threshold: raise ValueError( f"Attention backend reorders decodes with " f"threshold {reorder_batch_threshold_i} but other " f"backend uses threshold " - f"{self.reorder_batch_threshold}") + f"{self.reorder_batch_threshold}" + ) else: self.reorder_batch_threshold = reorder_batch_threshold_i - def may_reinitialize_input_batch(self, - kv_cache_config: KVCacheConfig) -> None: + def may_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: """ Re-initialize the input batch if the block sizes are different from `[self.cache_config.block_size]`. This usually happens when there @@ -3903,7 +4176,8 @@ def may_reinitialize_input_batch(self, assert self.cache_config.cpu_offload_gb == 0, ( "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 - "for more details.") + "for more details." + ) self.input_batch = InputBatch( max_num_reqs=self.max_num_reqs, max_model_len=max(self.max_model_len, self.max_encoder_len), @@ -3917,11 +4191,14 @@ def may_reinitialize_input_batch(self, is_pooling_model=self.is_pooling_model, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens - if self.vllm_config.speculative_config else 0), + if self.vllm_config.speculative_config + else 0 + ), ) def _allocate_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initializes the KV cache buffer with the correct size. The buffer needs to be reshaped to the desired shape before being used by the models. @@ -3931,12 +4208,12 @@ def _allocate_kv_cache_tensors( Returns: dict[str, torch.Tensor]: A map between layer names to their corresponding memory buffer for KV cache. 
- """ + """ kv_cache_raw_tensors: dict[str, torch.Tensor] = {} for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - tensor = torch.zeros(kv_cache_tensor.size, - dtype=torch.int8, - device=self.device) + tensor = torch.zeros( + kv_cache_tensor.size, dtype=torch.int8, device=self.device + ) for layer_name in kv_cache_tensor.shared_by: kv_cache_raw_tensors[layer_name] = tensor @@ -3946,8 +4223,9 @@ def _allocate_kv_cache_tensors( if layer_name in self.runner_only_attn_layers: continue layer_names.add(layer_name) - assert layer_names == set(kv_cache_raw_tensors.keys( - )), "Some layers are not correctly initialized" + assert layer_names == set(kv_cache_raw_tensors.keys()), ( + "Some layers are not correctly initialized" + ) return kv_cache_raw_tensors def _attn_group_iterator(self) -> Iterator[AttentionGroup]: @@ -3985,8 +4263,7 @@ def _reshape_kv_cache_tensors( continue raw_tensor = kv_cache_raw_tensors[layer_name] assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0 - num_blocks = (raw_tensor.numel() // - kv_cache_spec.page_size_bytes) + num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes if isinstance(kv_cache_spec, AttentionSpec): has_attn = True kv_cache_shape = attn_backend.get_kv_cache_shape( @@ -3994,41 +4271,43 @@ def _reshape_kv_cache_tensors( kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size, - cache_dtype_str=self.cache_config.cache_dtype) + cache_dtype_str=self.cache_config.cache_dtype, + ) dtype = kv_cache_spec.dtype try: - kv_cache_stride_order = \ - attn_backend.get_kv_cache_stride_order() - assert len(kv_cache_stride_order) == len( - kv_cache_shape) + kv_cache_stride_order = attn_backend.get_kv_cache_stride_order() + assert len(kv_cache_stride_order) == len(kv_cache_shape) except (AttributeError, NotImplementedError): - kv_cache_stride_order = tuple( - range(len(kv_cache_shape))) + kv_cache_stride_order = tuple(range(len(kv_cache_shape))) # The allocation respects the backend-defined stride order # to ensure the semantic remains consistent for each # backend. We first obtain the generic kv cache shape and # then permute it according to the stride order which could # result in a non-contiguous tensor. - kv_cache_shape = tuple(kv_cache_shape[i] - for i in kv_cache_stride_order) + kv_cache_shape = tuple( + kv_cache_shape[i] for i in kv_cache_stride_order + ) # Maintain original KV shape view. 
inv_order = [ kv_cache_stride_order.index(i) for i in range(len(kv_cache_stride_order)) ] - kv_caches[layer_name] = kv_cache_raw_tensors[ - layer_name].view(dtype).view(kv_cache_shape).permute( - *inv_order) + kv_caches[layer_name] = ( + kv_cache_raw_tensors[layer_name] + .view(dtype) + .view(kv_cache_shape) + .permute(*inv_order) + ) elif isinstance(kv_cache_spec, MambaSpec): has_mamba = True raw_tensor = kv_cache_raw_tensors[layer_name] state_tensors = [] storage_offset_bytes = 0 - for (shape, dtype) in zip(kv_cache_spec.shapes, - kv_cache_spec.dtypes): + for shape, dtype in zip(kv_cache_spec.shapes, kv_cache_spec.dtypes): dtype_size = get_dtype_size(dtype) num_element_per_page = ( - kv_cache_spec.page_size_bytes // dtype_size) + kv_cache_spec.page_size_bytes // dtype_size + ) target_shape = (num_blocks, *shape) stride = torch.empty(target_shape).stride() target_stride = (num_element_per_page, *stride[1:]) @@ -4052,7 +4331,8 @@ def _reshape_kv_cache_tensors( return kv_caches def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor]) -> None: + self, kv_caches: dict[str, torch.Tensor] + ) -> None: """ Update the layout of attention layers from (2, num_blocks, ...) to (num_blocks, 2, ...). @@ -4065,19 +4345,21 @@ def _update_hybrid_attention_mamba_layout( kv_cache_spec = group.kv_cache_spec for layer_name in group.layer_names: kv_cache = kv_caches[layer_name] - if (isinstance(kv_cache_spec, AttentionSpec) - and kv_cache.shape[0] == 2): - assert kv_cache.shape[1] != 2, \ - "Fail to determine whether the layout is " \ - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " \ + if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2: + assert kv_cache.shape[1] != 2, ( + "Fail to determine whether the layout is " + "(2, num_blocks, ...) or (num_blocks, 2, ...) for " f"a tensor of shape {kv_cache.shape}" + ) hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_(size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, - *kv_cache.stride()[2:])) + kv_cache.as_strided_( + size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), + ) def initialize_kv_cache_tensors( - self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]: + self, kv_cache_config: KVCacheConfig + ) -> dict[str, torch.Tensor]: """ Initialize the memory buffer for KV cache. 
@@ -4090,25 +4372,29 @@ def initialize_kv_cache_tensors( # Initialize the memory buffer for KV cache kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config) # Change the memory buffer to the desired shape - kv_caches = self._reshape_kv_cache_tensors(kv_cache_config, - kv_cache_raw_tensors) + kv_caches = self._reshape_kv_cache_tensors( + kv_cache_config, kv_cache_raw_tensors + ) # Set up cross-layer KV cache sharing - for layer_name, target_layer_name in self.shared_kv_cache_layers.items( - ): - logger.debug("%s reuses KV cache of %s", layer_name, - target_layer_name) + for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): + logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) kv_caches[layer_name] = kv_caches[target_layer_name] - num_attn_module = 2 \ - if self.model_config.hf_config.model_type == "longcat_flash" else 1 - bind_kv_cache(kv_caches, - self.compilation_config.static_forward_context, - self.kv_caches, num_attn_module) + num_attn_module = ( + 2 if self.model_config.hf_config.model_type == "longcat_flash" else 1 + ) + bind_kv_cache( + kv_caches, + self.compilation_config.static_forward_context, + self.kv_caches, + num_attn_module, + ) return kv_caches def maybe_add_kv_sharing_layers_to_kv_cache_groups( - self, kv_cache_config: KVCacheConfig) -> None: + self, kv_cache_config: KVCacheConfig + ) -> None: """ Add layers that re-use KV cache to KV cache group of its target layer. Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` @@ -4127,12 +4413,10 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups( # In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other # similar KV sharing setups, only the layers that generate KV caches # are involved in the prefill phase, enabling prefill to early exit. - attn_layers = get_layers_from_vllm_config(self.vllm_config, - Attention) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name in reversed(attn_layers): if layer_name in self.shared_kv_cache_layers: - self.kv_sharing_fast_prefill_eligible_layers.add( - layer_name) + self.kv_sharing_fast_prefill_eligible_layers.add(layer_name) else: break @@ -4164,23 +4448,23 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if self.dcp_world_size > 1: layer_names = self.attn_groups[0][0].layer_names - layers = get_layers_from_vllm_config(self.vllm_config, - AttentionLayerBase, - layer_names) + layers = get_layers_from_vllm_config( + self.vllm_config, AttentionLayerBase, layer_names + ) for layer in layers.values(): assert layer.impl.need_to_return_lse_for_decode, ( "DCP requires attention impls to return" " the softmax lse for decode, but the impl " f"{layer.impl.__class__.__name__} " - "does not return the softmax lse for decode.") + "does not return the softmax lse for decode." + ) def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
""" block_size = self.vllm_config.cache_config.block_size - encoder_only_attn_specs: dict[AttentionSpec, - list[str]] = defaultdict(list) + encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list) attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): if attn_module.attn_type == AttentionType.ENCODER_ONLY: @@ -4188,16 +4472,18 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) + dtype=self.kv_cache_dtype, + ) encoder_only_attn_specs[attn_spec].append(layer_name) self.runner_only_attn_layers.add(layer_name) if len(encoder_only_attn_specs) > 0: - assert len( - encoder_only_attn_specs - ) == 1, "Only support one encoder-only attention spec now" + assert len(encoder_only_attn_specs) == 1, ( + "Only support one encoder-only attention spec now" + ) spec, layer_names = encoder_only_attn_specs.popitem() self.kv_cache_config.kv_cache_groups.append( - KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec) + ) def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ @@ -4214,8 +4500,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): - if (kv_tgt_layer := - attn_module.kv_sharing_target_layer_name) is not None: + if (kv_tgt_layer := attn_module.kv_sharing_target_layer_name) is not None: # The layer doesn't need its own KV cache and will use that of # the target layer. We skip creating a KVCacheSpec for it, so # that KV cache management logic will act as this layer does @@ -4230,59 +4515,67 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # the attention backends if attn_module.attn_type == AttentionType.DECODER: if attn_module.sliding_window is not None: - assert not use_mla, "MLA is not supported for sliding" \ - "window" + assert not use_mla, "MLA is not supported for slidingwindow" kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window) + sliding_window=attn_module.sliding_window, + ) elif use_mla: kv_cache_spec[layer_name] = MLAAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - cache_dtype_str=cache_dtype_str) - elif self.attention_chunk_size is not None \ - and isinstance(attn_module, ChunkedLocalAttention): + cache_dtype_str=cache_dtype_str, + ) + elif self.attention_chunk_size is not None and isinstance( + attn_module, ChunkedLocalAttention + ): kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - attention_chunk_size=self.attention_chunk_size) + attention_chunk_size=self.attention_chunk_size, + ) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) + dtype=self.kv_cache_dtype, + ) elif attn_module.attn_type == AttentionType.ENCODER_DECODER: kv_cache_spec[layer_name] = CrossAttentionSpec( block_size=block_size, 
num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, - dtype=self.kv_cache_dtype) - elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): + dtype=self.kv_cache_dtype, + ) + elif attn_module.attn_type in ( + AttentionType.ENCODER, + AttentionType.ENCODER_ONLY, + ): # encoder-only attention does not need KV cache. continue else: - raise ValueError( - f"Unknown attention type: {attn_module.attn_type}") + raise ValueError(f"Unknown attention type: {attn_module.attn_type}") mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) if len(mamba_layers) > 0: - if (self.vllm_config.speculative_config is not None - and self.vllm_config.model_config.hf_config.model_type - not in ["qwen3_next"]): + if ( + self.vllm_config.speculative_config is not None + and self.vllm_config.model_config.hf_config.model_type + not in ["qwen3_next"] + ): raise NotImplementedError( - "Mamba with speculative decoding is not supported yet.") + "Mamba with speculative decoding is not supported yet." + ) mamba_block_size = self.vllm_config.cache_config.mamba_block_size - page_size_padded = ( - self.vllm_config.cache_config.mamba_page_size_padded) + page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded for layer_name, mamba_module in mamba_layers.items(): kv_cache_spec[layer_name] = MambaSpec( @@ -4293,10 +4586,13 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: mamba_type=mamba_module.mamba_type, num_speculative_blocks=( self.speculative_config.num_speculative_tokens - if self.speculative_config else 0), + if self.speculative_config + else 0 + ), ) ds_indexer_layers = get_layers_from_vllm_config( - self.vllm_config, DeepseekV32IndexerCache) + self.vllm_config, DeepseekV32IndexerCache + ) for layer_name, ds_indexer_module in ds_indexer_layers.items(): kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec() @@ -4311,7 +4607,7 @@ def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]: # this is in the critical path of every single model # forward loop, this has caused perf issue for a disagg # setup. 
- pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]] + pinned = self.sampled_token_ids_pinned_cpu[: sampled_token_ids.shape[0]] pinned.copy_(sampled_token_ids, non_blocking=True) self.transfer_event.record() self.transfer_event.synchronize() diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index d7ac33ad8f41..d2ef0232a553 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -35,18 +35,18 @@ def __init__( self.model_config = model_config self.scheduler_config = scheduler_config self.mm_registry = mm_registry - self.cache = cache = processor_only_cache_from_config( - model_config, mm_registry) + self.cache = cache = processor_only_cache_from_config(model_config, mm_registry) self.max_model_len = model_config.max_model_len self.max_num_reqs = scheduler_config.max_num_seqs - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, - cache=cache) + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache) - max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config, - cache=cache) + max_tokens_by_modality = ( + mm_registry.get_max_tokens_per_item_by_nonzero_modality( + model_config, cache=cache + ) + ) encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( scheduler_config, @@ -145,17 +145,14 @@ def create_with_metadata_builders( vllm_config: VllmConfig, device: torch.device, num_metadata_builders: int = 1, - ) -> 'AttentionGroup': + ) -> "AttentionGroup": metadata_builders = [ - backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, - device) + backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device) for _ in range(num_metadata_builders) ] - return AttentionGroup(backend, metadata_builders, layer_names, - kv_cache_spec) + return AttentionGroup(backend, metadata_builders, layer_names, kv_cache_spec) - def get_metadata_builder(self, - ubatch_id: int = 0) -> AttentionMetadataBuilder: + def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder: assert len(self.metadata_builders) > ubatch_id return self.metadata_builders[ubatch_id] @@ -172,19 +169,22 @@ def sanity_check_mm_encoder_outputs( "Expected multimodal embeddings to be a list/tuple of 2D tensors, " f"or a single 3D tensor, but got {type(mm_embeddings)} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) assert len(mm_embeddings) == expected_num_items, ( "Expected number of multimodal embeddings to match number of " f"input items: {expected_num_items}, but got {len(mm_embeddings)=} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) assert all(e.ndim == 2 for e in mm_embeddings), ( "Expected multimodal embeddings to be a sequence of 2D tensors, " f"but got tensors with shapes {[e.shape for e in mm_embeddings]} " "instead. This is most likely due to incorrect implementation " - "of the model's `get_multimodal_embeddings` method.") + "of the model's `get_multimodal_embeddings` method." + ) def scatter_mm_placeholders( @@ -290,13 +290,12 @@ def bind_kv_cache( # Convert kv_caches dict to a list of tensors in the order of layer_index. 
index2name = defaultdict(list) for layer_name in kv_caches: - index2name[extract_layer_index(layer_name, - num_attn_module)].append(layer_name) + index2name[extract_layer_index(layer_name, num_attn_module)].append(layer_name) for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] non_draft_layers = [ - name for name in layer_names if not name.startswith('draft_model.') + name for name in layer_names if not name.startswith("draft_model.") ] if len(non_draft_layers) > 1: # One typical case is encoder-decoder model, e.g., bart. @@ -323,16 +322,16 @@ def bind_kv_cache( forward_context[layer_name].kv_cache = [kv_cache] -def is_residual_scattered_for_sp(vllm_config: VllmConfig, - num_input_tokens: int) -> bool: +def is_residual_scattered_for_sp( + vllm_config: VllmConfig, num_input_tokens: int +) -> bool: """Check if the residual tensor is scattered for sequence parallelism. The residual tensor is scattered across tensor parallel ranks when sequence parallelism and tensor parallelism is enabled, and the number of input tokens is one of the compilation sizes. """ - if not vllm_config.compilation_config.pass_config.\ - enable_sequence_parallelism: + if not vllm_config.compilation_config.pass_config.enable_sequence_parallelism: return False tp = vllm_config.parallel_config.tensor_parallel_size @@ -345,4 +344,4 @@ def is_residual_scattered_for_sp(vllm_config: VllmConfig, assert num_input_tokens % tp == 0 # Currently, SP is only enabled for static size fx graphs. - return (num_input_tokens in vllm_config.compilation_config.compile_sizes) + return num_input_tokens in vllm_config.compilation_config.compile_sizes From eac09d2e1ae008340d8e4bfd1d10940960955f9e Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 6 Oct 2025 16:36:34 +0200 Subject: [PATCH 46/73] Get AL high again Signed-off-by: Tomas Ruiz --- examples/offline_inference/spec_decode.py | 1 + tests/v1/e2e/test_spec_decode.py | 16 +--- tests/v1/test_outputs.py | 19 ----- vllm/v1/attention/backends/utils.py | 5 +- vllm/v1/outputs.py | 3 - vllm/v1/spec_decode/draft_model.py | 90 ++--------------------- vllm/v1/spec_decode/eagle.py | 41 ++++------- vllm/v1/worker/gpu_model_runner.py | 4 - 8 files changed, 21 insertions(+), 158 deletions(-) delete mode 100644 tests/v1/test_outputs.py diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 5a788c8b9b55..e8aed7e81564 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -128,6 +128,7 @@ def main(args): "method": args.method, "model": args.draft_model, "num_speculative_tokens": args.num_spec_tokens, + "disable_padded_drafter_batch": True, "enforce_eager": args.enforce_eager, "max_model_len": args.max_model_len, } diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 775371897573..3ecbcfb7418d 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -16,7 +16,6 @@ from vllm.distributed import cleanup_dist_env_and_memory from vllm.outputs import RequestOutput from vllm.platforms import current_platform -from vllm.v1.spec_decode.draft_model import compute_subrange_indices from vllm.v1.spec_decode.metrics import compute_acceptance_len, compute_acceptance_rate MTP_SIMILARITY_RATE = 0.8 @@ -422,6 +421,7 @@ def test_draft_model_correctness( "max_model_len": args.max_model_len, "enforce_eager": enforce_eager, "tensor_parallel_size": args.draft_tensor_parallel_size, + "disable_padded_drafter_batch": True, }, 
max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, @@ -480,17 +480,3 @@ def compute_exact_matches( print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") return matches / len(ref_outputs) - - -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -def test_compute_subrange_indices(device: str): - start_locs = torch.tensor([3, 6, 12], device=device) - end_locs = torch.tensor([5, 6, 15], device=device) - # fmt: off - expected_indices = torch.tensor([3, 4, 5, - 6, - 12, 13, 14, 15], - device=device) - # fmt: on - indices = compute_subrange_indices(start_locs, end_locs) - assert torch.equal(indices, expected_indices) diff --git a/tests/v1/test_outputs.py b/tests/v1/test_outputs.py deleted file mode 100644 index 5ddf923eeac1..000000000000 --- a/tests/v1/test_outputs.py +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import torch - -from vllm.v1.outputs import SamplerOutput - - -def test_sampler_output(): - # fmt: off - # -1 is the padding token - sampled_token_ids = torch.tensor([ - [1, 2, 3, -1], - [1, -1, -1, -1], - [3, 2, -1, -1] - ]) - # fmt: on - so = SamplerOutput(sampled_token_ids=sampled_token_ids, logprobs_tensors=None) - expected_n_sampled_tokens = torch.tensor([3, 1, 2]) - assert so.n_sampled_tokens().eq(expected_n_sampled_tokens).all() diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 3186a6f71b0d..54520946aaf2 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,7 +4,7 @@ import enum import functools from abc import abstractmethod -from dataclasses import dataclass, fields, make_dataclass, replace +from dataclasses import dataclass, fields, make_dataclass from typing import ( TYPE_CHECKING, Any, @@ -96,9 +96,6 @@ class CommonAttentionMetadata: def batch_size(self) -> int: return self.seq_lens_cpu.shape[0] - def replace(self, **kwargs) -> "CommonAttentionMetadata": - return replace(self, **kwargs) - def query_lens(self) -> torch.Tensor: return self.query_start_loc[1:] - self.query_start_loc[:-1] diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 3d9c8d147090..d647b207575c 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -76,9 +76,6 @@ class SamplerOutput: sampled_token_ids: torch.Tensor logprobs_tensors: Optional[LogprobsTensors] - def n_sampled_tokens(self) -> torch.Tensor: - return self.sampled_token_ids.ne(-1).sum(dim=1) - @dataclass class KVConnectorOutput: diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 27e31eed7775..ce8a0da737ce 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -14,15 +14,12 @@ extend_all_queries_by_1, extend_flat_seqs, ) -from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.eagle import ( PADDING_SLOT_ID, CudaGraphArgs, SpecDecodeBaseProposer, - num_rejected_tokens, ) -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata class DraftModelProposer(SpecDecodeBaseProposer): @@ -41,7 +38,7 @@ def __init__( ) self._raise_if_multimodal() self._raise_if_mrope() - self._raise_if_disabled_padded_drafter_batch() + self._raise_if_padded_drafter_batch() def propose( self, @@ -57,8 +54,6 @@ def propose( common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, cudagraph_args: "CudaGraphArgs", - sampler_output: 
SamplerOutput, - spec_decode_metadata: Optional[SpecDecodeMetadata], mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: """ @@ -72,11 +67,6 @@ def propose( token_ids=target_token_ids, positions=target_positions, ) - inputs = trim_accepted_and_rejected_tokens( - inputs=inputs, - sampler_output=sampler_output, - spec_decode_metadata=spec_decode_metadata, - ) inputs = merge_next_token_ids_into_token_ids( inputs=inputs, next_token_ids=next_token_ids, @@ -91,8 +81,6 @@ def propose( common_attn_metadata=inputs.cad, cudagraph_args=cudagraph_args, sampling_metadata=sampling_metadata, - sampler_output=sampler_output, - spec_decode_metadata=spec_decode_metadata, # below are are not used by draft model target_hidden_states=None, next_token_ids=None, @@ -114,11 +102,12 @@ def _raise_if_mrope(self): "Speculative Decoding with draft models does not support M-RoPE yet" ) - def _raise_if_disabled_padded_drafter_batch(self): - if self.vllm_config.speculative_config.disable_padded_drafter_batch: + def _raise_if_padded_drafter_batch(self): + if not self.vllm_config.speculative_config.disable_padded_drafter_batch: raise NotImplementedError( "Speculative Decoding with draft models does not support " - "disabled padded drafter batch yet" + "padded drafter batch yet. Please pass --disable-padded-drafter-batch " + "in the speculative config." ) def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: @@ -185,75 +174,6 @@ class DraftModelInputs: cad: CommonAttentionMetadata -def trim_accepted_and_rejected_tokens( - inputs: DraftModelInputs, - sampler_output: SamplerOutput, - spec_decode_metadata: Optional[SpecDecodeMetadata], -) -> DraftModelInputs: - """ - Removes from the input.token_ids any tokens that have already been processed - by the draft model, as well as tokens rejected by the sampler. - Adjusts the positions accordingly, the slot mapping, - and the common_attn_metadata. - """ - cad: CommonAttentionMetadata = inputs.cad - - # Compute the new token ids and positions - n_accepted_tokens = sampler_output.n_sampled_tokens() - 1 - n_rejected_tokens = num_rejected_tokens( - spec_decode_metadata, sampler_output.n_sampled_tokens() - ) - from_loc = cad.query_start_loc[:-1] + n_accepted_tokens - to_loc = cad.query_start_loc[1:] - 1 - n_rejected_tokens - idxs = compute_subrange_indices(from_loc, to_loc) - new_token_ids = inputs.token_ids[idxs] - new_positions = inputs.positions[idxs] - - # The new slot mapping is a subset of the previous one, - # so no recomputation is needed. - new_slot_mapping = cad.slot_mapping[idxs] - - # Update common_attn_metadata - new_query_lens = to_loc - from_loc + 1 - new_query_start_loc = torch.zeros_like(cad.query_start_loc) - new_query_start_loc[1:] = new_query_lens.cumsum(0) - - new_cad: CommonAttentionMetadata = cad.replace( - query_start_loc=new_query_start_loc, - query_start_loc_cpu=new_query_start_loc.to("cpu", non_blocking=True), - num_actual_tokens=new_token_ids.shape[0], - max_query_len=new_query_lens.max().item(), - slot_mapping=new_slot_mapping, - ) - return DraftModelInputs( - token_ids=new_token_ids, positions=new_positions, cad=new_cad - ) - - -def compute_subrange_indices(start_locs: torch.Tensor, end_locs: torch.Tensor): - """ - Given two tensor of the same length containing start and end locations, - returns a tensor of indices with each subrange. E.g. - start_locs = [s1, s2, s3, ...], and - end_locs = [e1, e2, e3, ...], - return [*s1:e1, *s2:e2, *s3:e3, ...] 
as a flat tensor - """ - # Compute lengths of each subrange - lengths = end_locs - start_locs + 1 - # Build an index for each subrange - # torch.arange(max_len) creates [0, 1, ..., max_len-1] - # broadcasting + masking ensures we only keep valid positions - max_len = lengths.max() - offsets = torch.arange(max_len, device=start_locs.device).unsqueeze( - 0 - ) # shape [1, max_len] - mask = offsets < lengths.unsqueeze(1) # shape [n, max_len] - # Build all indices - all_indices = start_locs.unsqueeze(1) + offsets - all_indices = all_indices[mask] # flatten valid indices only - return all_indices - - def merge_next_token_ids_into_token_ids( inputs: DraftModelInputs, next_token_ids: torch.Tensor, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index cd0a3c2d5312..b80a367e96be 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -33,7 +33,6 @@ CommonAttentionMetadata, ) from vllm.v1.kv_cache_interface import KVCacheConfig -from vllm.v1.outputs import SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.utils import CpuGpuBuffer @@ -200,8 +199,6 @@ def propose( common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, cudagraph_args: "CudaGraphArgs", - sampler_output: SamplerOutput, - spec_decode_metadata: Optional[SpecDecodeMetadata], mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] @@ -629,9 +626,20 @@ def prepare_inputs_padded( used as padding and filtered out later by `token_indices_to_sample`. No blocking CPU operations should be introduced in this function. """ - num_rejected_tokens_gpu = num_rejected_tokens( - spec_decode_metadata, valid_sampled_tokens_count + num_draft_tokens_gpu = torch.cat( + [ + spec_decode_metadata.cu_num_draft_tokens[0:1], + spec_decode_metadata.cu_num_draft_tokens[1:] + - spec_decode_metadata.cu_num_draft_tokens[:-1], + ] ) + + num_rejected_tokens_gpu = torch.where( + num_draft_tokens_gpu > 0, + num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, + torch.zeros_like(num_draft_tokens_gpu), + ) + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] @@ -1197,29 +1205,6 @@ def compute_probs_and_sample_next_token( return next_token_ids, probs -def num_rejected_tokens( - spec_decode_metadata: Optional[SpecDecodeMetadata], - valid_sampled_tokens_count: torch.Tensor, -) -> torch.Tensor: - if spec_decode_metadata is None: - return torch.zeros_like(valid_sampled_tokens_count) - - num_draft_tokens_gpu = torch.cat( - [ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1], - ] - ) - - num_rejected_tokens_gpu = torch.where( - num_draft_tokens_gpu > 0, - num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu), - ) - return num_rejected_tokens_gpu - - def update_batch_descriptor(cudagraph_args: CudaGraphArgs, new_num_tokens: int) -> None: """The cudagraph padding can change the num_tokens, so the batch descriptor should be updated. 
The cudagraph_args is modified in place.""" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 487bced54c60..d7cb8303cc0d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2673,7 +2673,6 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_common_attn_metadata, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, - sampler_output=sampler_output, ) use_padded_batch = ( @@ -2775,7 +2774,6 @@ def propose_draft_token_ids( common_attn_metadata: CommonAttentionMetadata, cudagraph_runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor, - sampler_output: SamplerOutput, ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": @@ -2913,8 +2911,6 @@ def propose_draft_token_ids( common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, cudagraph_args=cudagraph_args, - sampler_output=sampler_output, - spec_decode_metadata=spec_decode_metadata, ) return draft_token_ids From ccac6cb7c8989d3f6af3fb9f8bdfc21e7872b7d3 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 6 Oct 2025 16:45:24 +0200 Subject: [PATCH 47/73] Minimze changes Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 8 +++----- vllm/v1/spec_decode/eagle.py | 3 +-- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index ce8a0da737ce..71026ab699cd 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -57,10 +57,8 @@ def propose( mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: """ - - Trims unnecessary tokens from the input, like those rejected by - the sampler, or those already processed by the draft model. - - Merges the next_token_ids with the existing token ids into - a flat sequence. + This function processes the inputs first before calling the .propose() + method of the parent class. """ inputs = DraftModelInputs( cad=common_attn_metadata, @@ -184,7 +182,7 @@ def merge_next_token_ids_into_token_ids( """ Merges the next token ids with the existing token ids into a flat sequence. Does the same for the positions, computes new slot mapping, - and updates the common_attn_metadata. + and updates the common_attn_metadata. The inputs are not modified in-place. """ cad: CommonAttentionMetadata = inputs.cad diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b80a367e96be..12c5522dff08 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -310,8 +310,7 @@ def propose( else: positions = target_positions[last_token_indices] - # NOTE(Tomas): What is the intention of this ifelse? 
- if self.method == "mtp": + if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"): hidden_states = self.hidden_states[last_token_indices] else: hidden_states = hidden_states[last_token_indices] From c094f5f361d8c91bfa3c6d87336b79fea0433b18 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 7 Oct 2025 12:17:01 +0200 Subject: [PATCH 48/73] Add flag for disable_padded_drafter_batch --- examples/offline_inference/spec_decode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index e8aed7e81564..8af9a8148552 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -74,6 +74,7 @@ def parse_args(): parser.add_argument("--custom-mm-prompts", action="store_true") parser.add_argument("--gpu-memory-utilization", type=float, default=0.8) parser.add_argument("--request-id-prefix", type=str, default="") + parser.add_argument("--disable-padded-drafter-batch", action="store_true") return parser.parse_args() @@ -114,6 +115,7 @@ def main(args): "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, + "disable_padded_drafter_batch": disable_padded_drafter_batch, } elif args.method == "ngram": speculative_config = { From a6f8484f5ad5ea653ffc990af17b755570801d4d Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 7 Oct 2025 12:52:01 +0200 Subject: [PATCH 49/73] Correct typo Signed-off-by: Tomas Ruiz --- examples/offline_inference/spec_decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 8af9a8148552..dd13ebcc0576 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -115,7 +115,7 @@ def main(args): "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, - "disable_padded_drafter_batch": disable_padded_drafter_batch, + "disable_padded_drafter_batch": args.disable_padded_drafter_batch, } elif args.method == "ngram": speculative_config = { From 4e77a809ee6821a8bbae330f49f2b82e53ce9f2f Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 7 Oct 2025 21:44:37 +0200 Subject: [PATCH 50/73] Ensure draft model uses CUDA graph Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 10 ++++------ vllm/v1/worker/gpu_model_runner.py | 8 ++++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d555be7efc3f..cad47425461a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -287,7 +287,7 @@ def propose( num_tokens=num_input_tokens, ) if self.pass_cudagraph_args_to_forward_ctx: - update_batch_descriptor(cudagraph_args, num_input_tokens) + cudagraph_args = self.cudagraph_args(num_tokens=num_input_tokens) forward_ctx_kwargs.update(cudagraph_args) with set_forward_context(**forward_ctx_kwargs): @@ -455,9 +455,7 @@ def propose( num_tokens=input_batch_size, ) if self.pass_cudagraph_args_to_forward_ctx: - cudagraph_args = self.decoding_cudagraph_args( - num_tokens=input_batch_size - ) + cudagraph_args = self.cudagraph_args(num_tokens=input_batch_size) forward_ctx_kwargs.update(cudagraph_args) with set_forward_context(**forward_ctx_kwargs): @@ -494,8 +492,8 @@ def set_input_ids_first_pass( def model_returns_tuple(self) -> bool: return self.method not in ("mtp", "draft_model") - def decoding_cudagraph_args(self, num_tokens: int) -> 
"CudaGraphArgs": - batch_descriptor = BatchDescriptor(num_tokens=num_tokens, uniform_decode=True) + def cudagraph_args(self, num_tokens: int) -> "CudaGraphArgs": + batch_descriptor = BatchDescriptor(num_tokens=num_tokens, uniform_decode=False) cudagraph_runtime_mode, batch_descriptor = ( self.runner.cudagraph_dispatcher.dispatch(batch_descriptor) ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 24bf70648ed3..2ae88e5a2bdb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2962,6 +2962,14 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) + if hasattr(self, "drafter") and isinstance( + self.drafter, DraftModelProposer + ): + self.drafter.model = CUDAGraphWrapper( + self.drafter.model, + self.drafter.vllm_config, + runtime_mode=CUDAGraphMode.FULL, + ) elif self.parallel_config.enable_dbo: if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.model = UBatchWrapper( From a1e899c440f7083c5db6026372f093dd183aef4a Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 8 Oct 2025 10:21:22 +0200 Subject: [PATCH 51/73] Remove unnecessary cudagraph inputs Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 8 +------- vllm/v1/spec_decode/eagle.py | 1 - vllm/v1/worker/gpu_model_runner.py | 7 +------ 3 files changed, 2 insertions(+), 14 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 71026ab699cd..8e9bb742a627 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -15,11 +15,7 @@ extend_flat_seqs, ) from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.eagle import ( - PADDING_SLOT_ID, - CudaGraphArgs, - SpecDecodeBaseProposer, -) +from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer class DraftModelProposer(SpecDecodeBaseProposer): @@ -53,7 +49,6 @@ def propose( last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - cudagraph_args: "CudaGraphArgs", mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: """ @@ -77,7 +72,6 @@ def propose( target_token_ids=inputs.token_ids, target_positions=inputs.positions, common_attn_metadata=inputs.cad, - cudagraph_args=cudagraph_args, sampling_metadata=sampling_metadata, # below are are not used by draft model target_hidden_states=None, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index cad47425461a..82fcc2cc1591 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -198,7 +198,6 @@ def propose( last_token_indices: Optional[torch.Tensor], common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - cudagraph_args: "CudaGraphArgs", mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2ae88e5a2bdb..7db055489291 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -126,7 +126,7 @@ from vllm.v1.sample.rejection_sampler import RejectionSampler from vllm.v1.sample.sampler import Sampler from vllm.v1.spec_decode.draft_model import DraftModelProposer -from vllm.v1.spec_decode.eagle import CudaGraphArgs, EagleProposer +from 
vllm.v1.spec_decode.eagle import EagleProposer from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -2819,10 +2819,6 @@ def propose_draft_token_ids( ) else: mm_embed_inputs = None - cudagraph_args: CudaGraphArgs = dict( - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ) draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -2832,7 +2828,6 @@ def propose_draft_token_ids( sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, - cudagraph_args=cudagraph_args, ) return draft_token_ids From 50dcbc497a8af7f60262b082b593a714876a00df Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 8 Oct 2025 10:27:18 +0200 Subject: [PATCH 52/73] Minimize changes Signed-off-by: Tomas Ruiz --- vllm/v1/worker/gpu_model_runner.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7db055489291..111fc6c21209 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2593,8 +2593,6 @@ def propose_draft_token_ids(sampled_token_ids): aux_hidden_states, spec_decode_metadata, spec_decode_common_attn_metadata, - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, ) use_padded_batch = ( @@ -2694,8 +2692,6 @@ def propose_draft_token_ids( aux_hidden_states: Optional[list[torch.Tensor]], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, - cudagraph_runtime_mode: CUDAGraphMode, - batch_descriptor: BatchDescriptor, ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": @@ -2819,6 +2815,7 @@ def propose_draft_token_ids( ) else: mm_embed_inputs = None + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -2829,6 +2826,7 @@ def propose_draft_token_ids( common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, ) + return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: @@ -2887,13 +2885,7 @@ def load_model(self, eep_scale_up: bool = False) -> None: ) if hasattr(self, "drafter"): logger.info("Loading drafter model...") - if self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) - self.drafter.load_model(self.model) - elif self.speculative_config.uses_draft_model(): - assert isinstance(self.drafter, DraftModelProposer) - # Passed something to satisfy the type checker - self.drafter.load_model(None) + self.drafter.load_model(self.model) if self.use_aux_hidden_state_outputs: if not supports_eagle3(self.model): raise RuntimeError( From c01e43baa1f576c90a272c0444f48fbe7346f434 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 8 Oct 2025 10:49:39 +0200 Subject: [PATCH 53/73] Minimize changes Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 2 +- vllm/v1/worker/utils.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 3ecbcfb7418d..484e2b0233d0 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -393,7 +393,7 @@ class ArgsTest: draft_model="Qwen/Qwen3-0.6B", 
sampling_config=stochastic_sampling(), num_speculative_tokens=3, - expected_acceptance_len=2.85 + 1, + expected_acceptance_len=2.8 + 1, expected_acceptance_rate=0.9, expected_same_output_fraction=0.9, ), diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index d2ef0232a553..20299231e865 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -294,10 +294,7 @@ def bind_kv_cache( for layer_index in sorted(index2name.keys()): layer_names = index2name[layer_index] - non_draft_layers = [ - name for name in layer_names if not name.startswith("draft_model.") - ] - if len(non_draft_layers) > 1: + if len(layer_names) > 1: # One typical case is encoder-decoder model, e.g., bart. # The cross attention and self attention in the same decoder layer # has different layer_name but the same layer_index. @@ -312,7 +309,6 @@ def bind_kv_cache( pass else: raise NotImplementedError - for layer_name in layer_names: runner_kv_caches.append(kv_caches[layer_name]) From cf99760b46f121b1c299c1a9bcc9fc3076a1fb83 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 8 Oct 2025 11:53:07 +0200 Subject: [PATCH 54/73] Remove unused fn Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 82fcc2cc1591..4bc96f0f410f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1195,14 +1195,3 @@ def compute_probs_and_sample_next_token( next_token_ids, ) return next_token_ids, probs - - -def update_batch_descriptor(cudagraph_args: CudaGraphArgs, new_num_tokens: int) -> None: - """The cudagraph padding can change the num_tokens, so the batch descriptor - should be updated. The cudagraph_args is modified in place.""" - old: Optional[BatchDescriptor] = cudagraph_args["batch_descriptor"] - if old is not None: - new = BatchDescriptor( - num_tokens=new_num_tokens, uniform_decode=old.uniform_decode - ) - cudagraph_args["batch_descriptor"] = new From c73929d340ea3997681f8660576426aa21543c6f Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Wed, 8 Oct 2025 11:55:40 +0200 Subject: [PATCH 55/73] Minimize changes Signed-off-by: Tomas Ruiz --- examples/offline_inference/spec_decode.py | 1 - tests/v1/spec_decode/test_eagle.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index dd13ebcc0576..88a675e649ef 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -73,7 +73,6 @@ def parse_args(): parser.add_argument("--draft-model", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") parser.add_argument("--gpu-memory-utilization", type=float, default=0.8) - parser.add_argument("--request-id-prefix", type=str, default="") parser.add_argument("--disable-padded-drafter-batch", action="store_true") return parser.parse_args() diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 3c748e25bd63..4c490f2188aa 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -568,7 +568,6 @@ def create_deterministic_logits(token_ids): last_token_indices=None, common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, - cudagraph_args=dict(), ) assert result.shape == (batch_size, num_speculative_tokens) @@ -723,7 +722,6 @@ def create_deterministic_logits(token_ids, k: int): last_token_indices=None, 
common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata, - cudagraph_args=dict(), ) assert result.shape == (batch_size, num_speculative_tokens) From 66d4f2b2a978e54429b234fe438c2a3f72599527 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 9 Oct 2025 10:56:51 +0200 Subject: [PATCH 56/73] Avoid OOB error on large batches Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4bc96f0f410f..748caabe28a1 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -67,7 +67,11 @@ def __init__( self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size self.num_speculative_tokens = self.speculative_config.num_speculative_tokens - self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + # The drafter can get longer sequences than the target model. + max_batch_size = vllm_config.scheduler_config.max_num_seqs + self.max_num_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size + ) self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's @@ -118,7 +122,6 @@ def __init__( # We need +1 here because the arange is used to set query_start_loc, # which has one more element than batch_size. - max_batch_size = vllm_config.scheduler_config.max_num_seqs max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens) self.arange = torch.arange( max_num_slots_for_arange, device=device, dtype=torch.int32 From de86231b77dceb98df9fca1649c2cae3ecea0c47 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Fri, 10 Oct 2025 13:06:26 +0200 Subject: [PATCH 57/73] Simplify away passing the CUDA graph args Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 1 - vllm/v1/spec_decode/eagle.py | 22 ++-------------------- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 8e9bb742a627..1bfab8be4ef5 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -29,7 +29,6 @@ def __init__( vllm_config=vllm_config, device=device, pass_hidden_states_to_model=False, - pass_cudagraph_args_to_forward_ctx=True, runner=runner, ) self._raise_if_multimodal() diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index f5c63e58018a..4ee133762719 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -3,7 +3,7 @@ import ast from dataclasses import replace from importlib.util import find_spec -from typing import Optional, TypedDict +from typing import Optional import numpy as np import torch @@ -16,7 +16,7 @@ get_layers_from_vllm_config, ) from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model @@ -54,7 +54,6 @@ def __init__( vllm_config: VllmConfig, device: torch.device, pass_hidden_states_to_model: bool, - pass_cudagraph_args_to_forward_ctx: bool, runner=None, ): self.vllm_config = vllm_config @@ -63,7 +62,6 @@ def __init__( self.draft_model_config 
= self.speculative_config.draft_model_config self.method = self.speculative_config.method self.pass_hidden_states_to_model = pass_hidden_states_to_model - self.pass_cudagraph_args_to_forward_ctx = pass_cudagraph_args_to_forward_ctx self.runner = runner self.device = device @@ -508,16 +506,6 @@ def set_input_ids_first_pass( def model_returns_tuple(self) -> bool: return self.method not in ("mtp", "draft_model") - def cudagraph_args(self, num_tokens: int) -> "CudaGraphArgs": - batch_descriptor = BatchDescriptor(num_tokens=num_tokens, uniform_decode=False) - cudagraph_runtime_mode, batch_descriptor = ( - self.runner.cudagraph_dispatcher.dispatch(batch_descriptor) - ) - return CudaGraphArgs( - cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor, - ) - def prepare_next_token_ids_cpu( self, sampled_token_ids: list[list[int]], @@ -1168,11 +1156,6 @@ def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: ), "All eagle layers should belong to the same kv cache group" -class CudaGraphArgs(TypedDict): - cudagraph_runtime_mode: CUDAGraphMode - batch_descriptor: BatchDescriptor - - class EagleProposer(SpecDecodeBaseProposer): def __init__( self, @@ -1184,7 +1167,6 @@ def __init__( vllm_config, device, pass_hidden_states_to_model=True, - pass_cudagraph_args_to_forward_ctx=False, runner=runner, ) From f8321d20a717013c5c41e1c22e1590937ceac911 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Fri, 10 Oct 2025 13:20:09 +0200 Subject: [PATCH 58/73] add option --max-num-seqs to spec_decode.py (useful for small GPUs) Signed-off-by: Tomas Ruiz --- examples/offline_inference/spec_decode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 88a675e649ef..9511b4bf3425 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -74,6 +74,7 @@ def parse_args(): parser.add_argument("--custom-mm-prompts", action="store_true") parser.add_argument("--gpu-memory-utilization", type=float, default=0.8) parser.add_argument("--disable-padded-drafter-batch", action="store_true") + parser.add_argument("--max-num-seqs", type=int, default=None) return parser.parse_args() @@ -153,6 +154,7 @@ def main(args): max_model_len=args.max_model_len, limit_mm_per_prompt={"image": 5}, disable_chunked_mm_input=True, + max_num_seqs=args.max_num_seqs, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) From e9560ef38442bfae166f7fbf7c2b8fa7d9172047 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Fri, 10 Oct 2025 14:42:10 +0200 Subject: [PATCH 59/73] Prevent different tokenizer vocab sizes Signed-off-by: Tomas Ruiz --- vllm/config/speculative.py | 19 ++++++++++++++++++- vllm/v1/spec_decode/draft_model.py | 4 ++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 521110431f93..6704b90ec4c0 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -575,9 +575,26 @@ def _verify_args(self) -> Self: f"Eagle3 is only supported for {eagle3_target_supported} models. 
" # noqa: E501 f"Got {self.target_model_config.hf_text_config.model_type=}" ) - + self.verify_equal_vocab_size_if_draft_model() return self + def verify_equal_vocab_size_if_draft_model(self): + if ( + self.method == "draft_model" + and self.target_model_config is not None + and self.draft_model_config is not None + ): + target_vocab_size = self.target_model_config.get_vocab_size() + draft_vocab_size = self.draft_model_config.get_vocab_size() + if target_vocab_size != draft_vocab_size: + raise ValueError( + f"Target and draft model should have the same vocabulary size. " + f"Target model vocab_size={target_vocab_size}. " + f"Draft model vocab_size={draft_vocab_size}. " + f"Using models with different tokenizers can cause out-of-bounds " + f"errors during speculative decoding." + ) + @property def num_lookahead_slots(self) -> int: """The number of additional slots the scheduler should allocate per diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 1bfab8be4ef5..0fc40c649138 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -34,6 +34,7 @@ def __init__( self._raise_if_multimodal() self._raise_if_mrope() self._raise_if_padded_drafter_batch() + self._raise_if_vocab_size_mismatch() def propose( self, @@ -101,6 +102,9 @@ def _raise_if_padded_drafter_batch(self): "in the speculative config." ) + def _raise_if_vocab_size_mismatch(self): + self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model() + def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: return { "input_ids": self.input_ids[:num_tokens], From 694faf8f62d70eef316ee114c7ed4d1746cac066 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Fri, 10 Oct 2025 16:05:45 +0200 Subject: [PATCH 60/73] Limit cudagraph capture time in test Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 5ea12e77b08c..4b2f9d203b8e 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -421,6 +421,7 @@ def test_draft_model_correctness( "enforce_eager": enforce_eager, "tensor_parallel_size": args.draft_tensor_parallel_size, "disable_padded_drafter_batch": True, + "max_num_seqs": 100, # limit cudagraph capture runtime }, max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, From fa6294fa1f70db07b6c047e2c2fb657aa814e780 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Fri, 10 Oct 2025 16:35:06 +0200 Subject: [PATCH 61/73] Minimize changes related to CUDA graph Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 16 ---------------- vllm/v1/spec_decode/eagle.py | 7 +++++-- vllm/v1/worker/gpu_model_runner.py | 24 +++++------------------- 3 files changed, 10 insertions(+), 37 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 0fc40c649138..6707aa7c911b 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -7,7 +7,6 @@ from vllm.attention.layer import Attention from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config -from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, @@ -105,21 +104,6 @@ def _raise_if_padded_drafter_batch(self): def _raise_if_vocab_size_mismatch(self): 
self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model() - def _model_kwargs(self, num_tokens: int) -> dict[str, Any]: - return { - "input_ids": self.input_ids[:num_tokens], - "positions": self.positions[:num_tokens], - } - - def dummy_run(self, num_tokens: int, forward_ctx_kwargs: dict): - model_kwargs = self._model_kwargs(num_tokens) - with set_forward_context( - vllm_config=self.vllm_config, - num_tokens=num_tokens, - **forward_ctx_kwargs, - ): - self.model(**model_kwargs) - def set_input_ids_first_pass( self, target_token_ids: torch.Tensor, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4ee133762719..503d2b1d0755 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1100,12 +1100,15 @@ def dummy_run( input_ids = self.input_ids[:num_tokens] inputs_embeds = None - self.model( + model_kwargs = dict( input_ids=input_ids, positions=self._get_positions(num_tokens), - hidden_states=self.hidden_states[:num_tokens], inputs_embeds=inputs_embeds, ) + if self.pass_hidden_states_to_model: + model_kwargs["hidden_states"] = self.hidden_states[:num_tokens] + + self.model(**model_kwargs) def _get_attention_metadata_builder(self) -> list[AttentionMetadataBuilder]: """Find and return the attention metadata builders for EAGLE layers. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 24f0e62fc5f2..98ba4a731c81 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2944,14 +2944,6 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) - if hasattr(self, "drafter") and isinstance( - self.drafter, DraftModelProposer - ): - self.drafter.model = CUDAGraphWrapper( - self.drafter.model, - self.drafter.vllm_config, - runtime_mode=CUDAGraphMode.FULL, - ) elif self.parallel_config.enable_dbo: if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.model = UBatchWrapper( @@ -3463,20 +3455,14 @@ def _dummy_run( else: hidden_states = outputs - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, (EagleProposer, DraftModelProposer)) use_cudagraphs = cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE self.drafter.dummy_run(num_tokens, use_cudagraphs=use_cudagraphs) - if self.speculative_config and self.speculative_config.uses_draft_model(): - assert isinstance(self.drafter, DraftModelProposer) - forward_ctx_kwargs = { - "attn_metadata": attn_metadata, - "cudagraph_runtime_mode": cudagraph_runtime_mode, - "batch_descriptor": batch_descriptor, - } - self.drafter.dummy_run(num_tokens, forward_ctx_kwargs) - # This is necessary to avoid blocking DP. # For dummy runs, we typically skip EPLB since we don't have any real # requests to process. 
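
At this point in the series the draft-model proposer is functional end to end, so a short usage sketch may help orient the remaining patches. This is a minimal offline example, not taken from any single patch: it assumes the Qwen3-1.7B / Qwen3-0.6B pair and the flag values exercised by tests/v1/e2e/test_spec_decode.py and examples/offline_inference/spec_decode.py, and it reflects the restrictions that DraftModelProposer enforces at startup (padded drafter batches must be disabled, and the draft model must have the same vocabulary size as the target):

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen3-1.7B",
        speculative_config={
            "method": "draft_model",
            "model": "Qwen/Qwen3-0.6B",  # same vocabulary size as the target
            "num_speculative_tokens": 3,
            # DraftModelProposer does not support padded drafter batches yet.
            "disable_padded_drafter_batch": True,
        },
        max_model_len=1024,
        gpu_memory_utilization=0.5,
    )

    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(temperature=0, max_tokens=10),  # greedy, as in the e2e test
    )
    print(outputs[0].outputs[0].text)

With greedy sampling the speculative run is expected to reproduce the plain target-model output exactly, which is what the correctness test above asserts.
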
From f49a5ea721625bf6f9503b04b79c422dfeaaa300 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 13 Oct 2025 12:18:27 +0200 Subject: [PATCH 62/73] Replace Optional[T] with T | None Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 6707aa7c911b..47f4826fdeb3 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, replace -from typing import Any, Optional +from typing import Any import torch @@ -45,10 +45,10 @@ def propose( target_hidden_states: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - last_token_indices: Optional[torch.Tensor], + last_token_indices: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, - mm_embed_inputs: Optional[tuple[list[torch.Tensor], torch.Tensor]] = None, + mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, ) -> torch.Tensor: """ This function processes the inputs first before calling the .propose() From 37f013ecf11426cda676551017e4c34016b58d52 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 13 Oct 2025 16:12:28 +0200 Subject: [PATCH 63/73] Add tests for quantized target / draft model Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 35 +++++++++++++++++++++++++----- vllm/config/vllm.py | 8 +++++++ vllm/v1/spec_decode/draft_model.py | 9 +++++--- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 79b240e57d13..b1fe79286706 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -399,14 +399,37 @@ class ArgsTest: @pytest.mark.parametrize("args", cases) @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_draft_model_correctness( - args: ArgsTest, - enforce_eager: bool, - monkeypatch: pytest.MonkeyPatch, -): +def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool): + assert_draft_model_correctness(args, enforce_eager) + + +@pytest.mark.parametrize( + "models", + [ + # target_model, draft_model + ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"), # target quantized + ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"), # draft quantized + ], + ids=["target_quantized", "draft_quantized"], +) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool): + tgt_model, draft_model = models + sd_case = ArgsTest( + model=tgt_model, + draft_model=draft_model, + sampling_config=greedy_sampling(), + num_speculative_tokens=3, + expected_acceptance_len=2.95 + 1, + expected_acceptance_rate=0.95, + expected_same_output_fraction=0.95, + ) + assert_draft_model_correctness(sd_case, enforce_eager) + + +def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): """Compare the outputs using and not using speculative decoding. 
In the greedy decoding case, the outputs must match EXACTLY.""" - monkeypatch.setenv("VLLM_USE_V1", "1") test_prompts = get_test_prompts(mm_enabled=False, quiet=True) spec_llm = LLM( diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b15d122c9161..4a7d0280a6dc 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -757,6 +757,14 @@ def compile_debug_dump_path(self) -> Path | None: path = self.compilation_config.debug_dump_path / append_path return path + def replace(self, **kwargs): + """ + Replace attributes of the config, and 'recompute' the config. + dataclass.replace() calls __init__() and __post_init__(), source: + https://docs.python.org/3/library/dataclasses.html#dataclasses.replace + """ + return replace(self, **kwargs) + def __str__(self): return ( f"model={self.model_config.model!r}, " diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 47f4826fdeb3..dc99fae3254e 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass, replace +from dataclasses import dataclass from typing import Any import torch @@ -118,8 +118,11 @@ def load_model(self, target_model: Any) -> None: draft_model_config: ModelConfig = ( self.vllm_config.speculative_config.draft_model_config ) - vllm_config_draft: VllmConfig = replace( - self.vllm_config, model_config=draft_model_config + # Recompute quant_config, which is configured for the target model + # But the draft model might not be quantized. + vllm_config_draft: VllmConfig = self.vllm_config.replace( + quant_config=None, + model_config=draft_model_config, ) # This must be computed before loading the draft model From 58f8496078f889c71c7c42ac3a70952a247683ca Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 13 Oct 2025 15:40:53 +0000 Subject: [PATCH 64/73] Add test for draft model + tensor parallelism Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 41 +++++++++++++++++++++--------- vllm/v1/spec_decode/draft_model.py | 5 ++++ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index b1fe79286706..cdf582111a30 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -359,7 +359,7 @@ def test_mtp_correctness( @dataclass class ArgsTest: - model: str + target_model: str draft_model: str sampling_config: SamplingParams num_speculative_tokens: int @@ -376,7 +376,7 @@ class ArgsTest: cases = [ # Same model for draft and target, greedy sampling. ArgsTest( - model="Qwen/Qwen3-0.6B", + target_model="Qwen/Qwen3-0.6B", draft_model="Qwen/Qwen3-0.6B", sampling_config=greedy_sampling(), num_speculative_tokens=3, # K @@ -386,7 +386,7 @@ class ArgsTest: ), # Smaller draft model, stochastic sampling. 
ArgsTest( - model="Qwen/Qwen3-1.7B", + target_model="Qwen/Qwen3-1.7B", draft_model="Qwen/Qwen3-0.6B", sampling_config=stochastic_sampling(), num_speculative_tokens=3, @@ -416,24 +416,31 @@ def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool): def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool): tgt_model, draft_model = models sd_case = ArgsTest( - model=tgt_model, + target_model=tgt_model, draft_model=draft_model, - sampling_config=greedy_sampling(), - num_speculative_tokens=3, - expected_acceptance_len=2.95 + 1, - expected_acceptance_rate=0.95, - expected_same_output_fraction=0.95, + **some_high_acceptance_metrics(), ) assert_draft_model_correctness(sd_case, enforce_eager) +def test_draft_model_tensor_parallelism(): + sd_case = ArgsTest( + target_model="Qwen/Qwen3-1.7B", + target_tensor_parallel_size=2, + draft_model="Qwen/Qwen3-0.6B", + draft_tensor_parallel_size=1, + **some_high_acceptance_metrics(), + ) + assert_draft_model_correctness(sd_case, enforce_eager=True) + + def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): """Compare the outputs using and not using speculative decoding. In the greedy decoding case, the outputs must match EXACTLY.""" test_prompts = get_test_prompts(mm_enabled=False, quiet=True) spec_llm = LLM( - model=args.model, + model=args.target_model, speculative_config={ "model": args.draft_model, "method": "draft_model", @@ -462,7 +469,7 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): assert acceptance_len >= args.expected_acceptance_len ref_llm = LLM( - model=args.model, + model=args.target_model, max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, tensor_parallel_size=args.target_tensor_parallel_size, @@ -480,7 +487,7 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): assert match_fraction >= args.expected_same_output_fraction print( - f"spec-decode: target={args.model}, draft={args.draft_model}, " + f"spec-decode: target={args.target_model}, draft={args.draft_model}, " f"temperature={args.sampling_config.temperature:.2f}, " f"acceptance_rate={acceptance_rate:.2f}, " f"acceptance_len={acceptance_len:.2f}, " @@ -501,3 +508,13 @@ def compute_exact_matches( print(f"ref_output: {ref_output.outputs[0].text}") print(f"spec_output: {spec_output.outputs[0].text}") return matches / len(ref_outputs) + + +def some_high_acceptance_metrics() -> dict: + return { + "sampling_config": greedy_sampling(), + "num_speculative_tokens": 3, + "expected_acceptance_len": 2.95 + 1, + "expected_acceptance_rate": 0.95, + "expected_same_output_fraction": 0.95, + } diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index dc99fae3254e..6e3855fa3c25 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -7,6 +7,7 @@ from vllm.attention.layer import Attention from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config +from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, @@ -16,6 +17,8 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.eagle import PADDING_SLOT_ID, SpecDecodeBaseProposer +logger = init_logger(__name__) + class DraftModelProposer(SpecDecodeBaseProposer): def __init__( @@ -118,6 +121,8 @@ def load_model(self, target_model: Any) -> None: draft_model_config: ModelConfig = ( 
self.vllm_config.speculative_config.draft_model_config ) + logger.info("Starting to load model %s...", draft_model_config.model) + # Recompute quant_config, which is configured for the target model # But the draft model might not be quantized. vllm_config_draft: VllmConfig = self.vllm_config.replace( From 4bd9a460cafbfb8d55084b88b114e400ff26538c Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 13 Oct 2025 16:15:43 +0000 Subject: [PATCH 65/73] Log why endpoint is not ready Signed-off-by: Tomas Ruiz --- vllm/benchmarks/lib/ready_checker.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/benchmarks/lib/ready_checker.py b/vllm/benchmarks/lib/ready_checker.py index 5649faf05597..0cfd053f5353 100644 --- a/vllm/benchmarks/lib/ready_checker.py +++ b/vllm/benchmarks/lib/ready_checker.py @@ -8,8 +8,12 @@ import aiohttp from tqdm.asyncio import tqdm +from vllm.logger import init_logger + from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput +logger = init_logger(__name__) + async def wait_for_endpoint( request_func: RequestFunc, @@ -61,6 +65,8 @@ async def wait_for_endpoint( if output.success: pbar.close() return output + else: + logger.warning("Endpoint is not ready. Error='%s'", output.error) except aiohttp.ClientConnectorError: pass From ff92d85272617801843167871dbca12e552b919b Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Mon, 13 Oct 2025 17:40:40 +0000 Subject: [PATCH 66/73] Test tensor parallelism more thoroughly Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 47 +++++++++++++++++++++++++++++- vllm/config/speculative.py | 8 +++++ vllm/v1/spec_decode/draft_model.py | 45 +++++++++++++++++----------- 3 files changed, 82 insertions(+), 18 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index cdf582111a30..49aab857b2a1 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -11,9 +11,12 @@ from vllm import LLM, SamplingParams from vllm.assets.base import VLLM_S3_BUCKET_URL from vllm.assets.image import VLM_IMAGES_DIR +from vllm.config.vllm import VllmConfig from vllm.distributed import cleanup_dist_env_and_memory +from vllm.engine.arg_utils import EngineArgs from vllm.outputs import RequestOutput from vllm.platforms import current_platform +from vllm.v1.spec_decode.draft_model import create_vllm_config_for_draft_model from vllm.v1.spec_decode.metrics import compute_acceptance_len, compute_acceptance_rate MTP_SIMILARITY_RATE = 0.8 @@ -434,6 +437,48 @@ def test_draft_model_tensor_parallelism(): assert_draft_model_correctness(sd_case, enforce_eager=True) +def test_draft_model_engine_args_tensor_parallelism(): + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B-FP8", # <<< tgt quantized + tensor_parallel_size=2, + speculative_config={ + "model": "Qwen/Qwen3-0.6B", # <<< draft not quantized + "method": "draft_model", + "num_speculative_tokens": 3, + "draft_tensor_parallel_size": 1, # <<< valid arg name + }, + ) + tgt_vllm_config: VllmConfig = engine_args.create_engine_config() + assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2 + assert tgt_vllm_config.quant_config.get_name() == "fp8" + assert ( + tgt_vllm_config.speculative_config.draft_parallel_config.tensor_parallel_size + == 1 + ) + + draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config) + assert draft_vllm_config.parallel_config.tensor_parallel_size == 1 + assert draft_vllm_config.quant_config is None + + +def 
test_draft_model_engine_args_rejects_invalid_tp_argname(): + """The user should pass "draft_tensor_parallel_size", rather than + "tensor_parallel_size". This is to catch bad syntax early.""" + + engine_args = EngineArgs( + model="Qwen/Qwen3-1.7B", + tensor_parallel_size=2, + speculative_config={ + "model": "Qwen/Qwen3-0.6B", + "method": "draft_model", + "num_speculative_tokens": 3, + "tensor_parallel_size": 1, # invalid arg name + }, + ) + with pytest.raises(ValueError): + engine_args.create_engine_config() + + def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): """Compare the outputs using and not using speculative decoding. In the greedy decoding case, the outputs must match EXACTLY.""" @@ -447,7 +492,7 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): "num_speculative_tokens": args.num_speculative_tokens, "max_model_len": args.max_model_len, "enforce_eager": enforce_eager, - "tensor_parallel_size": args.draft_tensor_parallel_size, + "draft_tensor_parallel_size": args.draft_tensor_parallel_size, "disable_padded_drafter_batch": True, "max_num_seqs": 100, # limit cudagraph capture runtime }, diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 7d5933d57dac..f5be33ef2251 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -79,6 +79,8 @@ class SpeculativeConfig: draft_tensor_parallel_size: int | None = None """The degree of the tensor parallelism for the draft model. Can only be 1 or the same as the target model's tensor parallel size.""" + tensor_parallel_size: int | None = None + """This is only used to capture and reject if passed""" disable_logprobs: bool = True """If set to True, token log probabilities are not returned during speculative decoding. If set to False, token log probabilities are returned @@ -537,6 +539,12 @@ def create_draft_parallel_config( @model_validator(mode="after") def _verify_args(self) -> Self: + if self.tensor_parallel_size is not None: + raise ValueError( + "'tensor_parallel_size' is not a valid argument in the " + "speculative_config. Please pass 'draft_tensor_parallel_size' instead." + ) + if self.num_speculative_tokens is None: raise ValueError( "num_speculative_tokens must be provided with " diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 6e3855fa3c25..38d1f91d60a9 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -6,7 +6,7 @@ import torch from vllm.attention.layer import Attention -from vllm.config import ModelConfig, VllmConfig, get_layers_from_vllm_config +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.utils import ( @@ -118,17 +118,6 @@ def set_input_ids_first_pass( def load_model(self, target_model: Any) -> None: """Takes target_model to satisfy the type checker.""" - draft_model_config: ModelConfig = ( - self.vllm_config.speculative_config.draft_model_config - ) - logger.info("Starting to load model %s...", draft_model_config.model) - - # Recompute quant_config, which is configured for the target model - # But the draft model might not be quantized. 
- vllm_config_draft: VllmConfig = self.vllm_config.replace( - quant_config=None, - model_config=draft_model_config, - ) # This must be computed before loading the draft model # because that mutates the forward_context of the vllm_config @@ -138,12 +127,16 @@ def load_model(self, target_model: Any) -> None: from vllm.compilation.backends import set_model_tag + draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model( + target_model_vllm_config=self.vllm_config + ) + logger.info( + "Starting to load model %s with tensor_parallel_size %d...", + draft_vllm_config.model_config.model, + draft_vllm_config.parallel_config.tensor_parallel_size, + ) with set_model_tag("draft_model"): - self.model = get_model( - vllm_config=vllm_config_draft, - model_config=draft_model_config, - prefix="draft_model", - ) + self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model") # This must be computed after loading the draft model # because that mutates the forward_context of the vllm_config @@ -154,6 +147,24 @@ def load_model(self, target_model: Any) -> None: self.attn_layer_names = list(draft_attn_layer_names) +def create_vllm_config_for_draft_model( + target_model_vllm_config: VllmConfig, +) -> VllmConfig: + """The vllm_config is configured for the target model, e.g. + its quant_config and parallel_config. But the draft model might + not be quantized the same way, and might have different tensor_parallel_size. + We need to create a new vllm_config for the draft model. + This is vllm_config is useful when loading the draft model. + """ + old = target_model_vllm_config + new: VllmConfig = old.replace( + quant_config=None, # quant_config is recomputed in __init__() + model_config=old.speculative_config.draft_model_config, + parallel_config=old.speculative_config.draft_parallel_config, + ) + return new + + @dataclass class DraftModelInputs: token_ids: torch.Tensor From c135ae15ae719251ffa4583a15f0de462deb282c Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 14 Oct 2025 08:21:12 +0000 Subject: [PATCH 67/73] Reject draft TP > 1 Signed-off-by: Tomas Ruiz --- vllm/config/speculative.py | 3 +++ vllm/v1/spec_decode/draft_model.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index f5be33ef2251..082e5ee95da6 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -545,6 +545,9 @@ def _verify_args(self) -> Self: "speculative_config. Please pass 'draft_tensor_parallel_size' instead." ) + if self.draft_tensor_parallel_size > 1: + raise ValueError("Only draft_tensor_parallel_size=1 is implemented so far.") + if self.num_speculative_tokens is None: raise ValueError( "num_speculative_tokens must be provided with " diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 38d1f91d60a9..0961c4a11372 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -131,9 +131,10 @@ def load_model(self, target_model: Any) -> None: target_model_vllm_config=self.vllm_config ) logger.info( - "Starting to load model %s with tensor_parallel_size %d...", + "Starting to load draft model %s. 
TP=%d, rank=%d", draft_vllm_config.model_config.model, draft_vllm_config.parallel_config.tensor_parallel_size, + draft_vllm_config.parallel_config.rank, ) with set_model_tag("draft_model"): self.model = get_model(vllm_config=draft_vllm_config, prefix="draft_model") From 7c011c0264a89dade5ea9878a8c534231f6ebd20 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 14 Oct 2025 10:24:51 +0000 Subject: [PATCH 68/73] Enforce same TP for draft & target Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 12 ++++-------- vllm/config/speculative.py | 3 --- vllm/v1/spec_decode/draft_model.py | 15 ++++++++++++++- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 49aab857b2a1..c3e20f66ba54 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -431,16 +431,16 @@ def test_draft_model_tensor_parallelism(): target_model="Qwen/Qwen3-1.7B", target_tensor_parallel_size=2, draft_model="Qwen/Qwen3-0.6B", - draft_tensor_parallel_size=1, + draft_tensor_parallel_size=2, **some_high_acceptance_metrics(), ) - assert_draft_model_correctness(sd_case, enforce_eager=True) + assert_draft_model_correctness(sd_case, enforce_eager=False) def test_draft_model_engine_args_tensor_parallelism(): engine_args = EngineArgs( model="Qwen/Qwen3-1.7B-FP8", # <<< tgt quantized - tensor_parallel_size=2, + tensor_parallel_size=4, speculative_config={ "model": "Qwen/Qwen3-0.6B", # <<< draft not quantized "method": "draft_model", @@ -449,12 +449,8 @@ def test_draft_model_engine_args_tensor_parallelism(): }, ) tgt_vllm_config: VllmConfig = engine_args.create_engine_config() - assert tgt_vllm_config.parallel_config.tensor_parallel_size == 2 + assert tgt_vllm_config.parallel_config.tensor_parallel_size == 4 assert tgt_vllm_config.quant_config.get_name() == "fp8" - assert ( - tgt_vllm_config.speculative_config.draft_parallel_config.tensor_parallel_size - == 1 - ) draft_vllm_config: VllmConfig = create_vllm_config_for_draft_model(tgt_vllm_config) assert draft_vllm_config.parallel_config.tensor_parallel_size == 1 diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 082e5ee95da6..f5be33ef2251 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -545,9 +545,6 @@ def _verify_args(self) -> Self: "speculative_config. Please pass 'draft_tensor_parallel_size' instead." ) - if self.draft_tensor_parallel_size > 1: - raise ValueError("Only draft_tensor_parallel_size=1 is implemented so far.") - if self.num_speculative_tokens is None: raise ValueError( "num_speculative_tokens must be provided with " diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 0961c4a11372..7a06fed8a7a0 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -7,6 +7,7 @@ from vllm.attention.layer import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config.speculative import SpeculativeConfig from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.utils import ( @@ -37,6 +38,7 @@ def __init__( self._raise_if_mrope() self._raise_if_padded_drafter_batch() self._raise_if_vocab_size_mismatch() + self._raise_if_draft_tp_mismatch() def propose( self, @@ -101,12 +103,23 @@ def _raise_if_padded_drafter_batch(self): raise NotImplementedError( "Speculative Decoding with draft models does not support " "padded drafter batch yet. 
Please pass --disable-padded-drafter-batch " - "in the speculative config." + "in the speculative_config." ) def _raise_if_vocab_size_mismatch(self): self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model() + def _raise_if_draft_tp_mismatch(self): + spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config + tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size + draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size + if draft_tp != tgt_tp: + raise ValueError( + f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' " + f"must be the same. Got {draft_tp} and {tgt_tp}. " + "Please pass 'draft_tensor_parallel_size' in the speculative_config." + ) + def set_input_ids_first_pass( self, target_token_ids: torch.Tensor, From 02d9d86e409681588612db8183467fe81e715597 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 14 Oct 2025 10:41:48 +0000 Subject: [PATCH 69/73] Explicitly set rank for draft TP Signed-off-by: Tomas Ruiz --- vllm/config/vllm.py | 18 ++++++++++++++++++ vllm/v1/worker/worker_base.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 4a7d0280a6dc..08ff0b635e8e 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -765,6 +765,24 @@ def replace(self, **kwargs): """ return replace(self, **kwargs) + def set_rank(self, rank: int): + self.parallel_config.rank = rank + self._if_necessary_set_rank_in_draft_parallel_config(rank) + + def _if_necessary_set_rank_in_draft_parallel_config(self, rank: int): + """If the speculative config is tensor parallel, + and it has the same tensor parallel size as the target model, + then set the rank in the speculative parallel config, too.""" + sd_cfg: SpeculativeConfig | None = self.speculative_config + set_rank = ( + sd_cfg is not None + and sd_cfg.draft_parallel_config is not None + and sd_cfg.draft_parallel_config.tensor_parallel_size + == self.parallel_config.tensor_parallel_size + ) + if set_rank: + sd_cfg.draft_parallel_config.rank = rank + def __str__(self): return ( f"model={self.model_config.model!r}, " diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 85436b443f7c..a0207e6fd4fc 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -77,7 +77,7 @@ def __init__( self.current_platform = current_platform - self.parallel_config.rank = rank + self.vllm_config.set_rank(rank) self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method From 14946cd5d0840410e866482d17275894fa8fa1a8 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 14 Oct 2025 10:55:32 +0000 Subject: [PATCH 70/73] Document why we enforce equal TP Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/draft_model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 7a06fed8a7a0..86c2c71239d7 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -110,6 +110,12 @@ def _raise_if_vocab_size_mismatch(self): self.vllm_config.speculative_config.verify_equal_vocab_size_if_draft_model() def _raise_if_draft_tp_mismatch(self): + # Note(Tomas Ruiz) If we run the target model with TP > 1 and + # the draft model with TP = 1, then the different TP ranks collide. + # Specifically when all ranks compile the draft model on rank 0 + # (because TP=1), then the torch compile cache is overwritten and corrupted. 
+ # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414 + # To prevent this error, we assert that both TP sizes must be the same. spec_cfg: SpeculativeConfig = self.vllm_config.speculative_config tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size From e1dbab15b487b1a219f79d06e71bf1cf75fedcd2 Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Tue, 14 Oct 2025 11:15:38 +0000 Subject: [PATCH 71/73] Simplify changes. Improve docs Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 12 ++++++++---- vllm/config/parallel.py | 4 ++++ vllm/config/speculative.py | 4 +++- vllm/config/vllm.py | 18 ------------------ vllm/v1/spec_decode/draft_model.py | 13 ++++++++----- vllm/v1/worker/worker_base.py | 2 +- 6 files changed, 24 insertions(+), 29 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index c3e20f66ba54..5f0c69ef626f 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -427,6 +427,7 @@ def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool): def test_draft_model_tensor_parallelism(): + """Ensure spec decode works when running with TP > 1.""" sd_case = ArgsTest( target_model="Qwen/Qwen3-1.7B", target_tensor_parallel_size=2, @@ -438,6 +439,9 @@ def test_draft_model_tensor_parallelism(): def test_draft_model_engine_args_tensor_parallelism(): + """Ensure the vllm_config for the draft model is created correctly, + and independently of the target model (quantization, TP, etc.)""" + engine_args = EngineArgs( model="Qwen/Qwen3-1.7B-FP8", # <<< tgt quantized tensor_parallel_size=4, @@ -458,17 +462,17 @@ def test_draft_model_engine_args_tensor_parallelism(): def test_draft_model_engine_args_rejects_invalid_tp_argname(): - """The user should pass "draft_tensor_parallel_size", rather than - "tensor_parallel_size". This is to catch bad syntax early.""" + """The user should pass "draft_tensor_parallel_size" rather than + "tensor_parallel_size". We enforce this with validation.""" engine_args = EngineArgs( model="Qwen/Qwen3-1.7B", - tensor_parallel_size=2, + tensor_parallel_size=1, speculative_config={ "model": "Qwen/Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 3, - "tensor_parallel_size": 1, # invalid arg name + "tensor_parallel_size": 1, # <<< invalid arg name }, ) with pytest.raises(ValueError): diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 084e458f8830..ffb1e176bd5f 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -3,6 +3,7 @@ import hashlib import os +from dataclasses import replace from typing import TYPE_CHECKING, Any, Literal import torch @@ -564,3 +565,6 @@ def _verify_args(self) -> Self: ) return self + + def replace(self, **kwargs) -> Self: + return replace(self, **kwargs) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index f5be33ef2251..37a420738702 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -80,7 +80,9 @@ class SpeculativeConfig: """The degree of the tensor parallelism for the draft model. Can only be 1 or the same as the target model's tensor parallel size.""" tensor_parallel_size: int | None = None - """This is only used to capture and reject if passed""" + """Users should pass "draft_tensor_parallel_size". 
This parameters is only + to reject it if passed.""" + disable_logprobs: bool = True """If set to True, token log probabilities are not returned during speculative decoding. If set to False, token log probabilities are returned diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 08ff0b635e8e..4a7d0280a6dc 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -765,24 +765,6 @@ def replace(self, **kwargs): """ return replace(self, **kwargs) - def set_rank(self, rank: int): - self.parallel_config.rank = rank - self._if_necessary_set_rank_in_draft_parallel_config(rank) - - def _if_necessary_set_rank_in_draft_parallel_config(self, rank: int): - """If the speculative config is tensor parallel, - and it has the same tensor parallel size as the target model, - then set the rank in the speculative parallel config, too.""" - sd_cfg: SpeculativeConfig | None = self.speculative_config - set_rank = ( - sd_cfg is not None - and sd_cfg.draft_parallel_config is not None - and sd_cfg.draft_parallel_config.tensor_parallel_size - == self.parallel_config.tensor_parallel_size - ) - if set_rank: - sd_cfg.draft_parallel_config.rank = rank - def __str__(self): return ( f"model={self.model_config.model!r}, " diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 86c2c71239d7..90de7331a128 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -171,16 +171,19 @@ def create_vllm_config_for_draft_model( target_model_vllm_config: VllmConfig, ) -> VllmConfig: """The vllm_config is configured for the target model, e.g. - its quant_config and parallel_config. But the draft model might - not be quantized the same way, and might have different tensor_parallel_size. - We need to create a new vllm_config for the draft model. - This is vllm_config is useful when loading the draft model. + its quant_config and parallel_config. But the draft model is potentially + quantized differently, and has potentially different tensor_parallel_size. + This function creates a new vllm_config configured for the draft model. + The vllm_config is useful when loading the draft model with get_model(). 
""" old = target_model_vllm_config + new_parallel_config = old.speculative_config.draft_parallel_config.replace( + rank=old.parallel_config.rank + ) new: VllmConfig = old.replace( quant_config=None, # quant_config is recomputed in __init__() model_config=old.speculative_config.draft_model_config, - parallel_config=old.speculative_config.draft_parallel_config, + parallel_config=new_parallel_config, ) return new diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index a0207e6fd4fc..85436b443f7c 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -77,7 +77,7 @@ def __init__( self.current_platform = current_platform - self.vllm_config.set_rank(rank) + self.parallel_config.rank = rank self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method From 4641ec6e6f4a66c5a0f2a500d517e2596f27a21f Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 16 Oct 2025 13:59:41 +0200 Subject: [PATCH 72/73] Simplify tests Signed-off-by: Tomas Ruiz --- tests/v1/e2e/test_spec_decode.py | 43 +++----------------------------- 1 file changed, 3 insertions(+), 40 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 5f0c69ef626f..65b966cec02f 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -14,7 +14,6 @@ from vllm.config.vllm import VllmConfig from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import EngineArgs -from vllm.outputs import RequestOutput from vllm.platforms import current_platform from vllm.v1.spec_decode.draft_model import create_vllm_config_for_draft_model from vllm.v1.spec_decode.metrics import compute_acceptance_len, compute_acceptance_rate @@ -368,7 +367,6 @@ class ArgsTest: num_speculative_tokens: int expected_acceptance_rate: float expected_acceptance_len: float - expected_same_output_fraction: float # Defaults target_tensor_parallel_size: int = 1 draft_tensor_parallel_size: int = 1 @@ -385,7 +383,6 @@ class ArgsTest: num_speculative_tokens=3, # K expected_acceptance_len=3 + 1, # K + 1 expected_acceptance_rate=1.0, - expected_same_output_fraction=1.0, ), # Smaller draft model, stochastic sampling. 
ArgsTest( @@ -395,7 +392,6 @@ class ArgsTest: num_speculative_tokens=3, expected_acceptance_len=2.8 + 1, expected_acceptance_rate=0.9, - expected_same_output_fraction=0.9, ), ] @@ -502,8 +498,10 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): enforce_eager=enforce_eager, disable_log_stats=False, # enables get_metrics() ) - spec_outputs = spec_llm.chat(test_prompts, args.sampling_config) + # we don't check the outputs, only check the metrics + spec_llm.chat(test_prompts, args.sampling_config) metrics = spec_llm.get_metrics() + acceptance_rate: float = compute_acceptance_rate(metrics) acceptance_len: float = compute_acceptance_len(metrics) del spec_llm # CLEANUP @@ -513,53 +511,18 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): assert acceptance_rate >= args.expected_acceptance_rate assert acceptance_len >= args.expected_acceptance_len - ref_llm = LLM( - model=args.target_model, - max_model_len=args.max_model_len, - gpu_memory_utilization=args.gpu_memory_utilization, - tensor_parallel_size=args.target_tensor_parallel_size, - enforce_eager=enforce_eager, - ) - ref_outputs = ref_llm.chat(test_prompts, args.sampling_config) - del ref_llm # CLEANUP - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() - - assert len(ref_outputs) > 0 - assert len(ref_outputs) == len(spec_outputs) - - match_fraction = compute_exact_matches(ref_outputs, spec_outputs) - assert match_fraction >= args.expected_same_output_fraction - print( f"spec-decode: target={args.target_model}, draft={args.draft_model}, " f"temperature={args.sampling_config.temperature:.2f}, " f"acceptance_rate={acceptance_rate:.2f}, " f"acceptance_len={acceptance_len:.2f}, " - f"match_fraction={match_fraction:.2f}" ) -def compute_exact_matches( - ref_outputs: list[RequestOutput], spec_outputs: list[RequestOutput] -) -> float: - """Compute the fraction of the prompts that match exactly""" - assert len(ref_outputs) == len(spec_outputs) - matches = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - if ref_output.outputs[0].text == spec_output.outputs[0].text: - matches += 1 - else: - print(f"ref_output: {ref_output.outputs[0].text}") - print(f"spec_output: {spec_output.outputs[0].text}") - return matches / len(ref_outputs) - - def some_high_acceptance_metrics() -> dict: return { "sampling_config": greedy_sampling(), "num_speculative_tokens": 3, "expected_acceptance_len": 2.95 + 1, "expected_acceptance_rate": 0.95, - "expected_same_output_fraction": 0.95, } From ea3bb0a0496f6eb5ae91f3475dd913672c94cc7d Mon Sep 17 00:00:00 2001 From: Tomas Ruiz Date: Thu, 16 Oct 2025 21:01:40 +0200 Subject: [PATCH 73/73] Reject draft models with multiple kv-cache groups Signed-off-by: Tomas Ruiz --- vllm/v1/spec_decode/eagle.py | 6 +++--- vllm/v1/worker/gpu_model_runner.py | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index d84cdfb13dd6..940f1edf0f9a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1135,8 +1135,8 @@ def _get_attention_metadata_builder(self) -> AttentionMetadataBuilder: def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: """ - Validate that all eagle layers belong to the same KVCacheGroup. - Need this assumption to ensure all eagle layers can use the + Validate that all drafting layers belong to the same KVCacheGroup. + Need this assumption to ensure all drafting layers can use the same AttentionMetadata. 
May extend to multiple AttentionMetadata in the future. """ @@ -1154,7 +1154,7 @@ def validate_same_kv_cache_group(self, kv_cache_config: KVCacheConfig) -> None: ) ) == 1 - ), "All eagle layers should belong to the same kv cache group" + ), "All drafting layers should belong to the same kv cache group" class EagleProposer(SpecDecodeBaseProposer): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 81893ab1d6fa..5b01833b5123 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4539,8 +4539,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.may_reinitialize_input_batch(kv_cache_config) kv_caches = self.initialize_kv_cache_tensors(kv_cache_config) - if self.speculative_config and self.speculative_config.use_eagle(): - assert isinstance(self.drafter, EagleProposer) + if self.speculative_config and ( + self.speculative_config.use_eagle() + or self.speculative_config.uses_draft_model() + ): + assert isinstance(self.drafter, EagleProposer | DraftModelProposer) # validate all draft model layers belong to the same kv cache # group self.drafter.validate_same_kv_cache_group(kv_cache_config)
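
For reference, a minimal offline-inference sketch of the configuration exercised by the tests above: draft-model speculative decoding where the draft uses the same tensor parallel size as the target, as enforced by `_raise_if_draft_tp_mismatch`. The model names, GPU count, and sampling settings are illustrative only; the `speculative_config` keys mirror the ones used in `examples/offline_inference/spec_decode.py` and the e2e tests.

```python
from vllm import LLM, SamplingParams

# Target/draft pair is illustrative; any pair with equal vocab sizes works.
# Assumes 2 GPUs are available for TP=2.
llm = LLM(
    model="Qwen/Qwen3-1.7B",              # target model
    tensor_parallel_size=2,               # target TP
    speculative_config={
        "method": "draft_model",
        "model": "Qwen/Qwen3-0.6B",       # draft model
        "num_speculative_tokens": 3,      # K drafted tokens per step
        # Must equal tensor_parallel_size above; passing
        # "tensor_parallel_size" here is rejected during validation.
        "draft_tensor_parallel_size": 2,
    },
    disable_log_stats=False,  # keep stats so get_metrics() can report acceptance
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```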
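
The e2e tests above validate speculative decoding through acceptance statistics rather than by comparing generated text. Below is a hedged sketch of how such statistics can be derived from `llm.get_metrics()`; the counter names are assumptions for illustration, and the canonical helpers are the ones added in `vllm/v1/spec_decode/metrics.py` (`compute_acceptance_rate`, `compute_acceptance_len`).

```python
# Sketch only: aggregates spec-decode counters returned by llm.get_metrics().
# The metric names below are assumed; consult vllm/v1/spec_decode/metrics.py
# for the canonical implementation.

def _counter_total(metrics, name: str) -> float:
    # Each returned Metric carries a name and a value; sum matching entries.
    return sum(m.value for m in metrics if m.name == name)


def acceptance_rate(metrics) -> float:
    """Fraction of drafted tokens that the target model accepted."""
    accepted = _counter_total(metrics, "vllm:spec_decode_num_accepted_tokens")
    drafted = _counter_total(metrics, "vllm:spec_decode_num_draft_tokens")
    return accepted / drafted if drafted else 0.0


def acceptance_length(metrics) -> float:
    """Mean tokens emitted per drafting step, bonus token included, so a
    perfect drafter with K speculative tokens approaches K + 1."""
    accepted = _counter_total(metrics, "vllm:spec_decode_num_accepted_tokens")
    drafts = _counter_total(metrics, "vllm:spec_decode_num_drafts")
    return (1.0 + accepted / drafts) if drafts else 1.0
```

Under these definitions the test expectations line up: an identical target/draft pair under greedy sampling should accept every drafted token (rate 1.0, length K + 1), while stochastic sampling lowers both.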