-
Notifications
You must be signed in to change notification settings - Fork 5k
Significantly improve whisper.cpp inference quality #1148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
4767223
38eaeff
2dd6884
4ebe450
6f445d1
7f690dd
527d7c6
f3e7774
95be6dc
bd1dbd1
2c49c9b
5df242c
e40ec27
715bf61
3fe41d5
36b0df7
444b59a
308f490
252f807
0a5f435
65fd0e1
386ef32
241df86
22d348c
590a12e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2396,12 +2396,10 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) { | |
| even.reserve(N/2); | ||
| odd.reserve(N/2); | ||
|
|
||
| for (int i = 0; i < N; i++) { | ||
| if (i % 2 == 0) { | ||
| even.push_back(in[i]); | ||
| } else { | ||
| odd.push_back(in[i]); | ||
| } | ||
| // | ||
| for (int i = 0; i < N; i+=2) { | ||
| even.push_back(in[i]); | ||
| odd.push_back(in[i + 1]); | ||
| } | ||
|
|
||
| std::vector<float> even_fft; | ||
|
|
@@ -2424,6 +2422,7 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) { | |
|
|
||
| out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; | ||
| out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; | ||
|
|
||
| } | ||
| } | ||
|
|
||
|
|
@@ -2432,34 +2431,43 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> | |
| const whisper_filters &filters, bool speed_up, whisper_mel &mel) { | ||
| std::vector<float> fft_in(fft_size, 0.0); | ||
| std::vector<float> fft_out(2 * fft_size); | ||
| // Is using 32-bit float to calculate log_mel appropriate? | ||
| // 32-bit float has about 7 digits of precision, but minimum value of log_mel is 1e-10. | ||
| int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2); | ||
| int i = ith; | ||
|
|
||
| for (int i = ith; i < mel.n_len; i += n_threads) { | ||
| // Calculate FFT only when fft_in are not all zero | ||
| for (; i < std::min((n_samples / fft_step) + 1, mel.n_len); i += n_threads) { | ||
| const int offset = i * fft_step; | ||
|
|
||
| // apply Hanning window | ||
| for (int j = 0; j < fft_size; j++) { | ||
| if (offset + j < n_samples) { | ||
| fft_in[j] = hann[j] * samples[offset + j]; | ||
| } else { | ||
| fft_in[j] = 0.0; | ||
| } | ||
| // apply Hanning window (~10% faster) | ||
| for (int j = 0; j < std::min(fft_size, n_samples - offset); j++) { | ||
| fft_in[j] = hann[j] * samples[offset + j]; | ||
| } | ||
| // fill the rest with zeros | ||
| if (n_samples - offset < fft_size) { | ||
| std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0); | ||
| } | ||
|
|
||
| // FFT -> mag^2 | ||
| // FFT | ||
| fft(fft_in, fft_out); | ||
|
|
||
| // Calculate modulus^2 of complex numbers | ||
| // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. | ||
| for (int j = 0; j < fft_size; j++) { | ||
| fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); | ||
| } | ||
| for (int j = 1; j < fft_size / 2; j++) { | ||
| fft_out[j] += fft_out[fft_size - j]; | ||
|
|
||
| // The frequency spectrum produced by real input data is symmetrical around the Nyquist frequency. | ||
| // This is where the actual issue lies | ||
| for (int j = 0; j < fft_size / 2; j++) { | ||
| fft_out[j] = (fft_out[fft_size - j - 1] + fft_out[j + 1]) / 2; | ||
| } | ||
|
||
|
|
||
| if (speed_up) { | ||
| // scale down in the frequency domain results in a speed up in the time domain | ||
| for (int j = 0; j < n_fft; j++) { | ||
| fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1]); | ||
| // scale down in the frequency domain results in a speed-up in the time domain | ||
| for (int j = 0; j < n_fft - 1; j++) { | ||
| fft_out[j] = (fft_out[2 * j] + fft_out[2 * j + 1]) / 2; | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -2471,10 +2479,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> | |
| int k = 0; | ||
| for (k = 0; k < n_fft - 3; k += 4) { | ||
| sum += | ||
| fft_out[k + 0] * filters.data[j*n_fft + k + 0] + | ||
| fft_out[k + 1] * filters.data[j*n_fft + k + 1] + | ||
| fft_out[k + 2] * filters.data[j*n_fft + k + 2] + | ||
| fft_out[k + 3] * filters.data[j*n_fft + k + 3]; | ||
| fft_out[k + 0] * filters.data[j * n_fft + k + 0] + | ||
| fft_out[k + 1] * filters.data[j * n_fft + k + 1] + | ||
| fft_out[k + 2] * filters.data[j * n_fft + k + 2] + | ||
| fft_out[k + 3] * filters.data[j * n_fft + k + 3]; | ||
| } | ||
|
|
||
| // handle n_fft remainder | ||
|
|
@@ -2487,6 +2495,14 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float> | |
| mel.data[j * mel.n_len + i] = sum; | ||
| } | ||
| } | ||
|
|
||
| // Otherwise fft_out are all zero | ||
| double sum = log10(1e-10); | ||
| for (; i < mel.n_len; i += n_threads) { | ||
| for (int j = 0; j < mel.n_mel; j++) { | ||
| mel.data[j * mel.n_len + i] = sum; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 | ||
|
|
@@ -2508,7 +2524,9 @@ static bool log_mel_spectrogram( | |
| std::vector<float> hann; | ||
| hann.resize(fft_size); | ||
| for (int i = 0; i < fft_size; i++) { | ||
| hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size))); | ||
| // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html | ||
| // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 | ||
| hann[i] = 0.5*(1.0 - cos((2.0*M_PI*i)/(fft_size - 1))); | ||
| } | ||
|
|
||
| mel.n_mel = n_mel; | ||
|
|
@@ -3634,7 +3652,7 @@ static void whisper_process_logits( | |
| WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); | ||
|
|
||
| // extract the logits for the last token | ||
| // we will be mutating and therefore we don't want to use the ctx.logits buffer directly | ||
| // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly | ||
| auto & probs = decoder.probs; | ||
| auto & logits = decoder.logits; | ||
| auto & logprobs = decoder.logprobs; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't doing a zero fill here problematic? Let's say you've got an `fft_size` of `256`, and `n_samples - offset` happens to come out to `128`. That means you're creating a cliff edge between `fft_in[127]` and `fft_in[128]`. The Hanning window is at its maximum value there, so you get no smoothing at all, it's going to be all artefacty.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question! So, I am writing a more detailed code analysis note, comparing the differences in the methods of generating log mel spectrograms between OpenAI's whisper and whisper.cpp. I have already completed part of it, and I will publish this finished portion shortly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've got a sneaking suspicion it might not be easy to get this right without doing some rework on the audio pipeline. Ideally you'd want to be saving unprocessed samples for the next round so you're always working with a full buffer.