Conversation

@bobqianic (Collaborator) commented Sep 3, 2023

PR (v1): The KV cache update logic on master is simple but heavy-handed: a temporary cache is allocated, the entire KV cache is copied into it, and after beam search completes the old KV cache is rebuilt from this temporary cache according to the beam-search results. If the beam size is X, this method reads and writes the KV cache 2X times, which is very inefficient. In reality, we often don't need to update a decoder's KV cache at all (zero-copy); when an update is needed, we can usually copy directly from another decoder's cache without going through the temporary cache (one-copy); and only when a read-write conflict is detected do we need to stage the data in the temporary cache first and copy it out afterwards (two-copy).
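The three update paths described above can be sketched as a small classification pass over the beam-search view (illustrative names and structure, not the actual whisper.cpp code; `view[i]` is assumed to be the index of the decoder whose KV cache decoder `i` should hold after the step):

```cpp
#include <vector>

enum class copy_kind { zero_copy, one_copy, two_copy };

// Illustrative sketch of the zero/one/two-copy classification described above.
std::vector<copy_kind> classify(const std::vector<int> & view) {
    const int n = (int) view.size();
    std::vector<copy_kind> kind(n, copy_kind::zero_copy);
    for (int i = 0; i < n; i++) {
        // zero-copy: decoder i keeps its own cache (or is inactive)
        if (view[i] == i || view[i] < 0) {
            continue;
        }
        kind[i] = copy_kind::one_copy;
        for (int j = i + 1; j < n; j++) {
            // two-copy: a later decoder still needs i's current cache, so
            // i's cache must be staged in a buffer before being overwritten
            if (view[j] == i) {
                kind[i] = copy_kind::two_copy;
                break;
            }
        }
    }
    return kind;
}
```

For example, with `view = {0, 0, 1}` decoder 0 keeps its cache (zero-copy), decoder 1 must be buffered first because decoder 2 still needs cache 1 (two-copy), and decoder 2 can copy directly (one-copy).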

PR (v2): The calculation of whisper_sample_token_topk on master is inefficient. First, it calls logits_id.clear() to empty the vector, then loops over logits, building a new std::pair from each logit value and its index and appending it with push_back. Since the length of logits is n_vocab, and n_vocab in whisper is 51,865, every call to whisper_sample_token_topk on master creates (via push_back) and destroys (via clear) 51,865 std::pair objects. Instead, we can use logits_id.resize() to set the size once; since the new size is less than or equal to the old capacity, no memory allocation occurs, and we can assign values directly to the existing pairs by index, avoiding the repeated creation and destruction.
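A sketch of this pattern (illustrative names, not the actual whisper.cpp code): resize once, then assign in place rather than `clear()` + `push_back()` for every vocabulary entry.

```cpp
#include <utility>
#include <vector>

// Fill (logit, index) pairs by overwriting existing elements in place.
// Once the vector's capacity has reached n_vocab, resize() performs no
// allocation and the loop does plain assignments.
void fill_logits_id(std::vector<std::pair<double, int>> & logits_id,
                    const float * logits, int n_vocab) {
    logits_id.resize(n_vocab);            // no allocation once capacity >= n_vocab
    for (int i = 0; i < n_vocab; i++) {
        logits_id[i].first  = logits[i];  // overwrite the existing pair in place
        logits_id[i].second = i;
    }
}
```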

PR (v3): Adopted ggerganov's suggestion to make `kv_bufs` static, avoiding repeated memory allocation and deallocation and thereby significantly improving performance.
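The idea can be sketched as follows (illustrative code, not the actual whisper.cpp implementation): a function-local static buffer survives between calls, so after warm-up no allocation or deallocation happens on the hot path.

```cpp
#include <string>
#include <vector>

struct kv_cache {
    std::string data;
};

// Return a buffer that is allocated once and reused across calls.
// Shrinking or growing within the existing capacity performs no allocation.
std::vector<kv_cache> & get_kv_bufs(size_t size) {
    static std::vector<kv_cache> kv_bufs;  // allocated once, persists between calls
    kv_bufs.resize(size);                  // cheap once capacity >= size
    return kv_bufs;
}
```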

PR(v4): Added the pointer swapping function, which slightly improved performance.

PR (v5): Updated the logic for determining two-copy, reducing unnecessary memory copying and slightly improving performance.

This PR optimizes both the KV cache update logic and the calculation of whisper_sample_token_topk, significantly reducing memory read/write operations and saving a lot of time.

The average sample time has decreased from 1.56 ms/run to 0.69 ms/run, a reduction of roughly 55%. @ggerganov

Master: i7-12700H diffusion2023-07-03.wav ggml-model-whisper-base.bin

| -bs 1 -bo 1 | -bs 2 -bo 2 | -bs 3 -bo 3 | -bs 4 -bo 4 | -bs 5 -bo 5 |
| --- | --- | --- | --- | --- |
| 0.57 ms/run | 1.56 ms/run | 1.55 ms/run | 1.58 ms/run | 1.55 ms/run |
Older Versions

This PR (v1): i7-12700H diffusion2023-07-03.wav ggml-model-whisper-base.bin

| -bs 1 -bo 1 | -bs 2 -bo 2 | -bs 3 -bo 3 | -bs 4 -bo 4 | -bs 5 -bo 5 |
| --- | --- | --- | --- | --- |
| 0.57 ms/run | 0.96 ms/run | 1.01 ms/run | 1.07 ms/run | 1.14 ms/run |

This PR (v2): i7-12700H diffusion2023-07-03.wav ggml-model-whisper-base.bin

| -bs 1 -bo 1 | -bs 2 -bo 2 | -bs 3 -bo 3 | -bs 4 -bo 4 | -bs 5 -bo 5 |
| --- | --- | --- | --- | --- |
| 0.57 ms/run | 0.76 ms/run | 0.81 ms/run | 0.87 ms/run | 0.94 ms/run |

This PR (v3): i7-12700H diffusion2023-07-03.wav ggml-model-whisper-base.bin

| -bs 1 -bo 1 | -bs 2 -bo 2 | -bs 3 -bo 3 | -bs 4 -bo 4 | -bs 5 -bo 5 |
| --- | --- | --- | --- | --- |
| 0.57 ms/run | 0.69 ms/run | 0.72 ms/run | 0.74 ms/run | 0.77 ms/run |

This PR (v4): i7-12700H diffusion2023-07-03.wav ggml-model-whisper-base.bin

| -bs 1 -bo 1 | -bs 2 -bo 2 | -bs 3 -bo 3 | -bs 4 -bo 4 | -bs 5 -bo 5 |
| --- | --- | --- | --- | --- |
| 0.57 ms/run | 0.67 ms/run | 0.70 ms/run | 0.72 ms/run | 0.75 ms/run |

This PR (v5): i7-12700H diffusion2023-07-03.wav ggml-model-whisper-base.bin

| -bs 1 -bo 1 | -bs 2 -bo 2 | -bs 3 -bo 3 | -bs 4 -bo 4 | -bs 5 -bo 5 |
| --- | --- | --- | --- | --- |
| 0.57 ms/run | 0.64 ms/run | 0.68 ms/run | 0.70 ms/run | 0.73 ms/run |

Refine the KV cache update logic for more intelligent and efficient updating.
@ggerganov (Member) left a comment


Clever implementation!

Some minor comments, and then I will take a deeper look and test to make sure this works correctly. But I understand the idea and agree this would be faster.

@ggerganov (Member)

> First, it uses `logits_id.clear()` to clear all elements in the vector, and then it uses a loop to extract the logit value from `logits`.

This is not slower than the new approach, because all that `clear()` does is reset the internal element counter to 0; the vector's memory remains allocated. But anyway, the new implementation is also fine.

@bobqianic (Collaborator, Author)

> First, it uses `logits_id.clear()` to clear all elements in the vector, and then it uses a loop to extract the logit value from `logits`.
>
> This is not slower than the new approach, because all that `clear()` does is reset the internal element counter to 0; the vector's memory remains allocated. But anyway, the new implementation is also fine.

You're right, that's my bad. I misunderstood what the `clear()` function does.

@bobqianic requested a review from ggerganov on September 5, 2023
@bobqianic (Collaborator, Author)

In certain scenarios, it might be more efficient to just swap the pointers instead of actually moving the memory around. For instance, say only two decoders are engaged in beam search: if we need to switch the KV caches between decoder A and decoder B, a simple pointer swap does the trick, just like `std::swap`.
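As a sketch (illustrative code, not the PR's implementation): `std::swap` on the underlying `std::string` exchanges the strings' internal state rather than copying the cache bytes, so for heap-allocated buffers only pointers move.

```cpp
#include <string>
#include <utility>

struct kv_cache {
    std::string data;
};

// Exchange two decoders' caches without copying the cached bytes
// (for heap-allocated strings, only internal pointers are swapped).
void swap_caches(kv_cache & a, kv_cache & b) {
    std::swap(a.data, b.data);
}
```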

@bobqianic (Collaborator, Author) commented Sep 7, 2023

I've now visualized the logic for updating the KV cache across different versions, as shown in the figure. I've made the changes in the new version based on the suggestions you provided earlier. If there aren't any bugs, it's ready to be merged. Thanks. @ggerganov

Master:

[diagram: KV cache update flow on master]

PR (v1):

[diagram: KV cache update flow in v1]

PR (v4):

[diagram: KV cache update flow in v4]

PR (v5):

[diagram: KV cache update flow in v5]

@bobqianic (Collaborator, Author) commented Sep 7, 2023

I tested this PR, and the transcription results from master and this PR are identical. I then wrote a demo and tested 1 million randomly generated index combinations with no mismatches, though further testing may still be needed. I am now starting to implement batch decoding. As for parallel sampling, I plan to wait until the grammar-based sampling is merged before proceeding.

Code
#include <iostream>
#include <string>
#include <vector>
#include <random>
#include <unordered_set>

struct kv_cache {
    std::string data;
};

template<class T>
bool is_same_vec(T & vec1, T & vec2) {
    if (vec1.size() != vec2.size()) {
        return false;
    }
    for (size_t i = 0; i < vec1.size(); i++) {
        if (vec1[i] != vec2[i]) {
            return false;
        }
    }
    return true;
}

template<class T>
bool is_same_struct(T & vec1, T & vec2) {
    if (vec1.size() != vec2.size()) {
        return false;
    }
    for (size_t i = 0; i < vec1.size(); i++) {
        if (vec1[i].data != vec2[i].data) {
            return false;
        }
    }
    return true;
}

template<class T>
void print_vec(T & vec) {
    std::cout << "{";
    for (auto &i : vec) {
        std::cout << i << " ";
    }
    std::cout << "}" << std::endl;
}

template<class T>
void print_vec(std::vector<std::pair<T, T>> & vec) {
    std::cout << "{";
    for (auto & i : vec) {
        std::cout << "(" << i.first << "," << i.second << ")" << " ";
    }
    std::cout << "}" << std::endl;
}

template<class T>
void print_struct(T & vec) {
    std::cout << "{";
    for (auto & i : vec) {
        std::cout << i.data << " ";
    }
    std::cout << "}" << std::endl;
}

// replace std::pair by using customized pair struct (reason: std::pair is very slow)
template<typename A, typename B>
struct whisper_pair {
    A first;
    B second;

    // Define a constructor that takes two arguments.
    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
    // Define a constructor that takes no argument.
    whisper_pair() : first(A()), second(B()) {}
};

static bool whisper_kv_swap_fast(std::vector<int> & view, std::vector<kv_cache> & src, int size) {
    // (decoder->buffer->decoder or decoder->buffer + decoder->decoder)
    std::unordered_set<int> two_copy; // decoder indices require two copies to safely modify KV caches
    two_copy.reserve(size);

    // (buffer->decoder or decoder->decoder)
    std::unordered_set<int> one_copy; // decoder indices require one copy to safely modify KV caches
    one_copy.reserve(size);

    // (decoder<->decoder)
    std::unordered_set<int> p_swap_set; // decoder indices able to swap KV-cache pointers
    p_swap_set.reserve(size);
    std::vector<whisper_pair<int, int>> p_swap_vec;
    p_swap_vec.reserve(size);

    for (int i = 0; i < size; i++) {
        // zero-copy
        if (i == view[i] || view[i] < 0) {continue;}

        bool is_one_copy = true;
        for (int j = i + 1; j < size; j++) {
            if (i == view[j]) {
                two_copy.insert(i);
                is_one_copy = false;
                // detect symmetric diagram
                if (j == view[i]) {
                    p_swap_set.insert(i);
                    p_swap_set.insert(j);
                    p_swap_vec.emplace_back(i, j);
                }
                break;
            }
        }
        if (is_one_copy) {
            one_copy.insert(i);
        }
    }

    std::vector<kv_cache> kv_bufs;
    kv_bufs.resize(size);

    for (auto & i : two_copy) {
        // make a copy of KV
        kv_bufs[i].data = src[i].data;
    }

    for (auto & i : two_copy) {
        if (p_swap_set.find(i) != p_swap_set.end()) {continue;}

        if (two_copy.find(view[i]) != two_copy.end()) {
            // modify KV of decoder using data from kv_bufs
            src[i].data = kv_bufs[view[i]].data;
        } else {
            // modify KV of decoder using data from correspond decoder KV
            src[i].data = src[view[i]].data;
        }
    }

    for (auto & i : one_copy) {
        if (p_swap_set.find(i) != p_swap_set.end()) {continue;}
        if (two_copy.find(view[i]) != two_copy.end()) {
            // modify KV of decoder using data from kv_bufs
            src[i].data = kv_bufs[view[i]].data;
        } else {
            // modify KV of decoder using data from correspond decoder KV
            src[i].data = src[view[i]].data;
        }
    }

    for (auto & i : p_swap_vec) {
        std::swap(src[i.first].data, src[i.second].data);
    }

    // print_vec(two_copy);
    // print_vec(one_copy);
    // print_vec(p_swap_vec);

    return true;
}

int main() {
    std::random_device rd;                     // seed source
    std::mt19937 gen(rd());                    // Mersenne Twister engine, seeded once (not per iteration)
    std::uniform_int_distribution<> dis(0, 4);

    bool diff = false;
    for (int r = 0; r < 1000000; r++) {
        std::vector<kv_cache> decoder_1(5);
        decoder_1[0].data = "A";
        decoder_1[1].data = "B";
        decoder_1[2].data = "C";
        decoder_1[3].data = "D";
        decoder_1[4].data = "E";

        std::vector<kv_cache> decoder_2(5);
        decoder_2[0].data = "A";
        decoder_2[1].data = "B";
        decoder_2[2].data = "C";
        decoder_2[3].data = "D";
        decoder_2[4].data = "E";

        std::vector<int> decoder_index(5);
        for (auto &i : decoder_index) {
            i = dis(gen);
        }

        // print_vec(decoder_index);
        whisper_kv_swap_fast(decoder_index, decoder_1, decoder_1.size());

        // baseline method: copy everything through a temporary buffer
        std::vector<kv_cache> buffer(decoder_2.size());
        for (size_t i = 0; i < decoder_2.size(); i++) {
            buffer[i] = decoder_2[i];
        }
        for (size_t i = 0; i < decoder_2.size(); i++) {
            decoder_2[i] = buffer[decoder_index[i]];
        }

        if (!is_same_struct(decoder_1, decoder_2)) {
            std::cout << "Error! Iteration: " << r << std::endl;
            print_struct(decoder_1);
            print_struct(decoder_2);
            diff = true;
            break;
        }
    }
    if (!diff) {
        std::cout << "No difference" << std::endl;
    }
}

@ggerganov (Member)

The test program that you made is very useful. You should try to add it as a unit test in `tests`.

It might be a good idea to create a wiki page with your idea and the plots from this PR. Or at the very least, you can just link to this PR. But the goal is for this information to remain available and easy to find in the future.

You can also add a wiki entry for your previous PR about matching OpenAI input.
I've sent you a collaborator invite.

@bobqianic (Collaborator, Author)

I apologize for the late update. I was very busy yesterday and didn't have time to handle it. I've just updated the code.

@ggerganov merged commit 9b14418 into ggml-org:master on Sep 10, 2023
bdonkey added a commit to bdonkey/whisper.cpp that referenced this pull request Sep 13, 2023
* master: (96 commits)
  whisper : fix bench regression + fix performance when using CPU BLAS (ggml-org#1275)
  whisper : faster beam_search sampling via reduced KV cache copies (ggml-org#1243)
  java : fixed signing of java artifact using gradle (ggml-org#1267)
  ci : try to fix gradle action (ggml-org#1265)
  gitignore : update
  sync : ggml (HBM + Metal + style) (ggml-org#1264)
  ci : upgrade gradle to 2.4.2 (ggml-org#1263)
  sync : ggml (CUDA faster rope)
  cmake : noramlize case (ggml-org#1129)
  build : do not use _GNU_SOURCE gratuitously (ggml-org#1129)
  examples : fix build + compile warnings (close ggml-org#1256)
  models : add quantum models to download-ggml-model.sh (ggml-org#1235)
  whisper.android : bump gradle plugin and dependencies + a lint pass (ggml-org#1255)
  sign jar for Maven Central repo
  whisper.android : address ARM's big.LITTLE arch by checking cpu info (ggml-org#1254)
  make : fix detection of AVX2 on macOS (ggml-org#1250)
  ggml : posixify pagesize (ggml-org#1251)
  configured publishing.repositories
  ggml : sync latest llama.cpp (view_src + alloc improvements) (ggml-org#1247)
  make : improve cpuinfo handling on x86 hosts (ggml-org#1238)
  ...
didzis pushed a commit to didzis/whisper.cpp that referenced this pull request Sep 30, 2023
…ml-org#1243)

* Faster `beam_search` sampling

Refine the KV cache update logic for more intelligent and efficient updating.

* Faster `whisper_sample_token_topk`

* Update whisper.cpp

* Update whisper.cpp

* Update whisper.cpp

* Reduce `memory allocation`

* Add `pointer swapping`

* Fixed some bugs

* Update whisper.cpp

* Apply suggestions from code review

* Updated the logic for determining `two-copy`

* Updated the logic for determining `two-copy` v2

* whisper : add debug logs + coding style

---------

Co-authored-by: Georgi Gerganov <[email protected]>
jacobwu-b pushed a commit to jacobwu-b/Transcriptify-by-whisper.cpp that referenced this pull request Oct 24, 2023
jacobwu-b pushed a commit to jacobwu-b/Transcriptify-by-whisper.cpp that referenced this pull request Oct 24, 2023
vonstring pushed a commit to vonstring/whisper.cpp that referenced this pull request Nov 7, 2023
landtanin pushed a commit to landtanin/whisper.cpp that referenced this pull request Dec 16, 2023
iThalay pushed a commit to iThalay/whisper.cpp that referenced this pull request Sep 23, 2024