Skip to content

Commit fe66b34

Browse files
authored
[Model] Mamba2 Prefill Performance Tweaks: Fixing Flurry of Unnecessary Memory Copies (#14778)
Signed-off-by: Chih-Chieh-Yang <[email protected]>
1 parent 270a5da commit fe66b34

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -466,10 +466,17 @@ def forward_cuda(
466466
if has_prefill:
467467

468468
initial_states = None
469-
if has_initial_states is not None and any(has_initial_states):
470-
for idx in mamba_cache_params.state_indices_tensor[
471-
~has_initial_states]:
472-
mamba_cache_params.ssm_state[idx].zero_()
469+
470+
if has_initial_states is not None and torch.any(
471+
has_initial_states):
472+
473+
# vectorized ssm_state zero init
474+
batched_zero_init_func = torch.vmap(
475+
lambda idx: mamba_cache_params.ssm_state[idx].zero_())
476+
batched_zero_init_func(
477+
mamba_cache_params.
478+
state_indices_tensor[~has_initial_states].unsqueeze(
479+
dim=-1), )
473480
initial_states = mamba_cache_params.ssm_state[
474481
mamba_cache_params.state_indices_tensor]
475482

@@ -493,10 +500,17 @@ def forward_cuda(
493500
dt_limit=(0.0, float("inf")),
494501
)
495502

496-
# update ssm states
497-
# - varlen state is a (batch, nheads, headdim, dstate) tensor
498-
for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
499-
mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
503+
# vectorized ssm state update using vmap
504+
# the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
505+
# limitation which doesn't allow use of `item()`
506+
# Note: the lambda capture can happen where ssm_state is initialized
507+
# instead of here
508+
batched_copy = torch.vmap(
509+
lambda idx, source_state: mamba_cache_params.ssm_state[
510+
idx].copy_(source_state))
511+
batched_copy(
512+
mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
513+
varlen_state)
500514

501515
# - reshape
502516
hidden_states = scan_output.view(seq_len, -1)

0 commit comments

Comments (0)