Faster beam_search sampling
#1243
Conversation
Refine the KV cache update logic for more intelligent and efficient updating.
ggerganov left a comment:
Clever implementation!
Some minor comments, and then I will take a deeper look and do some testing to make sure this works correctly. But I understand the idea and agree this would be faster.
This is not slower than the new approach, because all that …
You're right, that's my bad. I misunderstood what the …
In certain scenarios, it might be more efficient to simply swap the pointers instead of actually moving the memory around. For instance, suppose only two decoders are engaged in beam search. If we need to switch the KV caches between Decoder A and Decoder B, a simple pointer swap would do the trick. Just like …
I've now visualized the logic for updating the KV cache across the different versions, as shown in the figures. I've made the changes in the new version based on the suggestions you provided earlier. If there aren't any bugs, it's ready to be merged. Thanks. @ggerganov

(Figures: KV cache update diagrams for Master, PR (v1), PR (v4), and PR (v5).)
I tested this PR, and the transcription results from Master and this PR are the same. I then wrote a demo and tested 1 million possible combinations. There were no issues, but further testing may still be needed. I am now starting to implement batch decoding. As for parallel sampling, I plan to wait until the grammar-based sampling is merged before proceeding.

Code:

#include <iostream>
#include <string>
#include <cmath>
#include <vector>
#include <random>
#include <functional>
#include <numeric>
#include <unordered_set>
struct kv_cache {
std::string data;
};
template<class T>
bool is_same_vec(T & vec1, T & vec2) {
if (vec1.size() != vec2.size()) {
return false;
}
for (size_t i = 0; i < vec1.size(); i++) {
if (vec1[i] != vec2[i]) {
return false;
}
}
return true;
}
template<class T>
bool is_same_struct(T & vec1, T & vec2) {
if (vec1.size() != vec2.size()) {
return false;
}
for (size_t i = 0; i < vec1.size(); i++) {
if (vec1[i].data != vec2[i].data) {
return false;
}
}
return true;
}
template<class T>
void print_vec(T & vec) {
std::cout << "{";
for (auto &i : vec) {
std::cout << i << " ";
}
std::cout << "}" << std::endl;
}
template<class T>
void print_vec(std::vector<std::pair<T, T>> & vec) {
std::cout << "{";
for (auto & i : vec) {
std::cout << "(" << i.first << "," << i.second << ")" << " ";
}
std::cout << "}" << std::endl;
}
template<class T>
void print_struct(T & vec) {
std::cout << "{";
for (auto & i : vec) {
std::cout << i.data << " ";
}
std::cout << "}" << std::endl;
}
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
template<typename A, typename B>
struct whisper_pair {
A first;
B second;
// Define a constructor that takes two arguments.
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
// Define a constructor that takes no argument.
whisper_pair() : first(A()), second(B()) {}
};
static bool whisper_kv_swap_fast(std::vector<int> & view, std::vector<kv_cache> & src, int size) {
// (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
std::unordered_set<int> two_copy; // decoder indices require two copies to safely modify KV caches
two_copy.reserve(size);
// (buffer->decoder or decoder->decoder)
std::unordered_set<int> one_copy; // decoder indices require one copy to safely modify KV caches
one_copy.reserve(size);
// (decoder<->decoder)
std::unordered_set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
p_swap_set.reserve(size);
std::vector<whisper_pair<int, int>> p_swap_vec;
p_swap_vec.reserve(size);
for (int i = 0; i < size; i++) {
// zero-copy
if (i == view[i] || view[i] < 0) {continue;}
bool is_one_copy = true;
for (int j = i + 1; j < size; j++) {
if (i == view[j]) {
two_copy.insert(i);
is_one_copy = false;
// detect symmetric diagram
if (j == view[i]) {
p_swap_set.insert(i);
p_swap_set.insert(j);
p_swap_vec.emplace_back(i, j);
}
break;
}
}
if (is_one_copy) {
one_copy.insert(i);
}
}
std::vector<kv_cache> kv_bufs;
kv_bufs.resize(size);
for (auto & i : two_copy) {
// make a copy of KV
kv_bufs[i].data = src[i].data;
}
for (auto & i : two_copy) {
if (p_swap_set.find(i) != p_swap_set.end()) {continue;}
if (two_copy.find(view[i]) != two_copy.end()) {
// modify KV of decoder using data from kv_bufs
src[i].data = kv_bufs[view[i]].data;
} else {
// modify KV of decoder using data from correspond decoder KV
src[i].data = src[view[i]].data;
}
}
for (auto & i : one_copy) {
if (p_swap_set.find(i) != p_swap_set.end()) {continue;}
if (two_copy.find(view[i]) != two_copy.end()) {
// modify KV of decoder using data from kv_bufs
src[i].data = kv_bufs[view[i]].data;
} else {
// modify KV of decoder using data from correspond decoder KV
src[i].data = src[view[i]].data;
}
}
for (auto & i : p_swap_vec) {
std::swap(src[i.first].data, src[i.second].data);
}
// print_vec(two_copy);
// print_vec(one_copy);
// print_vec(p_swap_vec);
return true;
}
int main() {
bool diff = false;
for (int r = 0; r < 1000000; r++) {
std::vector<kv_cache> decoder_1(5);
decoder_1[0].data = "A";
decoder_1[1].data = "B";
decoder_1[2].data = "C";
decoder_1[3].data = "D";
decoder_1[4].data = "E";
std::vector<kv_cache> decoder_2(5);
decoder_2[0].data = "A";
decoder_2[1].data = "B";
decoder_2[2].data = "C";
decoder_2[3].data = "D";
decoder_2[4].data = "E";
std::random_device rd; // initialize random seed
std::mt19937 gen(rd()); // Mersenne Twister 19937 generator
std::uniform_int_distribution<> dis(0, 4);
std::vector<int> decoder_index(5);
for (auto &i : decoder_index) {
i = dis(gen);
}
// print_vec(decoder_index);
whisper_kv_swap_fast(decoder_index, decoder_1, decoder_1.size());
// baseline method
std::vector<kv_cache> buffer;
buffer.resize(5);
for (size_t i = 0; i < decoder_2.size(); i++) {
buffer[i] = decoder_2[i];
}
for (size_t i = 0; i < decoder_2.size(); i++) {
decoder_2[i] = buffer[decoder_index[i]];
}
if (!is_same_struct(decoder_1, decoder_2)) {
std::cout << "Error! Iteration: " << r << std::endl;
print_struct(decoder_1);
print_struct(decoder_2);
diff = true;
break;
}
}
if (!diff) {
std::cout << "No difference" << std::endl;
}
}
The test program that you made is very useful. You should try to add it as a unit test. It might also be a good idea to create a wiki page with your idea and the plots from this PR, or at the very least link to this PR. The goal is for this information to remain available and easy to find in the future. You can also add a wiki entry for your previous PR about matching the OpenAI input.
I apologize for the late update. I was very busy yesterday and didn't have time to handle it. I've just updated the code. |
* master: (96 commits)
  * whisper : fix bench regression + fix performance when using CPU BLAS (ggml-org#1275)
  * whisper : faster beam_search sampling via reduced KV cache copies (ggml-org#1243)
  * java : fixed signing of java artifact using gradle (ggml-org#1267)
  * ci : try to fix gradle action (ggml-org#1265)
  * gitignore : update
  * sync : ggml (HBM + Metal + style) (ggml-org#1264)
  * ci : upgrade gradle to 2.4.2 (ggml-org#1263)
  * sync : ggml (CUDA faster rope)
  * cmake : noramlize case (ggml-org#1129)
  * build : do not use _GNU_SOURCE gratuitously (ggml-org#1129)
  * examples : fix build + compile warnings (close ggml-org#1256)
  * models : add quantum models to download-ggml-model.sh (ggml-org#1235)
  * whisper.android : bump gradle plugin and dependencies + a lint pass (ggml-org#1255)
  * sign jar for Maven Central repo
  * whisper.android : address ARM's big.LITTLE arch by checking cpu info (ggml-org#1254)
  * make : fix detection of AVX2 on macOS (ggml-org#1250)
  * ggml : posixify pagesize (ggml-org#1251)
  * configured publishing.repositories
  * ggml : sync latest llama.cpp (view_src + alloc improvements) (ggml-org#1247)
  * make : improve cpuinfo handling on x86 hosts (ggml-org#1238)
  * ...
…ml-org#1243)
* Faster `beam_search` sampling — refine the KV cache update logic for more intelligent and efficient updating.
* Faster `whisper_sample_token_topk`
* Update whisper.cpp
* Update whisper.cpp
* Update whisper.cpp
* Reduce `memory allocation`
* Add `pointer swapping`
* Fixed some bugs
* Update whisper.cpp
* Apply suggestions from code review
* Updated the logic for determining `two-copy`
* Updated the logic for determining `two-copy` v2
* whisper : add debug logs + coding style

Co-authored-by: Georgi Gerganov <[email protected]>




PR (v1): The KV cache update logic in Master is quite simple and heavy-handed. First, a temporary cache is created, and the entire KV cache is copied into it. After the beam search completes, the KV cache is updated based on the beam search results, copying from the temporary cache to replace the old KV cache. If we assume the `beam_size` is `X`, this method reads and writes the KV cache `2X` times, which is very inefficient. In reality, we often don't need to update a decoder's KV cache at all (zero-copy); even when we do, we often only need to update a portion and can copy directly without going through the temporary cache (one-copy). Lastly, we can detect read-write conflicts to reduce how often we must write to the temporary cache and then copy back (two-copy).

PR (v2): The calculation of `whisper_sample_token_topk` in Master is highly inefficient. First, it uses `logits_id.clear()` to clear all elements in the vector, and then a loop extracts each logit value from `logits`, creates a new `std::pair` together with the current index, and pushes it to the end of `logits_id` using `push_back`. Because the length of `logits` is determined by `n_vocab`, and `n_vocab` in whisper is `51,865`, each time we calculate `whisper_sample_token_topk` with the Master method we create (via `logits_id.push_back`) and destroy (via `logits_id.clear`) 51,865 `std::pair` objects, which is highly inefficient. Instead, we can use `logits_id.resize()` to set the size of `logits_id`: since the new size is less than or equal to the old capacity, no memory allocation occurs, and we can assign values to the existing pairs directly by index, avoiding creation and destruction.

PR (v3): Adopted ggerganov's suggestion to make `kv_bufs` static, avoiding repeated memory allocation and deallocation and thereby significantly improving performance.

PR (v4): Added the pointer-swapping function, which slightly improved performance.

PR (v5): Updated the logic for determining `two-copy`, reducing unnecessary memory copying and slightly improving performance.

This PR optimizes both the KV cache update logic and the calculation of `whisper_sample_token_topk`, significantly reducing memory read-write operations and thereby saving a lot of time. The average sample time has decreased from `1.56 ms/run` to `0.69 ms/run`, reducing the computational time by approximately 55%. @ggerganov
(Benchmark tables for Master and for this PR (v1) through (v5): measured on an i7-12700H with diffusion2023-07-03.wav and ggml-model-whisper-base.bin, at settings -bs 1 -bo 1 through -bs 5 -bo 5. The timing values were in collapsed tables not captured in this export.)