@@ -371,8 +371,7 @@ def __init__( # pylint: disable=too-many-locals
371371 enable_disaggregation : bool
372372 Whether to enable disaggregation in the KV cache.
373373 """
374- if rope_mode == RopeMode .INLINE :
375- assert rotary_dim == qk_head_dim , "FlashInfer RoPE does not support partial rotary dim."
374+ assert rope_mode != RopeMode .INLINE , "FlashInfer RoPE does not support inline mode."
376375
377376 attn_kind_single = attn_kind [0 ] if isinstance (attn_kind , List ) else attn_kind
378377 if attn_kind_single == "mha_sliding" :
@@ -383,8 +382,8 @@ def __init__( # pylint: disable=too-many-locals
383382 dtype_o = dtype ,
384383 qk_head_dim = (qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim ),
385384 v_head_dim = (v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim ),
386- target = target ,
387- enable_inline_rope = rope_mode == RopeMode . INLINE ,
385+ enable_inline_rope = False ,
386+ return_static_libs = True ,
388387 )
389388 flashinfer_decode_mods = (
390389 rx .backend .cuda .flashinfer .gen_flashinfer_decode_module (
@@ -393,7 +392,8 @@ def __init__( # pylint: disable=too-many-locals
393392 dtype_o = dtype ,
394393 qk_head_dim = qk_head_dim ,
395394 v_head_dim = v_head_dim ,
396- target = target ,
395+ enable_inline_rope = False ,
396+ return_static_libs = True ,
397397 )
398398 if attn_kind_single == "mha"
399399 else []
@@ -405,7 +405,7 @@ def __init__( # pylint: disable=too-many-locals
405405 dtype_o = dtype ,
406406 head_dim_ckv = v_head_dim ,
407407 head_dim_kpe = qk_head_dim - v_head_dim ,
408- target = target ,
408+ return_static_libs = True ,
409409 )
410410 if attn_kind_single == "mla"
411411 else []
@@ -417,8 +417,8 @@ def __init__( # pylint: disable=too-many-locals
417417 bb = rx .BlockBuilder .current ()
418418 mha_functions = (
419419 [
420- rx .Tuple ([rx .StringImm ("flashinfer" ), rx .ExternFunc ("batch_prefill_with_paged_kv_cache_run " ), rx .ExternFunc ("batch_prefill_with_kv_cache_plan " )]),
421- rx .Tuple ([rx .StringImm ("flashinfer" ), rx .ExternFunc ("batch_decode_with_paged_kv_cache_run " ), rx .ExternFunc ("batch_decode_with_paged_kv_cache_plan " )]),
420+ rx .Tuple ([rx .StringImm ("flashinfer" ), rx .ExternFunc ("batch_prefill_paged_run " ), rx .ExternFunc ("batch_prefill_plan " )]),
421+ rx .Tuple ([rx .StringImm ("flashinfer" ), rx .ExternFunc ("batch_decode_run " ), rx .ExternFunc ("batch_decode_plan " )]),
422422 rx .Tuple ([rx .StringImm ("tir" ), bb .add_func (_attention_prefill (num_key_value_heads , num_attention_heads , qk_head_dim , dtype , True , rope_scaling , target ), "tir_attention_prefill_sliding_window" )]),
423423 rx .Tuple ([rx .StringImm ("tir" ), bb .add_func (_attention_decode (num_key_value_heads , num_attention_heads , qk_head_dim , dtype , True , rope_scaling , target ), "tir_attention_decode_sliding_window" )]),
424424 rx .Tuple ([rx .StringImm ("tir" ), bb .add_func (tree_attn_with_paged_kv_cache (num_key_value_heads , num_attention_heads , qk_head_dim , dtype , rope_scaling , target ), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache" )]),
@@ -427,7 +427,8 @@ def __init__( # pylint: disable=too-many-locals
427427 if attn_kind_single == "mha"
428428 else [rx .Tuple ([]) for _ in range (6 )]
429429 )
430- mla_function = rx .Tuple ([rx .StringImm ("flashinfer" ), rx .ExternFunc ("batch_mla_paged_attention_run" ), rx .ExternFunc ("batch_mla_paged_attention_plan" )] if attn_kind_single == "mla" else [])
430+ ragged_prefill_function = rx .Tuple ([rx .StringImm ("flashinfer" ), rx .ExternFunc ("batch_prefill_ragged_run" ), rx .ExternFunc ("batch_prefill_plan" )]) if attn_kind_single == "mha" else rx .Tuple ([rx .StringImm ("flashinfer" ), rx .ExternFunc ("batch_prefill_ragged_run" ), rx .ExternFunc ("batch_prefill_plan" ), rx .PrimValue (mla_original_qk_head_dim ), rx .PrimValue (mla_original_v_head_dim )])
431+ mla_function = rx .Tuple ([rx .StringImm ("flashinfer" ), rx .ExternFunc ("batch_mla_run" ), rx .ExternFunc ("batch_mla_plan" )] if attn_kind_single == "mla" else [])
431432 attn_merge_functions = [
432433 bb .add_func (_merge_state_inplace (num_attention_heads , v_head_dim , dtype , target , "tir_attention_merge_state" ), "tir_attention_merge_state" ),
433434 ]
@@ -463,7 +464,7 @@ def __init__( # pylint: disable=too-many-locals
463464 rx .op .zeros ((), dtype ),
464465 bb .add_func (_kv_cache_transpose_append (num_key_value_heads , qk_head_dim , dtype ), "kv_cache_transpose_append" ),
465466 bb .add_func (_kv_cache_transpose_append_mla (qk_head_dim , dtype ), "kv_cache_transpose_append_mla" ),
466- rx . Tuple ([ rx . StringImm ( "flashinfer" ), rx . ExternFunc ( "batch_prefill_with_ragged_kv_cache_run" ), rx . ExternFunc ( "batch_prefill_with_kv_cache_plan" )]) ,
467+ ragged_prefill_function ,
467468 * mha_functions ,
468469 mla_function ,
469470 rx .Tuple (attn_merge_functions ),
0 commit comments