
Commit 37f013e (parent: f49a5ea)

Add tests for quantized target / draft model

Signed-off-by: Tomas Ruiz <[email protected]>

3 files changed (+43, -9)

tests/v1/e2e/test_spec_decode.py (29 additions & 6 deletions)
@@ -399,14 +399,37 @@ class ArgsTest:
 
 @pytest.mark.parametrize("args", cases)
 @pytest.mark.parametrize("enforce_eager", [True, False])
-def test_draft_model_correctness(
-    args: ArgsTest,
-    enforce_eager: bool,
-    monkeypatch: pytest.MonkeyPatch,
-):
+def test_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
+    assert_draft_model_correctness(args, enforce_eager)
+
+
+@pytest.mark.parametrize(
+    "models",
+    [
+        # target_model, draft_model
+        ("Qwen/Qwen3-1.7B-FP8", "Qwen/Qwen3-0.6B"),  # target quantized
+        ("Qwen/Qwen3-1.7B", "Qwen/Qwen3-0.6B-FP8"),  # draft quantized
+    ],
+    ids=["target_quantized", "draft_quantized"],
+)
+@pytest.mark.parametrize("enforce_eager", [True, False])
+def test_draft_model_quantization(models: tuple[str, str], enforce_eager: bool):
+    tgt_model, draft_model = models
+    sd_case = ArgsTest(
+        model=tgt_model,
+        draft_model=draft_model,
+        sampling_config=greedy_sampling(),
+        num_speculative_tokens=3,
+        expected_acceptance_len=2.95 + 1,
+        expected_acceptance_rate=0.95,
+        expected_same_output_fraction=0.95,
+    )
+    assert_draft_model_correctness(sd_case, enforce_eager)
+
+
+def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool):
     """Compare the outputs using and not using speculative decoding.
     In the greedy decoding case, the outputs must match EXACTLY."""
-    monkeypatch.setenv("VLLM_USE_V1", "1")
     test_prompts = get_test_prompts(mm_enabled=False, quiet=True)
 
     spec_llm = LLM(

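For readers who want to try the new quantization pairing outside the test harness, here is a minimal sketch against vLLM's public LLM API, mirroring the "draft_quantized" case (unquantized target, FP8 draft). The speculative_config keys follow vLLM's documented dict format, but the exact keys, prompt, and max_tokens here are illustrative assumptions, not code from this commit:

    from vllm import LLM, SamplingParams

    # Sketch of the "draft_quantized" case: unquantized target, FP8 draft.
    llm = LLM(
        model="Qwen/Qwen3-1.7B",
        speculative_config={
            "model": "Qwen/Qwen3-0.6B-FP8",
            "num_speculative_tokens": 3,
        },
    )
    # Greedy decoding, matching the test's greedy_sampling() config.
    params = SamplingParams(temperature=0.0, max_tokens=32)
    print(llm.generate(["The capital of France is"], params)[0].outputs[0].text)

The assertions encode the expectation that quantizing only one side of the target/draft pair barely hurts speculation quality: an expected acceptance length of 2.95 + 1 and a 95% acceptance rate.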
vllm/config/vllm.py (8 additions & 0 deletions)
@@ -757,6 +757,14 @@ def compile_debug_dump_path(self) -> Path | None:
         path = self.compilation_config.debug_dump_path / append_path
         return path
 
+    def replace(self, **kwargs):
+        """
+        Replace attributes of the config and recompute derived state.
+        dataclasses.replace() calls __init__() and __post_init__(); see:
+        https://docs.python.org/3/library/dataclasses.html#dataclasses.replace
+        """
+        return replace(self, **kwargs)
+
     def __str__(self):
         return (
             f"model={self.model_config.model!r}, "

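The docstring's pointer is the load-bearing detail here: dataclasses.replace() constructs a brand-new instance through __init__(), so __post_init__() runs again and derived state is recomputed rather than copied from the old object. A self-contained toy illustration (Cfg is invented for this example, not vLLM code):

    from dataclasses import dataclass, replace

    @dataclass
    class Cfg:
        model: str
        quant: str | None = None

        def __post_init__(self):
            # Derived state, standing in for VllmConfig's computed fields.
            self.summary = f"{self.model} quant={self.quant}"

    c1 = Cfg("Qwen3-1.7B-FP8", quant="fp8")
    c2 = replace(c1, model="Qwen3-0.6B", quant=None)
    print(c2.summary)  # "Qwen3-0.6B quant=None" -- recomputed, not copied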
vllm/v1/spec_decode/draft_model.py (6 additions & 3 deletions)
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from typing import Any
 
 import torch
@@ -118,8 +118,11 @@ def load_model(self, target_model: Any) -> None:
         draft_model_config: ModelConfig = (
             self.vllm_config.speculative_config.draft_model_config
         )
-        vllm_config_draft: VllmConfig = replace(
-            self.vllm_config, model_config=draft_model_config
+        # Recompute quant_config, which is configured for the target model,
+        # but the draft model might not be quantized.
+        vllm_config_draft: VllmConfig = self.vllm_config.replace(
+            quant_config=None,
+            model_config=draft_model_config,
         )
 
         # This must be computed before loading the draft model

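Why passing quant_config=None matters: VllmConfig.__post_init__ derives a quantization config from the active model_config when quant_config is unset, so clearing it forces re-derivation for the draft model, whereas the previous bare dataclasses.replace() carried the target's quant_config over verbatim. A runnable toy model of the before/after behavior (ToyVllmConfig and its FP8 check are invented for illustration; the real derivation lives in VllmConfig.__post_init__):

    from dataclasses import dataclass, replace

    @dataclass
    class ToyVllmConfig:
        model_config: str
        quant_config: str | None = None

        def __post_init__(self):
            # Mirrors VllmConfig: derive quant_config only when it is unset.
            if self.quant_config is None and self.model_config.endswith("-FP8"):
                self.quant_config = "fp8"

        def replace(self, **kwargs):
            return replace(self, **kwargs)

    target = ToyVllmConfig("Qwen3-1.7B-FP8")       # quant_config == "fp8"

    # Old behavior: quant_config is copied verbatim, mislabeling the draft.
    stale = replace(target, model_config="Qwen3-0.6B")
    assert stale.quant_config == "fp8"             # stale target setting

    # New behavior: clear quant_config so __post_init__ re-derives it.
    fresh = target.replace(quant_config=None, model_config="Qwen3-0.6B")
    assert fresh.quant_config is None              # draft is unquantized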