Skip to content

Commit b610b5f

Browse files
committed
trigger stateful path for Phisilica model
Co-author: Beheshti, Nazanin
1 parent b8a1e82 commit b610b5f

File tree

2 files changed

+25
-15
lines changed

2 files changed

+25
-15
lines changed

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,10 @@ void OVInferRequest::Infer() {
360360

361361
StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device)
362362
: OVInferRequest(std::move(infer_request)), target_device(device) {
363-
bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
364-
if (gpu_or_npu) {
365-
prefill_use_full_chat_history = true;
366-
}
363+
// bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos));
364+
// if (gpu_or_npu) {
365+
// prefill_use_full_chat_history = true;
366+
// }
367367
}
368368

369369
void StatefulOVInferRequest::FillTensor(const std::string& tensor_name, const ov::element::Type& type,

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
7272
main_input_name = "input_ids";
7373
}
7474

75+
if (ModelHasInputOutputNames(ov_model, "input_hidden_states")) {
76+
main_input_name = "input_hidden_states";
77+
}
78+
79+
if (ModelHasInputOutputNames(ov_model, "/model/embed_tokens/Gather_output_0")) {
80+
main_input_name = "/model/embed_tokens/Gather_output_0";
81+
}
82+
7583
auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0];
7684

7785
auto beam_idx = std::make_shared<ov::opset13::Parameter>(ov::element::i32, ov::PartialShape({std::move(input_batch)}));
@@ -121,20 +129,22 @@ void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
121129
void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
122130
std::vector<std::string> key_value_input_names;
123131
std::vector<std::string> not_kv_inputs;
124-
for (const ov::Output<ov::Node>& input : model->inputs()) {
125-
auto& names = input.get_names();
126-
127-
bool found = false;
128-
for (auto& name : names) {
129-
if (name.find("key_values") != std::string::npos) {
130-
key_value_input_names.push_back(name);
132+
const auto& params = model->get_parameters();
133+
bool found = false;
134+
for (auto i = 0; i < params.size(); i++) {
135+
auto param_name = params.at(i)->output(0).get_any_name();
136+
if (param_name.find("key_values") != std::string::npos) {
137+
key_value_input_names.push_back(param_name);
138+
found = true;
139+
} else if (param_name.find("key") != std::string::npos) {
140+
key_value_input_names.push_back(param_name);
141+
found = true;
142+
} else if (param_name.find("value") != std::string::npos) {
143+
key_value_input_names.push_back(param_name);
131144
found = true;
132-
break;
133-
}
134145
}
135-
136146
if (!found) {
137-
not_kv_inputs.push_back(input.get_any_name());
147+
not_kv_inputs.push_back(param_name);
138148
}
139149
}
140150

0 commit comments

Comments
 (0)