|
116 | 116 | help='Data type for kv cache storage. If "auto", will use model '
|
117 | 117 | "data type. fp8 type now supports e5m2.",
|
118 | 118 | )
|
119 | | - |
| 119 | +parser.add_argument( |
| 120 | + "--input-mode", |
| 121 | + default="0", |
| 122 | + choices=["0", "1", "2", "3"], |
| 123 | + type=str, |
| 124 | + help="Input mode for multimodal models. 0: language; 1: vision; 2: speech; 3: vision_speech", |
| 125 | +) |
120 | 126 | args = parser.parse_args()
|
121 | 127 | print(args)
|
122 | 128 |
|
|
185 | 191 | config.lm_head_generation = True
|
186 | 192 | if model_type == "maira2" and not hasattr(config.text_config, "lm_head_generation"):
|
187 | 193 | config.text_config.lm_head_generation = True
|
| 194 | +if re.search("phi4mm", config.architectures[0], re.IGNORECASE): |
| 195 | + model_type = "phi4mm" |
| 196 | + model_class = MODEL_CLASSES[model_type] |
| 197 | + prompt = args.prompt |
| 198 | + _COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN = r"<\|image_\d+\|>" |
| 199 | + _COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN = r"<\|audio_\d+\|>" |
| 200 | + image_in_prompt = len(re.findall(_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN, prompt)) |
| 201 | + audio_in_prompt = len(re.findall(_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN, prompt)) |
| 202 | + is_vision = image_in_prompt > 0 |
| 203 | + is_speech = audio_in_prompt > 0 |
| 204 | + audio_batch_size = args.batch_size |
| 205 | + if is_vision: |
| 206 | + assert ( |
| 207 | + image_in_prompt == args.batch_size |
| 208 | + ), "Prompt is invalid. For multiple images, the user needs to \ |
| 209 | + insert multiple image placeholders in the prompt as below: \ |
| 210 | + <|user|><|image_1|><|image_2|><|image_3|>Summarize the content of the images.<|end|><|assistant|>" |
| 211 | + if is_speech: |
| 212 | + if not is_vision: |
| 213 | + assert ( |
| 214 | + audio_in_prompt == args.batch_size |
| 215 | + ), "Prompt is invalid. For multiple audios, the user needs to \ |
| 216 | + insert multiple audio placeholders in the prompt as below: \ |
| 217 | + <|user|><|audio_1|><|audio_2|><|audio_3|>Transcribe the audio clip into text.<|end|><|assistant|>" |
| 218 | + else: |
| 219 | + audio_batch_size = audio_in_prompt |
| 220 | + if not is_vision and not is_speech: |
| 221 | + config.input_mode = 0 |
| 222 | + elif is_vision and not is_speech: |
| 223 | + config.input_mode = 1 |
| 224 | + elif not is_vision and is_speech: |
| 225 | + config.input_mode = 2 |
| 226 | + else: |
| 227 | + config.input_mode = 3 |
188 | 228 |
|
| 229 | + assert config.input_mode == int( |
| 230 | + args.input_mode |
| 231 | + ), "Input mode in prompt is not consistent with the input mode in the command line." |
189 | 232 | if model_type != "llava":
|
| 233 | + config._attn_implementation = "eager" |
190 | 234 | model = model_class[0].from_pretrained(
|
191 | 235 | args.model_id,
|
192 | 236 | torch_dtype=amp_dtype,
|
193 | 237 | config=config,
|
194 | 238 | low_cpu_mem_usage=True if model_type != "maira2" else False,
|
195 | 239 | trust_remote_code=True,
|
| 240 | + attn_implementation="eager", |
196 | 241 | )
|
197 | 242 | tokenizer = model_class[1].from_pretrained(args.model_id, trust_remote_code=True)
|
198 | 243 | else:
|
@@ -240,7 +285,9 @@ def load_image(image_file):
|
240 | 285 | image = Image.open(image_file).convert("RGB")
|
241 | 286 | return image
|
242 | 287 |
|
243 | | -elif re.search("mllama", model.config.architectures[0], re.IGNORECASE): |
| 288 | +elif re.search("mllama", model.config.architectures[0], re.IGNORECASE) or re.search( |
| 289 | + "phi4mm", model.config.architectures[0], re.IGNORECASE |
| 290 | +): |
244 | 291 | from PIL import Image
|
245 | 292 |
|
246 | 293 | def load_image(image_file):
|
@@ -280,10 +327,20 @@ def download_and_open(url: str) -> Image.Image:
|
280 | 327 | "jamba", model.config.architectures[0], re.IGNORECASE
|
281 | 328 | ):
|
282 | 329 | model.config.batch_size = int(args.batch_size) * num_beams
|
| 330 | +if re.search("phi4mm", model.config.architectures[0], re.IGNORECASE): |
| 331 | + model.config.batch_size = int(args.batch_size) * num_beams |
| 332 | + model.config.audio_batch_size = audio_batch_size * num_beams |
283 | 333 | if re.search("whisper", model.config.architectures[0], re.IGNORECASE):
|
284 | 334 | import librosa
|
285 | 335 |
|
286 | 336 | sample = librosa.load(args.audio, sr=16000)
|
| 337 | +if re.search("phi4mm", model.config.architectures[0], re.IGNORECASE): |
| 338 | + if config.input_mode in [2, 3]: |
| 339 | + import soundfile |
| 340 | + |
| 341 | + sample = soundfile.read(args.audio) |
| 342 | + else: |
| 343 | + sample = None |
287 | 344 |
|
288 | 345 |
|
289 | 346 | def trace_handler(prof):
|
@@ -347,6 +404,8 @@ def trace_handler(prof):
|
347 | 404 | if hasattr(tokenizer, "process_reporting_input")
|
348 | 405 | else tokenizer.format_and_preprocess_reporting_input
|
349 | 406 | )
|
| 407 | + elif model_type == "phi4mm": |
| 408 | + prompt = args.prompt |
350 | 409 | else:
|
351 | 410 | # input prompt
|
352 | 411 | current_path = pathlib.Path(__file__).parent.resolve()
|
@@ -431,14 +490,26 @@ def trace_handler(prof):
|
431 | 490 | )
|
432 | 491 | input_ids = processed_inputs["input_ids"]
|
433 | 492 | output = model.generate(**processed_inputs, **generate_kwargs)
|
| 493 | + elif model_type == "phi4mm": |
| 494 | + raw_image = load_image(args.image_url) if is_vision else None |
| 495 | + raw_image = [raw_image] * args.batch_size |
| 496 | + samples = [sample] * audio_batch_size |
| 497 | + inputs = tokenizer( |
| 498 | + text=prompt[0], |
| 499 | + images=raw_image if is_vision else None, |
| 500 | + audios=samples if is_speech else None, |
| 501 | + return_tensors="pt", |
| 502 | + ) |
| 503 | + input_ids = inputs["input_ids"] |
| 504 | + output = model.generate(**inputs, **generate_kwargs) |
434 | 505 | else:
|
435 | 506 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
436 | 507 | output = model.generate(input_ids, **generate_kwargs)
|
437 | 508 | gen_ids = output[0] if args.token_latency else output
|
438 | 509 | gen_text = tokenizer.batch_decode(
|
439 | 510 | (
|
440 | 511 | gen_ids[:, input_ids.shape[1] :]
|
441 | | - if model_type in ["llava", "maira2"] |
| 512 | + if model_type in ["llava", "maira2", "phi4mm"] |
442 | 513 | else gen_ids
|
443 | 514 | ),
|
444 | 515 | skip_special_tokens=True,
|
@@ -514,6 +585,17 @@ def trace_handler(prof):
|
514 | 585 | get_grounding=False,
|
515 | 586 | )
|
516 | 587 | output = model.generate(**processed_inputs, **generate_kwargs)
|
| 588 | + elif model_type == "phi4mm": |
| 589 | + raw_image = load_image(args.image_url) if is_vision else None |
| 590 | + raw_image = [raw_image] * args.batch_size |
| 591 | + samples = [sample] * audio_batch_size |
| 592 | + inputs = tokenizer( |
| 593 | + text=prompt[0], |
| 594 | + images=raw_image if is_vision else None, |
| 595 | + audios=samples if is_speech else None, |
| 596 | + return_tensors="pt", |
| 597 | + ) |
| 598 | + output = model.generate(**inputs, **generate_kwargs) |
517 | 599 | else:
|
518 | 600 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
519 | 601 | output = model.generate(input_ids, **generate_kwargs)
|
|