@@ -690,18 +690,18 @@ def sparse_attn_indexer(
         padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(decode_lens.shape[0], -1, *q_fp8.shape[1:])
         # TODO: move and optimize below logic with triton kernels
         batch_size = padded_q_fp8_decode_tokens.shape[0]
+        next_n = padded_q_fp8_decode_tokens.shape[1]
         assert batch_size == decode_metadata.seq_lens.shape[0]
+        num_padded_tokens = batch_size * next_n
         logits = fp8_paged_mqa_logits(
             padded_q_fp8_decode_tokens,
             kv_cache,
-            weights[:num_decode_tokens],
+            weights[:num_padded_tokens],
             decode_metadata.seq_lens,
             decode_metadata.block_table,
             decode_metadata.schedule_metadata,
             max_model_len=max_model_len,
         )
-        # [B, N, L]
-        next_n = padded_q_fp8_decode_tokens.shape[1]
         # padded query len
         current_device = padded_q_fp8_decode_tokens.device
         padded_num_tokens = batch_size * next_n
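Note: a minimal sketch of the count mismatch this hunk addresses, using toy numbers (the values and names below are illustrative assumptions, not the real tensors; in vLLM the padding itself is handled by Triton kernels). When decode lengths are ragged, the padded layout holds batch_size * next_n slots, which can exceed the raw decode token count, so slicing weights by num_decode_tokens would not cover every padded slot.

import torch

# Illustrative per-request decode query lengths (assumed values).
decode_lens = torch.tensor([1, 3, 2])
batch_size = decode_lens.shape[0]            # 3 requests
next_n = int(decode_lens.max())              # 3: padded query len per request

num_decode_tokens = int(decode_lens.sum())   # 6 real decode tokens
num_padded_tokens = batch_size * next_n      # 9 slots once each request is padded

# The padded q tensor has one row per slot, so the weights passed alongside it
# need one row per slot as well, hence slicing by the padded count.
assert num_padded_tokens >= num_decode_tokens
print(num_decode_tokens, num_padded_tokens)  # 6 9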
@@ -721,7 +721,7 @@ def sparse_attn_indexer(
         topk_indices[topk_indices > index_end_pos] = -1
         if decode_metadata.requires_padding:
             # if padded, we need to unpack the topk indices removing padded tokens
-            topk_indices = unpack_seq_triton(topk_indices.reshape(batch_size, -1, logits.shape[-1]), decode_lens)
+            topk_indices = unpack_seq_triton(topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), decode_lens)
         topk_indices_buffer[:num_decode_tokens, :topk_indices.
                             shape[-1]] = topk_indices.to(
                                 dtype=torch.int32)
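Note: a plain-PyTorch stand-in for the unpacking step, assuming unpack_seq_triton drops the per-request pad slots (the helper below is only an illustrative equivalent, not the real Triton kernel). It also shows why the reshape width should be topk_indices.shape[-1], the top-k width of the indices themselves, rather than logits.shape[-1]:

import torch

def unpack_padded(x: torch.Tensor, decode_lens: torch.Tensor) -> torch.Tensor:
    # x: [batch_size, next_n, k], padded per request; keep the first
    # decode_lens[b] rows of each request and concatenate them back into a
    # packed [sum(decode_lens), k] tensor.
    return torch.cat([x[b, :int(n)] for b, n in enumerate(decode_lens)], dim=0)

decode_lens = torch.tensor([1, 3, 2])
batch_size, next_n, topk = 3, 3, 4           # topk is the index width, not the key length
flat = torch.zeros(batch_size * next_n, topk, dtype=torch.int32)

# Reshape using the last dim of the indices themselves; the logits' last dim
# (the key/context length) would not match the element count here.
packed = unpack_padded(flat.reshape(batch_size, -1, flat.shape[-1]), decode_lens)
assert packed.shape == (int(decode_lens.sum()), topk)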