Commits (27 total; this diff shows changes from 19 commits):

0860de4  First working version (simondanielsson, Oct 14, 2025)
7b41ac4  Merge remote-tracking branch 'upstream/main' into feature/gdn-apc (simondanielsson, Oct 14, 2025)
538c9a0  Update type hints in gdn_attn (simondanielsson, Oct 14, 2025)
3fffae0  [DCP] Support Decode Context Parallel (DCP) for GQA with FlashAttenti… (FENP, Oct 14, 2025)
76ac0fa  Enable cudagraphs support [skip ci] (simondanielsson, Oct 14, 2025)
1d3afe0  Merge remote-tracking branch 'upstream/main' into feature/gdn-apc (simondanielsson, Oct 14, 2025)
795ed51  Fix long() -> long [skip ci] (simondanielsson, Oct 14, 2025)
044990c  Add defensive programming asserts (simondanielsson, Oct 14, 2025)
68ca70f  Allocate metadata buffer by chunk count rather than block count, and … (simondanielsson, Oct 16, 2025)
fe8f0b7  Return hidden state when return_intermediate_states is passed, ignori… (simondanielsson, Oct 16, 2025)
ac226e8  Inline _reshape_intermediate_states in the fla chunk kernel wrapper (simondanielsson, Oct 16, 2025)
f975260  Add more explanatory comments in FLA's chunk.py (simondanielsson, Oct 16, 2025)
e74f67d  Improve logging (simondanielsson, Oct 16, 2025)
f177a1f  Add GDN model to APC tests (simondanielsson, Oct 16, 2025)
552ba6f  Add helpful comments in hard-to-understand areas (simondanielsson, Oct 16, 2025)
30b1ea0  Merge remote-tracking branch 'upstream/main' into feature/gdn-apc (simondanielsson, Oct 16, 2025)
2ab062d  Improve way to set chunk_size=64 for GDN (simondanielsson, Oct 16, 2025)
4837a11  Revert KV cache memory limit in test (simondanielsson, Oct 16, 2025)
3a88844  Merge remote-tracking branch 'upstream/main' into feature/gdn-apc (simondanielsson, Oct 16, 2025)
b58362a  Add dynamic counting of decode chunks, rather than static value (simondanielsson, Oct 16, 2025)
ccda04e  Add plot (simondanielsson, Oct 17, 2025)
03aa33c  Remove plot (simondanielsson, Oct 17, 2025)
9896ba4  Merge remote-tracking branch 'upstream/main' into feature/gdn-apc (simondanielsson, Nov 1, 2025)
46406f1  Remove extra trailing comma (simondanielsson, Nov 1, 2025)
dbb4fe3  Move hardcoded chunk size to GDN attn metadata builder (simondanielsson, Nov 1, 2025)
efd451b  Remove extra newline (simondanielsson, Nov 1, 2025)
bfa6ffc  Merge remote-tracking branch 'upstream/main' into feature/gdn-apc (simondanielsson, Nov 1, 2025)
15 changes: 8 additions & 7 deletions tests/models/language/generation/test_hybrid.py
@@ -27,6 +27,8 @@
# "yujiepan/mamba2-codestral-v0.1-tiny-random",
]

GDN_MODELS = ["tiny-random/qwen3-next-moe"]

HYBRID_MODELS = [
"ai21labs/Jamba-tiny-dev",
"pfnet/plamo-2-1b",
@@ -35,8 +37,7 @@
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
"LiquidAI/LFM2-1.2B",
"tiny-random/qwen3-next-moe",
]
] + GDN_MODELS

FULL_CUDA_GRAPH_MODELS = [
"ai21labs/Jamba-tiny-dev",
@@ -380,7 +381,7 @@ def _get_vLLM_output(
return outs, vllm_model


@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -446,7 +447,7 @@ def test_apc_single_prompt(
)


@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -528,7 +529,7 @@ def test_apc_single_prompt_block_align_alignment(
)


@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -595,7 +596,7 @@ def test_apc_multiple_prompts_all_cached_outputs(
)


@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
@@ -679,7 +680,7 @@ def test_apc_multiple_prompts_block_align_alignment(
)


@pytest.mark.parametrize("model", [HYBRID_MODELS[3]])
@pytest.mark.parametrize("model", [HYBRID_MODELS[3]] + GDN_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
5 changes: 5 additions & 0 deletions vllm/config/model.py
@@ -1505,6 +1505,11 @@ def get_mamba_chunk_size(self) -> int | None:
if chunk_size is None:
# used by e.g. Mamba2, NemotronH, Zamba
chunk_size = getattr(self.hf_text_config, "chunk_size", None)
if chunk_size is None and self.hf_text_config.model_type == "qwen3_next":
# Fallback for Qwen3-Next. 64 is a hardcoded value in the GDN kernel.
# https://github.com/fla-org/flash-linear-attention/blob/2e7336262c11f8bc6cd6a94b1eb5ee353ae8b4cd/fla/ops/common/chunk_delta_h.py#L439
return 64

return chunk_size

def get_multimodal_config(self) -> MultiModalConfig:
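For reference, a minimal sketch (not part of the diff) of what the hardcoded 64-token chunk size implies for the GDN prefix-caching path: the kernel materialises one state snapshot per (possibly partial) 64-token chunk, so a prefill of num_tokens tokens is expected to yield ceil(num_tokens / 64) snapshots. The helper below is illustrative only; its name is not an identifier from this PR.

import math

GDN_CHUNK_SIZE = 64  # hardcoded in the FLA chunk_delta_h kernel linked above

def expected_intermediate_state_count(num_tokens: int, chunk_size: int = GDN_CHUNK_SIZE) -> int:
    # One chunk-state snapshot per (possibly partial) chunk of the prefill.
    return math.ceil(num_tokens / chunk_size)

assert expected_intermediate_state_count(64) == 1
assert expected_intermediate_state_count(200) == 4  # ceil(200 / 64)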
40 changes: 34 additions & 6 deletions vllm/model_executor/layers/fla/ops/chunk.py
@@ -32,6 +32,7 @@ def chunk_gated_delta_rule_fwd(
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
return_intermediate_states: bool = False,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
@@ -66,7 +67,15 @@ def chunk_gated_delta_rule_fwd(
cu_seqlens=cu_seqlens,
)
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
return (
g,
o,
A,
final_state,
None,
h if return_intermediate_states else None,
None,
)
elif SUPPRESS_LEVEL >= 3:
return g, o, A, final_state, w, h, v_new

@@ -87,6 +96,7 @@ def forward(
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
return_intermediate_states: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
@@ -102,10 +112,22 @@
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
return_intermediate_states=return_intermediate_states,
)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
return o.to(q.dtype), final_state
intermediate_states = None
if return_intermediate_states:
assert h is not None
# Convert intermediate states into "chunk-major" form
# Equal-length batches keep their batch dimension; flatten it together
# with the chunk axis so callers receive a contiguous chunk stream.
# Variable-length inputs collapse the batch dimension during preprocessing,
# so the kernel already emits a linearised chunk stream in ``states[:, i]``.
# Flattening mirrors the metadata builder's chunk enumeration order.
# Last three axes of h are [H, K, V], producing [num_chunks_total, H, K, V]
intermediate_states = h.reshape(-1, *h.shape[-3:])
return o.to(q.dtype), final_state, intermediate_states


@torch.compiler.disable
@@ -121,6 +143,7 @@ def chunk_gated_delta_rule(
cu_seqlens: torch.LongTensor | None = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
return_intermediate_states: bool = False,
):
r"""
Args:
@@ -155,6 +178,10 @@
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
intermediate_states (Optional[torch.Tensor]):
When ``return_intermediate_states`` is ``True``, a tensor containing
the per-chunk state snapshots, shaped ``[num_chunks_total, H, K, V]``;
otherwise ``None``.

Examples::
>>> import torch
@@ -169,7 +196,7 @@
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
>>> o, ht, _ = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True
@@ -178,7 +205,7 @@
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = chunk_gated_delta_rule(
>>> o_var, ht_var, _ = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True,
@@ -223,7 +250,7 @@ def chunk_gated_delta_rule(
)
if scale is None:
scale = k.shape[-1] ** -0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
o, final_state, intermediate_states = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
@@ -234,7 +261,8 @@
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
return_intermediate_states,
)
if head_first:
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state
return o, final_state, intermediate_states
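A minimal usage sketch of the extended return signature (not part of the diff; it mirrors the docstring example above, assumes a CUDA device and bfloat16 inputs, and uses arbitrary toy shapes):

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule

B, T, H, K, V = 1, 256, 4, 64, 64
q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')

o, final_state, intermediate_states = chunk_gated_delta_rule(
    q, k, v, g, beta,
    initial_state=h0,
    output_final_state=True,
    return_intermediate_states=True,
)
# o: [B, T, H, V]; final_state: [N, H, K, V] with N == B here.
# intermediate_states: [num_chunks_total, H, K, V], one snapshot per 64-token
# chunk of the prefill (T=256 is expected to give 4 snapshots). Callers that do
# not need the snapshots can leave return_intermediate_states=False and receive
# None as the third return value.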
13 changes: 10 additions & 3 deletions vllm/model_executor/models/config.py
@@ -305,12 +305,19 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"NemotronHForCausalLM",
"Zamba2ForCausalLM",
]
GDN_MODELS = [
"Qwen3NextForCausalLM",
]
if cache_config.enable_prefix_caching:
if model_config.architecture in MAMBA2_MODELS:
if model_config.architecture in MAMBA2_MODELS + GDN_MODELS:
layer_type = (
"Mamba2" if model_config.architecture in MAMBA2_MODELS else "GDN"
)
logger.info(
"Warning: Prefix caching is currently enabled. "
"Its support for Mamba2 layers is experimental. "
"Please report any issues you may observe."
"Its support for %s layers is experimental. "
"Please report any issues you may observe.",
layer_type,
)
else:
logger.info(