@@ -466,10 +466,17 @@ def forward_cuda(
466466 if has_prefill :
467467
468468 initial_states = None
469- if has_initial_states is not None and any (has_initial_states ):
470- for idx in mamba_cache_params .state_indices_tensor [
471- ~ has_initial_states ]:
472- mamba_cache_params .ssm_state [idx ].zero_ ()
469+
470+ if has_initial_states is not None and torch .any (
471+ has_initial_states ):
472+
473+ # vectorized ssm_state zero init
474+ batched_zero_init_func = torch .vmap (
475+ lambda idx : mamba_cache_params .ssm_state [idx ].zero_ ())
476+ batched_zero_init_func (
477+ mamba_cache_params .
478+ state_indices_tensor [~ has_initial_states ].unsqueeze (
479+ dim = - 1 ), )
473480 initial_states = mamba_cache_params .ssm_state [
474481 mamba_cache_params .state_indices_tensor ]
475482
@@ -493,10 +500,17 @@ def forward_cuda(
493500 dt_limit = (0.0 , float ("inf" )),
494501 )
495502
496- # update ssm states
497- # - varlen state is a (batch, nheads, headdim, dstate) tensor
498- for i , idx in enumerate (mamba_cache_params .state_indices_tensor ):
499- mamba_cache_params .ssm_state [idx ].copy_ (varlen_state [i ])
503+ # vectorized ssm state update using vmap
504+ # the 1d state_indices_tensor needs to be unsqueezed to avoid vmap
505+ # limitation which doesn't allow use of `item()`
506+ # Note: the lambda capture can happen where ssm_state is initialized
507+ # instead of here
508+ batched_copy = torch .vmap (
509+ lambda idx , source_state : mamba_cache_params .ssm_state [
510+ idx ].copy_ (source_state ))
511+ batched_copy (
512+ mamba_cache_params .state_indices_tensor .unsqueeze (dim = - 1 ),
513+ varlen_state )
500514
501515 # - reshape
502516 hidden_states = scan_output .view (seq_len , - 1 )
0 commit comments