80 changes: 80 additions & 0 deletions tests/v1/e2e/test_spec_decode.py
@@ -121,6 +121,86 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
"model_path",
[
"RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
"RedHatAI/Qwen3-8B-speculator.eagle3",
],
ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
)
def test_speculators_model_integration(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_path: str,
):
"""
Test that speculators models work with the simplified integration.

This verifies the `vllm serve <speculator-model>` use case, where the
speculative config is automatically detected from the model config
without requiring an explicit --speculative-config argument.

Tests:
1. Speculator model is correctly detected
2. Verifier model is extracted from speculator config
3. Speculative decoding is automatically enabled
4. Text generation works correctly
5. Output matches reference (non-speculative) generation
"""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

# Generate test prompts
test_prompts = get_test_prompts(mm_enabled=False)

# First run: Direct speculator model (simplified integration)
spec_llm = LLM(model=model_path, max_model_len=1024)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)

# Verify speculative config was auto-detected
assert spec_llm.llm_engine.vllm_config.speculative_config is not None, (
f"Speculative config should be auto-detected for {model_path}"
)

spec_config = spec_llm.llm_engine.vllm_config.speculative_config
assert spec_config.num_speculative_tokens > 0, (
f"Expected positive speculative tokens, "
f"got {spec_config.num_speculative_tokens}"
)

# Verify draft model is set to the speculator model
assert spec_config.model == model_path, (
f"Draft model should be {model_path}, got {spec_config.model}"
)

# Extract verifier model for reference run
verifier_model = spec_llm.llm_engine.vllm_config.model_config.model

del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

# Second run: Reference without speculative decoding
ref_llm = LLM(model=verifier_model, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()

# Compare outputs
matches = sum(
1
for ref, spec in zip(ref_outputs, spec_outputs)
if ref.outputs[0].text == spec.outputs[0].text
)

# Heuristic: expect at least 66% of prompts to match exactly
assert matches >= int(0.66 * len(ref_outputs)), (
f"Only {matches}/{len(ref_outputs)} outputs matched. "
f"Expected at least {int(0.66 * len(ref_outputs))} matches."
)


@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
[
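For context, the "simplified integration" exercised by the new test above can be sketched with the offline LLM API roughly as follows. This is an illustrative sketch, not part of the diff: the checkpoint name and max_model_len mirror the test, the prompt is arbitrary, and the attribute accesses assume the same auto-detection behavior the test asserts.

from vllm import LLM, SamplingParams

# Point vLLM directly at a speculators-format Eagle-3 checkpoint. The verifier
# (target) model and the speculative config are resolved from the checkpoint's
# own config, so no explicit speculative_config / --speculative-config is given.
llm = LLM(model="RedHatAI/Qwen3-8B-speculator.eagle3", max_model_len=1024)

# The auto-detected config is visible on the engine config, as the test asserts.
spec_config = llm.llm_engine.vllm_config.speculative_config
assert spec_config is not None and spec_config.num_speculative_tokens > 0

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0, max_tokens=16),
)
print(outputs[0].outputs[0].text)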
@@ -22,10 +22,6 @@
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
id="qwen3-eagle3-speculator-w4a16-verifier",
),
pytest.param(
"nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
id="llama3-eagl3-multiple-layers",
),
],
)
def test_eagle3_speculators_model(
28 changes: 17 additions & 11 deletions vllm/engine/arg_utils.py
@@ -81,7 +81,7 @@
is_interleaved,
maybe_override_with_speculators,
)
from vllm.transformers_utils.utils import check_gguf_file
from vllm.transformers_utils.utils import check_gguf_file, is_s3
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip
@@ -1305,20 +1305,26 @@ def create_engine_config(

device_config = DeviceConfig(device=cast(Device, current_platform.device_type))

# Check if the model is a speculator and override model/tokenizer/config
# BEFORE creating ModelConfig, so the config is created with the target model
# Skip speculator detection for S3 models since HuggingFace cannot load
# configs directly from S3 URLs. S3 models can still use speculators with
# explicit --speculative-config.
if not is_s3(self.model):
(self.model, self.tokenizer, self.speculative_config) = (
maybe_override_with_speculators(
model=self.model,
tokenizer=self.tokenizer,
revision=self.revision,
trust_remote_code=self.trust_remote_code,
vllm_speculative_config=self.speculative_config,
)
)

model_config = self.create_model_config()
self.model = model_config.model
self.tokenizer = model_config.tokenizer

(self.model, self.tokenizer, self.speculative_config) = (
maybe_override_with_speculators(
model=self.model,
tokenizer=self.tokenizer,
revision=self.revision,
trust_remote_code=self.trust_remote_code,
vllm_speculative_config=self.speculative_config,
)
)

# * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
# and fall back to V0 for experimental or unsupported features.
# * If VLLM_USE_V1=1, we enable V1 for supported + experimental
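For reference, the guarded flow added in this hunk reads roughly as in the self-contained sketch below. The is_s3 body and the maybe_override_with_speculators stub are assumptions made only to keep the sketch runnable; the call signature and the skip-for-S3 behavior are taken from the diff itself.

def is_s3(path: str) -> bool:
    # Assumed behavior of vllm.transformers_utils.utils.is_s3: S3 URLs cannot be
    # read by Hugging Face config loading, so speculator detection is skipped.
    return path.startswith("s3://")

def maybe_override_with_speculators(model, tokenizer, revision, trust_remote_code,
                                    vllm_speculative_config):
    # Stub: the real helper inspects the model config and, for a speculators-format
    # checkpoint, returns (verifier_model, verifier_tokenizer, auto-built
    # speculative config); otherwise it returns its inputs unchanged.
    return model, tokenizer, vllm_speculative_config

model, tokenizer, speculative_config = "s3://bucket/llama-3.1-8b", None, None
if not is_s3(model):
    # Only non-S3 models go through auto-detection; S3 models can still enable
    # speculative decoding with an explicit --speculative-config.
    model, tokenizer, speculative_config = maybe_override_with_speculators(
        model=model,
        tokenizer=tokenizer,
        revision=None,
        trust_remote_code=False,
        vllm_speculative_config=speculative_config,
    )
print(model, speculative_config)  # detection skipped for the s3:// path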