@@ -29,10 +29,7 @@ class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass
std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node>& node, const std::vector<int>& dims);
ov::OutputVector make_split(const ov::Output<ov::Node>& value, int64_t num_splits, int64_t axis);
std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
ov::Output<ov::Node> past_seqlen,
std::shared_ptr<ov::Node> seqlen_k,
std::shared_ptr<ov::Node> cos_cache,
std::shared_ptr<ov::Node> sin_cache,
std::shared_ptr<ov::Node> dim_head_size,
ov::Output<ov::Node> cos,
ov::Output<ov::Node> sin,
bool interleaved);
};
@@ -10,11 +10,13 @@
#include "openvino/core/graph_util.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/greater_eq.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
@@ -23,6 +25,7 @@
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
@@ -70,51 +73,65 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
auto cos_cache = node->input_value(6);
auto sin_cache = node->input_value(7);

// The length of all tokens (past + current) is `seqlens_k` + 1
// The length of all tokens (past + current) is `seqlens_k` + 1.
// current = Q.shape[2], past = `seqlens_k` + 1 - current
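// Illustrative example (hypothetical numbers, not taken from the PR): if seqlens_k = 6
// and Q.shape[2] = 1, the total token count is 6 + 1 = 7, current = 1, past = 7 - 1 = 6.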

const auto T = Q.get_element_type();
const auto q_shape = register_new_node<v3::ShapeOf>(Q);
const auto current_sequence_length = get_dimensions(q_shape, {2});
const auto current_seqlen = get_dimensions(q_shape, {2});
const auto head_size_node = get_dimensions(q_shape, {3});

auto zero = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}));
auto one = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}));
auto one_without_shape = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {1}));
auto two = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}));
auto seqlens_elemi64 = register_new_node<v0::Convert>(seqlens_k, ov::element::i64);
auto real_seqlens = register_new_node<v1::Add>(seqlens_elemi64, one);
const auto zero = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}));
const auto zero_without_shape = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {0}));
const auto one = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}));
const auto one_without_shape = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {1}));
const auto two = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}));
const auto seqlens_elemi64 = register_new_node<v0::Convert>(seqlens_k, ov::element::i64);
const auto real_seqlens = register_new_node<v1::Add>(seqlens_elemi64, one);

// Only consider batch size of 1
auto seqlens_1d = register_new_node<v1::Reshape>(real_seqlens, one, false);
auto past_sequence_length = register_new_node<v1::Subtract>(seqlens_1d, current_sequence_length);
const auto seqlens_1d = register_new_node<v1::Reshape>(real_seqlens, one, false);
const auto past_seqlen = register_new_node<v1::Subtract>(seqlens_1d, current_seqlen);
const auto curr_seqlen_scalar = register_new_node<v0::Squeeze>(current_seqlen);

if (do_rotary) {
Q = rotaryEmbedding(Q,
past_sequence_length,
seqlens_1d,
cos_cache.get_node_shared_ptr(),
sin_cache.get_node_shared_ptr(),
head_size_node,
rotary_interleaved);
K = rotaryEmbedding(K,
past_sequence_length,
seqlens_1d,
cos_cache.get_node_shared_ptr(),
sin_cache.get_node_shared_ptr(),
head_size_node,
rotary_interleaved);
ov::Output<ov::Node> position_ids =
register_new_node<v4::Range>(zero_without_shape, curr_seqlen_scalar, one_without_shape, ov::element::i64);
position_ids = register_new_node<v1::Add>(position_ids, past_seqlen);

const auto cos = register_new_node<v8::Gather>(cos_cache, position_ids, zero);
const auto sin = register_new_node<v8::Gather>(sin_cache, position_ids, zero);
Q = rotaryEmbedding(Q, cos, sin, rotary_interleaved);
K = rotaryEmbedding(K, cos, sin, rotary_interleaved);
}
const auto is_static_input = K.get_partial_shape().is_static() && past_key.get_partial_shape().is_static();

auto construct_kv_cache = [&](const ov::Output<ov::Node>& past, const ov::Output<ov::Node>& current) {
auto past_datas = register_new_node<v8::Slice>(past, zero, past_sequence_length, one, two);
auto curr_datas = register_new_node<v8::Slice>(current, zero, current_sequence_length, one, two);
return register_new_node<v0::Concat>(ov::NodeVector{past_datas, curr_datas}, 2);
return register_new_node<v0::Concat>(ov::OutputVector{past, current}, 2);
};
if (is_static_input) {
// Cache memory layout for static shapes:
// - Keys: [0, ..., 0, past_key[0], ..., past_key[N-1], K[0], ..., K[M-1]]
// - Values: [0, ..., 0, past_value[0], ..., past_value[N-1], V[0], ..., V[M-1]]
// Here, the zero padding is placed at the front of the buffer.
// M = current_seqlen, which is always 1 for the KV cache model.
const auto current_kv_len_const = register_new_node(
v0::Constant::create(ov::element::i64, ov::Shape{1}, {K.get_partial_shape()[2].get_length()}));
const auto past_kv_len_const = register_new_node(
v0::Constant::create(ov::element::i64, ov::Shape{1}, {past_key.get_partial_shape()[2].get_length()}));
@CuriousPanCake (Contributor) commented on May 8, 2025:
Just for safety reasons, before calling get_length() on past_key's dimensions, I'd check that the shape is static.
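A minimal sketch of the suggested guard (illustrative only; the helper name and the fallback behaviour are assumptions, not part of the PR):

#include "openvino/core/node.hpp"

// Return the static length of the requested axis, or a caller-provided fallback when the
// rank or that dimension is dynamic, so get_length() is never called on a dynamic dimension.
static int64_t static_dim_or(const ov::Output<ov::Node>& value, int64_t axis, int64_t fallback) {
    const auto& pshape = value.get_partial_shape();
    if (pshape.rank().is_static() && pshape[axis].is_static()) {
        return pshape[axis].get_length();
    }
    return fallback;
}

In the pass itself this role is played by the is_static_input check above, which gates the branch that calls get_length().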

past_key = register_new_node<v8::Slice>(past_key, current_kv_len_const, past_kv_len_const, one, two);
A Contributor commented:
Do we have a document that defines the cache layout for static shapes?

At first glance, one might think the cache grows at the end, i.e.

index:
0 -> past_len -> cur_len
data layout:
[past cache] | [current cache]

However, the code here assumes that past data is placed after current data; the memory growth direction is the opposite of what one would ordinarily expect. It would be better to have a document or an agreement about this:

index:
0 -> cur_len -> past_len
data layout:
[current cache] | [past cache]

@sgbihu (Contributor) replied on Apr 27, 2025:
We only describe the logic in the PR description. Your understanding is not correct: the latest cache entries are always at the end of the buffer. This part pops the zeros at the beginning of the buffer; the concat logic is then at L120 (the construct_kv_cache lambda).

A Contributor replied:
Even if the concat part is a real concat of past_kv + cur_kv, the layout of the past_kv cache is still confusing: why are zeros padded before past_kv, and could zeros be padded after past_kv in some other implementation? The problem here is that we apply a strong assumption about the layout of past_kv, but there is no document describing this assumption.

A Contributor replied:
Updated comments.
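For readers of this thread, here is a small standalone sketch (not part of the PR; all values are made up) of the layout being agreed on: the cache buffer has a fixed length, zero padding sits at the front, valid past tokens sit at the tail, and each decode step drops leading zeros equal to the current length and appends the newest tokens at the end.

#include <cstdio>
#include <vector>

int main() {
    const int max_len = 8;                       // fixed sequence length of the static cache
    const int past_valid = 3;                    // valid past tokens already in the cache
    std::vector<int> cache(max_len, 0);          // [0, 0, 0, 0, 0, p0, p1, p2]
    for (int i = 0; i < past_valid; ++i)
        cache[max_len - past_valid + i] = 100 + i;

    // One decode step (current length 1): drop one leading zero and append the new token,
    // so the buffer length stays max_len and the newest token sits at the end.
    cache.erase(cache.begin());
    cache.push_back(200);                        // [0, 0, 0, 0, p0, p1, p2, new]

    for (int v : cache)
        std::printf("%d ", v);
    std::printf("\n");
    return 0;
}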

past_value = register_new_node<v8::Slice>(past_value, current_kv_len_const, past_kv_len_const, one, two);
}
K = construct_kv_cache(past_key, K);
V = construct_kv_cache(past_value, V);
auto present_k = K;
auto present_v = V;

ov::Output<ov::Node> present_k = K;
ov::Output<ov::Node> present_v = V;

const auto concat_kv_len = get_dimensions(K.get_node_shared_ptr(), {2});
const auto concat_kv_len_scalar = register_new_node<v0::Squeeze>(concat_kv_len);

// Broadcast KV if grouped query attention
const size_t kv_num_heads_factor = num_heads / kv_num_heads;
if (kv_num_heads_factor > 1) {
const auto kv_shape = register_new_node<v3::ShapeOf>(K);
@@ -132,34 +149,44 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
V = register_new_node<v1::Reshape>(V, extended_kv_shape, false);
}

// need to apply low-triangle mask to attention score.
// two steps, construct the total_sequence x total_sequence triangle, then slice the current length
auto seqlens_1d_scalar = register_new_node<v1::Reshape>(seqlens_1d, one_without_shape, false);
std::shared_ptr<ov::Node> mask_per_line_node =
register_new_node<v4::Range>(register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {0})),
seqlens_1d_scalar,
one_without_shape,
ov::element::i64);
auto hori_range = register_new_node<v0::Unsqueeze>(mask_per_line_node, zero);
auto vert_range = register_new_node<v0::Unsqueeze>(mask_per_line_node, one);
auto triu = register_new_node<v1::Greater>(hori_range, vert_range);
auto typed_zero = register_new_node(v0::Constant::create(T, ov::Shape{}, {0}));
// Make attention mask
std::shared_ptr<ov::Node> mask;

std::shared_ptr<ov::Node> hori_range =
register_new_node<v4::Range>(zero_without_shape, concat_kv_len_scalar, one_without_shape, ov::element::i64);
hori_range = register_new_node<v0::Unsqueeze>(hori_range, zero);

std::shared_ptr<ov::Node> vert_range =
register_new_node<v4::Range>(zero_without_shape, curr_seqlen_scalar, one_without_shape, ov::element::i64);
vert_range = register_new_node<v0::Unsqueeze>(vert_range, one);
const auto past_k_node_len = get_dimensions(past_key.get_node_shared_ptr(), {2});
vert_range = register_new_node<v1::Add>(vert_range, past_k_node_len);

const auto triu = register_new_node<v1::Greater>(hori_range, vert_range);
const auto typed_zero = register_new_node(v0::Constant::create(T, ov::Shape{}, {0}));
// cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp
std::shared_ptr<ov::Node> minus_inf = nullptr;
if (T == ov::element::f32)
minus_inf = register_new_node(v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits<float>::infinity()}));
else if (T == ov::element::f16)
minus_inf =
register_new_node(v0::Constant::create(T, ov::Shape{}, {std::numeric_limits<ov::float16>::lowest()}));
auto atten_mask = register_new_node<v1::Select>(triu, minus_inf, typed_zero);
auto atten_mask_sliced = register_new_node<v8::Slice>(atten_mask, past_sequence_length, seqlens_1d, one, zero);
mask = register_new_node<v1::Select>(triu, minus_inf, typed_zero);

if (is_static_input) {
const auto padding_len = register_new_node<v1::Subtract>(concat_kv_len, seqlens_1d);
const auto padding_mask_vert_shape = register_new_node<v0::Concat>(ov::NodeVector{current_seqlen, one}, 0);
const auto padding_mask_vert = register_new_node<v3::Broadcast>(padding_len, padding_mask_vert_shape);
const auto padding_mask = register_new_node<v1::GreaterEqual>(hori_range, padding_mask_vert);
mask = register_new_node<v1::Select>(padding_mask, mask, minus_inf);
}

std::shared_ptr<ov::Node> qga_output;
if (scale != 0.0f) {
auto scale_node = register_new_node(v0::Constant::create(T, Shape{}, {scale}));
qga_output = register_new_node<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, scale_node, false);
qga_output = register_new_node<v13::ScaledDotProductAttention>(Q, K, V, mask, scale_node, false);
} else {
qga_output = register_new_node<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, false);
qga_output = register_new_node<v13::ScaledDotProductAttention>(Q, K, V, mask, false);
}

// transpose the result from (batch_size, num_heads, sequence_length, head_size)
@@ -198,40 +225,26 @@ std::shared_ptr<ov::Node> ov::pass::GroupQueryAttentionDecomposition::get_dimens
return get_dimensions(register_new_node<ov::op::v3::ShapeOf>(node), dims);
}

std::shared_ptr<ov::Node> ov::pass::GroupQueryAttentionDecomposition::rotaryEmbedding(
ov::Output<ov::Node> input,
ov::Output<ov::Node> past_seqlen,
std::shared_ptr<ov::Node> seqlen_k,
std::shared_ptr<ov::Node> cos_cache,
std::shared_ptr<ov::Node> sin_cache,
std::shared_ptr<ov::Node> dim_head_size,
bool interleaved) {
std::shared_ptr<ov::Node> ov::pass::GroupQueryAttentionDecomposition::rotaryEmbedding(ov::Output<ov::Node> input,
ov::Output<ov::Node> cos,
ov::Output<ov::Node> sin,
bool interleaved) {
using namespace ov::op;
auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});

auto slice_cache_dim_shape = seqlen_k;

auto cos = register_new_node<v8::Slice>(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero);
auto sin = register_new_node<v8::Slice>(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero);

if (interleaved) {
auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});

auto cache_shape = register_new_node<v3::ShapeOf>(cos_cache);
auto cache_last_dim = get_dimensions(cos_cache, {-1});

auto cos_last_dim = get_dimensions(cos.get_node_shared_ptr(), {-1});
auto input_shape = register_new_node<v3::ShapeOf>(input);

auto dim_bns = get_dimensions(input_shape, {0, 1, 2});
std::shared_ptr<ov::Node> half_last_dim = cache_last_dim;

auto negtive_one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1});
auto split_input_shape = register_new_node<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, two}, 0);
auto split_input_shape = register_new_node<v0::Concat>(ov::NodeVector{dim_bns, cos_last_dim, two}, 0);
auto reshaped_input = register_new_node<v1::Reshape>(input, split_input_shape, false);

auto in_split = make_split(reshaped_input, 2, -1);
split_input_shape = register_new_node<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim}, 0);
split_input_shape = register_new_node<v0::Concat>(ov::NodeVector{dim_bns, cos_last_dim}, 0);
auto in_split_0 = register_new_node<v1::Reshape>(in_split[0], split_input_shape, false);
auto in_split_1 = register_new_node<v1::Reshape>(in_split[1], split_input_shape, false);

@@ -240,7 +253,7 @@ std::shared_ptr<ov::Node> ov::pass::GroupQueryAttentionDecomposition::rotaryEmbe
auto res_1 = register_new_node<v1::Add>(register_new_node<v1::Multiply>(in_split_0, sin),
register_new_node<v1::Multiply>(in_split_1, cos));

split_input_shape = register_new_node<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, one}, 0);
split_input_shape = register_new_node<v0::Concat>(ov::NodeVector{dim_bns, cos_last_dim, one}, 0);
auto res_0_5d = register_new_node<v1::Reshape>(res_0, split_input_shape, false);
auto res_1_5d = register_new_node<v1::Reshape>(res_1, split_input_shape, false);

19 changes: 14 additions & 5 deletions src/core/src/op/group_query_attention.cpp
@@ -25,17 +25,26 @@ GroupQueryAttention::GroupQueryAttention(const OutputVector& args,

void GroupQueryAttention::validate_and_infer_types() {
OV_OP_SCOPE(GroupQueryAttention_validate_and_infer_types);
// GQA expectes the following inputs: query, key, value, past_key, past_value, seqlens_k, cos_cache, sin_cache
// All qkv's should have the shape [batch, num_heads, seq_len, head_size] ([B, N, S, H])
// It has three outputs: output of shape [B, S, N * H], and present_key/value of shape [B, N, S, H]
// seqlens_k is number of 1's in the attention_mask minus 1
// GroupQueryAttention expects the following inputs:
// query, key, value, past_key, past_value, seqlens_k, cos_cache, sin_cache.
// All qkv tensors should have the shape [batch, num_heads, seq_len, head_size] ([B, N, S, H]).
// The operation produces three outputs:
// 1. Output tensor of shape [B, S, N * H].
// 2. Present_key tensor of shape [B, N, S, H].
// 3. Present_value tensor of shape [B, N, S, H].
// Note: seqlens_k represents the number of 1's in the attention_mask minus 1.
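// Illustrative shape example (hypothetical values, not from the PR): with B = 1, N = 8,
// S = 1, H = 64 and a dynamic past KV length of 31, the output is [1, 1, 8 * 64] = [1, 1, 512]
// and present_key / present_value are [1, 8, 31 + 1, 64] = [1, 8, 32, 64].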

const auto& q_shape = get_input_partial_shape(0);
const auto& batch_size = q_shape[0];
const auto& sequence_len = q_shape[2];
const auto& head_size = q_shape[3];
const auto& past_sequence_len = get_input_partial_shape(3)[2];
const auto& output_kv_len = past_sequence_len + sequence_len;

auto output_kv_len = past_sequence_len;
if (past_sequence_len.is_dynamic() || sequence_len.is_dynamic()) {
// For dynamic shapes, concatenate the past and current sequence lengths.
output_kv_len += sequence_len;
}

const auto& element_type = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,