diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 7dbdf0ca0710..45b48e585893 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -121,6 +121,86 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
+@pytest.mark.parametrize(
+    "model_path",
+    [
+        "RedHatAI/Llama-3.1-8B-Instruct-speculator.eagle3",
+        "RedHatAI/Qwen3-8B-speculator.eagle3",
+    ],
+    ids=["llama3_eagle3_speculator", "qwen3_eagle3_speculator"],
+)
+def test_speculators_model_integration(
+    monkeypatch: pytest.MonkeyPatch,
+    sampling_config: SamplingParams,
+    model_path: str,
+):
+    """
+    Test that speculators models work with the simplified integration.
+
+    This verifies the `vllm serve <speculators-model>` use case where
+    speculative config is automatically detected from the model config
+    without requiring an explicit --speculative-config argument.
+
+    Tests:
+    1. Speculator model is correctly detected
+    2. Verifier model is extracted from speculator config
+    3. Speculative decoding is automatically enabled
+    4. Text generation works correctly
+    5. Output matches reference (non-speculative) generation
+    """
+    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
+
+    # Generate test prompts
+    test_prompts = get_test_prompts(mm_enabled=False)
+
+    # First run: Direct speculator model (simplified integration)
+    spec_llm = LLM(model=model_path, max_model_len=1024)
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+
+    # Verify speculative config was auto-detected
+    assert spec_llm.llm_engine.vllm_config.speculative_config is not None, (
+        f"Speculative config should be auto-detected for {model_path}"
+    )
+
+    spec_config = spec_llm.llm_engine.vllm_config.speculative_config
+    assert spec_config.num_speculative_tokens > 0, (
+        f"Expected positive speculative tokens, "
+        f"got {spec_config.num_speculative_tokens}"
+    )
+
+    # Verify draft model is set to the speculator model
+    assert spec_config.model == model_path, (
+        f"Draft model should be {model_path}, got {spec_config.model}"
+    )
+
+    # Extract verifier model for reference run
+    verifier_model = spec_llm.llm_engine.vllm_config.model_config.model
+
+    del spec_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    # Second run: Reference without speculative decoding
+    ref_llm = LLM(model=verifier_model, max_model_len=1024)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    # Compare outputs
+    matches = sum(
+        1
+        for ref, spec in zip(ref_outputs, spec_outputs)
+        if ref.outputs[0].text == spec.outputs[0].text
+    )
+
+    # Heuristic: expect at least 66% of prompts to match exactly
+    assert matches >= int(0.66 * len(ref_outputs)), (
+        f"Only {matches}/{len(ref_outputs)} outputs matched. "
+        f"Expected at least {int(0.66 * len(ref_outputs))} matches."
+    )
+
+
 @pytest.mark.parametrize(
     ["model_setup", "mm_enabled"],
     [
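
For context, a minimal usage sketch (not part of the patch) of the simplified integration the new test exercises, built only from calls that appear in the test above; the prompt, sampling settings, and use of generate() instead of the test's chat()-based flow are illustrative assumptions:

    from vllm import LLM, SamplingParams

    # Pass a speculators-format checkpoint directly; the speculative config
    # should be auto-detected from the model config, with no explicit
    # --speculative-config / speculative_config argument.
    llm = LLM(model="RedHatAI/Qwen3-8B-speculator.eagle3", max_model_len=1024)
    assert llm.llm_engine.vllm_config.speculative_config is not None

    # Illustrative generation call; prompt and sampling params are placeholders.
    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(temperature=0.0, max_tokens=32),
    )
    print(outputs[0].outputs[0].text)
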
diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/v1/spec_decode/test_speculators_eagle3.py
similarity index 94%
rename from tests/speculative_decoding/speculators/test_eagle3.py
rename to tests/v1/spec_decode/test_speculators_eagle3.py
index 19ba32d8dee4..5ce6e1593b5c 100644
--- a/tests/speculative_decoding/speculators/test_eagle3.py
+++ b/tests/v1/spec_decode/test_speculators_eagle3.py
@@ -22,10 +22,6 @@
             "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
             id="qwen3-eagle3-speculator-w4a16-verifier",
         ),
-        pytest.param(
-            "nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
-            id="llama3-eagl3-multiple-layers",
-        ),
     ],
 )
 def test_eagle3_speculators_model(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index e8f8e3f8c2b5..f13ce935ec4b 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -81,7 +81,7 @@
     is_interleaved,
     maybe_override_with_speculators,
 )
-from vllm.transformers_utils.utils import check_gguf_file
+from vllm.transformers_utils.utils import check_gguf_file, is_s3
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.network_utils import get_ip
@@ -1305,20 +1305,26 @@ def create_engine_config(
         device_config = DeviceConfig(device=cast(Device, current_platform.device_type))
 
+        # Check if the model is a speculator and override model/tokenizer/config
+        # BEFORE creating ModelConfig, so the config is created with the target model
+        # Skip speculator detection for S3 models since HuggingFace cannot load
+        # configs directly from S3 URLs. S3 models can still use speculators with
+        # explicit --speculative-config.
+        if not is_s3(self.model):
+            (self.model, self.tokenizer, self.speculative_config) = (
+                maybe_override_with_speculators(
+                    model=self.model,
+                    tokenizer=self.tokenizer,
+                    revision=self.revision,
+                    trust_remote_code=self.trust_remote_code,
+                    vllm_speculative_config=self.speculative_config,
+                )
+            )
+
         model_config = self.create_model_config()
         self.model = model_config.model
         self.tokenizer = model_config.tokenizer
-        (self.model, self.tokenizer, self.speculative_config) = (
-            maybe_override_with_speculators(
-                model=self.model,
-                tokenizer=self.tokenizer,
-                revision=self.revision,
-                trust_remote_code=self.trust_remote_code,
-                vllm_speculative_config=self.speculative_config,
-            )
-        )
-
         # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
         # and fall back to V0 for experimental or unsupported features.
         # * If VLLM_USE_V1=1, we enable V1 for supported + experimental