@@ -270,6 +270,7 @@ class Batch:
     prefix_lens: torch.Tensor = None
     position_ids_offsets: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
+    extend_num_tokens: int = None
 
     # For processing logprobs
     return_logprob: bool = False
@@ -280,10 +281,6 @@ class Batch:
     image_sizes: List[List[int]] = None
     image_offsets: List[int] = None
 
-    # Other arguments for control
-    output_ids: torch.Tensor = None
-    extend_num_tokens: int = None
-
     # Batched sampling params
     temperatures: torch.Tensor = None
     top_ps: torch.Tensor = None
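
The hunk above moves `extend_num_tokens` out of the "Other arguments for control" group, placing it next to the other forward-pass tensors, and drops the unused `output_ids` field. As a hedged illustration of what the field plausibly holds (the patch itself does not show this), assuming it counts the new, non-prefix tokens of an extend/prefill batch and that `Batch` also carries a `seq_lens` tensor alongside the `prefix_lens` shown above:

import torch

# Illustrative values, not from the patch: per-request total lengths and
# cached prefix lengths (mirroring prefix_lens above; seq_lens is assumed).
seq_lens = torch.tensor([7, 5, 9])
prefix_lens = torch.tensor([3, 0, 4])

# Assumed semantics: total number of new tokens the extend step must process.
extend_num_tokens = int((seq_lens - prefix_lens).sum())  # (7-3)+(5-0)+(9-4) = 14
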
@@ -820,6 +817,7 @@ def init_flashinfer_args(
     prefix_lens,
     flashinfer_decode_wrapper,
 ):
+    """Init auxiliary variables for the FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
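
The head counts above are split across tensor-parallel ranks: query/output heads divide evenly by `tp_size`, and `get_num_kv_heads` presumably performs the matching per-rank split for (possibly grouped) KV heads. A toy calculation with made-up numbers, purely illustrative:

# Hypothetical configuration, for illustration only.
num_attention_heads = 32  # model's total query/output heads
tp_size = 4               # tensor-parallel world size
num_qo_heads = num_attention_heads // tp_size  # 8 QO heads per rank
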
@@ -885,6 +883,7 @@ def init_flashinfer_args(
 
 
 def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for the Triton attention backend."""
     batch_size = len(seq_lens)
     max_seq_len = int(torch.max(seq_lens))
     start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
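
The diff cuts off right after `start_loc` is allocated. A minimal sketch of how such auxiliary variables are commonly filled, assuming `start_loc` holds the exclusive prefix sum of `seq_lens` (each sequence's starting offset in a packed token buffer); this is an assumption about the omitted lines, not the patch's code:

import torch

def init_triton_args_sketch(seq_lens: torch.Tensor):
    # Mirrors the lines shown in the hunk above, minus the fixed CUDA device
    # so the sketch runs anywhere.
    batch_size = len(seq_lens)
    max_seq_len = int(torch.max(seq_lens))
    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device=seq_lens.device)
    # Assumption: start_loc[i] is where sequence i begins in the packed buffer,
    # i.e. the exclusive cumulative sum of the preceding lengths.
    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
    return max_seq_len, start_loc

# Example: lengths [4, 2, 3] -> start offsets [0, 4, 6], max length 4.
print(init_triton_args_sketch(torch.tensor([4, 2, 3], dtype=torch.int32)))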