 # SPDX-License-Identifier: Apache-2.0

-import random
 from typing import Optional

 import pytest
@@ -171,19 +170,31 @@ def ref_context_attention(
     return output


+@pytest.mark.parametrize(
+    "block_size, large_tile_size",
+    [
+        (32, 2048),  # 64 blocks
+        (32, 4096),  # 128 blocks
+        (32, 8192),  # 256 blocks
+        (64, 8192),  # 128 blocks
+    ],
+)
 @pytest.mark.parametrize(
     "num_heads,num_queries_per_kv,head_size,mixed_precision",
     [
         (4, 2, 8, False),
         (4, 2, 8, True),
         (32, 8, 64, True),
+        (16, 2, 128, True),
     ],
 )
 @torch.inference_mode()
 def test_contexted_kv_attention(
     num_heads: int,
     num_queries_per_kv: int,
     head_size: int,
+    block_size: int,
+    large_tile_size,
     mixed_precision: bool,
 ) -> None:
     import os
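A note on the new parametrization above: the trailing `# N blocks` comments are just `large_tile_size // block_size`, i.e. how many KV-cache blocks one large tile spans. A standalone sketch of that arithmetic (illustrative only, not part of the test):

for block_size, large_tile_size in [(32, 2048), (32, 4096), (32, 8192), (64, 8192)]:
    # each large tile must hold a whole number of blocks
    assert large_tile_size % block_size == 0
    print(large_tile_size // block_size)  # 64, 128, 256, 128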
@@ -192,40 +203,46 @@ def test_contexted_kv_attention(

     from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc

+    assert large_tile_size % block_size == 0
+
     device = xm.xla_device()

-    os.environ["NEURON_CC_FLAGS"] = (
-        " --model-type=transformer -O1 "
-        " --internal-hlo2tensorizer-options='--verify-hlo' ")
+    compiler_flags = [
+        "--model-type=transformer -O1",
+        "--internal-hlo2tensorizer-options='--verify-hlo'",
+        "--retry_failed_compilation",
+    ]
+    compiler_flags_str = " ".join(compiler_flags)
+    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str

-    random.seed(0)
     torch.manual_seed(0)
     torch.set_printoptions(sci_mode=False)

-    min_ctx_len = 2
-    max_ctx_len = 64
-    min_query_len = 2
-    max_query_len = 64
-    prefill_batch_size = 2
-    decode_batch_size = 6
+    min_ctx_len = 32
+    max_ctx_len = 1024
+    min_query_len = 16
+    max_query_len = 512
+    prefill_batch_size = 4
+    decode_batch_size = 12
     batch_size = prefill_batch_size + decode_batch_size
-    block_size = 32
     max_model_len = (max_query_len + max_ctx_len) * 4

     max_block_per_request = max_model_len // block_size
     dtype = torch.float32
     cache_size = (batch_size * max_block_per_request) + 2
-    ctx_lens = [
-        random.randint(min_ctx_len, max_ctx_len)
-        for _ in range(prefill_batch_size)
-    ] + [
-        random.randint(min_ctx_len, max_ctx_len)
-        for _ in range(decode_batch_size)
-    ]
-    query_lens = [
-        random.randint(min_query_len, max_query_len)
-        for _ in range(prefill_batch_size)
-    ] + [1 for _ in range(decode_batch_size)]
+    prefill_ctx_lens = torch.randint(min_ctx_len,
+                                     max_ctx_len + 1, (prefill_batch_size, ),
+                                     dtype=torch.long).tolist()
+    decode_ctx_lens = torch.randint(min_ctx_len,
+                                    max_ctx_len + 1, (decode_batch_size, ),
+                                    dtype=torch.long).tolist()
+    ctx_lens = prefill_ctx_lens + decode_ctx_lens
+    query_lens = torch.randint(
+        min_query_len,
+        max_query_len + 1,
+        (prefill_batch_size, ),
+        dtype=torch.long,
+    ).tolist() + [1 for _ in range(decode_batch_size)]
     seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
     num_kv_heads = num_heads // num_queries_per_kv

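Hedged aside on the hunk above: all sequence lengths are now drawn with torch.randint, so torch.manual_seed(0) alone makes the sampling reproducible and the random module import can go. A simplified standalone sketch of the same pattern (bounds are the test's new values; prefill and decode context lengths are drawn together here for brevity):

import torch

torch.manual_seed(0)
prefill_batch_size, decode_batch_size = 4, 12
ctx_lens = torch.randint(32, 1024 + 1, (prefill_batch_size + decode_batch_size, ),
                         dtype=torch.long).tolist()
# prefills get real query lengths, decodes always contribute exactly 1 token
query_lens = torch.randint(16, 512 + 1, (prefill_batch_size, ),
                           dtype=torch.long).tolist() + [1] * decode_batch_size
seq_lens = [q + c for q, c in zip(query_lens, ctx_lens)]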
@@ -254,7 +271,6 @@ def test_contexted_kv_attention(
     values = values[torch.randperm(cache_size)]
     block_table = values[:batch_size * max_block_per_request].view(
         batch_size, max_block_per_request)
-    torch.tensor(seq_lens, dtype=torch.long)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
     b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
                                             dtype=torch.long),
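The b_start_loc line carried over as context above is a prefix sum over query_lens, giving each sequence's first-token offset into the packed query tensor. A small worked example with hypothetical lengths:

import torch

query_lens = [3, 2, 1]
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], dtype=torch.long), dim=0)
# tensor([0, 3, 5]): sequence i starts at b_start_loc[i]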
@@ -311,9 +327,7 @@ def test_contexted_kv_attention(
     # build neuron program
     return_debug_tensors = False
     B_P_SIZE = 128
-    LARGE_TILE_SZ = 2048
-    max_num_queries = (
-        (sum(query_lens) + block_size - 1) // block_size) * block_size
+    LARGE_TILE_SZ = large_tile_size

     def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
                                 num_blocks):
@@ -332,26 +346,28 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
             0,
         )

-    def shift_bit_length(x):
-        return 1 << (x - 1).bit_length()
+    def ceil_div(a, b):
+        return (a + b - 1) // b
+
+    def pad_to_multiple(a, b):
+        return ceil_div(a, b) * b
+
+    def pad_to_next_power_of_2(a):
+        assert a > 0
+        return 2**int(a - 1).bit_length()

     # calculate input shapes
-    max_num_queries_shifted = shift_bit_length(max_num_queries)
-    max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
-    max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
-    assert (max_num_queries_padded == B_P_SIZE
-            ), "invalid {max_num_queries_padded=}"
+    max_num_queries = pad_to_multiple(sum(query_lens), block_size)
+    max_num_queries = pad_to_next_power_of_2(max_num_queries)
     head_size_padded = B_P_SIZE
+    assert head_size_padded >= head_size
     context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
-    num_active_blocks_shifted = shift_bit_length(
-        ((context_lens + block_size - 1) // block_size).sum().item())
-    num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
-                                num_active_blocks_shifted)
-    num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
-    assert (num_active_blocks *
-            block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
+    num_active_blocks = ceil_div(context_lens, block_size).sum().item()
+    num_active_blocks = pad_to_multiple(num_active_blocks,
+                                        LARGE_TILE_SZ // block_size)
     context_kv_len = num_active_blocks * block_size
-    assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
+    assert (context_kv_len %
+            LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"

     # pad QKV tensors
     pad_dims = (
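The sizing logic replaced above no longer forces context_kv_len to equal exactly one LARGE_TILE_SZ. Instead the query total is rounded up to a block multiple and then to a power of two, and the active block count is rounded up to a whole number of large tiles. A worked sketch using the new helpers, with made-up totals (numbers are illustrative, not from a test run):

def ceil_div(a, b):
    return (a + b - 1) // b

def pad_to_multiple(a, b):
    return ceil_div(a, b) * b

def pad_to_next_power_of_2(a):
    assert a > 0
    return 2**int(a - 1).bit_length()

block_size, LARGE_TILE_SZ = 32, 2048
sum_query_lens, total_ctx_blocks = 530, 70  # hypothetical totals
max_num_queries = pad_to_next_power_of_2(pad_to_multiple(sum_query_lens, block_size))
num_active_blocks = pad_to_multiple(total_ctx_blocks, LARGE_TILE_SZ // block_size)
context_kv_len = num_active_blocks * block_size
assert context_kv_len % LARGE_TILE_SZ == 0
# max_num_queries == 1024, num_active_blocks == 128, context_kv_len == 4096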
@@ -360,7 +376,7 @@ def shift_bit_length(x):
         0,
         0,
         0,
-        max_num_queries_padded - query.shape[0],
+        max_num_queries - query.shape[0],
     )
     query = F.pad(query, pad_dims, "constant", 0)
     k = F.pad(k, pad_dims, "constant", 0)
@@ -397,7 +413,7 @@ def shift_bit_length(x):
             0,
             context_kv_len - prior_mask.shape[1],
             0,
-            B_P_SIZE - prior_mask.shape[0],
+            max_num_queries - prior_mask.shape[0],
         ),
         "constant",
         0,
@@ -406,9 +422,9 @@ def shift_bit_length(x):
         active_mask,
         (
             0,
-            B_P_SIZE - active_mask.shape[1],
+            max_num_queries - active_mask.shape[1],
             0,
-            B_P_SIZE - active_mask.shape[0],
+            max_num_queries - active_mask.shape[0],
         ),
         "constant",
         0,
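In the two mask hunks above, rows (and, for the active mask, columns) are now padded to max_num_queries rather than the hard B_P_SIZE = 128 cap, which is what lets query totals above 128 through. A minimal F.pad sketch of the same padding, with made-up shapes:

import torch
import torch.nn.functional as F

max_num_queries, context_kv_len = 256, 4096  # hypothetical padded sizes
prior_mask = torch.ones(130, 4000)  # (num_queries, num_ctx_kv)
prior_mask = F.pad(
    prior_mask,
    # F.pad pairs run from the last dim inward: (left, right, top, bottom)
    (0, context_kv_len - prior_mask.shape[1], 0, max_num_queries - prior_mask.shape[0]),
    "constant",
    0,
)
assert prior_mask.shape == (256, 4096)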
@@ -430,6 +446,8 @@ def shift_bit_length(x):
         n_kv_head=num_kv_heads,
         head_size=head_size,
         mixed_precision=mixed_precision,
+        LARGE_TILE_SZ=LARGE_TILE_SZ,
+        return_debug_tensors=return_debug_tensors,
     )

     if return_debug_tensors:
@@ -439,17 +457,15 @@ def shift_bit_length(x):
         output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
         debug_tensors = []

-    output_nki = torch.tensor(output_nki).cpu()
     debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]

     num_actual_tokens = sum(query_lens)
-    print(f"{num_actual_tokens=}")
     # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
-    output_nki = output_nki.permute(
-        0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
+    output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
+    output_nki = output_nki[0, :num_actual_tokens, :, :]
     output_ref_padded = F.pad(
         output_ref,
-        (0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
+        (0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
         "constant",
         0,
     )
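For the output handling above: per the in-diff comment, the kernel returns (bs, n_heads, seq_q, d); the rewritten lines move the tensor to CPU first, then permute and trim the head-size and query padding before comparing against the reference. A shape-only sketch, assuming that layout and hypothetical padded dimensions:

import torch

bs, n_heads, seq_q_padded, d_padded = 1, 4, 1024, 128  # hypothetical padded dims
head_size, num_actual_tokens = 8, 530
output_nki = torch.zeros(bs, n_heads, seq_q_padded, d_padded)
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
output_nki = output_nki[0, :num_actual_tokens, :, :]
assert output_nki.shape == (num_actual_tokens, n_heads, head_size)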