diff --git a/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp new file mode 100644 index 00000000000000..ad39d57f32e291 --- /dev/null +++ b/src/common/transformations/include/transformations/op_conversions/group_query_attention_decomposition.hpp @@ -0,0 +1,38 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/op/group_query_attention.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/pass/matcher_pass.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API GroupQueryAttentionDecomposition; + +} // namespace pass +} // namespace ov + +class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("GroupQueryAttentionDecomposition"); + GroupQueryAttentionDecomposition(); + +private: + ov::OutputVector decompose(std::shared_ptr node); + std::shared_ptr get_dimensions(const std::shared_ptr& shape, + const std::vector& dims); + std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims); + ov::OutputVector make_split(const ov::Output& value, int64_t num_splits, int64_t axis); + std::shared_ptr rotaryEmbedding(ov::Output input, + ov::Output past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved); +}; diff --git a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 87813bae65538d..66651b3907f344 100644 --- a/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ 
b/src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -108,6 +108,7 @@ #include "transformations/op_conversions/eye_decomposition.hpp" #include "transformations/op_conversions/gelu7_downgrade.hpp" #include "transformations/op_conversions/group_normalization_decomposition.hpp" +#include "transformations/op_conversions/group_query_attention_decomposition.hpp" #include "transformations/op_conversions/hsigmoid_decomposition.hpp" #include "transformations/op_conversions/hswish_decomposition.hpp" #include "transformations/op_conversions/log_softmax_decomposition.hpp" @@ -156,6 +157,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr(); + ADD_MATCHER(decomp, GroupQueryAttentionDecomposition) ADD_MATCHER(decomp, ScaledDotProductAttentionDecomposition) ADD_MATCHER(decomp, Gelu7Downgrade) ADD_MATCHER(decomp, BidirectionalSequenceDecomposition) diff --git a/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp new file mode 100644 index 00000000000000..22c9c551908db5 --- /dev/null +++ b/src/common/transformations/src/transformations/op_conversions/group_query_attention_decomposition.cpp @@ -0,0 +1,257 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/group_query_attention_decomposition.hpp" + +#include + +#include "itt.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/greater.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/range.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/scaled_dot_product_attention.hpp" +#include "openvino/op/select.hpp" +#include 
"openvino/op/shape_of.hpp" +#include "openvino/op/slice.hpp" +#include "openvino/op/split.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/transpose.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() { + MATCHER_SCOPE(GroupQeuryAttentionDecomposition); + auto pattern_node = ov::pass::pattern::wrap_type(); + + matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { + auto& pattern_to_output = m.get_pattern_value_map(); + auto node = ov::as_type_ptr( + pattern_to_output.at(pattern_node).get_node_shared_ptr()); + + if (node == nullptr || transformation_callback(node)) { + return false; + } + + auto new_output_node = decompose(node); + ov::replace_node(node, new_output_node); + return true; + }; + + auto m = std::make_shared(pattern_node, matcher_name); + register_matcher(m, callback); +} + +ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose( + std::shared_ptr node) { + using namespace ov::op; + + const auto num_heads = node->get_num_heads(); + const auto kv_num_heads = node->get_kv_num_heads(); + const auto scale = node->get_scale(); + const auto do_rotary = node->get_do_rotary(); + const auto rotary_interleaved = node->get_rotary_interleaved(); + // TODO: add softcap support + + auto Q = node->input_value(0); + auto K = node->input_value(1); + auto V = node->input_value(2); + auto past_key = node->input_value(3); + auto past_value = node->input_value(4); + auto seqlens_k = node->input_value(5); + auto cos_cache = node->input_value(6); + auto sin_cache = node->input_value(7); + + // The length of all tokens (past + current) is `seqlens_k` + 1 + // current = Q.shape[2], past = `seqlens_k` + 1 - current + + const auto T = Q.get_element_type(); + const auto q_shape = register_new_node(Q); + const auto current_sequence_length = get_dimensions(q_shape, {2}); + const auto 
head_size_node = get_dimensions(q_shape, {3}); + + auto zero = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {0})); + auto one = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {1})); + auto one_without_shape = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {1})); + auto two = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {2})); + auto seqlens_elemi64 = register_new_node(seqlens_k, ov::element::i64); + auto real_seqlens = register_new_node(seqlens_elemi64, one); + + // Only consider batch is 1 + auto seqlens_1d = register_new_node(real_seqlens, one, false); + auto past_sequence_length = register_new_node(seqlens_1d, current_sequence_length); + if (do_rotary) { + Q = rotaryEmbedding(Q, + past_sequence_length, + seqlens_1d, + cos_cache.get_node_shared_ptr(), + sin_cache.get_node_shared_ptr(), + head_size_node, + rotary_interleaved); + K = rotaryEmbedding(K, + past_sequence_length, + seqlens_1d, + cos_cache.get_node_shared_ptr(), + sin_cache.get_node_shared_ptr(), + head_size_node, + rotary_interleaved); + } + + auto construct_kv_cache = [&](const ov::Output& past, const ov::Output& current) { + auto past_datas = register_new_node(past, zero, past_sequence_length, one, two); + auto curr_datas = register_new_node(current, zero, current_sequence_length, one, two); + return register_new_node(ov::NodeVector{past_datas, curr_datas}, 2); + }; + K = construct_kv_cache(past_key, K); + V = construct_kv_cache(past_value, V); + auto present_k = K; + auto present_v = V; + + const size_t kv_num_heads_factor = num_heads / kv_num_heads; + if (kv_num_heads_factor > 1) { + const auto kv_shape = register_new_node(K); + const auto kv_shape_prev_2 = get_dimensions(kv_shape, {0, 1}); + const auto kv_shape_last_2 = get_dimensions(kv_shape, {2, 3}); + auto new_kv_shape = register_new_node(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0); + K = register_new_node(K, new_kv_shape, false); + 
V = register_new_node(V, new_kv_shape, false); + K = register_new_node(ov::OutputVector(kv_num_heads_factor, K), 2); + V = register_new_node(ov::OutputVector(kv_num_heads_factor, V), 2); + const auto q_shape = register_new_node(Q); + const auto q_shape_prev_2 = get_dimensions(q_shape, {0, 1}); + auto extended_kv_shape = register_new_node(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0); + K = register_new_node(K, extended_kv_shape, false); + V = register_new_node(V, extended_kv_shape, false); + } + + // need to apply low-triangle mask to attention score. + // two steps, construct the total_sequence x total_sequence triangle, then slice the current length + auto seqlens_1d_scalar = register_new_node(seqlens_1d, one_without_shape, false); + std::shared_ptr mask_per_line_node = + register_new_node(register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {0})), + seqlens_1d_scalar, + one_without_shape, + ov::element::i64); + auto hori_range = register_new_node(mask_per_line_node, zero); + auto vert_range = register_new_node(mask_per_line_node, one); + auto triu = register_new_node(hori_range, vert_range); + auto typed_zero = register_new_node(v0::Constant::create(T, ov::Shape{}, {0})); + // cf. 
make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp + std::shared_ptr minus_inf = nullptr; + if (T == ov::element::f32) + minus_inf = register_new_node(v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits::infinity()})); + else if (T == ov::element::f16) + minus_inf = + register_new_node(v0::Constant::create(T, ov::Shape{}, {std::numeric_limits::lowest()})); + auto atten_mask = register_new_node(triu, minus_inf, typed_zero); + auto atten_mask_sliced = register_new_node(atten_mask, past_sequence_length, seqlens_1d, one, zero); + + std::shared_ptr qga_output; + if (scale != 0.0f) { + auto scale_node = register_new_node(v0::Constant::create(T, Shape{}, {scale})); + qga_output = register_new_node(Q, K, V, atten_mask_sliced, scale_node, false); + } else { + qga_output = register_new_node(Q, K, V, atten_mask_sliced, false); + } + + // transpose the result from (batch_size, num_heads, sequence_length, head_size) + // to (batch_size, sequence_length, num_heads * head_size) + auto perm = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3})); + auto qga_output_transposed = register_new_node(qga_output, perm); + auto dim_merge_shape = register_new_node(v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1})); + auto output = register_new_node(qga_output_transposed, dim_merge_shape, true)->output(0); + + return {output, present_k, present_v}; +} + +// make split functions is a copy-past from ONNX FE. 
TODO: move it to one place +ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::make_split(const ov::Output& value, + int64_t num_splits, + int64_t axis) { + using namespace ov::op; + const auto axis_node = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {axis})); + const auto split = register_new_node(value, axis_node, num_splits); + + return split->outputs(); +} + +std::shared_ptr ov::pass::GroupQueryAttentionDecomposition::get_dimensions( + const std::shared_ptr& shape, + const std::vector& dims) { + using namespace ov::op; + const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return register_new_node(shape, dims_const, zero); +} + +std::shared_ptr ov::pass::GroupQueryAttentionDecomposition::get_dimensions( + const std::shared_ptr& node, + const std::vector& dims) { + return get_dimensions(register_new_node(node), dims); +} + +std::shared_ptr ov::pass::GroupQueryAttentionDecomposition::rotaryEmbedding( + ov::Output input, + ov::Output past_seqlen, + std::shared_ptr seqlen_k, + std::shared_ptr cos_cache, + std::shared_ptr sin_cache, + std::shared_ptr dim_head_size, + bool interleaved) { + using namespace ov::op; + auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}); + + auto slice_cache_dim_shape = seqlen_k; + + auto cos = register_new_node(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero); + auto sin = register_new_node(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero); + + if (interleaved) { + auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}); + + auto cache_shape = register_new_node(cos_cache); + auto cache_last_dim = get_dimensions(cos_cache, {-1}); + + auto input_shape = register_new_node(input); + + auto dim_bns = get_dimensions(input_shape, {0, 1, 2}); + std::shared_ptr half_last_dim 
= cache_last_dim; + + auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}); + auto split_input_shape = register_new_node(ov::NodeVector{dim_bns, half_last_dim, two}, 0); + auto reshaped_input = register_new_node(input, split_input_shape, false); + + auto in_split = make_split(reshaped_input, 2, -1); + split_input_shape = register_new_node(ov::NodeVector{dim_bns, half_last_dim}, 0); + auto in_split_0 = register_new_node(in_split[0], split_input_shape, false); + auto in_split_1 = register_new_node(in_split[1], split_input_shape, false); + + auto res_0 = register_new_node(register_new_node(in_split_0, cos), + register_new_node(in_split_1, sin)); + auto res_1 = register_new_node(register_new_node(in_split_0, sin), + register_new_node(in_split_1, cos)); + + split_input_shape = register_new_node(ov::NodeVector{dim_bns, half_last_dim, one}, 0); + auto res_0_5d = register_new_node(res_0, split_input_shape, false); + auto res_1_5d = register_new_node(res_1, split_input_shape, false); + + auto concat_ret = register_new_node(ov::NodeVector{res_0_5d, res_1_5d}, -1); + return register_new_node(concat_ret, input_shape, false); + } else { + auto in_split = make_split(input, 2, -1); + auto res_0 = register_new_node(register_new_node(in_split[0], cos), + register_new_node(in_split[1], sin)); + auto res_1 = register_new_node(register_new_node(in_split[0], sin), + register_new_node(in_split[1], cos)); + + return register_new_node(ov::NodeVector{res_0, res_1}, -1); + } +} diff --git a/src/core/dev_api/openvino/op/group_query_attention.hpp b/src/core/dev_api/openvino/op/group_query_attention.hpp new file mode 100644 index 00000000000000..8e450197be4590 --- /dev/null +++ b/src/core/dev_api/openvino/op/group_query_attention.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "openvino/op/op.hpp" + +namespace ov::op::internal { + +// This is an experimental operation that is 
implemented in the plugins. +class OPENVINO_API GroupQueryAttention : public Op { +public: + OPENVINO_OP("GroupQueryAttention"); + + GroupQueryAttention() = default; + GroupQueryAttention(const ov::OutputVector& args, + int64_t num_heads, + int64_t kv_num_heads, + float scale, + bool do_rotary, + bool rotary_interleaved); + void validate_and_infer_types() override; + bool visit_attributes(AttributeVisitor& visitor) override; + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + + int64_t get_num_heads() const { + return m_num_heads; + } + int64_t get_kv_num_heads() const { + return m_kv_num_heads; + } + float get_scale() const { + return m_scale; + } + bool get_do_rotary() const { + return m_do_rotary; + } + bool get_rotary_interleaved() const { + return m_rotary_interleaved; + } + +private: + int64_t m_num_heads = 0; + int64_t m_kv_num_heads = 0; + float m_scale = 0; + bool m_do_rotary = false; + bool m_rotary_interleaved = false; +}; + +} // namespace ov::op::internal diff --git a/src/core/src/op/group_query_attention.cpp b/src/core/src/op/group_query_attention.cpp new file mode 100644 index 00000000000000..8680a4c4b44087 --- /dev/null +++ b/src/core/src/op/group_query_attention.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/group_query_attention.hpp" + +#include "itt.hpp" + +namespace ov::op::internal { + +GroupQueryAttention::GroupQueryAttention(const OutputVector& args, + int64_t num_heads, + int64_t kv_num_heads, + float scale, + bool do_rotary, + bool rotary_interleaved) + : Op(args), + m_num_heads(num_heads), + m_kv_num_heads(kv_num_heads), + m_scale(scale), + m_do_rotary(do_rotary), + m_rotary_interleaved(rotary_interleaved) { + constructor_validate_and_infer_types(); +} + +void GroupQueryAttention::validate_and_infer_types() { + OV_OP_SCOPE(GroupQueryAttention_validate_and_infer_types); + // GQA expectes the following inputs: 
query, key, value, past_key, past_value, seqlens_k, cos_cache, sin_cache + // All qkv's should have the shape [batch, num_heads, seq_len, head_size] ([B, N, S, H]) + // It has three outputs: output of shape [B, S, N * H], and present_key/value of shape [B, N, S, H] + // seqlens_k is number of 1's in the attention_mask minus 1 + + const auto& q_shape = get_input_partial_shape(0); + const auto& batch_size = q_shape[0]; + const auto& sequence_len = q_shape[2]; + const auto& head_size = q_shape[3]; + const auto& past_sequence_len = get_input_partial_shape(3)[2]; + const auto& output_kv_len = past_sequence_len + sequence_len; + + const auto& element_type = get_input_element_type(0); + NODE_VALIDATION_CHECK(this, + element_type == element::f32 || element_type == element::f16, + "GroupQueryAttention only suuports f32 and f16"); + + set_output_type(0, element_type, PartialShape{batch_size, sequence_len, head_size * m_num_heads}); + set_output_type(1, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size}); + set_output_type(2, element_type, PartialShape{batch_size, m_kv_num_heads, output_kv_len, head_size}); +} + +bool GroupQueryAttention::visit_attributes(AttributeVisitor& visitor) { + OV_OP_SCOPE(GroupQueryAttention_visit_attributes); + visitor.on_attribute("do_rotary", m_do_rotary); + visitor.on_attribute("kv_num_heads", m_kv_num_heads); + visitor.on_attribute("num_heads", m_num_heads); + visitor.on_attribute("rotary_interleaved", m_rotary_interleaved); + visitor.on_attribute("scale", m_scale); + return true; +} + +std::shared_ptr GroupQueryAttention::clone_with_new_inputs(const ov::OutputVector& new_args) const { + OV_OP_SCOPE(GroupQueryAttention_clone_with_new_inputs); + check_new_args_count(this, new_args); + return std::make_shared(new_args, + m_num_heads, + m_kv_num_heads, + m_scale, + m_do_rotary, + m_rotary_interleaved); +} + +} // namespace ov::op::internal diff --git 
a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp new file mode 100644 index 00000000000000..e70369ec7c675d --- /dev/null +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp @@ -0,0 +1,121 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/group_query_attention.hpp" + +#include + +#include "core/null_node.hpp" +#include "core/operator_set.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/divide.hpp" +#include "openvino/op/gather.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/transpose.hpp" +#include "utils/common.hpp" +#include "utils/split.hpp" + +using namespace ov::op; + +namespace ov::frontend::onnx::com_microsoft { + +namespace detail { +namespace { +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims); +} // namespace +} // namespace detail + +namespace opset_1 { +ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) { + // At least given "query" and "seqlens_k" + common::default_op_checks(node, 2); + + const auto onnx_op_inputs = node.get_ov_inputs(); + const auto num_heads = node.get_attribute_value("num_heads"); + const auto kv_num_heads = node.get_attribute_value("kv_num_heads"); + const auto scale = node.get_attribute_value("scale", 0.0f); + const auto do_rotary = node.get_attribute_value("do_rotary", 0); + const auto rotary_interleaved = node.get_attribute_value("rotary_interleaved", 0); + + // In ONNX, the format of input QKV is [B, S, N*H] and of past_kv is [B, N, S, H] + // In OV, we always use [B, N, S, H] + auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}); + + auto Q = onnx_op_inputs[0]; + auto K = onnx_op_inputs[1]; + auto V = onnx_op_inputs[2]; + const auto q_shape_node = std::make_shared(Q); + const 
auto batch_size_node = detail::get_dimensions(q_shape_node, {0}); + const auto current_seqlen_size_node = detail::get_dimensions(q_shape_node, {1}); + const auto hidden_size_node = detail::get_dimensions(q_shape_node, {2}); + + OutputVector ov_op_inputs; + if (ov::op::util::is_null(K)) { + auto total_num_heads_node = + v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads}); + auto head_size_node = std::make_shared(hidden_size_node, total_num_heads_node); + auto packed_qkv_shape = std::make_shared( + ov::NodeVector{batch_size_node, current_seqlen_size_node, total_num_heads_node, head_size_node}, + 0); + + auto inputs_qkv = std::make_shared(Q, packed_qkv_shape, false)->output(0); + inputs_qkv = std::make_shared(inputs_qkv, perm); + auto split = ov::op::util::make_split(inputs_qkv, {num_heads, kv_num_heads, kv_num_heads}, 1); + + std::copy(split.begin(), split.end(), std::back_inserter(ov_op_inputs)); + } else { + auto num_heads_node = v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads}); + auto head_size_node = std::make_shared(hidden_size_node, num_heads_node); + auto q_shape = std::make_shared( + ov::NodeVector{batch_size_node, current_seqlen_size_node, num_heads_node, head_size_node}, + 0); + + Q = std::make_shared(Q, q_shape, false)->output(0); + Q = std::make_shared(Q, perm); + ov_op_inputs.push_back(Q); + + auto kv_num_heads_node = v0::Constant::create(ov::element::i64, ov::Shape{1}, {kv_num_heads}); + auto kv_shape = std::make_shared( + ov::NodeVector{batch_size_node, current_seqlen_size_node, kv_num_heads_node, head_size_node}, + 0); + + K = std::make_shared(K, kv_shape, false)->output(0); + V = std::make_shared(V, kv_shape, false)->output(0); + K = std::make_shared(K, perm); + V = std::make_shared(V, perm); + ov_op_inputs.push_back(K); + ov_op_inputs.push_back(V); + } + + for (int i = 3; i < 9; ++i) { + // skip total_sequence_length + if (i == 6) + continue; + ov_op_inputs.push_back(onnx_op_inputs[i]); + 
} + return std::make_shared(ov_op_inputs, + num_heads, + kv_num_heads, + scale, + do_rotary, + rotary_interleaved) + ->outputs(); +} + +ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN); + +} // namespace opset_1 + +namespace detail { +namespace { +std::shared_ptr get_dimensions(const std::shared_ptr& shape, const std::vector& dims) { + static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared(shape, dims_const, zero); +} +} // namespace +} // namespace detail + +} // namespace ov::frontend::onnx::com_microsoft diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary.prototxt new file mode 100644 index 00000000000000..179e7935d7dcd6 --- /dev/null +++ b/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary.prototxt @@ -0,0 +1,181 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 0 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_param: "sequence_length" 
} + dim { dim_value: 64 } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "past_sequence_length" } + dim { dim_value: 16 } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "past_sequence_length" } + dim { dim_value: 16 } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "max_sequence_length" } + dim { dim_value: 8 } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "max_sequence_length" } + dim { dim_value: 8 } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_param: "sequence_length" } + dim { dim_value: 32 } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "total_sequence_length" } + dim { dim_value: 16 } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "total_sequence_length" } + dim { dim_value: 16 } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} diff --git a/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary_interleaved.prototxt b/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary_interleaved.prototxt new file mode 100644 index 00000000000000..9186bafc770c65 --- /dev/null +++ 
b/src/frontends/onnx/tests/models/com.microsoft/gqa_rotary_interleaved.prototxt @@ -0,0 +1,181 @@ +ir_version: 10 +graph { + node { + input: "query" + input: "" + input: "" + input: "past_key" + input: "past_value" + input: "seqlens_k" + input: "" + input: "cos_cache" + input: "sin_cache" + output: "output" + output: "present_key" + output: "present_value" + name: "GroupQueryAttention_0" + op_type: "GroupQueryAttention" + attribute { + name: "do_rotary" + i: 1 + type: INT + } + attribute { + name: "kv_num_heads" + i: 1 + type: INT + } + attribute { + name: "local_window_size" + i: -1 + type: INT + } + attribute { + name: "num_heads" + i: 2 + type: INT + } + attribute { + name: "rotary_interleaved" + i: 1 + type: INT + } + attribute { + name: "smooth_softmax" + i: 0 + type: INT + } + attribute { + name: "softcap" + f: 0 + type: FLOAT + } + domain: "com.microsoft" + } + name: "GroupQueryAttention_Graph" + input { + name: "query" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_param: "sequence_length" } + dim { dim_value: 64 } + } + } + } + } + input { + name: "past_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "past_sequence_length" } + dim { dim_value: 16 } + } + } + } + } + input { + name: "past_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "past_sequence_length" } + dim { dim_value: 16 } + } + } + } + } + input { + name: "seqlens_k" + type { + tensor_type { + elem_type: 6 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + } + } + } + } + input { + name: "cos_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "max_sequence_length" } + dim { dim_value: 8 } + } + } + } + } + input { + name: "sin_cache" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "max_sequence_length" } + dim { 
dim_value: 8 } + } + } + } + } + output { + name: "output" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_param: "sequence_length" } + dim { dim_value: 32 } + } + } + } + } + output { + name: "present_key" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "total_sequence_length" } + dim { dim_value: 16 } + } + } + } + } + output { + name: "present_value" + type { + tensor_type { + elem_type: 1 + shape { + dim { dim_param: "batch_size" } + dim { dim_value: 1 } + dim { dim_param: "total_sequence_length" } + dim { dim_value: 16 } + } + } + } + } +} +opset_import { + version: 11 +} +opset_import { + domain: "com.microsoft" + version: 1 +} diff --git a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp index 170476aae05dd3..764e4d8ca12956 100644 --- a/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp +++ b/src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp @@ -1740,3 +1740,407 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_bias_add) { test_case.run(); } +

