python/tvm/relax/frontend/nn/llm/kv_cache.py (52 changes: 24 additions & 28 deletions)
@@ -374,20 +374,15 @@ def __init__( # pylint: disable=too-many-locals
if rope_mode == RopeMode.INLINE:
assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim."

attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
if attn_kind_single == "mha_sliding":
attn_kind_single = "mha"
flashinfer_prefill_mods = rx.backend.cuda.flashinfer.gen_flashinfer_prefill_module(
dtype_q=dtype,
dtype_kv=dtype,
dtype_o=dtype,
qk_head_dim=(
qk_head_dim
if (attn_kind == "mha" or isinstance(attn_kind, List))
else mla_original_qk_head_dim
),
v_head_dim=(
v_head_dim
if (attn_kind == "mha" or isinstance(attn_kind, List))
else mla_original_v_head_dim
),
qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim),
v_head_dim=(v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim),
target=target,
enable_inline_rope=rope_mode == RopeMode.INLINE,
)
@@ -400,7 +395,7 @@ def __init__( # pylint: disable=too-many-locals
v_head_dim=v_head_dim,
target=target,
)
if (attn_kind == "mha" or isinstance(attn_kind, List))
if attn_kind_single == "mha"
else []
)
flashinfer_mla_mods = (
@@ -412,7 +407,7 @@ def __init__( # pylint: disable=too-many-locals
head_dim_kpe=qk_head_dim - v_head_dim,
target=target,
)
if attn_kind == "mla"
if attn_kind_single == "mla"
else []
)
self.extern_mods = flashinfer_prefill_mods + flashinfer_decode_mods + flashinfer_mla_mods
@@ -429,21 +424,21 @@ def __init__( # pylint: disable=too-many-locals
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
]
if (attn_kind == "mha" or isinstance(attn_kind, List))
if attn_kind_single == "mha"
else [rx.Tuple([]) for _ in range(6)]
)
mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else [])
mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" else [])
attn_merge_functions = [
bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
]
if attn_kind == "mla":
if attn_kind_single == "mla":
attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))


if isinstance(attn_kind, List):
attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
else:
attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]

args = [
rx.ShapeExpr(
[
@@ -459,9 +454,7 @@ def __init__( # pylint: disable=too-many-locals
rx.PrimValue(num_key_value_heads),
rx.PrimValue(qk_head_dim),
rx.PrimValue(v_head_dim),
rx.ShapeExpr(
[int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)]
),
rx.ShapeExpr(attn_kind),
rx.PrimValue(enable_disaggregation),
rx.PrimValue(rope_mode),
rx.PrimValue(rope_scale),
@@ -475,7 +468,7 @@ def __init__( # pylint: disable=too-many-locals
mla_function,
rx.Tuple(attn_merge_functions),
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
# fmt: on
@@ -567,6 +560,9 @@ def __init__( # pylint: disable=too-many-locals
target : Target
The target to build the model to.
"""
attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
if attn_kind_single == "mha_sliding":
attn_kind_single = "mha"
if isinstance(attn_kind, List):
attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind]
else:
@@ -605,7 +601,7 @@ def __init__( # pylint: disable=too-many-locals
]

if str(target.kind) == "llvm":
if attn_kind == "mla":
if attn_kind_single == "mla":
raise ValueError("MLA is not supported in TIR kernels for now.")
# pylint: disable=line-too-long
# fmt: off
@@ -631,9 +627,9 @@ def __init__( # pylint: disable=too-many-locals
else:
# pylint: disable=line-too-long
# fmt: off
ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim
ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim
args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
ragged_qk_head_dim = qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim
ragged_v_head_dim = v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim
args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind_single == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")]))
mha_functions = (
[
rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]),
@@ -643,22 +639,22 @@ def __init__( # pylint: disable=too-many-locals
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]),
rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]),
]
if (attn_kind == "mha" or isinstance(attn_kind, List))
if attn_kind_single == "mha"
else [rx.Tuple([]) for _ in range(6)]
)
mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else [])
mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind_single == "mla" else [])
attn_merge_functions = [
bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"),
]
if attn_kind == "mla":
if attn_kind_single == "mla":
attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla"))
args.extend(mha_functions)
args.append(mla_function)
args.extend(
[
rx.Tuple(attn_merge_functions),
bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"),
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"),
]
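
For context, the change threads a single normalized kind (attn_kind_single) through the kernel-selection branches, while the per-layer attn_kind list is still converted to integer AttnKind codes and passed to the cache constructor as rx.ShapeExpr(attn_kind). The standalone sketch below illustrates that normalization only; the AttnKind values and the normalize_attn_kind helper are illustrative assumptions, not code from this PR or from TVM.

from enum import IntEnum
from typing import List, Tuple, Union

class AttnKind(IntEnum):
    # Illustrative stand-in for the AttnKind enum used in
    # python/tvm/relax/frontend/nn/llm/kv_cache.py; the numeric values
    # here are assumptions, not the real ones.
    MHA = 0
    MLA = 1
    MHA_SLIDING = 2

def normalize_attn_kind(
    attn_kind: Union[str, List[str]], num_hidden_layers: int
) -> Tuple[str, List[int]]:
    # Pick a single representative kind for kernel selection. A per-layer
    # list (e.g. models mixing "mha" and "mha_sliding" layers) is represented
    # by its first entry, and "mha_sliding" collapses to "mha" because
    # sliding-window layers reuse the plain MHA kernels.
    attn_kind_single = attn_kind[0] if isinstance(attn_kind, list) else attn_kind
    if attn_kind_single == "mha_sliding":
        attn_kind_single = "mha"

    # Expand to one integer code per layer, mirroring what the constructor
    # now hands to the KV cache as rx.ShapeExpr(attn_kind).
    if isinstance(attn_kind, list):
        per_layer = [int(AttnKind[kind.upper()]) for kind in attn_kind]
    else:
        per_layer = [int(AttnKind[attn_kind.upper()])] * num_hidden_layers
    return attn_kind_single, per_layer

# A model with a sliding-window layer between two full-attention layers:
print(normalize_attn_kind(["mha", "mha_sliding", "mha"], num_hidden_layers=3))
# -> ('mha', [0, 2, 0]) with the illustrative enum values above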