Skip to content

Commit ff2bc1a

Browse files
committed
fix memory leak
1 parent 607a66d commit ff2bc1a

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

crates/whisper-local/src/model.rs

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ impl Whisper {
109109
let token_beg = self.token_beg;
110110
let language = self.get_language(audio)?;
111111

112-
let params = {
112+
let mut params = {
113113
let mut p = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
114114

115115
let parts = [self.dynamic_prompt.trim()];
@@ -124,10 +124,6 @@ impl Whisper {
124124

125125
p.set_initial_prompt(&initial_prompt);
126126

127-
unsafe {
128-
Self::set_logit_filter(&mut p, &token_beg, &self.bias_trie);
129-
}
130-
131127
p.set_no_timestamps(true);
132128
p.set_token_timestamps(false);
133129
p.set_split_on_word(true);
@@ -146,6 +142,8 @@ impl Whisper {
146142
p
147143
};
148144

145+
let _guard = unsafe { Self::set_logit_filter(&mut params, &token_beg, &self.bias_trie) };
146+
149147
self.state.full(params, &audio[..])?;
150148
let num_segments = self.state.full_n_segments();
151149

@@ -256,12 +254,7 @@ impl Whisper {
256254
params: &mut FullParams,
257255
token_beg: &WhisperTokenId,
258256
bias_trie: &BiasTrie,
259-
) {
260-
struct Context {
261-
token_beg: WhisperTokenId,
262-
bias_trie: BiasTrie,
263-
}
264-
257+
) -> LogitFilterGuard {
265258
let context = Box::new(Context {
266259
token_beg: *token_beg,
267260
bias_trie: bias_trie.clone(),
@@ -288,9 +281,31 @@ impl Whisper {
288281
.apply_bias_to_logits(tokens, n_tokens, logits);
289282
}
290283

284+
let context_ptr = Box::into_raw(context) as *mut std::ffi::c_void;
285+
291286
params.set_filter_logits_callback(Some(logits_filter_callback));
292-
params
293-
.set_filter_logits_callback_user_data(Box::into_raw(context) as *mut std::ffi::c_void);
287+
params.set_filter_logits_callback_user_data(context_ptr);
288+
289+
LogitFilterGuard { context_ptr }
290+
}
291+
}
292+
293+
struct Context {
294+
token_beg: WhisperTokenId,
295+
bias_trie: BiasTrie,
296+
}
297+
298+
struct LogitFilterGuard {
299+
context_ptr: *mut std::ffi::c_void,
300+
}
301+
302+
impl Drop for LogitFilterGuard {
303+
fn drop(&mut self) {
304+
if !self.context_ptr.is_null() {
305+
unsafe {
306+
let _ = Box::from_raw(self.context_ptr as *mut Context);
307+
}
308+
}
294309
}
295310
}
296311

0 commit comments

Comments
 (0)