// GQA with rotary embedding (non-interleaved): one input token, empty past key/value
// (past shapes are {1, 1, 0, 16}, seqlens_k = 0).
OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary) {
    const auto model = convert_model("com.microsoft/gqa_rotary.onnx");

    std::vector<float> query = {
        -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f,  0.6920f,  -0.3160f, -2.1152f, 0.3223f,  -1.2633f, 0.3500f,
        0.3081f,  0.1198f,  1.2377f,  1.1168f,  -0.2473f, -1.3527f, -1.6959f, 0.5667f,  0.7935f,  0.5988f,  -1.5551f,
        -0.3414f, 1.8530f,  0.7502f,  -0.5855f, -0.1734f, 0.1835f,  1.3894f,  1.5863f,  0.9463f,  -0.8437f, 1.6459f,
        -1.3602f, 0.3446f,  0.5199f,  -2.6133f, -1.6965f, -0.2282f, 0.2800f,  0.2469f,  0.0769f,  0.3380f,  0.4544f,
        0.4569f,  -0.8654f, 0.7813f,  -0.9268f, -0.2188f, -2.4351f, -0.0729f, -0.0340f, 0.9625f,  0.3492f,  -0.9215f,
        -0.0562f, -0.6227f, -0.4637f, 1.9218f,  -0.4025f, 0.1239f,  1.1648f,  0.9234f,  1.3873f,
    };
    std::vector<float> past_key = {};
    std::vector<float> past_value = {};
    std::vector<int32_t> seqlens_k = {0};
    std::vector<float> cos_cache = {
        0.8437f,
        -0.7849f,
        -0.7829f,
        0.4581f,
        -0.9870f,
        0.6273f,
        -0.9483f,
        -0.9962f,
    };
    std::vector<float> sin_cache = {
        0.5368f,
        0.6196f,
        -0.6222f,
        0.8889f,
        0.1605f,
        -0.7788f,
        0.3174f,
        -0.0872f,
    };

    std::vector<float> expected_output = {-0.2188f, -2.4351f, -0.0729f, -0.034f,  0.9625f, 0.3492f, -0.9215f, -0.0562f,
                                          -0.6227f, -0.4637f, 1.9218f,  -0.4025f, 0.1239f, 1.1648f, 0.9234f,  1.3873f,
                                          -0.2188f, -2.4351f, -0.0729f, -0.034f,  0.9625f, 0.3492f, -0.9215f, -0.0562f,
                                          -0.6227f, -0.4637f, 1.9218f,  -0.4025f, 0.1239f, 1.1648f, 0.9234f,  1.3873f};

    std::vector<float> expected_present_key = {1.2561098f,
                                               1.0199738f,
                                               -0.05948371f,
                                               -0.16574995f,
                                               2.5059946f,
                                               -1.738188f,
                                               -0.03158256f,
                                               -0.35975295f,
                                               1.0918287f,
                                               -0.90313876f,
                                               -0.4790303f,
                                               0.67029977f,
                                               -0.87039495f,
                                               0.7783688f,
                                               -0.81333745f,
                                               0.89886224f};

    std::vector<float> expected_present_value = {-0.2188f,
                                                 -2.4351f,
                                                 -0.0729f,
                                                 -0.034f,
                                                 0.9625f,
                                                 0.3492f,
                                                 -0.9215f,
                                                 -0.0562f,
                                                 -0.6227f,
                                                 -0.4637f,
                                                 1.9218f,
                                                 -0.4025f,
                                                 0.1239f,
                                                 1.1648f,
                                                 0.9234f,
                                                 1.3873f};

    auto test_case = ov::test::TestCase(model, s_device);
    test_case.add_input<float>(Shape{1, 1, 64}, query);
    test_case.add_input<float>(Shape{1, 1, 0, 16}, past_key);
    test_case.add_input<float>(Shape{1, 1, 0, 16}, past_value);
    test_case.add_input<int32_t>(Shape{1, 1}, seqlens_k);
    test_case.add_input<float>(Shape{1, 8}, cos_cache);
    test_case.add_input<float>(Shape{1, 8}, sin_cache);
    test_case.add_expected_output<float>(Shape{1, 1, 32}, expected_output);
    test_case.add_expected_output<float>(Shape{1, 1, 1, 16}, expected_present_key);
    test_case.add_expected_output<float>(Shape{1, 1, 1, 16}, expected_present_value);
    test_case.run_with_tolerance_as_fp();
}

