@@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size):
273273 "unsupported for encoder/ "
274274 "decoder models" )
275275@pytest .mark .parametrize ("batch_size" , BATCH_SIZES )
276- def test_prepare_decode (batch_size ):
276+ @pytest .mark .parametrize ("multiple_seqs_per_seq_group" , [True , False ])
277+ def test_prepare_decode (batch_size , multiple_seqs_per_seq_group ):
277278 '''
278279 Test the ability of the encoder/decoder model runner subclass to
279280 produce decode-phase model inputs & attention metadata.
@@ -288,6 +289,7 @@ def test_prepare_decode(batch_size):
     Arguments:
 
     * batch_size
+    * multiple_seqs_per_seq_group
     * backend_name: The attention backend under test
     * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
     '''
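What this parametrization exercises: a SequenceGroupMetadata can carry several decoder sequences in its seq_data dict (keyed by sequence id), all sharing one encoder input and one cross-attention block table, so every per-sequence list the test builds must be replicated once per entry in seq_data. A minimal pure-Python sketch of that expansion (the helper name is illustrative, not part of the test):

    # Each group holds n sequences; per-sequence lists grow by n per group.
    def expand_per_sequence(values_per_group, seqs_per_group):
        expanded = []
        for value, n in zip(values_per_group, seqs_per_group):
            expanded.extend([value] * n)
        return expanded

    # Two groups with 2 sequences each -> 4 per-sequence entries.
    assert expand_per_sequence([7, 9], [2, 2]) == [7, 7, 9, 9]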
@@ -305,29 +307,40 @@ def test_prepare_decode(batch_size):
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     cross_block_table = [2]
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
+
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
             cross_block_table=cross_block_table,
         )
         assert seq_group_metadata.token_chunk_size == 1
         seq_group_metadata_list.append(seq_group_metadata)
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
 
     # Build
     # * Decoder model inputs
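To make the loop's arithmetic concrete: seq_len = i % (block_size - 1) + 1 keeps every sequence strictly shorter than a block. Assuming block_size is 16 (the real value comes from the test's model config), the first two iterations produce lengths 1 and 2, and with two sequences per group each length lands in seq_lens twice:

    block_size = 16  # assumed for illustration
    seqs_per_group = 2  # multiple_seqs_per_seq_group=True
    seq_lens = []
    for i in range(2):  # batch_size = 2
        seq_len = i % (block_size - 1) + 1  # i=0 -> 1, i=1 -> 2
        seq_lens.extend([seq_len] * seqs_per_group)
    assert seq_lens == [1, 1, 2, 2]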
@@ -398,19 +411,24 @@ def test_prepare_decode(batch_size):
 
     # Verify block tables are correct for prompts
     # - Decoder self-attention
-    expected = torch.tensor(
-        [block_tables[0] for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    flattened_block_tables = [
+        block_table for block_table in block_tables.values()
+    ]
+    expected = torch.tensor(flattened_block_tables *
+                            len(seq_group_metadata_list),
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.block_tables,
         expected,
     )
     # - Encoder/decoder cross-attention
-    expected = torch.tensor(
-        [cross_block_table for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    expected = torch.tensor([
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ],
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.cross_block_tables,
         expected,
@@ -474,7 +492,8 @@ def test_prepare_decode(batch_size):
 
 
 @pytest.mark.parametrize("batch_size", list(range(1, 257)))
-def test_prepare_decode_cuda_graph(batch_size):
+@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
+def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     """
     Tests that for encoder-decoder models with CUDA Graph capture and replay
     enabled, the tensors used during the decode phase are correctly padded
@@ -489,32 +508,45 @@ def test_prepare_decode_cuda_graph(batch_size):
         enable_chunked_prefill=False,
         enforce_eager=False,
     )
-
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
+
     cross_block_table = [2]
+    expanded_batch_size = 0
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
             cross_block_table=cross_block_table,
         )
         assert seq_group_metadata.token_chunk_size == 1
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        expanded_batch_size = expanded_batch_size + len(
+            seq_group_metadata.seq_data)
         seq_group_metadata_list.append(seq_group_metadata)
 
     model_input = model_runner.prepare_model_input(seq_group_metadata_list)
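expanded_batch_size counts sequences rather than groups because in the decode phase each sequence contributes exactly one token (token_chunk_size == 1, as asserted above), so the CUDA Graph batch dimension is the total sequence count:

    # With multiple_seqs_per_seq_group=True every group holds 2 sequences,
    # so the expanded batch is simply twice the group count.
    batch_size, seqs_per_group = 5, 2
    expanded_batch_size = sum([seqs_per_group] * batch_size)
    assert expanded_batch_size == 10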
@@ -530,8 +562,8 @@ def test_prepare_decode_cuda_graph(batch_size):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = _get_graph_batch_size(batch_size)
-    cuda_graph_pad_size = graph_batch_size - batch_size
+    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+    cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
         itertools.repeat(1, cuda_graph_pad_size))
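For reference, _get_graph_batch_size rounds a batch size up to the nearest size for which a CUDA Graph was captured; in vLLM around the time of this change the behavior was roughly the sketch below (verify against the installed version, since the capture sizes are an internal detail):

    _BATCH_SIZE_ALIGNMENT = 8  # assumed capture-size alignment

    def graph_batch_size(batch_size: int) -> int:
        # Round up to the nearest captured CUDA Graph batch size.
        if batch_size <= 2:
            return batch_size
        if batch_size <= 4:
            return 4
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1)
                // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

    assert graph_batch_size(10) == 16  # so cuda_graph_pad_size == 6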
@@ -560,10 +592,13 @@ def test_prepare_decode_cuda_graph(batch_size):
 
     # Verify block tables are correct for prompts
     # - Decoder self-attention. Pad the block tables as expected.
-    expected = [block_tables[0] for _ in range(batch_size)]
-    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    flattened_block_tables = [
+        block_table for _ in range(len(seq_group_metadata_list))
+        for block_table in block_tables.values()
+    ]
+    flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
-        expected,
+        flattened_block_tables,
         max_len=64,
         pad=0,
         dtype=torch.int32,
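make_tensor_with_pad right-pads each ragged row to max_len, so the [] rows appended for graph padding become all-zero rows. A plain-torch stand-in showing the expected shape (illustrative, not vLLM's implementation):

    import torch

    def pad_rows(rows, max_len, pad=0, dtype=torch.int32):
        # Right-pad each ragged row to max_len with `pad`.
        out = torch.full((len(rows), max_len), pad, dtype=dtype)
        for i, row in enumerate(rows):
            out[i, :len(row)] = torch.tensor(row, dtype=dtype)
        return out

    rows = [[1], [3], []]  # last row is CUDA Graph padding
    assert pad_rows(rows, max_len=4).tolist() == [
        [1, 0, 0, 0], [3, 0, 0, 0], [0, 0, 0, 0]]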
@@ -575,7 +610,10 @@ def test_prepare_decode_cuda_graph(batch_size):
     )
     # - Encoder/decoder cross-attention. Pad the cross-attention block tables
     # as expected.
-    expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
+    expected = [
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ]
     expected.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
         expected,