
Commit c8625ad

MasterJH5574 authored and yongwww committed

[KVCache] Fix kernel dispatch based on attention kinds (apache#18122)

* [KVCache] Fix kernel dispatch based on attention kinds

  This PR fixes a few kernel dispatch issues due to the recent introduction of `mha_sliding` as a new attention kind. Tested on Qwen3 1.7B with MLC-LLM.

* Fix lint

---------

Co-authored-by: Yong Wu <[email protected]>

1 parent 0c38164 · commit c8625ad
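
For context, the core of the change is to normalize the attention-kind argument (which may now be a per-layer list that includes "mha_sliding") into a single kind used for kernel dispatch. Below is a minimal standalone sketch of that normalization pattern, not TVM code; the select_prefill_kernel helper is hypothetical and only illustrates the dispatch predicate, assuming the same string kinds as in the diff.

    from typing import List, Union

    def normalize_attn_kind(attn_kind: Union[str, List[str]]) -> str:
        """Collapse a per-layer attention-kind list into one dispatch kind.

        Mirrors the pattern added in this commit: take the first entry of a
        per-layer list and treat "mha_sliding" as plain "mha" for dispatch.
        """
        attn_kind_single = attn_kind[0] if isinstance(attn_kind, list) else attn_kind
        if attn_kind_single == "mha_sliding":
            attn_kind_single = "mha"
        return attn_kind_single

    # Hypothetical helper, only to illustrate dispatching on the normalized kind.
    def select_prefill_kernel(attn_kind: Union[str, List[str]]) -> str:
        kind = normalize_attn_kind(attn_kind)
        return "mha_prefill" if kind == "mha" else "mla_prefill"

    assert normalize_attn_kind(["mha", "mha_sliding"]) == "mha"
    assert normalize_attn_kind("mla") == "mla"
    assert select_prefill_kernel(["mha_sliding", "mha"]) == "mha_prefill"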

File tree: 1 file changed (+24 −28)


python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 24 additions & 28 deletions
@@ -374,20 +374,15 @@ def __init__( # pylint: disable=too-many-locals
         if rope_mode == RopeMode.INLINE:
             assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim."
 
+        attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+        if attn_kind_single == "mha_sliding":
+            attn_kind_single = "mha"
         flashinfer_prefill_mods = rx.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
             dtype_q=dtype,
             dtype_kv=dtype,
             dtype_o=dtype,
-            qk_head_dim=(
-                qk_head_dim
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
-                else mla_original_qk_head_dim
-            ),
-            v_head_dim=(
-                v_head_dim
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
-                else mla_original_v_head_dim
-            ),
+            qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim),
+            v_head_dim=(v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim),
             target=target,
             enable_inline_rope=rope_mode == RopeMode.INLINE,
         )
@@ -400,7 +395,7 @@ def __init__( # pylint: disable=too-many-locals
                 v_head_dim=v_head_dim,
                 target=target,
             )
-            if (attn_kind == "mha" or isinstance(attn_kind, List))
+            if attn_kind_single == "mha"
             else []
         )
         flashinfer_mla_mods = (
@@ -412,7 +407,7 @@ def __init__( # pylint: disable=too-many-locals
                 head_dim_kpe=qk_head_dim - v_head_dim,
                 target=target,
             )
-            if attn_kind == "mla"
+            if attn_kind_single == "mla"
             else []
         )
         self.extern_mods = flashinfer_prefill_mods + flashinfer_decode_mods + flashinfer_mla_mods
@@ -429,21 +424,21 @@ def __init__( # pylint: disable=too-many-locals
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                 rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
             ]
-            if (attn_kind == "mha" or isinstance(attn_kind, List))
+            if attn_kind_single == "mha"
             else [rx.Tuple([]) for _ in range(6)]
         )
-        mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
+        mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" else [])
         attn_merge_functions = [
             bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
         ]
-        if attn_kind == "mla":
+        if attn_kind_single == "mla":
             attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
 
-
         if isinstance(attn_kind, List):
             attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
         else:
             attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
+
         args = [
             rx.ShapeExpr(
                 [
@@ -459,9 +454,7 @@ def __init__( # pylint: disable=too-many-locals
             rx.PrimValue(num_key_value_heads),
             rx.PrimValue(qk_head_dim),
             rx.PrimValue(v_head_dim),
-            rx.ShapeExpr(
-                [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
-            ),
+            rx.ShapeExpr(attn_kind),
             rx.PrimValue(enable_disaggregation),
             rx.PrimValue(rope_mode),
             rx.PrimValue(rope_scale),
@@ -475,7 +468,7 @@ def __init__( # pylint: disable=too-many-locals
             mla_function,
             rx.Tuple(attn_merge_functions),
             bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-            bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+            bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
             bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
             bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
             # fmt: on
@@ -567,6 +560,9 @@ def __init__( # pylint: disable=too-many-locals
         target : Target
             The target to build the model to.
         """
+        attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
+        if attn_kind_single == "mha_sliding":
+            attn_kind_single = "mha"
         if isinstance(attn_kind, List):
             attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
         else:
@@ -605,7 +601,7 @@ def __init__( # pylint: disable=too-many-locals
         ]
 
         if str(target.kind) == "llvm":
-            if attn_kind == "mla":
+            if attn_kind_single == "mla":
                 raise ValueError("MLA is not supported in TIR kernels for now.")
             # pylint: disable=line-too-long
             # fmt: off
@@ -631,9 +627,9 @@ def __init__( # pylint: disable=too-many-locals
         else:
             # pylint: disable=line-too-long
             # fmt: off
-            ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
-            ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
-            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
+            ragged_qk_head_dim = qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim
+            ragged_v_head_dim = v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim
+            args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind_single == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
             mha_functions = (
                 [
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
@@ -643,22 +639,22 @@ def __init__( # pylint: disable=too-many-locals
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
                     rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
                 ]
-                if (attn_kind == "mha" or isinstance(attn_kind, List))
+                if attn_kind_single == "mha"
                 else [rx.Tuple([]) for _ in range(6)]
            )
-            mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
+            mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind_single == "mla" else [])
             attn_merge_functions = [
                 bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
             ]
-            if attn_kind == "mla":
+            if attn_kind_single == "mla":
                 attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
             args.extend(mha_functions)
             args.append(mla_function)
             args.extend(
                 [
                     rx.Tuple(attn_merge_functions),
                     bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
-                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
+                    bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
                     bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
                     bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
                 ]
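
A side note on the rx.ShapeExpr(attn_kind) change: the per-layer kinds are converted to integer enum values once, and that list is then passed to the cache constructor directly instead of being rebuilt from the already-converted argument. The sketch below reproduces that conversion outside of TVM; the stand-in AttnKind enum and its member values are assumptions for illustration, not necessarily the actual definition in tvm.relax.frontend.nn.llm.kv_cache.

    from enum import IntEnum
    from typing import List, Union

    class AttnKind(IntEnum):
        # Stand-in enum; member names/values are assumed for this example only.
        MHA = 0
        MLA = 1
        MHA_SLIDING = 2

    def per_layer_kind_ids(attn_kind: Union[str, List[str]], num_hidden_layers: int) -> List[int]:
        """Produce one integer attention-kind id per layer, mirroring the
        conversion done before the list is wrapped in rx.ShapeExpr(attn_kind)."""
        if isinstance(attn_kind, list):
            return [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
        return [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]

    print(per_layer_kind_ids(["mha", "mha_sliding", "mha"], 3))  # [0, 2, 0] with the assumed values
    print(per_layer_kind_ids("mla", 2))                          # [1, 1] with the assumed values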

0 commit comments