// Same inputs as the test above, run through the interleaved rotary model; only the expected
// present key differs.
OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_0_input_1_rotary_interleaved) {
    const auto model = convert_model("com.microsoft/gqa_rotary_interleaved.onnx");

    std::vector<float> query = {
        -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f,  0.6920f,  -0.3160f, -2.1152f, 0.3223f,  -1.2633f, 0.3500f,
        0.3081f,  0.1198f,  1.2377f,  1.1168f,  -0.2473f, -1.3527f, -1.6959f, 0.5667f,  0.7935f,  0.5988f,  -1.5551f,
        -0.3414f, 1.8530f,  0.7502f,  -0.5855f, -0.1734f, 0.1835f,  1.3894f,  1.5863f,  0.9463f,  -0.8437f, 1.6459f,
        -1.3602f, 0.3446f,  0.5199f,  -2.6133f, -1.6965f, -0.2282f, 0.2800f,  0.2469f,  0.0769f,  0.3380f,  0.4544f,
        0.4569f,  -0.8654f, 0.7813f,  -0.9268f, -0.2188f, -2.4351f, -0.0729f, -0.0340f, 0.9625f,  0.3492f,  -0.9215f,
        -0.0562f, -0.6227f, -0.4637f, 1.9218f,  -0.4025f, 0.1239f,  1.1648f,  0.9234f,  1.3873f,
    };
    std::vector<float> past_key = {};
    std::vector<float> past_value = {};
    std::vector<int32_t> seqlens_k = {0};
    std::vector<float> cos_cache = {
        0.8437f,
        -0.7849f,
        -0.7829f,
        0.4581f,
        -0.9870f,
        0.6273f,
        -0.9483f,
        -0.9962f,
    };
    std::vector<float> sin_cache = {
        0.5368f,
        0.6196f,
        -0.6222f,
        0.8889f,
        0.1605f,
        -0.7788f,
        0.3174f,
        -0.0872f,
    };

    std::vector<float> expected_output = {-0.2188f, -2.4351f, -0.0729f, -0.034f,  0.9625f, 0.3492f, -0.9215f, -0.0562f,
                                          -0.6227f, -0.4637f, 1.9218f,  -0.4025f, 0.1239f, 1.1648f, 0.9234f,  1.3873f,
                                          -0.2188f, -2.4351f, -0.0729f, -0.034f,  0.9625f, 0.3492f, -0.9215f, -0.0562f,
                                          -0.6227f, -0.4637f, 1.9218f,  -0.4025f, 0.1239f, 1.1648f, 0.9234f,  1.3873f};

    std::vector<float> expected_present_key = {2.118801f,
                                               -0.2640816f,
                                               -0.5926066f,
                                               -0.19455537f,
                                               0.9903903f,
                                               2.954185f,
                                               -0.35343042f,
                                               -0.07457897f,
                                               -0.25603274f,
                                               -0.03627284f,
                                               0.56591415f,
                                               0.02181074f,
                                               -0.1586003f,
                                               0.96567893f,
                                               -0.8591481f,
                                               0.85514885f};

    std::vector<float> expected_present_value = {-0.2188f,
                                                 -2.4351f,
                                                 -0.0729f,
                                                 -0.034f,
                                                 0.9625f,
                                                 0.3492f,
                                                 -0.9215f,
                                                 -0.0562f,
                                                 -0.6227f,
                                                 -0.4637f,
                                                 1.9218f,
                                                 -0.4025f,
                                                 0.1239f,
                                                 1.1648f,
                                                 0.9234f,
                                                 1.3873f};

    auto test_case = ov::test::TestCase(model, s_device);
    test_case.add_input<float>(Shape{1, 1, 64}, query);
    test_case.add_input<float>(Shape{1, 1, 0, 16}, past_key);
    test_case.add_input<float>(Shape{1, 1, 0, 16}, past_value);
    test_case.add_input<int32_t>(Shape{1, 1}, seqlens_k);
    test_case.add_input<float>(Shape{1, 8}, cos_cache);
    test_case.add_input<float>(Shape{1, 8}, sin_cache);
    test_case.add_expected_output<float>(Shape{1, 1, 32}, expected_output);
    test_case.add_expected_output<float>(Shape{1, 1, 1, 16}, expected_present_key);
    test_case.add_expected_output<float>(Shape{1, 1, 1, 16}, expected_present_value);
    test_case.run_with_tolerance_as_fp();
}

