@@ -374,20 +374,15 @@ def __init__( # pylint: disable=too-many-locals
         if rope_mode == RopeMode.INLINE:
             assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim."
 
+        attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+        if attn_kind_single == "mha_sliding":
+            attn_kind_single = "mha"
         flashinfer_prefill_mods = rx.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
             dtype_q=dtype,
             dtype_kv=dtype,
             dtype_o=dtype,
-            qk_head_dim=(
-                qk_head_dim
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
-                else mla_original_qk_head_dim
-            ),
-            v_head_dim=(
-                v_head_dim
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
-                else mla_original_v_head_dim
-            ),
+            qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim),
+            v_head_dim=(v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim),
             target=target,
             enable_inline_rope=rope_mode == RopeMode.INLINE,
         )
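# --- Illustrative aside, not part of the commit -------------------------------
# The three "+" lines above collapse the attn_kind argument (a single kind
# string or a per-layer list, where "mha_sliding" layers reuse the plain MHA
# kernels) into one string that drives kernel/module selection below.
# A minimal standalone sketch of that logic; the helper name
# normalize_attn_kind is hypothetical, the actual change inlines it.
from typing import List, Union


def normalize_attn_kind(attn_kind: Union[str, List[str]]) -> str:
    """Collapse attn_kind to a single kind string, mirroring the inline logic above."""
    # A per-layer list is represented by its first entry.
    kind = attn_kind[0] if isinstance(attn_kind, list) else attn_kind
    # Sliding-window MHA uses the same MHA kernels and FlashInfer modules.
    return "mha" if kind == "mha_sliding" else kind


assert normalize_attn_kind("mla") == "mla"
assert normalize_attn_kind(["mha_sliding", "mha", "mha_sliding"]) == "mha"
# -------------------------------------------------------------------------------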
@@ -400,7 +395,7 @@ def __init__( # pylint: disable=too-many-locals
                 v_head_dim=v_head_dim,
                 target=target,
             )
-            if (attn_kind == "mha" or isinstance(attn_kind, List))
+            if attn_kind_single == "mha"
             else []
         )
         flashinfer_mla_mods = (
@@ -412,7 +407,7 @@ def __init__( # pylint: disable=too-many-locals
                 head_dim_kpe=qk_head_dim - v_head_dim,
                 target=target,
             )
-            if attn_kind == "mla"
+            if attn_kind_single == "mla"
             else []
         )
         self.extern_mods = flashinfer_prefill_mods + flashinfer_decode_mods + flashinfer_mla_mods
@@ -429,21 +424,21 @@ def __init__( # pylint: disable=too-many-locals
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
             ]
-            if (attn_kind == "mha" or isinstance(attn_kind, List))
+            if attn_kind_single == "mha"
             else [rx.Tuple([]) for _ in range(6)]
         )
-        mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
+        mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" else [])
         attn_merge_functions = [
             bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
         ]
-        if attn_kind == "mla":
+        if attn_kind_single == "mla":
             attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
 
-
         if isinstance(attn_kind, List):
             attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
         else:
             attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
+
         args = [
             rx.ShapeExpr(
                 [
@@ -459,9 +454,7 @@ def __init__( # pylint: disable=too-many-locals
             rx.PrimValue(num_key_value_heads),
             rx.PrimValue(qk_head_dim),
             rx.PrimValue(v_head_dim),
-            rx.ShapeExpr(
-                [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
-            ),
+            rx.ShapeExpr(attn_kind),
             rx.PrimValue(enable_disaggregation),
             rx.PrimValue(rope_mode),
             rx.PrimValue(rope_scale),
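# --- Illustrative aside, not part of the commit -------------------------------
# rx.ShapeExpr(attn_kind) can be passed directly here because attn_kind has
# already been rewritten (just before "args = [" above) into a list of integer
# AttnKind codes, one per hidden layer. A sketch of that conversion, assuming
# an IntEnum-style AttnKind; the member values below are illustrative, not
# taken from the source.
import enum
from typing import List, Union


class AttnKind(enum.IntEnum):
    MHA = 0
    MLA = 1
    MHA_SLIDING = 2


def layer_kind_codes(attn_kind: Union[str, List[str]], num_hidden_layers: int) -> List[int]:
    """One integer code per hidden layer, as consumed by rx.ShapeExpr(attn_kind)."""
    if isinstance(attn_kind, list):
        return [int(getattr(AttnKind, kind.upper())) for kind in attn_kind]
    return [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]


# e.g. layer_kind_codes("mha", 4) == [0, 0, 0, 0]
# -------------------------------------------------------------------------------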
@@ -475,7 +468,7 @@ def __init__( # pylint: disable=too-many-locals
             mla_function,
             rx.Tuple(attn_merge_functions),
             bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-            bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+            bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
             bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
             bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
             # fmt: on
@@ -567,6 +560,9 @@ def __init__( # pylint: disable=too-many-locals
         target : Target
             The target to build the model to.
         """
+        attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+        if attn_kind_single == "mha_sliding":
+            attn_kind_single = "mha"
         if isinstance(attn_kind, List):
             attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
         else:
@@ -605,7 +601,7 @@ def __init__( # pylint: disable=too-many-locals
         ]
 
         if str(target.kind) == "llvm":
-            if attn_kind == "mla":
+            if attn_kind_single == "mla":
                 raise ValueError("MLA is not supported in TIR kernels for now.")
             # pylint: disable=line-too-long
             # fmt: off
@@ -631,9 +627,9 @@ def __init__( # pylint: disable=too-many-locals
         else:
             # pylint: disable=line-too-long
             # fmt: off
-            ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
-            ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
-            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
+            ragged_qk_head_dim = qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim
+            ragged_v_head_dim = v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim
+            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind_single == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
             mha_functions = (
                 [
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
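# --- Illustrative aside, not part of the commit -------------------------------
# Spelled out, the ragged-prefill parameter selection in the "+" lines above
# reduces to the helper below. The function name is introduced here for
# clarity only; reading the mla_original_* values as the original
# (pre-compression) MLA head dims is an assumption.
from typing import Tuple


def ragged_prefill_dims(
    attn_kind_single: str,
    num_attention_heads: int,
    num_key_value_heads: int,
    qk_head_dim: int,
    v_head_dim: int,
    mla_original_qk_head_dim: int,
    mla_original_v_head_dim: int,
) -> Tuple[int, int, int]:
    """Return (num_kv_heads, qk_head_dim, v_head_dim) for _attention_prefill_ragged."""
    if attn_kind_single == "mha":
        # MHA (including sliding-window MHA): use the paged-cache head config.
        return num_key_value_heads, qk_head_dim, v_head_dim
    # MLA: the ragged prefill runs on the original per-head dims, with as many
    # KV heads as attention heads.
    return num_attention_heads, mla_original_qk_head_dim, mla_original_v_head_dim
# -------------------------------------------------------------------------------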
@@ -643,22 +639,22 @@ def __init__( # pylint: disable=too-many-locals
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
                 ]
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
+                if attn_kind_single == "mha"
                 else [rx.Tuple([]) for _ in range(6)]
             )
-            mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
+            mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind_single == "mla" else [])
             attn_merge_functions = [
                 bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
             ]
-            if attn_kind == "mla":
+            if attn_kind_single == "mla":
                 attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
             args.extend(mha_functions)
             args.append(mla_function)
             args.extend(
                 [
                     rx.Tuple(attn_merge_functions),
                     bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
                     bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
                     bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
                 ]