Conversation

@wine99
Contributor

@wine99 wine99 commented Mar 27, 2025

The KV cache handling logic differs between dynamic and static shapes.

With dynamic shapes, the KV cache buffer holds only valid data, so appending the new key/value states needs just a Concat op.
With static shapes, the valid data is stored at the end of a fixed-size buffer and the beginning of the buffer is zero-padded. A plain Concat would grow the result beyond the fixed buffer size, so the past cache must first be sliced down to the real data before the new states are concatenated.
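A minimal numpy sketch of the two update schemes described above (illustration only; the shapes and names are made up for the example and are not taken from the actual decomposition):

import numpy as np

batch, heads, head_size = 1, 32, 96

# Dynamic shapes (CPU/GPU): the past cache holds only valid positions,
# so appending the new key/value states is a single concat on the sequence axis.
past_dyn = np.random.rand(batch, heads, 5, head_size).astype(np.float32)   # 5 valid positions
new_kv = np.random.rand(batch, heads, 1, head_size).astype(np.float32)     # 1 new position
present_dyn = np.concatenate([past_dyn, new_kv], axis=2)                   # sequence length 6

# Static shapes (NPU): the past buffer has a fixed length, valid data sits at
# the end, and the front is zero padding. Drop as many leading (padding)
# positions as will be appended, then concat, so the result keeps the same
# fixed length as the input buffer.
past_len, cur_len = 511, 1
past_static = np.zeros((batch, heads, past_len, head_size), dtype=np.float32)
past_static[:, :, -5:, :] = 1.0                      # pretend 5 positions hold valid data
trimmed = past_static[:, :, cur_len:, :]             # the Slice step: drop cur_len padding rows
present_static = np.concatenate([trimmed, new_kv], axis=2)
assert present_static.shape == past_static.shape     # output shape equals input shape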

The following script works for both CPU/GPU (dynamic shapes) and NPU (static shapes):

import onnxruntime as rt
import os
import numpy as np
import time

import onnxruntime.tools.add_openvino_win_libs as utils
utils.add_openvino_libs_to_path()
from transformers import PreTrainedTokenizerFast

LOOP_TIME = 2
NUM_INFERENCE = 16 # how many 2nd-token (decode) inferences to run

def get_average_time(time_list):
    return (sum(time_list) - max(time_list) - min(time_list)) / (len(time_list) - 2)

GTA = False
test_phi3 = True
test_lama3 = False
is_npu = False

if test_phi3:
    gta_modelPath = os.path.join('C:\\', 'Users', 'gta', 'Downloads', 'Phi-3-mini-4k-instruct-onnx', 'model.onnx')
    if is_npu:
        gta_modelPath = os.path.join('C:\\', 'Users', 'gta', 'Downloads', 'Phi-3-mini-4k-instruct-onnx-rows-newalgo-int4', 'model.onnx')
    gta_tokenizerPath = os.path.join('C:\\', 'Users', 'gta', 'Downloads', 'Phi-3-mini-4k-instruct-onnx', 'tokenizer.json')
    server_modelPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'model.onnx')
    server_tokenizerPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'tokenizer.json')

if test_lama3:
    gta_modelPath = os.path.join('C:\\', 'Users', 'gta', 'Downloads', 'llama3.1-8B-instruct-onnx', 'model.onnx')
    gta_tokenizerPath = os.path.join('C:\\', 'Users', 'gta', 'Downloads', 'llama3.1-8B-instruct-onnx', 'tokenizer.json')
    server_modelPath = os.path.join('D:\\', 'models', 'llm', 'llama3.1-8B-instruct-onnx', 'model.onnx')
    server_tokenizerPath = os.path.join('D:\\', 'models', 'llm', 'llama3.1-8B-instruct-onnx', 'tokenizer.json')

if GTA:
    modelPath = gta_modelPath
    tokenizerPath = gta_tokenizerPath
else:
    modelPath = server_modelPath
    tokenizerPath = server_tokenizerPath

so = rt.SessionOptions()
# so.log_severity_level = 3
# so.enable_profiling = False

sess = rt.InferenceSession(modelPath, so, providers=['CPUExecutionProvider'])
#sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU", 'cache_dir': "cpucache"}])
# sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "NPU", 'load_config':'{ "NPU": { "NPUW_CACHE_DIR": "ncache", "NPUW_DEVICES": "NPU", "NPU_USE_NPUW": "YES", "NPUW_DUMP_IO":"NO", "NPUW_FOLD": "YES", "NPUW_DUMP_SUBS": "NO", "NPUW_DQ": "YES", "NPUW_HOST_GATHER": "NO", "NPU_COMPILATION_MODE_PARAMS": "compute-layers-with-higher-precision=Sqrt,ReduceMean,Power,RMS" } }'}])
tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizerPath)

outputs = sess.get_outputs()
output_names = list(map(lambda output: output.name, outputs))


def get_phi3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 32
    sequence_length = 512
    hidden_size = 96
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

def get_llama3_param():
    num_layers = 32
    batch_size = 1
    num_heads = 8
    sequence_length = 512
    hidden_size = 128
    return num_layers, batch_size, num_heads, sequence_length, hidden_size

if test_phi3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_phi3_param()

if test_lama3:
    num_layers, batch_size, num_heads, sequence_length, hidden_size = get_llama3_param()

def create_numpy_inputs(inputToken):
    tokenLen = len(inputToken)
    npinput_ids = np.array([inputToken], dtype=np.int64)
    if is_npu:
        npattention_mask = np.array([[1] * tokenLen + [0] * (sequence_length - tokenLen)], dtype=np.int64)
    else:
        npattention_mask = np.array([[1] * tokenLen], dtype=np.int64)
    return npinput_ids, npattention_mask

def init_npinput(inputToken):
    flattened_past_key_values = {}
    for index in range(num_layers):
        if is_npu:
            key_state = np.zeros((batch_size, num_heads, sequence_length - len(inputToken), hidden_size), dtype=np.float32)
            value_state = np.zeros((batch_size, num_heads, sequence_length - len(inputToken), hidden_size), dtype=np.float32)
        else:
            key_state = np.zeros((batch_size, num_heads, 0, hidden_size), dtype=np.float32)
            value_state = np.zeros((batch_size, num_heads, 0, hidden_size), dtype=np.float32)
        flattened_past_key_values[f'past_key_values.{index}.key'] = key_state
        flattened_past_key_values[f'past_key_values.{index}.value'] = value_state
    flattened_past_key_values['input_ids'], flattened_past_key_values['attention_mask'] = create_numpy_inputs(inputToken)
    return flattened_past_key_values


for loop_idx in range(LOOP_TIME):
    first_token_time_list = []
    second_token_time_list = []
    print(f"start loop {loop_idx}")

    if test_phi3:
        input = """<|user|> 
The Sun is yellow because <|end|>
<|assistant|>
"""
    if test_lama3:
        input = """<|begin_of_text|><|user|>
The Sun is yellow because <|end|>
<|assistant|>
"""

    inputToken = tokenizer.encode(input)
    history_tokens = inputToken
    flattened_past_key_values = init_npinput(inputToken)
    lastTokenLen = len(inputToken)

    before = time.time()
    results = sess.run(output_names, flattened_past_key_values)
    after = time.time()
    first_token_time_list.append(int((after - before) * 1000))

    last_generated_token = np.argmax(results[0][-1, -1, :], axis=-1)
    print(last_generated_token, end=' ')
    history_tokens.append(last_generated_token)
    for i in range(NUM_INFERENCE):
        # update kv cache
        for index in range(len(output_names)):
            if not output_names[index].startswith('present'):
                continue
            outputname = output_names[index]
            inputname = outputname.replace('present', 'past_key_values')
            flattened_past_key_values[inputname] = results[index]

        # update input token
        flattened_past_key_values['input_ids'] = np.array([[last_generated_token]], dtype=np.int64)
        if is_npu:
            flattened_past_key_values['attention_mask'] = np.array([[1] * len(history_tokens) + [0] * (sequence_length - len(history_tokens))], dtype=np.int64)
        else:
            flattened_past_key_values['attention_mask'] = np.array([[1] * len(history_tokens)], dtype=np.int64)

        before = time.time()
        results = sess.run(output_names, flattened_past_key_values)
        after = time.time()
        second_token_time_list.append(int((after - before) * 1000))

        last_generated_token = np.argmax(results[0][-1, -1, :], axis=-1)
        print(last_generated_token, end=' ')
        history_tokens.append(last_generated_token)

    print(f"1st token times: {first_token_time_list}, avg {first_token_time_list[0]} ms")
    print(f"2nd token times: {secod_token_time_list}, avg {int(get_average_time(secod_token_time_list))} ms")

print(tokenizer.decode(history_tokens))
print(f"loop {LOOP_TIME} times finished")

@wine99 wine99 requested review from a team as code owners March 27, 2025 04:02
@wine99 wine99 requested review from itikhono and removed request for a team March 27, 2025 04:02
@wine99 wine99 marked this pull request as draft March 27, 2025 04:02
@github-actions github-actions bot added labels category: Core (OpenVINO Core, aka ngraph) and category: transformations (OpenVINO Runtime library - Transformations) Mar 27, 2025
@sys-openvino-ci sys-openvino-ci added the ExternalPR External contributor label Mar 27, 2025
@mlukasze mlukasze requested a review from mitruska March 27, 2025 05:48
@wine99 wine99 marked this pull request as ready for review April 8, 2025 06:13

ov::Output<ov::Node> present_k;
ov::Output<ov::Node> present_v;
if (is_static_input) {

Can we use "ShapeOf -> Gather" instead of K.get_partial_shape()[2].get_length() and cover the dynamic case as well?
For a static shape the "ShapeOf -> Gather" subgraph will be constant-folded; for a dynamic shape it remains in the graph.
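For illustration, a rough sketch of this suggestion using the OpenVINO Python opset (a hypothetical standalone snippet; the actual transformation is written in C++ and key here is just a placeholder parameter):

import numpy as np
import openvino.runtime.opset13 as ops
from openvino.runtime import PartialShape, Type

# Placeholder K input with a dynamic sequence dimension; for a fully static
# shape the ShapeOf -> Gather subgraph below would be constant-folded away,
# for a dynamic shape it stays in the graph and yields the runtime length.
key = ops.parameter(PartialShape([1, 32, -1, 96]), Type.f32, name="key")
kv_shape = ops.shape_of(key, output_type="i64")
kv_len = ops.gather(kv_shape,
                    ops.constant(np.array([2], dtype=np.int64)),  # sequence axis index
                    ops.constant(np.array(0, dtype=np.int64)))    # gather along axis 0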

Contributor

Such a case seems to be already handled by the get_dimensions helper and stored as concat_kv_len.
It looks like the idea of this change is to intentionally introduce different behavior for static and dynamic shapes: a static shape is treated as an indicator of the "maximum sequence length", while a dynamic shape reflects the actual size that changes between inference calls.

Contributor

The NPU doesn't support dynamic shapes, which causes a runtime issue during type inference. At the beginning of the implementation we used a Gather op, but Slice gives better performance.

Contributor

@mitruska mitruska left a comment

The main concern regarding the proposed changes is that the shape inference for static and dynamic shapes is not unified, and a "static" shape at the operator level is used as a flag to comply with plugin-specific requirements (CPU/GPU vs NPU). The shared GQA op should rather be plugin-independent by design. As I understand it, in the proposed changes a static shape is assumed to be the "NPU" case, where the KV sequence length dimension means the "maximum" sequence length, while a dynamic shape is assumed to carry the actual sequence length, as supported by CPU/GPU. This may lead to unexpected behaviour.

Have you considered any alternative solutions, such as a transformation-level flag/fallback to select the proper decomposition for the target plugin?


@sgbihu
Contributor

sgbihu commented Apr 14, 2025

Have you considered any alternative solutions, such as a transformation-level flag/fallback to select the proper decomposition for the target plugin?

@mitruska We have considered a logic that does not depend on the device type; it depends only on the shape. So if a CPU model uses static shapes, this logic works there as well.

I have also changed some of the logic to make the shape inference more reasonable: the returned shape is now the same as the input shape. Could you help review it and give me more input?

Contributor

@praasz praasz left a comment

OK for the core part.

Contributor

@mitruska mitruska left a comment

I can see the goal of the proposed changes, but I'm not fully convinced this is a long-term solution. As GQA is still part of the dev API it can be modified, but at some point we may need to make it public and compatible with frontend frameworks, for example ONNX Attention.
If these changes are needed right now, I don't want to block them, but I recommend making sure that the other plugins and the transformations team approve this approach as maintainable for the common GQA op.
cc: @a-sidorova @itikhono @jane-intel

@zhangYiIntel
Contributor

For correctness, I think the PR is good, and we should add a test like

src/common/transformations/tests/op_conversions/scaled_dot_product_decomposition_test.cpp

One question about this PR, though: the KV cache usage here is quite different from the current stateful-model approach. In a stateful model the KV cache is maintained internally by the CPU/GPU plugin, which avoids memory copies between devices, and multi-query is already supported there. With this PR, KV cache management is done by the application. In terms of reducing memory I/O, I think this is sub-optimal compared to the stateful-model approach.

@sgbihu
Contributor

sgbihu commented Apr 27, 2025

For correctness, I think the PR is good, and we should add a test like

src/common/transformations/tests/op_conversions/scaled_dot_product_decomposition_test.cpp

For the tests, we already added GQA tests in PR 28163, and this PR is a follow-up for NPU.

One question about this PR, though: the KV cache usage here is quite different from the current stateful-model approach. In a stateful model the KV cache is maintained internally by the CPU/GPU plugin, which avoids memory copies between devices, and multi-query is already supported there. With this PR, KV cache management is done by the application. In terms of reducing memory I/O, I think this is sub-optimal compared to the stateful-model approach.

The script only simulates the GenAI flow and is for test purposes only. GQA does not care about the Assign/ReadValue ops used by stateful models; it is just an operator. So this PR implements only the GQA-related part, and the remaining pieces will be covered by ORT GenAI.

Contributor

@zhangYiIntel zhangYiIntel left a comment

LGTM, but I still have a question about KV-cache management.

v0::Constant::create(ov::element::i64, ov::Shape{1}, {K.get_partial_shape()[2].get_length()}));
const auto past_kv_len_const = register_new_node(
v0::Constant::create(ov::element::i64, ov::Shape{1}, {past_key.get_partial_shape()[2].get_length()}));
past_key = register_new_node<v8::Slice>(past_key, current_kv_len_const, past_kv_len_const, one, two);
Contributor

Do we have a document that describes the cache layout for static shapes?

At first glance, one might think the cache grows towards the end:

index:
0 -> past_len -> cur_len
data layout:
[past cache] | [current cache]

However, the code here assumes that the past data is placed after the current data, so the growth direction is the opposite of what one would ordinarily expect. It would be better to have a document or an agreement about this:

index:
0 -> cur_len -> past_len
data layout:
[current cache] | [past cache]

Contributor

@sgbihu sgbihu Apr 27, 2025

We only describe the logic in the PR description. Your understanding is not correct: the latest cache entries are always at the end of the buffer. This part pops the zeros at the beginning of the buffer; the concat logic then follows (L120 of the diff).
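A tiny numpy illustration of that layout along the sequence axis only (values are arbitrary, just to show where the valid entries end up):

import numpy as np

max_past = 7                        # fixed past-buffer length (sequence axis only)
buf = np.zeros(max_past)            # step 0: all zero padding, no valid data yet
for step, token_kv in enumerate([1.0, 2.0, 3.0], start=1):
    # pop one leading zero, append the new entry at the end
    buf = np.concatenate([buf[1:], [token_kv]])
    print(step, buf)
# step 1: [0. 0. 0. 0. 0. 0. 1.]
# step 2: [0. 0. 0. 0. 0. 1. 2.]
# step 3: [0. 0. 0. 0. 1. 2. 3.]   -> the newest entry is always at the end,
#                                     and the zero padding shrinks from the front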

Contributor

Even if the concat part is a real concat of past_kv + cur_kv, the layout of the past_kv cache is still confusing: why are zeros padded before past_kv, and could zeros be padded after past_kv in some other implementation? The problem is that we apply a strong assumption about the layout of past_kv, but there is no document describing this assumption.

Contributor

Updated comments.

@sgbihu
Contributor

sgbihu commented May 21, 2025

build_jenkins

@sgbihu
Contributor

sgbihu commented May 21, 2025

build_jenkins

@mlukasze mlukasze removed the request for review from e-nugmanova May 21, 2025 09:01
@sgbihu
Contributor

sgbihu commented May 21, 2025

build_jenkins

@sgbihu
Contributor

sgbihu commented May 21, 2025

build_jenkins

@sgbihu sgbihu enabled auto-merge May 21, 2025 23:38
@sgbihu sgbihu added this pull request to the merge queue May 22, 2025
Merged via the queue into openvinotoolkit:master with commit 8e77c28 May 22, 2025
213 of 217 checks passed
@sgbihu sgbihu deleted the gqa_npu branch May 22, 2025 05:45
@mlukasze mlukasze added this to the 2025.2 milestone May 22, 2025
@mlukasze
Contributor

Thank you @wine99, we got it!
