@@ -1644,56 +1644,62 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
16441644
16451645ggml_tensor * llm_graph_context::build_rs (
16461646 ggml_tensor * s,
1647- ggml_tensor * state_copy,
1647+ ggml_tensor * state_copy_main,
1648+ ggml_tensor * state_copy_extra,
16481649 int32_t state_size,
16491650 int32_t n_seqs,
1650- uint32_t n_kv ,
1651- uint32_t kv_head ,
1652- uint32_t kv_size ,
1651+ uint32_t n_rs ,
1652+ uint32_t rs_head ,
1653+ uint32_t rs_size ,
16531654 int32_t rs_zero,
16541655 const llm_graph_get_rows_fn & get_state_rows) const {
16551656
1656- ggml_tensor * states = ggml_reshape_2d (ctx0, s, state_size, kv_size );
1657+ ggml_tensor * states = ggml_reshape_2d (ctx0, s, state_size, rs_size );
16571658
16581659 // Clear a single state which will then be copied to the other cleared states.
16591660 // Note that this is a no-op when the view is zero-sized.
16601661 ggml_tensor * state_zero = ggml_view_1d (ctx0, states, state_size*(rs_zero >= 0 ), rs_zero*states->nb [1 ]*(rs_zero >= 0 ));
16611662 ggml_build_forward_expand (gf, ggml_scale_inplace (ctx0, state_zero, 0 ));
16621663
16631664 // copy states
1664- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1665- // {state_size, kv_size } -> {state_size, n_seqs}
1666- ggml_tensor * output_states = get_state_rows (ctx0, states, ggml_view_1d (ctx0, state_copy, n_seqs, 0 ) );
1665+ // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs
1666+ // {state_size, rs_size } -> {state_size, n_seqs}
1667+ ggml_tensor * output_states = get_state_rows (ctx0, states, state_copy_main );
16671668 ggml_build_forward_expand (gf, output_states);
16681669
1669- // copy extra states which won't be changed further (between n_seqs and n_kv )
1670- ggml_tensor * states_extra = ggml_get_rows (ctx0, states, ggml_view_1d (ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy-> nb [ 0 ]) );
1670+ // copy extra states which won't be changed further (between n_seqs and n_rs )
1671+ ggml_tensor * states_extra = ggml_get_rows (ctx0, states, state_copy_extra );
16711672 ggml_build_forward_expand (gf,
16721673 ggml_cpy (ctx0,
16731674 states_extra,
1674- ggml_view_1d (ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size (s))));
1675+ ggml_view_1d (ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size (s))));
16751676
16761677 return output_states;
16771678}
16781679
16791680static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl (
16801681 ggml_context * ctx0,
1682+ const llama_ubatch & ubatch,
16811683 const llama_memory_recurrent_context * mctx_cur) {
16821684
16831685 auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
16841686
1685- const auto n_rs = mctx_cur->get_n_rs ();
1687+ const int64_t n_rs = mctx_cur->get_n_rs ();
1688+ const int64_t n_seqs = ubatch.n_seqs ;
16861689
16871690 inp->s_copy = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_rs);
16881691 ggml_set_input (inp->s_copy );
16891692
1693+ inp->s_copy_main = ggml_view_1d (ctx0, inp->s_copy , n_seqs, 0 );
1694+ inp->s_copy_extra = ggml_view_1d (ctx0, inp->s_copy , n_rs - n_seqs, n_seqs * inp->s_copy ->nb [0 ]);
1695+
16901696 return inp;
16911697}
16921698
16931699llm_graph_input_rs * llm_graph_context::build_rs_inp () const {
16941700 const auto * mctx_cur = static_cast <const llama_memory_recurrent_context *>(mctx);
16951701
1696- auto inp = build_rs_inp_impl (ctx0, mctx_cur);
1702+ auto inp = build_rs_inp_impl (ctx0, ubatch, mctx_cur);
16971703
16981704 return (llm_graph_input_rs *) res->add_input (std::move (inp));
16991705}
@@ -1706,7 +1712,9 @@ ggml_tensor * llm_graph_context::build_rs(
17061712 const llm_graph_get_rows_fn & get_state_rows) const {
17071713 const auto * kv_state = inp->mctx ;
17081714
1709- return build_rs (s, inp->s_copy , state_size, n_seqs, kv_state->get_n_rs (), kv_state->get_head (), kv_state->get_size (), kv_state->get_rs_z (), get_state_rows);
1715+ return build_rs (s, inp->s_copy_main , inp->s_copy_extra , state_size, n_seqs,
1716+ kv_state->get_n_rs (), kv_state->get_head (), kv_state->get_size (), kv_state->get_rs_z (),
1717+ get_state_rows);
17101718}
17111719
17121720ggml_tensor * llm_graph_context::build_rwkv_token_shift_load (
@@ -1753,7 +1761,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
17531761llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid () const {
17541762 const auto * mctx_cur = static_cast <const llama_memory_hybrid_context *>(mctx);
17551763
1756- auto inp_rs = build_rs_inp_impl (ctx0, mctx_cur->get_recr ());
1764+ auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr ());
17571765 auto inp_attn = build_attn_inp_kv_unified_impl (ctx0, ubatch, hparams, cparams, mctx_cur->get_attn ());
17581766
17591767 auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move (inp_attn), std::move (inp_rs), mctx_cur);
0 commit comments