30 changes: 8 additions & 22 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -466,17 +466,10 @@ def forward_cuda(
         if has_prefill:
 
             initial_states = None
-
-            if has_initial_states is not None and torch.any(
-                    has_initial_states):
-
-                # vectorized ssm_state zero init
-                batched_zero_init_func = torch.vmap(
-                    lambda idx: mamba_cache_params.ssm_state[idx].zero_())
-                batched_zero_init_func(
-                    mamba_cache_params.
-                    state_indices_tensor[~has_initial_states].unsqueeze(
-                        dim=-1), )
+            if has_initial_states is not None and any(has_initial_states):
+                for idx in mamba_cache_params.state_indices_tensor[
+                        ~has_initial_states]:
+                    mamba_cache_params.ssm_state[idx].zero_()
                 initial_states = mamba_cache_params.ssm_state[
                     mamba_cache_params.state_indices_tensor]
 
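For illustration, here is a minimal standalone sketch of the zero-init pattern this hunk adopts. The shapes and values are toy assumptions; only the names ssm_state, state_indices_tensor, and has_initial_states mirror the diff:

import torch

# stand-in for mamba_cache_params.ssm_state: one cached state per slot
ssm_state = torch.randn(8, 4, 16, 16)
# stand-in for mamba_cache_params.state_indices_tensor: the cache slot
# owned by each sequence in the batch
state_indices_tensor = torch.tensor([5, 0, 2])
# which sequences carry a state over from a previous chunk
has_initial_states = torch.tensor([True, False, True])

# zero the slots of sequences that have no carried-over state
for idx in state_indices_tensor[~has_initial_states]:
    ssm_state[idx].zero_()

# gather per-sequence initial states for the scan
initial_states = ssm_state[state_indices_tensor]

assert torch.all(initial_states[1] == 0)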
@@ -500,17 +493,10 @@ def forward_cuda(
                 dt_limit=(0.0, float("inf")),
             )
 
-            # vectorized ssm state update using vmap
-            # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
-            # limitation which doesn't allow use of `item()`
-            # Note: the lambda capture can happen where ssm_state is initialized
-            # instead of here
-            batched_copy = torch.vmap(
-                lambda idx, source_state: mamba_cache_params.ssm_state[
-                    idx].copy_(source_state))
-            batched_copy(
-                mamba_cache_params.state_indices_tensor.unsqueeze(dim=-1),
-                varlen_state)
+            # update ssm states
+            # - varlen state is a (batch, nheads, headdim, dstate) tensor
+            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
+                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
 
             # - reshape
             hidden_states = scan_output.view(seq_len, -1)
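Similarly, a minimal sketch of the loop-based state write-back in the second hunk. Shapes are toy assumptions, and varlen_state here is random stand-in data, not an actual scan output:

import torch

batch, nheads, headdim, dstate = 3, 4, 16, 16
ssm_state = torch.zeros(8, nheads, headdim, dstate)
state_indices_tensor = torch.tensor([5, 0, 2])

# stand-in for the final per-sequence states returned by the chunked scan
varlen_state = torch.randn(batch, nheads, headdim, dstate)

# scatter each sequence's final state back into its cache slot,
# as the replacement loop does
for i, idx in enumerate(state_indices_tensor):
    ssm_state[idx].copy_(varlen_state[i])

assert torch.equal(ssm_state[5], varlen_state[0])

On these toy tensors the plain index assignment ssm_state[state_indices_tensor] = varlen_state would produce the same result in one vectorized step.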