@@ -118,6 +118,21 @@ static void byteswap_tensor(ggml_tensor * tensor) {
118118#define WHISPER_USE_SCRATCH
119119#define WHISPER_MAX_SCRATCH_BUFFERS 16
120120
121+ //
122+ // ggml helpers
123+ //
124+
125+ static void ggml_graph_compute_helper (std::vector<uint8_t > & buf, ggml_cgraph * graph, int n_threads) {
126+ struct ggml_cplan plan = ggml_graph_plan (graph, n_threads);
127+
128+ if (plan.work_size > 0 ) {
129+ buf.resize (plan.work_size );
130+ plan.work_data = buf.data ();
131+ }
132+
133+ ggml_graph_compute (graph, &plan);
134+ }
135+
121136// available whisper models
122137enum e_model {
123138 MODEL_UNKNOWN,
@@ -666,6 +681,7 @@ struct whisper_state {
666681
667682 // memory buffers used by encode / decode contexts
668683 std::vector<uint8_t > buf_compute;
684+ std::vector<uint8_t > buf_work;
669685 std::vector<uint8_t > buf_scratch[WHISPER_MAX_SCRATCH_BUFFERS];
670686
671687 int buf_last = 0 ;
@@ -1830,8 +1846,8 @@ static bool whisper_encode_internal(
18301846 {
18311847 struct ggml_cgraph gf = {};
18321848
1833- ggml_build_forward_expand (&gf, cur);
1834- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
1849+ ggml_build_forward_expand (&gf, cur);
1850+ ggml_graph_compute_helper (wstate. buf_work , &gf, n_threads);
18351851
18361852 // ggml_graph_print(&gf);
18371853 }
@@ -1916,7 +1932,7 @@ static bool whisper_encode_internal(
19161932 ggml_build_forward_expand (&gf, ggml_cpy (ctx0, Vcross, v));
19171933 }
19181934
1919- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
1935+ ggml_graph_compute_helper (wstate. buf_work , &gf, n_threads);
19201936 // ggml_graph_print(&gf);
19211937 }
19221938
@@ -2329,8 +2345,8 @@ static bool whisper_decode_internal(
23292345
23302346 // run the computation
23312347 {
2332- ggml_build_forward_expand (&gf, logits);
2333- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
2348+ ggml_build_forward_expand (&gf, logits);
2349+ ggml_graph_compute_helper (wstate. buf_work , &gf, n_threads);
23342350 }
23352351
23362352 // extract logits for all N tokens
@@ -5225,7 +5241,8 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
52255241 // b: N*N*sizeof(float)
52265242 // c: N*N*sizeof(float)
52275243 // when F16 is used, there is an extra work buffer of size N*N*sizeof(float)
5228- std::vector<char > buf (4llu*N_max*N_max*sizeof (float ) + 4 *512 );
5244+ std::vector<uint8_t > buf (3llu*N_max*N_max*sizeof (float ) + 3 *ggml_tensor_overhead ());
5245+ std::vector<uint8_t > work (1llu*N_max*N_max*sizeof (float ) + 1 *ggml_tensor_overhead ());
52295246
52305247 // put a bunch of random data in the buffer
52315248 for (size_t i = 0 ; i < buf.size (); i++) buf[i] = i;
@@ -5280,12 +5297,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
52805297 double tsum = 0.0 ;
52815298
52825299 // heat-up
5283- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
5300+ ggml_graph_compute_helper (work , &gf, n_threads);
52845301
52855302 for (int i = 0 ; i < n_max; ++i) {
52865303 const int64_t t0 = ggml_time_us ();
52875304
5288- ggml_graph_compute_with_ctx (ctx0 , &gf, n_threads);
5305+ ggml_graph_compute_helper (work , &gf, n_threads);
52895306
52905307 const int64_t t1 = ggml_time_us ();
52915308
0 commit comments