33import pytest
44
55import vllm
6+ from tests .utils import fork_new_process_for_each_test
67from vllm .assets .image import ImageAsset
78from vllm .lora .request import LoRARequest
8-
9- from ..utils import multi_gpu_test
9+ from vllm .platforms import current_platform
1010
1111MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
1212
1717
1818IMAGE_ASSETS = [
1919 ImageAsset ("stop_sign" ),
20- ImageAsset ("cherry_blossom" ),
2120]
2221
2322# After fine-tuning with LoRA, all generated content should start begin `A`.
2423EXPECTED_OUTPUT = [
2524 "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents." , # noqa: E501
26- "A pink cherry blossom tree with a blue sky in the background." ,
2725]
2826
2927
@@ -50,48 +48,75 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
5048 # Print the outputs.
5149 generated_texts : List [str ] = []
5250 for output in outputs :
53- prompt = output .prompt
5451 generated_text = output .outputs [0 ].text .strip ()
5552 generated_texts .append (generated_text )
56- print (f"Prompt: { prompt !r } , Generated text: { generated_text !r} " )
53+ print (f"Generated text: { generated_text !r} " )
5754 return generated_texts
5855
5956
60- @multi_gpu_test (num_gpus = 2 )
61- @pytest .mark .parametrize ("fully_sharded" , [True , False ])
62- def test_minicpmv_tp2 (minicpmv_lora_files , fully_sharded ):
@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_lora(minicpmv_lora_files):
    """Single-GPU smoke test for MiniCPM-V LoRA.

    Loads the fine-tuned adapter under two distinct lora_ids and checks
    that both produce the expected outputs, i.e. adapter slots behave
    identically regardless of id.
    """
    llm = vllm.LLM(
        MODEL_PATH,
        max_num_seqs=2,
        enable_lora=True,
        max_loras=2,
        max_lora_rank=8,
        enforce_eager=True,
        trust_remote_code=True,
        enable_chunked_prefill=True,
    )
    for lora_id in (1, 2):
        output = do_sample(llm, minicpmv_lora_files, lora_id=lora_id)
        # One generation per prompt is required; a short result list would
        # previously have surfaced as an IndexError — make it an assert.
        assert len(output) >= len(EXPECTED_OUTPUT)
        # Generation may stop early, so the reference text is expected to
        # *start with* the generated text rather than match it exactly.
        for expected, generated in zip(EXPECTED_OUTPUT, output):
            assert expected.startswith(generated)
78+
79+
@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files):
    """Tensor-parallel (TP=4) MiniCPM-V LoRA test WITHOUT fully sharded
    LoRA layers: the adapter must still yield the expected outputs."""
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=4,
        trust_remote_code=True,
        enforce_eager=True,
        enable_chunked_prefill=True,
    )
    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
    # One generation per prompt is required; a short result list would
    # previously have surfaced as an IndexError — make it an assert.
    assert len(output_tp) >= len(EXPECTED_OUTPUT)
    # Generation may stop early, so the reference text is expected to
    # *start with* the generated text rather than match it exactly.
    for expected, generated in zip(EXPECTED_OUTPUT, output_tp):
        assert expected.startswith(generated)
7999
80100
81- @multi_gpu_test (num_gpus = 4 )
82- @pytest .mark .parametrize ("fully_sharded" , [True , False ])
83- def test_minicpmv_tp4 (minicpmv_lora_files , fully_sharded ):
@pytest.mark.xfail(
    current_platform.is_rocm(),
    reason="MiniCPM-V dependency xformers incompatible with ROCm")
@fork_new_process_for_each_test
def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files):
    """Tensor-parallel (TP=4) MiniCPM-V LoRA test WITH fully sharded
    LoRA layers, exercised under two distinct lora_ids."""
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=2,
        max_lora_rank=8,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=True,
        enable_chunked_prefill=True,
    )
    for lora_id in (1, 2):
        output_tp = do_sample(llm, minicpmv_lora_files, lora_id=lora_id)
        # One generation per prompt is required; a short result list would
        # previously have surfaced as an IndexError — make it an assert.
        assert len(output_tp) >= len(EXPECTED_OUTPUT)
        # Generation may stop early, so the reference text is expected to
        # *start with* the generated text rather than match it exactly.
        for expected, generated in zip(EXPECTED_OUTPUT, output_tp):
            assert expected.startswith(generated)