55import tempfile
66
77import depyf
8- import pytest
98
10- from vllm .config import CompilationLevel
119
12-
13- @pytest .mark .skip (reason = "Not working; needs investigation." )
1410def test_tpu_compilation ():
1511 temp_dir = tempfile .mkdtemp ()
1612 with depyf .prepare_debug (temp_dir ):
@@ -22,27 +18,24 @@ def test_tpu_compilation():
2218 "The greatest glory in living lies not in never falling," ,
2319 ]
2420 answers = [
25- " or, through inaction, allow a human being to come to harm. " ,
26- " what is essential is invisible to the eye. " ,
27- " but in rising every time we fall. " ,
21+ " or, through inaction" ,
22+ " what is essential " ,
23+ " but in rising " ,
2824 ]
29- N = 1
25+
3026 # Currently, top-p sampling is disabled. `top_p` should be 1.0.
27+ N = 1
3128 sampling_params = SamplingParams (temperature = 0.7 ,
3229 top_p = 1.0 ,
3330 n = N ,
3431 max_tokens = 16 )
3532
36- # Set `enforce_eager=True` to avoid ahead-of-time compilation.
37- # In real workloads, `enforace_eager` should be `False`.
38-
39- # disable custom dispatcher, let Dynamo takes over
40- # all the control
4133 llm = LLM (model = "Qwen/Qwen2.5-1.5B-Instruct" ,
42- max_model_len = 512 ,
43- max_num_seqs = 64 ,
44- enforce_eager = True ,
45- compilation_config = {"level" : CompilationLevel .DYNAMO_AS_IS })
34+ max_num_batched_tokens = 256 ,
35+ max_model_len = 256 ,
36+ max_num_seqs = 32 ,
37+ enforce_eager = False )
38+
4639 outputs = llm .generate (prompts , sampling_params )
4740 for output , answer in zip (outputs , answers ):
4841 prompt = output .prompt
@@ -56,16 +49,11 @@ def test_tpu_compilation():
5649 for i , compiled_code in enumerate (compiled_codes ):
5750 print ("{} file: {}" .format (i + 1 , compiled_code ))
5851
59- # We should only trigger Dynamo compilation 4 times:
60- # 1. forward pass (symbolic)
61- # 2. compute_logits (symbolic)
62- # 3. forward pass (shape 16)
63- # 4. forward pass (shape 32)
64- # and later calls should not trigger Dynamo compilation again.
65- # NOTE: It might still trigger XLA compilation.
66-
52+ # We should only trigger Dynamo compilation 2 times:
53+ # 1. Forward pass without kv_caches
54+ # 2. Forward pass with kv_caches
6755 # Check we have 2 compiled codes
68- assert len (compiled_codes ) == 4
56+ assert len (compiled_codes ) == 2
6957
7058 kv_cache_prefix = "kv_cache"
7159 attn_prefix = "ragged_paged_attention"
@@ -77,24 +65,13 @@ def test_tpu_compilation():
7765 for i , compiled_fn in enumerate (compiled_fns ):
7866 print ("{} file: {}" .format (i + 1 , compiled_fn ))
7967
80- # The first compilation is symbolic, so it should not have any kv_caches
68+ # The first compilation should not have any kv_caches
8169 with open (compiled_fns [0 ]) as f :
8270 content = f .read ()
8371 assert kv_cache_prefix not in content
8472
85- # The second compilation is symbolic, so it should not have any kv_caches
86- with open (compiled_fns [1 ]) as f :
87- content = f .read ()
88- assert kv_cache_prefix not in content
89-
90- # The third compilation is shape 16, so it should have kv_caches and the
91- # ragged_paged_attention
92- with open (compiled_fns [2 ]) as f :
93- content = f .read ()
94- assert (kv_cache_prefix in content and attn_prefix in content )
95-
96- # The forth compilation is shape 32, so it should have kv_caches and the
73+ # The second compilation should have kv_caches and the
9774 # ragged_paged_attention
98- with open (compiled_fns [3 ]) as f :
75+ with open (compiled_fns [1 ]) as f :
9976 content = f .read ()
10077 assert (kv_cache_prefix in content and attn_prefix in content )
0 commit comments