Skip to content

Commit 2766aa8

Browse files
committed
one
1 parent 177eca9 commit 2766aa8

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

crates/whisper-local/src/model.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,26 @@ impl Whisper {
189189

190190
let lang_str = {
191191
self.state.pcm_to_mel(audio, 1)?;
192-
let (lang_id, _lang_probs) = self.state.lang_detect(0, 1)?;
193-
whisper_rs::get_lang_str(lang_id)
192+
let (_lang_id, lang_probs) = self.state.lang_detect(0, 1)?;
193+
194+
let mut best_lang = None;
195+
let mut best_prob = f32::NEG_INFINITY;
196+
197+
for lang in &self.languages {
198+
let lang_id = lang.whisper_index();
199+
if lang_id < lang_probs.len() {
200+
let prob = lang_probs[lang_id];
201+
if prob > best_prob {
202+
best_prob = prob;
203+
best_lang = Some(lang.as_ref().to_string());
204+
}
205+
}
206+
}
207+
208+
best_lang
194209
};
195210

196-
Ok(lang_str.map(|s| s.to_owned()))
211+
Ok(lang_str)
197212
}
198213

199214
fn filter_segments(segments: Vec<Segment>) -> Vec<Segment> {

crates/whisper/src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// https://github.com/openai/whisper/blob/ba3f3cd/whisper/tokenizer.py#L10-L128
2-
#[derive(Debug, strum::EnumString, strum::Display, strum::AsRefStr)]
2+
#[repr(u8)]
3+
#[derive(Debug, Copy, Clone, strum::EnumString, strum::Display, strum::AsRefStr)]
34
pub enum Language {
45
#[strum(serialize = "en")]
56
En,
@@ -202,3 +203,9 @@ pub enum Language {
202203
#[strum(serialize = "yue")]
203204
Yue,
204205
}
206+
207+
impl Language {
208+
pub fn whisper_index(self) -> usize {
209+
self as usize
210+
}
211+
}

0 commit comments

Comments
 (0)