@@ -2445,40 +2445,50 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
24452445 }
24462446}
24472447
2448- static void log_mel_spectrogram_worker_thread (int ith, const std::vector<float > &hann, const float *samples,
2449- int n_samples, int fft_size, int fft_step, int n_threads,
2450- const whisper_filters &filters, bool speed_up, whisper_mel &mel) {
2451- std::vector<float > fft_in (fft_size, 0.0 );
2452- std::vector<float > fft_out (2 * fft_size);
2453- int n_fft = 1 + (speed_up ? fft_size / 4 : fft_size / 2 );
2454-
2455- for (int i = ith; i < mel.n_len ; i += n_threads) {
2456- const int offset = i * fft_step;
2457-
2458- // apply Hanning window
2459- for (int j = 0 ; j < fft_size; j++) {
2460- if (offset + j < n_samples) {
2461- fft_in[j] = hann[j] * samples[offset + j];
2462- } else {
2463- fft_in[j] = 0.0 ;
2464- }
2465- }
2448+ static bool hann_window (int length, bool periodic, std::vector<float > & output) {
2449+ if (output.size () < length) {
2450+ output.resize (length);
2451+ }
2452+ int offset = -1 ;
2453+ if (periodic) {
2454+ offset = 0 ;
2455+ }
2456+ for (int i = 0 ; i < length; i++) {
2457+ output[i] = 0.5 *(1.0 - cosf ((2.0 *M_PI*i)/(length + offset)));
2458+ }
24662459
2467- // FFT -> mag^2
2468- fft (fft_in, fft_out);
2460+ return true ;
2461+ }
24692462
2470- for (int j = 0 ; j < fft_size; j++) {
2471- fft_out[j] = (fft_out[2 * j + 0 ] * fft_out[2 * j + 0 ] + fft_out[2 * j + 1 ] * fft_out[2 * j + 1 ]);
2463+ static void log_mel_spectrogram_worker_thread (int ith, const std::vector<float > & hann, const std::vector<float > & samples,
2464+ int n_samples, int frame_size, int frame_step, int n_threads,
2465+ const whisper_filters & filters, whisper_mel & mel) {
2466+ std::vector<float > fft_in (frame_size, 0.0 );
2467+ std::vector<float > fft_out (2 * frame_step);
2468+ // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist
2469+ int n_fft = 1 + (frame_size / 2 );
2470+ int i = ith;
2471+
2472+ // calculate FFT only when fft_in are not all zero
2473+ for (; i < std::min (n_samples / frame_step + 1 , mel.n_len ); i += n_threads) {
2474+ const int offset = i * frame_step;
2475+
2476+ // apply Hanning window (~10% faster)
2477+ for (int j = 0 ; j < std::min (frame_size, n_samples - offset); j++) {
2478+ fft_in[j] = hann[j] * samples[offset + j];
24722479 }
2473- for (int j = 1 ; j < fft_size / 2 ; j++) {
2474- fft_out[j] += fft_out[fft_size - j];
2480+ // fill the rest with zeros
2481+ if (n_samples - offset < frame_size) {
2482+ std::fill (fft_in.begin () + (n_samples - offset), fft_in.end (), 0.0 );
24752483 }
24762484
2477- if (speed_up) {
2478- // scale down in the frequency domain results in a speed up in the time domain
2479- for (int j = 0 ; j < n_fft; j++) {
2480- fft_out[j] = 0.5 * (fft_out[2 * j] + fft_out[2 * j + 1 ]);
2481- }
2485+ // FFT
2486+ fft (fft_in, fft_out);
2487+
2488+ // Calculate modulus^2 of complex numbers
2489+ // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
2490+ for (int j = 0 ; j < frame_size; j++) {
2491+ fft_out[j] = (fft_out[2 * j + 0 ] * fft_out[2 * j + 0 ] + fft_out[2 * j + 1 ] * fft_out[2 * j + 1 ]);
24822492 }
24832493
24842494 // mel spectrogram
@@ -2489,10 +2499,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
24892499 int k = 0 ;
24902500 for (k = 0 ; k < n_fft - 3 ; k += 4 ) {
24912501 sum +=
2492- fft_out[k + 0 ] * filters.data [j* n_fft + k + 0 ] +
2493- fft_out[k + 1 ] * filters.data [j* n_fft + k + 1 ] +
2494- fft_out[k + 2 ] * filters.data [j* n_fft + k + 2 ] +
2495- fft_out[k + 3 ] * filters.data [j* n_fft + k + 3 ];
2502+ fft_out[k + 0 ] * filters.data [j * n_fft + k + 0 ] +
2503+ fft_out[k + 1 ] * filters.data [j * n_fft + k + 1 ] +
2504+ fft_out[k + 2 ] * filters.data [j * n_fft + k + 2 ] +
2505+ fft_out[k + 3 ] * filters.data [j * n_fft + k + 3 ];
24962506 }
24972507
24982508 // handle n_fft remainder
@@ -2505,68 +2515,73 @@ static void log_mel_spectrogram_worker_thread(int ith, const std::vector<float>
25052515 mel.data [j * mel.n_len + i] = sum;
25062516 }
25072517 }
2518+
2519+ // Otherwise fft_out are all zero
2520+ double sum = log10 (1e-10 );
2521+ for (; i < mel.n_len ; i += n_threads) {
2522+ for (int j = 0 ; j < mel.n_mel ; j++) {
2523+ mel.data [j * mel.n_len + i] = sum;
2524+ }
2525+ }
25082526}
25092527
2510- // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2528+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157
25112529static bool log_mel_spectrogram (
2512- whisper_state & wstate,
2513- const float * samples,
2530+ whisper_state & wstate,
2531+ const float * samples,
25142532 const int n_samples,
25152533 const int /* sample_rate*/ ,
2516- const int fft_size ,
2517- const int fft_step ,
2534+ const int frame_size ,
2535+ const int frame_step ,
25182536 const int n_mel,
25192537 const int n_threads,
2520- const whisper_filters & filters,
2521- const bool speed_up ,
2522- whisper_mel & mel) {
2538+ const whisper_filters & filters,
2539+ const bool debug ,
2540+ whisper_mel & mel) {
25232541 const int64_t t_start_us = ggml_time_us ();
25242542
2525- // Hanning window
2543+ // Hanning window (Use cosf to eliminate difference)
2544+ // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html
2545+ // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147
25262546 std::vector<float > hann;
2527- hann.resize (fft_size);
2528- for (int i = 0 ; i < fft_size; i++) {
2529- hann[i] = 0.5 *(1.0 - cos ((2.0 *M_PI*i)/(fft_size)));
2530- }
2531-
2532- mel.n_mel = n_mel;
2533- mel.n_len = n_samples/fft_step;
2534- mel.n_len_org = mel.n_len ;
2547+ hann_window (frame_size, true , hann);
25352548
2536- std::vector<float > samples_padded;
25372549
2538- // pad audio with at least one extra chunk of zeros
2539- {
2540- const int pad = ( 100 *WHISPER_CHUNK_SIZE)/ 2 ;
2550+ // Calculate the length of padding
2551+ int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30 ;
2552+ int64_t stage_2_pad = frame_size / 2 ;
25412553
2542- if (mel. n_len % pad != 0 ) {
2543- mel. n_len = (mel. n_len /pad + 1 )*pad ;
2544- }
2545- mel. n_len += pad ;
2554+ // Initialize a vector and copy data from C array to it.
2555+ std::vector< float > samples_padded ;
2556+ samples_padded. resize (n_samples + stage_1_pad + stage_2_pad * 2 );
2557+ std::copy (samples, samples + n_samples, samples_padded. begin () + stage_2_pad) ;
25462558
2547- samples_padded.resize (mel.n_len *fft_step);
2548- memcpy (samples_padded.data (), samples, n_samples*sizeof (float ));
2549- memset (samples_padded.data () + n_samples, 0 , (mel.n_len *fft_step - n_samples)*sizeof (float ));
2559+ // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio
2560+ std::fill (samples_padded.begin () + n_samples + stage_2_pad, samples_padded.begin () + n_samples + stage_1_pad + 2 * stage_2_pad, 0 );
25502561
2551- samples = samples_padded. data ();
2552- }
2562+ // reflective pad 200 samples at the beginning of audio
2563+ std::reverse_copy (samples + 1 , samples + 1 + stage_2_pad, samples_padded. begin ());
25532564
2554- mel.data .resize (mel.n_mel *mel.n_len );
2565+ mel.n_mel = n_mel;
2566+ // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936
2567+ // Calculate number of frames + remove the last frame
2568+ mel.n_len = (samples_padded.size () - frame_size) / frame_step;
2569+ // Calculate semi-padded sample length to ensure compatibility
2570+ mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step;
2571+ mel.data .resize (mel.n_mel * mel.n_len );
25552572
2556- // printf("%s: n_samples = %d, n_len = %d\n", __func__, n_samples, mel.n_len);
2557- // printf("%s: recording length: %f s\n", __func__, (float) n_samples/sample_rate);
25582573
25592574 {
25602575 std::vector<std::thread> workers (n_threads - 1 );
25612576 for (int iw = 0 ; iw < n_threads - 1 ; ++iw) {
25622577 workers[iw] = std::thread (
2563- log_mel_spectrogram_worker_thread, iw + 1 , std::cref (hann), samples ,
2564- n_samples, fft_size, fft_step , n_threads,
2565- std::cref (filters), speed_up, std::ref (mel));
2578+ log_mel_spectrogram_worker_thread, iw + 1 , std::cref (hann), samples_padded ,
2579+ n_samples + stage_2_pad, frame_size, frame_step , n_threads,
2580+ std::cref (filters), std::ref (mel));
25662581 }
25672582
25682583 // main thread
2569- log_mel_spectrogram_worker_thread (0 , hann, samples , n_samples, fft_size, fft_step , n_threads, filters, speed_up , mel);
2584+ log_mel_spectrogram_worker_thread (0 , hann, samples_padded , n_samples + stage_2_pad, frame_size, frame_step , n_threads, filters, mel);
25702585
25712586 for (int iw = 0 ; iw < n_threads - 1 ; ++iw) {
25722587 workers[iw].join ();
@@ -2580,7 +2595,6 @@ static bool log_mel_spectrogram(
25802595 mmax = mel.data [i];
25812596 }
25822597 }
2583- // printf("%s: max = %f\n", __func__, mmax);
25842598
25852599 mmax -= 8.0 ;
25862600
@@ -2594,7 +2608,16 @@ static bool log_mel_spectrogram(
25942608
25952609 wstate.t_mel_us += ggml_time_us () - t_start_us;
25962610
2597- // printf("mel.n_len() = %d, divided by 1500: %f, n_samples / fft_step: %d\n", mel.n_len, mel.n_len / 1500.0, n_samples / fft_step);
2611+ // Dump log_mel_spectrogram
2612+ if (debug) {
2613+ std::ofstream outFile (" log_mel_spectrogram.json" );
2614+ outFile << " [" ;
2615+ for (uint64_t i = 0 ; i < mel.data .size () - 1 ; i++) {
2616+ outFile << mel.data [i] << " , " ;
2617+ }
2618+ outFile << mel.data [mel.data .size () - 1 ] << " ]" ;
2619+ outFile.close ();
2620+ }
25982621
25992622 return true ;
26002623}
@@ -3026,21 +3049,30 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int
30263049 return whisper_pcm_to_mel_with_state (ctx, ctx->state , samples, n_samples, n_threads);
30273050}
30283051
3029- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
3052+ // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
30303053int whisper_pcm_to_mel_phase_vocoder_with_state (struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
3031- if (!log_mel_spectrogram (*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model .filters , true , state->mel )) {
3054+ if (!log_mel_spectrogram (*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model .filters , false , state->mel )) {
30323055 log (" %s: failed to compute mel spectrogram\n " , __func__);
30333056 return -1 ;
30343057 }
30353058
30363059 return 0 ;
30373060}
30383061
3039- // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2
3062+ // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
30403063int whisper_pcm_to_mel_phase_vocoder (struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) {
30413064 return whisper_pcm_to_mel_phase_vocoder_with_state (ctx, ctx->state , samples, n_samples, n_threads);
30423065}
30433066
3067+ // same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2
3068+ // TODO
3069+
3070+ // same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2
3071+ // TODO
3072+
3073+ // same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2
3074+ // TODO
3075+
30443076int whisper_set_mel_with_state (
30453077 struct whisper_context * /* ctx*/ ,
30463078 struct whisper_state * state,
@@ -3492,6 +3524,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
34923524 /* .max_tokens =*/ 0 ,
34933525
34943526 /* .speed_up =*/ false ,
3527+ /* .debug_mode =*/ false ,
34953528 /* .audio_ctx =*/ 0 ,
34963529
34973530 /* .tdrz_enable =*/ false ,
@@ -3653,7 +3686,7 @@ static void whisper_process_logits(
36533686 WHISPER_ASSERT (n_logits == ctx.vocab .n_vocab );
36543687
36553688 // extract the logits for the last token
3656- // we will be mutating and therefore we don't want to use the ctx.logits buffer directly
3689+ // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly
36573690 auto & probs = decoder.probs ;
36583691 auto & logits = decoder.logits ;
36593692 auto & logprobs = decoder.logprobs ;
@@ -4056,10 +4089,9 @@ int whisper_full_with_state(
40564089
40574090 // compute log mel spectrogram
40584091 if (params.speed_up ) {
4059- if (whisper_pcm_to_mel_phase_vocoder_with_state (ctx, state, samples, n_samples, params.n_threads ) != 0 ) {
4060- log (" %s: failed to compute log mel spectrogram\n " , __func__);
4061- return -1 ;
4062- }
4092+ // TODO: Replace PV with more advanced algorithm
4093+ log (" %s: failed to compute log mel spectrogram\n " , __func__);
4094+ return -1 ;
40634095 } else {
40644096 if (whisper_pcm_to_mel_with_state (ctx, state, samples, n_samples, params.n_threads ) != 0 ) {
40654097 log (" %s: failed to compute log mel spectrogram\n " , __func__);
@@ -4095,8 +4127,8 @@ int whisper_full_with_state(
40954127 const int seek_start = params.offset_ms /10 ;
40964128 const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state (state) : seek_start + params.duration_ms /10 ;
40974129
4098- // if length of spectrogram is less than 1s (100 samples ), then return
4099- // basically don't process anything that is less than 1s
4130+ // if length of spectrogram is less than 1.0s (100 frames ), then return
4131+ // basically don't process anything that is less than 1.0s
41004132 // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39
41014133 if (seek_end < seek_start + (params.speed_up ? 50 : 100 )) {
41024134 return 0 ;
0 commit comments