Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fla/ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
"""

cache_entries: tuple[tuple | None, dict | None, Any] = []
cache_size = 4
cache_size = 8

@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
Expand Down
21 changes: 12 additions & 9 deletions vllm/model_executor/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def rearrange_mixed_qkv(self, mixed_qkv):
(query, key),
)
value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
return query, key, value
return query.contiguous(), key.contiguous(), value.contiguous()

def forward(
self,
Expand Down Expand Up @@ -456,6 +456,8 @@ def _forward(
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
spec_sequence_masks = attn_metadata.spec_sequence_masks
spec_token_masks = attn_metadata.spec_token_masks
spec_token_indx = attn_metadata.spec_token_indx
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
Expand Down Expand Up @@ -487,8 +489,8 @@ def _forward(
mixed_qkv_spec = mixed_qkv
mixed_qkv_non_spec = None
else:
mixed_qkv_spec = mixed_qkv[spec_token_masks]
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
else:
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv
Expand Down Expand Up @@ -558,10 +560,10 @@ def _forward(
g_non_spec = None
beta_non_spec = None
else:
g_spec = g[:, spec_token_masks]
beta_spec = beta[:, spec_token_masks]
g_non_spec = g[:, ~spec_token_masks]
beta_non_spec = beta[:, ~spec_token_masks]
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
else:
g_spec = None
beta_spec = None
Expand Down Expand Up @@ -638,8 +640,9 @@ def _forward(
dtype=core_attn_out_non_spec.dtype,
device=core_attn_out_non_spec.device,
)
core_attn_out[:, spec_token_masks] = core_attn_out_spec
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)

elif spec_sequence_masks is not None:
core_attn_out = core_attn_out_spec
else:
Expand Down
49 changes: 46 additions & 3 deletions vllm/v1/attention/backends/gdn_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class GDNAttentionMetadata:
spec_token_masks: torch.Tensor | None = (
None # shape: [num_prefill_tokens + num_decode_tokens,]
)
spec_token_indx: torch.Tensor | None = None
non_spec_token_indx: torch.Tensor | None = None

num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]

# The following attributes are for triton implementation of causal_conv1d
Expand Down Expand Up @@ -110,6 +113,16 @@ def __init__(
dtype=torch.bool,
device=device,
)
self.spec_token_indx = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.non_spec_token_indx = torch.empty(
(self.decode_cudagraph_max_bs * (self.num_spec + 1),),
dtype=torch.int32,
device=device,
)
self.spec_query_start_loc = torch.empty(
(self.decode_cudagraph_max_bs + 1,),
dtype=torch.int32,
Expand Down Expand Up @@ -167,6 +180,8 @@ def build( # type: ignore[override]
)
num_spec_decode_tokens = 0
spec_token_masks = None
spec_token_indx = None
non_spec_token_indx = None
spec_state_indices_tensor = None
non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
spec_query_start_loc = None
Expand All @@ -180,6 +195,9 @@ def build( # type: ignore[override]
num_prefills = non_spec_query_lens.size(0) - num_decodes
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
num_spec_decode_tokens = (
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
)

if num_prefills == 0 and num_decodes == 0:
spec_token_masks = torch.ones(
Expand All @@ -192,6 +210,14 @@ def build( # type: ignore[override]
dtype=torch.bool,
device=query_start_loc.device,
)
spec_token_indx = torch.arange(
spec_token_masks.size(0),
dtype=torch.int32,
device=query_start_loc.device,
)
non_spec_token_indx = torch.empty(
0, dtype=torch.int32, device=query_start_loc.device
)
spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
Expand All @@ -200,6 +226,11 @@ def build( # type: ignore[override]
spec_token_masks = torch.repeat_interleave(
spec_sequence_masks, query_lens
)
index = torch.argsort(spec_token_masks)
num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]

spec_state_indices_tensor = m.block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
Expand All @@ -226,9 +257,6 @@ def build( # type: ignore[override]
out=non_spec_query_start_loc[1:],
)

num_spec_decode_tokens = (
query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
)
assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

Expand Down Expand Up @@ -281,6 +309,19 @@ def build( # type: ignore[override]
spec_token_masks = self.spec_token_masks[:num_actual_tokens]
spec_token_masks[spec_token_masks.size(0) :].fill_(False)

assert non_spec_token_indx is not None and spec_token_indx is not None
self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
non_spec_token_indx, non_blocking=True
)
non_spec_token_indx = self.non_spec_token_indx[
: non_spec_token_indx.size(0)
]

self.spec_token_indx[: spec_token_indx.size(0)].copy_(
spec_token_indx, non_blocking=True
)
spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]

self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
spec_query_start_loc, non_blocking=True
)
Expand Down Expand Up @@ -333,6 +374,8 @@ def build( # type: ignore[override]
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
spec_sequence_masks=spec_sequence_masks,
spec_token_masks=spec_token_masks,
spec_token_indx=spec_token_indx,
non_spec_token_indx=non_spec_token_indx,
num_accepted_tokens=num_accepted_tokens,
nums_dict=nums_dict,
batch_ptr=batch_ptr,
Expand Down