by the model.
"""
66from argparse import Namespace
7- from typing import List
7+ from typing import List , NamedTuple , Optional
88
9+ from PIL .Image import Image
910from transformers import AutoProcessor , AutoTokenizer
1011
1112from vllm import LLM , SamplingParams
1920]
2021
2122
22- def load_qwenvl_chat (question : str , image_urls : List [str ]):
class ModelRequestData(NamedTuple):
    """Everything a per-model loader returns for running multi-image inference.

    Bundles the constructed engine with the prompt and decoding inputs so the
    run_* helpers can consume one object instead of a positional tuple.
    """

    # The vLLM engine instance constructed by the loader.
    llm: LLM
    # Fully formatted prompt string (model-specific chat markup already applied).
    prompt: str
    # Token IDs produced by tokenizer.convert_tokens_to_ids(...), hence ints,
    # not strings; None when the model needs no explicit stop tokens.
    stop_token_ids: Optional[List[int]]
    # Pre-fetched PIL images, one per input URL.
    image_data: List[Image]
    # Custom chat template for LLM.chat, or None to use the model's default.
    chat_template: Optional[str]
29+
30+
31+ def load_qwenvl_chat (question : str , image_urls : List [str ]) -> ModelRequestData :
2332 model_name = "Qwen/Qwen-VL-Chat"
2433 llm = LLM (
2534 model = model_name ,
@@ -48,10 +57,16 @@ def load_qwenvl_chat(question: str, image_urls: List[str]):
4857
4958 stop_tokens = ["<|endoftext|>" , "<|im_start|>" , "<|im_end|>" ]
5059 stop_token_ids = [tokenizer .convert_tokens_to_ids (i ) for i in stop_tokens ]
51- return llm , prompt , stop_token_ids , None , chat_template
60+ return ModelRequestData (
61+ llm = llm ,
62+ prompt = prompt ,
63+ stop_token_ids = stop_token_ids ,
64+ image_data = [fetch_image (url ) for url in image_urls ],
65+ chat_template = chat_template ,
66+ )
5267
5368
54- def load_phi3v (question : str , image_urls : List [str ]):
69+ def load_phi3v (question : str , image_urls : List [str ]) -> ModelRequestData :
5570 llm = LLM (
5671 model = "microsoft/Phi-3.5-vision-instruct" ,
5772 trust_remote_code = True ,
@@ -62,10 +77,17 @@ def load_phi3v(question: str, image_urls: List[str]):
6277 for i , _ in enumerate (image_urls , start = 1 ))
6378 prompt = f"<|user|>\n { placeholders } \n { question } <|end|>\n <|assistant|>\n "
6479 stop_token_ids = None
65- return llm , prompt , stop_token_ids , None , None
80+
81+ return ModelRequestData (
82+ llm = llm ,
83+ prompt = prompt ,
84+ stop_token_ids = stop_token_ids ,
85+ image_data = [fetch_image (url ) for url in image_urls ],
86+ chat_template = None ,
87+ )
6688
6789
68- def load_internvl (question : str , image_urls : List [str ]):
90+ def load_internvl (question : str , image_urls : List [str ]) -> ModelRequestData :
6991 model_name = "OpenGVLab/InternVL2-2B"
7092
7193 llm = LLM (
@@ -93,10 +115,16 @@ def load_internvl(question: str, image_urls: List[str]):
93115 stop_tokens = ["<|endoftext|>" , "<|im_start|>" , "<|im_end|>" , "<|end|>" ]
94116 stop_token_ids = [tokenizer .convert_tokens_to_ids (i ) for i in stop_tokens ]
95117
96- return llm , prompt , stop_token_ids , None , None
118+ return ModelRequestData (
119+ llm = llm ,
120+ prompt = prompt ,
121+ stop_token_ids = stop_token_ids ,
122+ image_data = [fetch_image (url ) for url in image_urls ],
123+ chat_template = None ,
124+ )
97125
98126
99- def load_qwen2_vl (question , image_urls : List [str ]):
127+ def load_qwen2_vl (question , image_urls : List [str ]) -> ModelRequestData :
100128 try :
101129 from qwen_vl_utils import process_vision_info
102130 except ModuleNotFoundError :
@@ -143,7 +171,13 @@ def load_qwen2_vl(question, image_urls: List[str]):
143171 else :
144172 image_data , _ = process_vision_info (messages )
145173
146- return llm , prompt , stop_token_ids , image_data , None
174+ return ModelRequestData (
175+ llm = llm ,
176+ prompt = prompt ,
177+ stop_token_ids = stop_token_ids ,
178+ image_data = image_data ,
179+ chat_template = None ,
180+ )
147181
148182
149183model_example_map = {
@@ -155,20 +189,17 @@ def load_qwen2_vl(question, image_urls: List[str]):
155189
156190
157191def run_generate (model , question : str , image_urls : List [str ]):
158- llm , prompt , stop_token_ids , image_data , _ = model_example_map [model ](
159- question , image_urls )
160- if image_data is None :
161- image_data = [fetch_image (url ) for url in image_urls ]
192+ req_data = model_example_map [model ](question , image_urls )
162193
163194 sampling_params = SamplingParams (temperature = 0.0 ,
164195 max_tokens = 128 ,
165- stop_token_ids = stop_token_ids )
196+ stop_token_ids = req_data . stop_token_ids )
166197
167- outputs = llm .generate (
198+ outputs = req_data . llm .generate (
168199 {
169- "prompt" : prompt ,
200+ "prompt" : req_data . prompt ,
170201 "multi_modal_data" : {
171- "image" : image_data
202+ "image" : req_data . image_data
172203 },
173204 },
174205 sampling_params = sampling_params )
@@ -179,13 +210,12 @@ def run_generate(model, question: str, image_urls: List[str]):
179210
180211
181212def run_chat (model : str , question : str , image_urls : List [str ]):
182- llm , _ , stop_token_ids , _ , chat_template = model_example_map [model ](
183- question , image_urls )
213+ req_data = model_example_map [model ](question , image_urls )
184214
185215 sampling_params = SamplingParams (temperature = 0.0 ,
186216 max_tokens = 128 ,
187- stop_token_ids = stop_token_ids )
188- outputs = llm .chat (
217+ stop_token_ids = req_data . stop_token_ids )
218+ outputs = req_data . llm .chat (
189219 [{
190220 "role" :
191221 "user" ,
@@ -203,7 +233,7 @@ def run_chat(model: str, question: str, image_urls: List[str]):
203233 ],
204234 }],
205235 sampling_params = sampling_params ,
206- chat_template = chat_template ,
236+ chat_template = req_data . chat_template ,
207237 )
208238
209239 for o in outputs :