@@ -72,6 +72,14 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
7272 main_input_name = " input_ids" ;
7373 }
7474
75+ if (ModelHasInputOutputNames (ov_model, " input_hidden_states" )) {
76+ main_input_name = " input_hidden_states" ;
77+ }
78+
79+ if (ModelHasInputOutputNames (ov_model, " /model/embed_tokens/Gather_output_0" )) {
80+ main_input_name = " /model/embed_tokens/Gather_output_0" ;
81+ }
82+
7583 auto input_batch = ov_model->input (main_input_name).get_partial_shape ()[0 ];
7684
7785 auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32 , ov::PartialShape ({std::move (input_batch)}));
@@ -121,20 +129,22 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
121129void PatchStatefulDecoder (std::shared_ptr<ov::Model> model) {
122130 std::vector<std::string> key_value_input_names;
123131 std::vector<std::string> not_kv_inputs;
124- for (const ov::Output<ov::Node>& input : model->inputs ()) {
125- auto & names = input.get_names ();
126-
127- bool found = false ;
128- for (auto & name : names) {
129- if (name.find (" key_values" ) != std::string::npos) {
130- key_value_input_names.push_back (name);
132+ const auto & params = model->get_parameters ();
133+ bool found = false ;
134+ for (auto i = 0 ; i < params.size (); i++) {
135+ auto param_name = params.at (i)->output (0 ).get_any_name ();
136+ if (param_name.find (" key_values" ) != std::string::npos) {
137+ key_value_input_names.push_back (param_name);
138+ found = true ;
139+ } else if (param_name.find (" key" ) != std::string::npos) {
140+ key_value_input_names.push_back (param_name);
141+ found = true ;
142+ } else if (param_name.find (" value" ) != std::string::npos) {
143+ key_value_input_names.push_back (param_name);
131144 found = true ;
132- break ;
133- }
134145 }
135-
136146 if (!found) {
137- not_kv_inputs.push_back (input. get_any_name () );
147+ not_kv_inputs.push_back (param_name );
138148 }
139149 }
140150
0 commit comments