+OPENVINO_TEST(${BACKEND_NAME}, onnx_model_gqa_past_1_input_1_rotary) { + const auto model = convert_model("com.microsoft/gqa_rotary.onnx"); + + std::vector query = { + -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, + 0.3081f, 0.1198f, 1.2377f, 1.1168f, -0.2473f, -1.3527f, -1.6959f, 0.5667f, 0.7935f, 0.5988f, -1.5551f, + -0.3414f, 1.8530f, 0.7502f, -0.5855f, -0.1734f, 0.1835f, 1.3894f, 1.5863f, 0.9463f, -0.8437f, 1.6459f, + -1.3602f, 0.3446f, 0.5199f, -2.6133f, -1.6965f, -0.2282f, 0.2800f, 0.2469f, 0.0769f, 0.3380f, 0.4544f, + 0.4569f, -0.8654f, 0.7813f, -0.9268f, -0.2188f, -2.4351f, -0.0729f, -0.0340f, 0.9625f, 0.3492f, -0.9215f, + -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f, + }; + std::vector past_key = { + -0.6136f, + 0.0316f, + -0.4927f, + 0.2484f, + 0.4397f, + 0.1124f, + 0.6408f, + 0.4412f, + -0.1023f, + 0.7924f, + -0.2897f, + 0.0525f, + 0.5229f, + 2.3022f, + -1.4689f, + -1.5867f, + }; + std::vector past_value = { + -0.5692f, + 0.9200f, + 1.1108f, + 1.2899f, + -1.4782f, + 2.5672f, + -0.4731f, + 0.3356f, + -1.6293f, + -0.5497f, + -0.4798f, + -0.4997f, + -1.0670f, + 1.1149f, + -0.1407f, + 0.8058f, + }; + std::vector seqlens_k = {1}; + std::vector cos_cache = { + 0.8437f, + -0.7849f, + -0.7829f, + 0.4581f, + -0.9870f, + 0.6273f, + -0.9483f, + -0.9962f, + -0.9635f, + -0.8046f, + 0.4139f, + 0.9863f, + 0.4117f, + 0.9874f, + -0.9743f, + 0.9494f, + }; + std::vector sin_cache = { + 0.5368f, + 0.6196f, + -0.6222f, + 0.8889f, + 0.1605f, + -0.7788f, + 0.3174f, + 
-0.0872f, + 0.2677f, + -0.5938f, + -0.9103f, + -0.1650f, + -0.9113f, + -0.1583f, + 0.2253f, + 0.3140f, + }; + + std::vector expected_output = { + -0.53934956f, 0.6341806f, 1.0099611f, 1.1771176f, -1.270278f, 2.3782496f, -0.511299f, 0.30222273f, + -1.5435482f, -0.5423737f, -0.27520883f, -0.4914196f, -0.96554786f, 1.1191509f, -0.05004983f, 0.85533774f, + -0.49356747f, 0.19581467f, 0.8553029f, 1.0041412f, -0.9513843f, 2.088453f, -0.5698854f, 0.25103146f, + -1.4120293f, -0.5311372f, 0.03857604f, -0.47871974f, -0.8099488f, 1.1256707f, 0.08898184f, 0.93131447f}; + + std::vector expected_present_key = { + -0.6136f, 0.0316f, -0.4927f, 0.2484f, 0.4397f, 0.1124f, 0.6408f, 0.4412f, + -0.1023f, 0.7924f, -0.2897f, 0.0525f, 0.5229f, 2.3022f, -1.4689f, -1.5867f, + -1.6519198f, 1.1400802f, 0.45031136f, 0.5877534f, -0.65952265f, -1.8121169f, 0.04630837f, 0.5568472f, + 0.20271924f, 0.7458131f, -0.17379119f, 0.3623912f, 2.5696063f, -0.58594f, -0.8126341f, -0.7919839f}; + + std::vector expected_present_value = { + -0.5692f, 0.9200f, 1.1108f, 1.2899f, -1.4782f, 2.5672f, -0.4731f, 0.3356f, -1.6293f, -0.5497f, -0.4798f, + -0.4997f, -1.0670f, 1.1149f, -0.1407f, 0.8058f, -0.2188f, -2.4351f, -0.0729f, -0.034f, 0.9625f, 0.3492f, + -0.9215f, -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(Shape{1, 1, 64}, query); + test_case.add_input(Shape{1, 1, 1, 16}, past_key); + test_case.add_input(Shape{1, 1, 1, 16}, past_value); + test_case.add_input(Shape{1, 1}, seqlens_k); + test_case.add_input(Shape{2, 8}, cos_cache); + test_case.add_input(Shape{2, 8}, sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +} + +OPENVINO_TEST(${BACKEND_NAME}, 
onnx_model_gqa_past_1_input_1_rotary_interleaved) { + const auto model = convert_model("com.microsoft/gqa_rotary_interleaved.onnx"); + + std::vector query = { + -1.1258f, -1.1524f, -0.2506f, -0.4339f, 0.8487f, 0.6920f, -0.3160f, -2.1152f, 0.3223f, -1.2633f, 0.3500f, + 0.3081f, 0.1198f, 1.2377f, 1.1168f, -0.2473f, -1.3527f, -1.6959f, 0.5667f, 0.7935f, 0.5988f, -1.5551f, + -0.3414f, 1.8530f, 0.7502f, -0.5855f, -0.1734f, 0.1835f, 1.3894f, 1.5863f, 0.9463f, -0.8437f, 1.6459f, + -1.3602f, 0.3446f, 0.5199f, -2.6133f, -1.6965f, -0.2282f, 0.2800f, 0.2469f, 0.0769f, 0.3380f, 0.4544f, + 0.4569f, -0.8654f, 0.7813f, -0.9268f, -0.2188f, -2.4351f, -0.0729f, -0.0340f, 0.9625f, 0.3492f, -0.9215f, + -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f, + }; + std::vector past_key = { + -0.6136f, + 0.0316f, + -0.4927f, + 0.2484f, + 0.4397f, + 0.1124f, + 0.6408f, + 0.4412f, + -0.1023f, + 0.7924f, + -0.2897f, + 0.0525f, + 0.5229f, + 2.3022f, + -1.4689f, + -1.5867f, + }; + std::vector past_value = { + -0.5692f, + 0.9200f, + 1.1108f, + 1.2899f, + -1.4782f, + 2.5672f, + -0.4731f, + 0.3356f, + -1.6293f, + -0.5497f, + -0.4798f, + -0.4997f, + -1.0670f, + 1.1149f, + -0.1407f, + 0.8058f, + }; + std::vector seqlens_k = {1}; + std::vector cos_cache = { + 0.8437f, + -0.7849f, + -0.7829f, + 0.4581f, + -0.9870f, + 0.6273f, + -0.9483f, + -0.9962f, + -0.9635f, + -0.8046f, + 0.4139f, + 0.9863f, + 0.4117f, + 0.9874f, + -0.9743f, + 0.9494f, + }; + std::vector sin_cache = { + 0.5368f, + 0.6196f, + -0.6222f, + 0.8889f, + 0.1605f, + -0.7788f, + 0.3174f, + -0.0872f, + 0.2677f, + -0.5938f, + -0.9103f, + -0.1650f, + -0.9113f, + -0.1583f, + 0.2253f, + 0.3140f, + }; + + std::vector expected_output = { + -0.33396345f, -1.332403f, 0.31613833f, 0.40111685f, 0.16033238f, 1.0781744f, -0.7741276f, 0.07257013f, + -0.9535321f, -0.491965f, 1.1324831f, -0.43444604f, -0.2675047f, 1.1483997f, 0.57366973f, 1.1961825f, + -0.24709277f, -2.164195f, 0.02267693f, 0.07289726f, 0.7654276f, 
0.5282906f, -0.8852943f, -0.02456442f, + -0.7039771f, -0.47064403f, 1.7278847f, -0.41034833f, 0.02774171f, 1.1607709f, 0.83748007f, 1.3403473f}; + + std::vector expected_present_key = { + -0.6136f, 0.0316f, -0.4927f, 0.2484f, 0.4397f, 0.1124f, 0.6408f, 0.4412f, + -0.1023f, 0.7924f, -0.2897f, 0.0525f, 0.5229f, 2.3022f, -1.4689f, -1.5867f, + -1.2216992f, 1.7511603f, 0.03145146f, -0.62293506f, -2.625969f, 1.6767058f, -0.17887366f, 0.313817f, + 0.1717277f, -0.19334024f, 0.4056727f, 0.39516917f, -0.25018305f, 0.9460988f, 1.0327814f, -0.6345757f}; + + std::vector expected_present_value = { + -0.5692f, 0.9200f, 1.1108f, 1.2899f, -1.4782f, 2.5672f, -0.4731f, 0.3356f, -1.6293f, -0.5497f, -0.4798f, + -0.4997f, -1.0670f, 1.1149f, -0.1407f, 0.8058f, -0.2188f, -2.4351f, -0.0729f, -0.034f, 0.9625f, 0.3492f, + -0.9215f, -0.0562f, -0.6227f, -0.4637f, 1.9218f, -0.4025f, 0.1239f, 1.1648f, 0.9234f, 1.3873f}; + + auto test_case = ov::test::TestCase(model, s_device); + test_case.add_input(Shape{1, 1, 64}, query); + test_case.add_input(Shape{1, 1, 1, 16}, past_key); + test_case.add_input(Shape{1, 1, 1, 16}, past_value); + test_case.add_input(Shape{1, 1}, seqlens_k); + test_case.add_input(Shape{2, 8}, cos_cache); + test_case.add_input(Shape{2, 8}, sin_cache); + test_case.add_expected_output(Shape{1, 1, 32}, expected_output); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_key); + test_case.add_expected_output(Shape{1, 1, 2, 16}, expected_present_value); + test_case.run_with_tolerance_as_fp(); +}