
Commit 9e0726e

morgendave and Roger Wang authored
[Meta] Official Eagle mm support, first enablement on llama4 (#20788)
Signed-off-by: morgendave <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
1 parent 53c21e4 commit 9e0726e

File tree

8 files changed: +206 additions, -37 deletions


examples/offline_inference/spec_decode.py

Lines changed: 59 additions & 7 deletions
@@ -13,6 +13,38 @@
 from argparse import ArgumentParser as FlexibleArgumentParser
 
 
+QUESTION = "What is the content of each image?"
+IMAGE_URLS = [
+    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
+]
+
+
+def get_custom_mm_prompts(num_prompts):
+    prompts = []
+    for url in IMAGE_URLS:
+        prompts.append(
+            [
+                {"type": "image_url", "image_url": {"url": url}},
+                {"type": "text", "text": QUESTION},
+            ]
+        )
+    if num_prompts > len(IMAGE_URLS):
+        prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1)
+
+    return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]]
+
+
 def parse_args():
     parser = FlexibleArgumentParser()
     add_dataset_parser(parser)
@@ -35,6 +67,7 @@ def parse_args():
     parser.add_argument("--output-len", type=int, default=256)
     parser.add_argument("--model-dir", type=str, default=None)
     parser.add_argument("--eagle-dir", type=str, default=None)
+    parser.add_argument("--custom-mm-prompts", action="store_true")
     return parser.parse_args()
 
 
@@ -44,14 +77,26 @@ def main():
 
     model_dir = args.model_dir
     if args.model_dir is None:
+        if args.custom_mm_prompts:
+            raise ValueError(
+                "custom_mm_prompts requires mm based models"
+                "default llama3.1-8b-instruct is not mm based"
+                "please specify model_dir to give a mm based model"
+            )
         model_dir = "meta-llama/Llama-3.1-8B-Instruct"
     tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-    prompts = get_samples(args, tokenizer)
-    # add_special_tokens is False to avoid adding bos twice when using chat templates
-    prompt_ids = [
-        tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
-    ]
+    args.custom_skip_chat_template = True
+
+    if not args.custom_mm_prompts:
+        prompts = get_samples(args, tokenizer)
+        # add_special_tokens is False to avoid adding bos twice
+        # when using chat templates
+        prompt_ids = [
+            tokenizer.encode(prompt.prompt, add_special_tokens=False)
+            for prompt in prompts
+        ]
+    else:
+        prompts = get_custom_mm_prompts(args.num_prompts)
 
     if args.method == "eagle" or args.method == "eagle3":
         eagle_dir = args.eagle_dir
@@ -85,10 +130,17 @@ def main():
         speculative_config=speculative_config,
         disable_log_stats=False,
         max_model_len=16384,
+        limit_mm_per_prompt={"image": 5},
+        disable_chunked_mm_input=True,
    )
 
     sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
-    outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
+    if not args.custom_mm_prompts:
+        outputs = llm.generate(
+            prompt_token_ids=prompt_ids, sampling_params=sampling_params
+        )
+    else:
+        outputs = llm.chat(prompts, sampling_params=sampling_params)
 
     # print the generated text
     if args.print_output:
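Taken together, these changes let the example script drive Eagle speculative decoding with chat-style multimodal prompts. Below is a minimal standalone sketch of the equivalent API usage, assuming the Llama 4 target and draft checkpoints used in the test file further down; the speculative_config keys and the num_speculative_tokens value are assumptions, since the hunk that builds that dict is not part of this diff.

```python
from vllm import LLM, SamplingParams

# Hypothetical condensed equivalent of `spec_decode.py --custom-mm-prompts`
# (checkpoint names come from the test file below; config keys are assumed).
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=4,
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,  # assumed value
    },
    max_model_len=16384,
    limit_mm_per_prompt={"image": 5},
    disable_chunked_mm_input=True,
)

# One chat-format multimodal prompt, shaped like the get_custom_mm_prompts() output.
messages = [[{
    "role": "user",
    "content": [
        {"type": "image_url",
         "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg"}},
        {"type": "text", "text": "What is the content of each image?"},
    ],
}]]

outputs = llm.chat(messages, SamplingParams(temperature=0, max_tokens=256))
print(outputs[0].outputs[0].text)
```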

tests/v1/e2e/test_spec_decode.py

Lines changed: 45 additions & 16 deletions
@@ -3,29 +3,34 @@
 from __future__ import annotations
 
 import random
-from typing import Any
+from typing import Any, Union
 
 import pytest
 import torch
 
 from vllm import LLM, SamplingParams
+from vllm.assets.base import VLLM_S3_BUCKET_URL
+from vllm.assets.image import VLM_IMAGES_DIR
 from vllm.distributed import cleanup_dist_env_and_memory
 
 
-@pytest.fixture
-def test_prompts():
+def get_test_prompts(mm_enabled: bool):
     prompt_types = ["repeat", "sentence"]
+    if mm_enabled:
+        prompt_types.append("mm")
     num_prompts = 100
     prompts = []
 
     random.seed(0)
     random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+    print(f"Prompt types: {random_prompt_type_choices}")
 
     # Generate a mixed batch of prompts, some of which can be easily
     # predicted by n-gram matching and some which likely cannot.
     for kind in random_prompt_type_choices:
         word_choices = ["test", "temp", "hello", "where"]
         word = random.choice(word_choices)
+        prompt: Union[str, list[dict[str, Any]]] = ""
         if kind == "repeat":
             prompt = f"""
             please repeat the word '{word}' 10 times.
@@ -38,6 +43,21 @@ def test_prompts():
             uses the word {word} at least once.
             give no other output than that simple sentence without quotes.
             """
+        elif kind == "mm":
+            placeholders = [{
+                "type": "image_url",
+                "image_url": {
+                    "url":
+                    f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
+                },
+            }]
+            prompt = [
+                *placeholders,
+                {
+                    "type": "text",
+                    "text": "The meaning of the image is"
+                },
+            ]
         else:
             raise ValueError(f"Unknown prompt type: {kind}")
         prompts.append([{"role": "user", "content": prompt}])
@@ -57,7 +77,6 @@ def model_name():
 
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_name: str,
 ):
@@ -67,6 +86,7 @@ def test_ngram_correctness(
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        test_prompts = get_test_prompts(mm_enabled=False)
 
         ref_llm = LLM(model=model_name, max_model_len=1024)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@@ -103,23 +123,32 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()
 
 
-@pytest.mark.parametrize("model_setup", [
-    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
-    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
-    pytest.param(
-        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-],
-                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
+@pytest.mark.parametrize(
+    ["model_setup", "mm_enabled"], [
+        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            False,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+        pytest.param(
+            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+            True,
+            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    ],
+    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_setup: tuple[str, str, str, int],
+    mm_enabled: bool,
 ):
+    # Generate test prompts inside the function instead of using fixture
+    test_prompts = get_test_prompts(mm_enabled)
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
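For orientation, here is a hypothetical condensed version of what the new llama4_eagle_mm case exercises, written as if run inside this test module (get_test_prompts is the helper introduced above; the speculative_config keys, num_speculative_tokens, and max_model_len are assumptions, not taken from the diff):

```python
from vllm import LLM, SamplingParams

test_prompts = get_test_prompts(mm_enabled=True)  # mixes "repeat", "sentence", "mm"
sampling_config = SamplingParams(temperature=0, max_tokens=64)  # assumed values

spec_llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=4,
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,  # assumed
    },
    max_model_len=4096,  # assumed
)
# Multimodal chat prompts go through llm.chat, the same entry point the
# reference model uses in test_ngram_correctness.
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
print(spec_outputs[0].outputs[0].text)
```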

vllm/model_executor/models/llama4.py

Lines changed: 1 addition & 0 deletions
@@ -256,6 +256,7 @@ def __init__(
         super().__init__()
 
         self.layer_idx = extract_layer_index(prefix)
+        self.global_layer = config.no_rope_layers[self.layer_idx] == 0
         self.hidden_size = config.hidden_size
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
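The new global_layer flag only records whether a layer skips RoPE. A tiny illustration, assuming the Llama-4-style convention that a 0 entry in config.no_rope_layers marks a NoPE (global-attention) layer; the pattern below is hypothetical:

```python
# Hypothetical no_rope_layers pattern where every fourth layer is a NoPE/global layer.
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]

global_layers = [i for i, flag in enumerate(no_rope_layers) if flag == 0]
print(global_layers)  # [3, 7] -> these layers would get global_layer = True
```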

vllm/model_executor/models/llama4_eagle.py

Lines changed: 31 additions & 4 deletions
@@ -37,8 +37,9 @@
 from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
                                                Llama4ForCausalLM)
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.multimodal.inputs import NestedTensors
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
 
@@ -78,15 +79,23 @@ def __init__(
         self.norm = RMSNorm(self.config.hidden_size,
                             eps=self.config.rms_norm_eps)
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        input_embeds = self.embed_tokens(input_ids)
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings(input_ids)
         hidden_states = self.fc(
-            torch.cat((input_embeds, hidden_states), dim=-1))
+            torch.cat((inputs_embeds, hidden_states), dim=-1))
         residual = None
         for layer in self.layers:
             hidden_states, residual = layer(
@@ -190,8 +199,9 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.model(input_ids, positions, hidden_states)
+        return self.model(input_ids, positions, hidden_states, inputs_embeds)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> None:
@@ -212,3 +222,20 @@ def load_weights(self, weights: Iterable[tuple[str,
             model_weights[name] = loaded_weight
 
         loader.load_weights(model_weights.items())
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.model.get_input_embeddings(input_ids)
+
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                self.config.image_token_index,
+            )
+
+        return inputs_embeds
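A minimal sketch of how a caller could drive the new embedding path of this draft model (hypothetical driver code: draft_model, the tensors, and mm_embeds are assumptions; the text-only path simply leaves inputs_embeds as None so forward() falls back to embed_tokens):

```python
from typing import Optional

import torch


def run_draft(draft_model, input_ids: torch.Tensor, positions: torch.Tensor,
              target_hidden_states: torch.Tensor,
              mm_embeds: Optional[torch.Tensor] = None):
    # With multimodal inputs, image placeholder tokens in input_ids are replaced
    # by the precomputed multimodal embeddings before the fc/decoder stack runs.
    inputs_embeds = None
    if mm_embeds is not None:
        inputs_embeds = draft_model.get_input_embeddings(
            input_ids, multimodal_embeddings=mm_embeds)
    return draft_model(input_ids, positions, target_hidden_states,
                       inputs_embeds=inputs_embeds)
```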

vllm/model_executor/models/llama_eagle.py

Lines changed: 6 additions & 0 deletions
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -148,7 +149,12 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
         return self.model(input_ids, positions, hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

vllm/model_executor/models/llama_eagle3.py

Lines changed: 5 additions & 0 deletions
@@ -202,7 +202,12 @@ def forward(
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not support multimodal inputs yet."
+            )
         return self.model(input_ids, positions, hidden_states)
 
     def compute_logits(
