[Op][Transformations] Adjustment of internal GQA op shape infer and decomposition to Enable NPU #29766
Conversation
ov::Output<ov::Node> present_k;
ov::Output<ov::Node> present_v;
if (is_static_input) {
Can we use a "ShapeOf -> Gather" pattern to extract K.get_partial_shape()[2].get_length() and cover the dynamic case as well?
For a static shape, the "ShapeOf -> Gather" would be constant-folded; for a dynamic shape, it would remain in the graph.
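For illustration, a minimal sketch of the suggested pattern, assuming OpenVINO's C++ graph-building API (the helper name get_kv_len is invented here, not part of the PR):

```cpp
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/shape_of.hpp"

using namespace ov;

// Extract the kv sequence length (dimension 2 of K) at graph level:
// for static shapes this subgraph constant-folds away; for dynamic
// shapes it stays in the graph and is evaluated per inference call.
Output<Node> get_kv_len(const Output<Node>& K) {
    auto shape = std::make_shared<op::v3::ShapeOf>(K, element::i64);
    auto index = op::v0::Constant::create(element::i64, Shape{1}, {2});
    auto axis = op::v0::Constant::create(element::i64, Shape{}, {0});
    return std::make_shared<op::v8::Gather>(shape, index, axis);
}
```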
Such a case seems to be handled already by the get_dimensions helper and stored as concat_kv_len.
It looks like the idea of this change is to intentionally introduce different behavior for static and dynamic shapes: a static shape is treated as an indicator of the "maximum sequence length", while a dynamic shape is treated as the actual size that changes between inference calls.
The NPU doesn't support dynamic shapes, which causes a runtime issue during type inference. At the beginning of the implementation we used a Gather op, but the Slice op gives better performance.
The main concern with the proposed changes is that the shape inference for static and dynamic shapes is not unified, and a "static" shape at the operator level is used as a flag to comply with plugin-specific requirements (CPU/GPU vs NPU). The shared GQA op should rather be plugin-independent by design. As I understand the proposed changes, a static shape is assumed to be the "NPU" case, where the kv sequence length dimension means the "maximum" sequence length, while a dynamic shape is assumed to carry the actual sequence length, as supported by CPU/GPU.
This may lead to unexpected behaviour.
Have you considered any alternative solutions, such as a transformation-level flag/fallback that selects the proper decomposition for the target plugin?
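For illustration only, one hypothetical shape such a transformation-level flag could take; the enum, class name, and constructor parameter below are invented for this sketch and are not existing API:

```cpp
#include "openvino/pass/matcher_pass.hpp"

namespace ov::pass {

// Hypothetical: the caller selects the cache semantics explicitly,
// instead of inferring them from the staticness of the input shape.
enum class GQACacheSemantics {
    ActualLength,  // CPU/GPU: kv-len dimension is the real sequence length
    MaxLength      // NPU: kv-len dimension is the maximum (padded) length
};

class GroupQueryAttentionDecompositionSketch : public MatcherPass {
public:
    explicit GroupQueryAttentionDecompositionSketch(GQACacheSemantics semantics)
        : m_semantics(semantics) {
        // register_matcher(...) would pick the Concat-only vs
        // Concat+Slice lowering based on m_semantics.
    }

private:
    GQACacheSemantics m_semantics;
};

}  // namespace ov::pass
```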
@mitruska We have considered logic that does not depend on the device type; it depends only on the shape, so if CPU provides a static shape, this logic works there as well. I have also changed some of the logic to make the shape inference more reasonable: the returned shape should be the same as the input. Could you review it and give me more input?
OK for the core part.
I can see the goal of the proposed changes, but I'm not fully convinced that this is a long-term solution. As GQA is still part of the dev API it can be modified, but at some point we may need to make it public and compatible with frontend frameworks, for example ONNX Attention.
If these changes are needed right now, I don't want to block them, but I recommend making sure that the other plugins and the transformation team approve this approach as maintainable for the common GQA op.
cc: @a-sidorova @itikhono @jane-intel
For correctness, I think the PR is good and we should add a test for it. But one question about this PR: the usage of the KV cache is quite different from the current stateful-model method. In a stateful model, the KV cache is maintained internally by the CPU/GPU plugin, which avoids memory copies between devices, and multi-query is already supported there. With this PR, KV-cache management is done by the application. In terms of reducing memory I/O, I think this is sub-optimal compared to the stateful-model method.
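For context, a minimal sketch of the stateful KV-cache pattern referenced above, assuming OpenVINO's ReadValue/Assign API (the function and variable names are invented here; this is not code from the PR):

```cpp
#include "openvino/op/assign.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/read_value.hpp"
#include "openvino/op/util/variable.hpp"

using namespace ov;

// The plugin keeps the cache in a Variable, so the application never
// copies past keys/values in and out between inference calls.
std::shared_ptr<op::v6::Assign> make_kv_state(const Output<Node>& new_k) {
    auto var = std::make_shared<op::util::Variable>(
        op::util::VariableInfo{PartialShape{-1, -1, -1, -1}, element::f32, "past_k"});
    auto past_k = std::make_shared<op::v6::ReadValue>(var);
    // Dynamic buffer holds only valid data, so a single Concat is enough.
    auto concat = std::make_shared<op::v0::Concat>(OutputVector{past_k, new_k}, 2);
    // Write the grown cache back into the state.
    return std::make_shared<op::v6::Assign>(concat, var);
}
```

The Assign node would then be registered in the model's sink vector so the state update is kept alive in the graph.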
For the test, we already added GQA tests in PR #28163, and this PR is a follow-up for NPU.
The script is meant to simulate the GenAI flow and is for test purposes only. GQA doesn't care about the Assign/ReadValue ops in a stateless model; it is just an operator. So we only implement the GQA-related part, and the remaining pieces will be covered by ORT-GenAI.
LGTM, but I still have a question about KV-cache management.
v0::Constant::create(ov::element::i64, ov::Shape{1}, {K.get_partial_shape()[2].get_length()}));
const auto past_kv_len_const = register_new_node(
    v0::Constant::create(ov::element::i64, ov::Shape{1}, {past_key.get_partial_shape()[2].get_length()}));
past_key = register_new_node<v8::Slice>(past_key, current_kv_len_const, past_kv_len_const, one, two);
Do we have a document that defines the cache layout for static shapes?
At first glance, one might assume the cache grows forward, i.e.:
index:
0->past_len->cur_len
data layout:
[past cache]|[current cache]
However, the code here assumes that the past data is placed after the current data, so the memory growth direction is the opposite of what one would ordinarily expect. It would be better to have a document or an agreement about this:
index:
0->cur_len->past_len
data layout:
[current cache]|[past cache]
We only describe the logic in the PR description. And your understanding is not correct: the latest cache is always at the end of the buffer. This part pops the zeros at the beginning of the buffer; L120 is then the concat logic.
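For reference, a condensed sketch of the static-shape path as described above, under the assumption that the buffer layout is [zero padding][valid past kv] with new tokens appended at the end (the helper name update_static_kv is invented here):

```cpp
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/slice.hpp"

using namespace ov;

// Drop cur_len leading (zero) entries from the fixed-size past buffer,
// then append the current kv; the result keeps the fixed buffer size:
// (max_len - cur_len) + cur_len == max_len.
Output<Node> update_static_kv(const Output<Node>& past_k, const Output<Node>& K) {
    const auto cur_len = K.get_partial_shape()[2].get_length();       // new tokens
    const auto max_len = past_k.get_partial_shape()[2].get_length();  // buffer size
    auto start = op::v0::Constant::create(element::i64, Shape{1}, {cur_len});
    auto stop = op::v0::Constant::create(element::i64, Shape{1}, {max_len});
    auto step = op::v0::Constant::create(element::i64, Shape{1}, {1});
    auto axes = op::v0::Constant::create(element::i64, Shape{1}, {2});
    auto trimmed = std::make_shared<op::v8::Slice>(past_k, start, stop, step, axes);
    return std::make_shared<op::v0::Concat>(OutputVector{trimmed, K}, 2);
}
```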
Even if the concat part is a real concat of past_kv + cur_kv, the layout of the past_kv cache is still confusing. Why are zeros padded before past_kv? Could zeros be padded after past_kv in some other implementation? The problem is that we apply a strong assumption about the layout of past_kv, but there is no document describing this assumption.
Updated comments.
Thank you @wine99, we got it!
The KV-cache handling logic differs between dynamic and static shapes.
With dynamic shapes, the KV-cache buffer holds only valid data, so a single Concat op is enough.
With static shapes, the valid data is stored at the end of the buffer and the beginning of the buffer is zero-filled. A plain Concat would grow the result beyond the buffer size, so the decomposition has to slice out the real-size data.
The accompanying test script works for both CPU/GPU (dynamic) and NPU (static).
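A condensed, hypothetical sketch of the branching described above (the original script is not reproduced here; concat_kv is an invented name, is_static_input mirrors the flag in the PR's diff, and the static path reuses the update_static_kv sketch from earlier):

```cpp
#include "openvino/op/concat.hpp"

using namespace ov;

Output<Node> concat_kv(const Output<Node>& past_k, const Output<Node>& K,
                       bool is_static_input) {
    if (is_static_input) {
        // NPU: fixed-size buffer; drop leading zeros, then append at the end.
        return update_static_kv(past_k, K);
    }
    // CPU/GPU: dynamic buffer holds only valid data; one Concat suffices.
    return std::make_shared<op::v0::Concat>(OutputVector{past_k, K}, 2);
}
```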