 from __future__ import annotations

 import random
-from typing import Any
+from typing import Any, Union

 import pytest
 import torch

 from vllm import LLM, SamplingParams
+from vllm.assets.base import VLLM_S3_BUCKET_URL
+from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory


-@pytest.fixture
-def test_prompts():
+def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
+    if mm_enabled:
+        prompt_types.append("mm")
     num_prompts = 100
     prompts = []

     random.seed(0)
     random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+    print(f"Prompt types: {random_prompt_type_choices}")

     # Generate a mixed batch of prompts, some of which can be easily
     # predicted by n-gram matching and some which likely cannot.
     for kind in random_prompt_type_choices:
         word_choices = ["test", "temp", "hello", "where"]
         word = random.choice(word_choices)
+        prompt: Union[str, list[dict[str, Any]]] = ""
         if kind == "repeat":
             prompt = f"""
             please repeat the word '{word}' 10 times.
@@ -38,6 +43,21 @@ def test_prompts():
             uses the word {word} at least once.
             give no other output than that simple sentence without quotes.
             """
+        elif kind == "mm":
+            placeholders = [{
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
+                },
+            }]
+            prompt = [
+                *placeholders,
+                {
+                    "type": "text",
+                    "text": "The meaning of the image is"
+                },
+            ]
         else:
             raise ValueError(f"Unknown prompt type: {kind}")
         prompts.append([{"role": "user", "content": prompt}])
@@ -57,7 +77,6 @@ def model_name():

 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_name: str,
 ):
@@ -67,6 +86,7 @@ def test_ngram_correctness(
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        test_prompts = get_test_prompts(mm_enabled=False)

         ref_llm = LLM(model=model_name, max_model_len=1024)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@@ -103,23 +123,32 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()


-@pytest.mark.parametrize("model_setup", [
-    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
-    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-],
-                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"], [
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            False,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            True,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    ],
+    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
 ):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(mm_enabled)
     '''
     Compare the outputs of an original LLM and a speculative LLM;
     they should be the same when using eagle speculative decoding.
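For reference, a minimal usage sketch (not part of the diff) of the new get_test_prompts() helper introduced above; the import path is an assumption about where this test module lives and may need adjusting:

# Hypothetical import path -- adjust to the actual location of the test module.
from tests.v1.e2e.test_spec_decode import get_test_prompts

prompts = get_test_prompts(mm_enabled=True)
print(len(prompts))                    # 100 chat conversations (num_prompts)
print(prompts[0][0]["role"])           # "user"
print(type(prompts[0][0]["content"]))  # str for text prompts, list for "mm" prompts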