Changes from 19 commits
Commits
95 commits
f27261a
Speculative Decoding with Draft Model
tomasruizt Aug 18, 2025
3b06a7c
Undo change to 'vllm bench throughput'
tomasruizt Sep 8, 2025
e41b0a3
Don't return too early
tomasruizt Sep 8, 2025
10366b9
Undo change to bind_kv_cache()
tomasruizt Sep 8, 2025
92af339
Undo changes to pyproject.toml
tomasruizt Sep 8, 2025
5b8b1c6
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 8, 2025
f2f9876
Simplify test array
tomasruizt Sep 8, 2025
824ba10
Ensure EAGLE loads correctly
tomasruizt Sep 9, 2025
5e248c1
Pass input_embeds when model is multimodal
tomasruizt Sep 9, 2025
1669ea7
Raise NotImplementedError on Mrope or Multimodal models
tomasruizt Sep 9, 2025
6040697
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 9, 2025
4b77a83
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 17, 2025
5a6cc82
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 17, 2025
54e107d
Speculative decoding with draft model separate from EAGLE
tomasruizt Sep 17, 2025
134b841
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 18, 2025
36fb940
Pass last_token_indices
tomasruizt Sep 18, 2025
b018560
Undo unnecessary changes
tomasruizt Sep 18, 2025
17e9fe5
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 18, 2025
daee8ec
Move more methods to base class
tomasruizt Sep 19, 2025
b45f7af
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 22, 2025
07d1b97
Fix call to model.compute_logits()
tomasruizt Sep 22, 2025
86d8040
Move .propose() to superclass
tomasruizt Sep 22, 2025
a696797
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 22, 2025
1afbe14
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 24, 2025
d37d780
Minimize git diffs in EAGLE
tomasruizt Sep 24, 2025
5967e09
Fix missing input
tomasruizt Sep 24, 2025
7b03a45
fix next_token_ids issue
benchislett Sep 25, 2025
35fa5a9
Merge pull request #3 from CentML/spec-decode-draft-model
tomasruizt Sep 25, 2025
ef5da86
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 25, 2025
c7d2fd5
Test also acceptance-len
tomasruizt Sep 25, 2025
ac90311
Pass missing argument in test_eagle.py
tomasruizt Sep 26, 2025
857415b
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 26, 2025
b477e10
CKPT: Remove extra forward
tomasruizt Sep 26, 2025
309d827
Prevent illegal access to hidden_states
tomasruizt Sep 26, 2025
2e97fab
Remove forward. single prompt works. Batch fails
tomasruizt Sep 28, 2025
794c3cf
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Sep 28, 2025
89b9c1d
Remove unnecessary if-else statement
tomasruizt Sep 28, 2025
c767118
Merge branch 'feature/spec-decode-draft-model' into featury/remove-ex…
tomasruizt Sep 30, 2025
e74c71e
Minimize changes
tomasruizt Sep 30, 2025
994e9cc
Commit unit test success
tomasruizt Sep 30, 2025
26ab913
Remove unnecessary variables
tomasruizt Sep 30, 2025
01dd981
Minimize changes
tomasruizt Sep 30, 2025
09a0bb3
Remove token logging
tomasruizt Sep 30, 2025
42faf1c
Relocate utility method
tomasruizt Sep 30, 2025
044e45c
Simplify extend_flat_seqs()
tomasruizt Sep 30, 2025
7a1949d
Document test
tomasruizt Sep 30, 2025
316a6b8
Document funcs
tomasruizt Sep 30, 2025
0e75db7
Merge pull request #5 from tomasruizt/featury/remove-extra-forward
tomasruizt Sep 30, 2025
af06030
Update BatchDescriptor with correct num_tokens
tomasruizt Oct 1, 2025
a791d2e
Make sure AL benchmark can run
tomasruizt Oct 1, 2025
1de5ef4
Extend drafter max_num_tokens
tomasruizt Oct 1, 2025
4371d47
CKPT: Find bug affecting acceptance length
tomasruizt Oct 1, 2025
1718892
Fix AL for default drafter padding
tomasruizt Oct 2, 2025
ac56891
Remove logging
tomasruizt Oct 2, 2025
4b43999
use non-blocking cpu move, document and test helper fns
tomasruizt Oct 2, 2025
10eb718
Minimize changes
tomasruizt Oct 2, 2025
4c7eb11
Reduce changes footprint
tomasruizt Oct 2, 2025
d123018
Reduce changes
tomasruizt Oct 2, 2025
02872ad
Minimize changes
tomasruizt Oct 2, 2025
50ae07f
Merge commit '17edd8a' into feature/spec-decode-draft-model
tomasruizt Oct 6, 2025
33bcc08
ruff
tomasruizt Oct 6, 2025
fa99c05
Merge commit 'd6953be' into feature/spec-decode-draft-model
tomasruizt Oct 6, 2025
eac09d2
Get AL high again
tomasruizt Oct 6, 2025
ccac6cb
Minimize changes
tomasruizt Oct 6, 2025
2ba8c5a
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Oct 7, 2025
c094f5f
Add flag for disable_padded_drafter_batch
tomasruizt Oct 7, 2025
a6f8484
Correct typo
tomasruizt Oct 7, 2025
4e77a80
Ensure draft model uses CUDA graph
tomasruizt Oct 7, 2025
a1e899c
Remove unnecessary cudagraph inputs
tomasruizt Oct 8, 2025
50dcbc4
Minimize changes
tomasruizt Oct 8, 2025
c01e43b
Minimize changes
tomasruizt Oct 8, 2025
cf99760
Remove unused fn
tomasruizt Oct 8, 2025
c73929d
Minimize changes
tomasruizt Oct 8, 2025
66d4f2b
Avoid OOB error on large batches
tomasruizt Oct 9, 2025
c27b6a7
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Oct 10, 2025
de86231
Simplify away passing the CUDA graph args
tomasruizt Oct 10, 2025
f8321d2
add option --max-num-seqs to spec_decode.py (useful for small GPUs)
tomasruizt Oct 10, 2025
e9560ef
Prevent different tokenizer vocab sizes
tomasruizt Oct 10, 2025
694faf8
Limit cudagraph capture time in test
tomasruizt Oct 10, 2025
fa6294f
Minimize changes related to CUDA graph
tomasruizt Oct 10, 2025
c9ff19a
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Oct 13, 2025
f49a5ea
Replace Optional[T] with T | None
tomasruizt Oct 13, 2025
37f013e
Add tests for quantized target / draft model
tomasruizt Oct 13, 2025
58f8496
Add test for draft model + tensor parallelism
tomasruizt Oct 13, 2025
4bd9a46
Log why endpoint is not ready
tomasruizt Oct 13, 2025
ff92d85
Test tensor parallelism more thoroughly
tomasruizt Oct 13, 2025
c135ae1
Reject draft TP > 1
tomasruizt Oct 14, 2025
7c011c0
Enforce same TP for draft & target
tomasruizt Oct 14, 2025
02d9d86
Explicitly set rank for draft TP
tomasruizt Oct 14, 2025
14946cd
Document why we enforce equal TP
tomasruizt Oct 14, 2025
e1dbab1
Simplify changes. Improve docs
tomasruizt Oct 14, 2025
f346cfa
Merge pull request #6 from tomasruizt/feature/correct-tensor-parallel…
tomasruizt Oct 14, 2025
4641ec6
Simplify tests
tomasruizt Oct 16, 2025
ea3bb0a
Reject draft models with multiple kv-cache groups
tomasruizt Oct 16, 2025
6ca55ab
Merge branch 'main' into feature/spec-decode-draft-model
tomasruizt Oct 29, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
# Scripts for development
scripts/

# version file generated by setuptools-scm
/vllm/_version.py

17 changes: 15 additions & 2 deletions examples/offline_inference/spec_decode.py
@@ -67,7 +67,11 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.8)
parser.add_argument("--request-id-prefix", type=str, default="")
parser.add_argument("--max-model-len", type=int, default=16384)
return parser.parse_args()


@@ -117,6 +121,15 @@ def main():
"prompt_lookup_max": args.prompt_lookup_max,
"prompt_lookup_min": args.prompt_lookup_min,
}
elif args.method == "draft_model":
assert args.draft_model is not None and args.draft_model != ""
speculative_config = {
"method": args.method,
"model": args.draft_model,
"num_speculative_tokens": args.num_spec_tokens,
"enforce_eager": args.enforce_eager,
"max_model_len": args.max_model_len,
}
elif args.method.endswith("mtp"):
speculative_config = {
"method": args.method,
@@ -131,10 +144,10 @@
tensor_parallel_size=args.tp,
enable_chunked_prefill=args.enable_chunked_prefill,
enforce_eager=args.enforce_eager,
gpu_memory_utilization=0.8,
gpu_memory_utilization=args.gpu_memory_utilization,
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
max_model_len=args.max_model_len,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
)
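
For reference, a minimal offline sketch of the draft-model path this example now exposes; the model names and token counts below are illustrative choices (taken from the tests later in this PR), not part of the change itself:

from vllm import LLM, SamplingParams

# Minimal sketch: the target model speculates with a smaller draft model.
# The speculative_config keys mirror the ones assembled above.
llm = LLM(
    model="Qwen/Qwen3-1.7B",
    speculative_config={
        "method": "draft_model",
        "model": "Qwen/Qwen3-0.6B",
        "num_speculative_tokens": 3,
        "max_model_len": 16384,
    },
    max_model_len=16384,
)
outputs = llm.generate(
    ["The quick brown fox jumps over the lazy "],
    SamplingParams(temperature=0, max_tokens=32),
)
print(outputs[0].outputs[0].text)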
126 changes: 124 additions & 2 deletions tests/v1/e2e/test_spec_decode.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Any, Union

import pytest
@@ -13,10 +14,12 @@
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.v1.spec_decode.metrics import compute_acceptance_rate


def get_test_prompts(mm_enabled: bool):
def get_test_prompts(mm_enabled: bool, quiet: bool = False):
prompt_types = ["repeat", "sentence"]
if mm_enabled:
prompt_types.append("mm")
@@ -25,7 +28,9 @@ def get_test_prompts(mm_enabled: bool):

random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")

if not quiet:
print(f"Prompt types: {random_prompt_type_choices}")

# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
@@ -69,9 +74,17 @@

@pytest.fixture
def sampling_config():
return greedy_sampling()


def greedy_sampling() -> SamplingParams:
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)


def stochastic_sampling() -> SamplingParams:
return SamplingParams(temperature=1.0, max_tokens=10, ignore_eos=False)


@pytest.fixture
def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"
@@ -223,3 +236,112 @@ def test_eagle_correctness(
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@dataclass
class ArgsTest:
model: str
draft_model: str
sampling_config: SamplingParams
expected_acceptance_rate: float
expected_same_output_fraction: float
# Defaults
target_tensor_parallel_size: int = 1
draft_tensor_parallel_size: int = 1
max_model_len: int = 1024
gpu_memory_utilization: float = 0.5


cases = [
ArgsTest(
model="Qwen/Qwen3-0.6B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=greedy_sampling(),
expected_acceptance_rate=1.0,
expected_same_output_fraction=1.0,
),
ArgsTest(
model="Qwen/Qwen3-1.7B",
draft_model="Qwen/Qwen3-0.6B",
sampling_config=stochastic_sampling(),
expected_acceptance_rate=0.9,
expected_same_output_fraction=0.9,
),
]


@pytest.mark.parametrize("args", cases)
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
def test_draft_model_correctness(
args: ArgsTest,
enforce_eager: bool,
disable_padded_drafter_batch: bool,
monkeypatch: pytest.MonkeyPatch,
):
"""Compare the outputs using and not using speculative decoding.
In the greedy decoding case, the outputs must match EXACTLY."""
monkeypatch.setenv("VLLM_USE_V1", "1")
test_prompts = get_test_prompts(mm_enabled=False, quiet=True)

spec_llm = LLM(
model=args.model,
speculative_config={
"model": args.draft_model,
"method": "draft_model",
"num_speculative_tokens": 3,
"max_model_len": args.max_model_len,
"enforce_eager": enforce_eager,
"tensor_parallel_size": args.draft_tensor_parallel_size,
"disable_padded_drafter_batch": disable_padded_drafter_batch,
},
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager,
disable_log_stats=False, # enables get_metrics()
)
spec_outputs = spec_llm.chat(test_prompts, args.sampling_config)
acceptance_rate = compute_acceptance_rate(spec_llm.get_metrics())
del spec_llm # CLEANUP
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

assert acceptance_rate >= args.expected_acceptance_rate

ref_llm = LLM(
model=args.model,
max_model_len=args.max_model_len,
gpu_memory_utilization=args.gpu_memory_utilization,
tensor_parallel_size=args.target_tensor_parallel_size,
enforce_eager=enforce_eager,
)
ref_outputs = ref_llm.chat(test_prompts, args.sampling_config)
del ref_llm # CLEANUP
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

assert len(ref_outputs) > 0
assert len(ref_outputs) == len(spec_outputs)

match_fraction = compute_exact_matches(ref_outputs, spec_outputs)
assert match_fraction >= args.expected_same_output_fraction

print(f"spec-decode: target={args.model}, draft={args.draft_model}, "
f"temperature={args.sampling_config.temperature:.2f}, "
f"acceptance_rate={acceptance_rate:.2f}, "
f"match_fraction={match_fraction:.2f}")


def compute_exact_matches(ref_outputs: list[RequestOutput],
spec_outputs: list[RequestOutput]) -> float:
"""Compute the fraction of the prompts that match exactly"""
assert len(ref_outputs) == len(spec_outputs)
matches = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
return matches / len(ref_outputs)
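
For context, the acceptance rate asserted above is the fraction of proposed draft tokens that the target model accepts; an identical draft and target under greedy sampling should therefore reach 1.0. A rough sketch of deriving it from the engine's aggregated metrics follows; the counter names are assumptions for illustration, not taken from this PR:

# Hedged sketch: acceptance rate = accepted draft tokens / proposed draft tokens.
# The metric names below are assumptions; compute_acceptance_rate in
# vllm.v1.spec_decode.metrics is the authoritative implementation.
def acceptance_rate_sketch(metrics) -> float:
    accepted = drafted = 0
    for metric in metrics:
        if metric.name == "vllm:spec_decode_num_accepted_tokens":
            accepted += metric.value
        elif metric.name == "vllm:spec_decode_num_draft_tokens":
            drafted += metric.value
    return accepted / drafted if drafted else 0.0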
32 changes: 32 additions & 0 deletions tests/v1/test_utils.py
@@ -42,6 +42,38 @@ def test_bind_kv_cache():
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']


def test_bind_kv_cache_draft_model():
from vllm.attention import Attention
ctx = {
'model.layers.0.attn': Attention(32, 128, 0.1),
'model.layers.1.attn': Attention(32, 128, 0.1),
'draft_model.layers.0.attn': Attention(32, 128, 0.1),
'draft_model.layers.1.attn': Attention(32, 128, 0.1),
}
kv_cache = {
'model.layers.0.attn': torch.zeros((1, )),
'model.layers.1.attn': torch.zeros((1, )),
'draft_model.layers.0.attn': torch.zeros((1, )),
'draft_model.layers.1.attn': torch.zeros((1, )),
}
runner_kv_caches: list[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx['model.layers.0.attn'].kv_cache[0] is kv_cache[
'model.layers.0.attn']
assert ctx['model.layers.1.attn'].kv_cache[0] is kv_cache[
'model.layers.1.attn']
assert ctx['draft_model.layers.0.attn'].kv_cache[0] is kv_cache[
'draft_model.layers.0.attn']
assert ctx['draft_model.layers.1.attn'].kv_cache[0] is kv_cache[
'draft_model.layers.1.attn']

# caches are ordered by layer_index, interleaving target and draft model
assert runner_kv_caches[0] is kv_cache['model.layers.0.attn']
assert runner_kv_caches[1] is kv_cache['draft_model.layers.0.attn']
assert runner_kv_caches[2] is kv_cache['model.layers.1.attn']
assert runner_kv_caches[3] is kv_cache['draft_model.layers.1.attn']


def test_bind_kv_cache_non_attention():
from vllm.attention import Attention

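The interleaved ordering asserted above follows from sorting cache entries by the numeric layer index in their names, which is the same for 'model.layers.N.attn' and 'draft_model.layers.N.attn'. A simplified sketch of that ordering logic, not the actual bind_kv_cache implementation:

import re

# Hedged sketch: order KV caches by the layer index embedded in the layer
# name, so target and draft entries with the same index end up interleaved
# (sorted() is stable, so within an index the dict insertion order is kept).
def layer_index(layer_name: str) -> int:
    match = re.search(r"\.layers\.(\d+)\.", layer_name)
    assert match is not None, f"no layer index in {layer_name!r}"
    return int(match.group(1))

def ordered_caches(kv_cache: dict) -> list:
    # For the dict in the test above this yields:
    # model.layers.0.attn, draft_model.layers.0.attn,
    # model.layers.1.attn, draft_model.layers.1.attn
    return [kv_cache[name] for name in sorted(kv_cache, key=layer_index)]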
8 changes: 3 additions & 5 deletions vllm/config/speculative.py
@@ -336,11 +336,6 @@ def __post_init__(self):
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")

# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
@@ -552,6 +547,9 @@ def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")

def uses_draft_model(self) -> bool:
return self.method == "draft_model"

def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
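A hypothetical illustration of how call sites can branch on the new uses_draft_model() helper alongside the existing use_eagle(); the stub class and proposer labels below are invented for this sketch and are not part of the PR:

# Hypothetical sketch only: the stub config and proposer labels just show
# the intended branching on the new helper.
class _StubSpecConfig:
    def __init__(self, method: str):
        self.method = method

    def use_eagle(self) -> bool:
        return self.method in ("eagle", "eagle3", "deepseek_mtp",
                               "ernie_mtp", "qwen3_next_mtp")

    def uses_draft_model(self) -> bool:
        return self.method == "draft_model"

def proposer_kind(cfg) -> str:
    if cfg.uses_draft_model():
        return "draft-model proposer"
    if cfg.use_eagle():
        return "eagle-style proposer"
    return "ngram proposer"

assert proposer_kind(_StubSpecConfig("draft_model")) == "draft-model proposer"
assert proposer_kind(_StubSpecConfig("eagle")) == "eagle-style proposer"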
8 changes: 0 additions & 8 deletions vllm/engine/arg_utils.py
@@ -1534,14 +1534,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# V1 supports N-gram, Medusa, and Eagle speculative decoding.
if (self.speculative_config is not None
and self.speculative_config.get("method") == "draft_model"):
raise NotImplementedError(
"Speculative decoding with draft model is not supported yet. "
"Please consider using other speculative decoding methods "
"such as ngram, medusa, eagle, or deepseek_mtp.")

V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1",
"FLASH_ATTN",
6 changes: 4 additions & 2 deletions vllm/model_executor/model_loader/__init__.py
@@ -112,12 +112,14 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:

def get_model(*,
vllm_config: VllmConfig,
model_config: Optional[ModelConfig] = None) -> nn.Module:
model_config: Optional[ModelConfig] = None,
prefix: str = "") -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(vllm_config=vllm_config,
model_config=model_config)
model_config=model_config,
prefix=prefix)


__all__ = [
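The new prefix argument is what lets a second model be loaded under its own namespace, matching the 'draft_model.layers.N.attn' names exercised in the KV-cache test above. A hedged sketch of how a drafter might call it; the function and exact prefix value below are illustrative assumptions, since the draft-loading code is not part of this file's diff:

from vllm.model_executor.model_loader import get_model

# Hedged sketch: load the draft model under a "draft_model" prefix so its
# attention layers do not collide with the target model's layers.
def load_draft_model(vllm_config, draft_model_config):
    return get_model(
        vllm_config=vllm_config,
        model_config=draft_model_config,  # the drafter's ModelConfig
        prefix="draft_model",             # yields names like draft_model.layers.0.attn
    )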
9 changes: 6 additions & 3 deletions vllm/model_executor/model_loader/base_loader.py
@@ -32,8 +32,10 @@ def load_weights(self, model: nn.Module,
inplace weights loading for an already-initialized model"""
raise NotImplementedError

def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
def load_model(self,
vllm_config: VllmConfig,
model_config: ModelConfig,
prefix: str = "") -> nn.Module:
"""Load a model with the given configurations."""
device_config = vllm_config.device_config
load_config = vllm_config.load_config
@@ -43,7 +45,8 @@ def load_model(self, vllm_config: VllmConfig,
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
model_config=model_config,
prefix=prefix)

logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
9 changes: 6 additions & 3 deletions vllm/model_executor/model_loader/gguf_loader.py
@@ -124,8 +124,10 @@ def load_weights(self, model: nn.Module,
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map))

def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
def load_model(self,
vllm_config: VllmConfig,
model_config: ModelConfig,
prefix: str = "") -> nn.Module:
device_config = vllm_config.device_config
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
@@ -148,7 +150,8 @@ def load_model(self, vllm_config: VllmConfig,
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
model = initialize_model(vllm_config=vllm_config,
prefix=prefix)
self.load_weights(model, model_config)

process_weights_after_loading(model, model_config, target_device)
13 changes: 9 additions & 4 deletions vllm/model_executor/model_loader/tensorizer_loader.py
@@ -59,6 +59,7 @@ def _get_weights_iterator(
def _load_model_serialized_cpu(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.

@@ -71,7 +72,8 @@ def _load_model_serialized_cpu(
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = initialize_model(vllm_config=vllm_config)
model = initialize_model(vllm_config=vllm_config,
prefix=prefix)

model.load_weights(self._get_weights_iterator())
return model.eval()
@@ -104,8 +106,10 @@ def load_weights(self, model: nn.Module,
else:
model.load_weights(self._get_weights_iterator())

def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
def load_model(self,
vllm_config: VllmConfig,
model_config: ModelConfig,
prefix: str = "") -> nn.Module:
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)

@@ -126,7 +130,8 @@ def load_model(self, vllm_config: VllmConfig,
vllm_config=vllm_config)
self.load_weights(model, model_config)
return model
return self._load_model_serialized_cpu(vllm_config=vllm_config)
return self._load_model_serialized_cpu(vllm_config=vllm_config,
prefix=prefix)

@staticmethod
def save_model(