-
Notifications
You must be signed in to change notification settings - Fork 406
Whisper custom vocab #1417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Whisper custom vocab #1417
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| use std::hint::black_box; | ||
| use std::time::Duration; | ||
|
|
||
| use criterion::{criterion_group, criterion_main, Criterion}; | ||
| use hypr_whisper::Language; | ||
| use whisper_local::Whisper; | ||
|
|
||
| fn benchmark_whisper_transcription(c: &mut Criterion) { | ||
| let audio: Vec<f32> = hypr_data::english_1::AUDIO | ||
| .chunks_exact(2) | ||
| .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / 32768.0) | ||
| .collect(); | ||
|
|
||
| let model_path = concat!(env!("CARGO_MANIFEST_DIR"), "/model.bin"); | ||
|
|
||
| let mut whisper_without_vocab = Whisper::builder() | ||
| .model_path(model_path) | ||
| .languages(vec![Language::En]) | ||
| .build() | ||
| .unwrap(); | ||
|
|
||
| let mut whisper_with_vocab = Whisper::builder() | ||
| .model_path(model_path) | ||
| .languages(vec![Language::En]) | ||
| .vocabulary( | ||
| vec![ | ||
| "profound", | ||
| "acquire", | ||
| "complementary", | ||
| "deeply", | ||
| "repositories", | ||
| "brilliant", | ||
| "pockets", | ||
| "thread", | ||
| "stumbling", | ||
| "stumble", | ||
| "communities", | ||
| "invested", | ||
| "undergrad", | ||
| "Googleable", | ||
| "exploring", | ||
| "neuroscientist", | ||
| "psychology", | ||
| "engineering", | ||
| "researcher", | ||
| "thinker", | ||
| "skill", | ||
| "invest", | ||
| "solved", | ||
| "entire", | ||
| "especially", | ||
| "actually", | ||
| "often", | ||
| "already", | ||
| "important", | ||
| "definitely", | ||
| "much", | ||
| ] | ||
| .into_iter() | ||
| .map(|s| s.into()) | ||
| .collect(), | ||
| ) | ||
| .build() | ||
| .unwrap(); | ||
|
|
||
| let mut group = c.benchmark_group("whisper_comparison"); | ||
| group.measurement_time(Duration::from_secs(100)); | ||
| group.sample_size(10); | ||
|
|
||
| group.bench_function("without_vocab", |b| { | ||
| b.iter(|| { | ||
| let segments = whisper_without_vocab.transcribe(black_box(&audio)).unwrap(); | ||
| black_box(segments) | ||
| }) | ||
| }); | ||
|
|
||
| group.bench_function("with_vocab", |b| { | ||
| b.iter(|| { | ||
| let segments = whisper_with_vocab.transcribe(black_box(&audio)).unwrap(); | ||
| black_box(segments) | ||
| }) | ||
| }); | ||
|
|
||
yujonglee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| group.finish(); | ||
| } | ||
|
|
||
| criterion_group!(benches, benchmark_whisper_transcription); | ||
| criterion_main!(benches); | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,93 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use trie_rs::map::{Trie, TrieBuilder}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| use whisper_rs::{WhisperContext, WhisperTokenId}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #[derive(Clone)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pub struct BiasTrie { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trie: Trie<WhisperTokenId, f32>, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| impl BiasTrie { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pub fn new(ctx: &WhisperContext, custom_vocab: &[&str]) -> Result<Self, crate::Error> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let mut builder = TrieBuilder::new(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for word in custom_vocab { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let variants = Self::generate_tokenization_variants(ctx, word)?; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for tokens in variants { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for i in 1..=tokens.len() { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let progress = i as f32 / tokens.len() as f32; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let prefix_bias = 10.0 + 90.0 * progress.powi(2); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let prefix = &tokens[..i]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| builder.push(prefix, prefix_bias); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+13
to
+26
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Deduplicate overlapping prefixes across variants; skip empty tokenizations. Multiple variants can yield identical prefixes; pushing each inflates/overwrites bias unpredictably. Aggregate by prefix and keep the max bias. - for word in custom_vocab {
- let variants = Self::generate_tokenization_variants(ctx, word)?;
- for tokens in variants {
- for i in 1..=tokens.len() {
- let progress = i as f32 / tokens.len() as f32;
-
- let prefix_bias = 10.0 + 90.0 * progress.powi(2);
-
- let prefix = &tokens[..i];
- builder.push(prefix, prefix_bias);
- }
- }
- }
+ let mut acc: HashMap<Vec<WhisperTokenId>, f32> = HashMap::new();
+ for word in custom_vocab {
+ let variants = Self::generate_tokenization_variants(ctx, word)?;
+ for tokens in variants {
+ if tokens.is_empty() { continue; }
+ for i in 1..=tokens.len() {
+ let progress = i as f32 / tokens.len() as f32;
+ let bias = 10.0 + 90.0 * progress.powi(2);
+ let key = tokens[..i].to_vec();
+ acc.entry(key).and_modify(|v| *v = v.max(bias)).or_insert(bias);
+ }
+ }
+ }
+ for (k, v) in acc {
+ builder.push(&k, v);
+ }📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let trie = builder.build(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Ok(BiasTrie { trie }) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fn generate_tokenization_variants( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ctx: &WhisperContext, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| word: &str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Result<Vec<Vec<WhisperTokenId>>, crate::Error> { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let mut variants = Vec::new(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| variants.push(ctx.tokenize(word, 99)?); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| variants.push(ctx.tokenize(&format!(" {}", word), 99)?); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let lower = word.to_lowercase(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if lower != word { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| variants.push(ctx.tokenize(&lower, 99)?); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| variants.push(ctx.tokenize(&format!(" {}", lower), 99)?); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let upper = word.to_uppercase(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if upper != word { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| variants.push(ctx.tokenize(&upper, 99)?); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| variants.push(ctx.tokenize(&format!("'{}", word), 99)?); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| variants.push(ctx.tokenize(&format!("\"{}", word), 99)?); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Ok(variants) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pub unsafe fn apply_bias_to_logits( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| &self, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokens: *const whisper_rs::whisper_rs_sys::whisper_token_data, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n_tokens: std::os::raw::c_int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logits: *mut f32, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if tokens.is_null() || n_tokens <= 0 { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let current_tokens: Vec<WhisperTokenId> = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::slice::from_raw_parts(tokens, n_tokens as usize) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .iter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .map(|t| t.id) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .collect(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for suffix_len in 1..=std::cmp::min(10, current_tokens.len()) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let suffix = ¤t_tokens[current_tokens.len() - suffix_len..]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (full_sequence, bias_value_ref) in self.trie.predictive_search(suffix) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let bias_value = *bias_value_ref; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let full_sequence: Vec<WhisperTokenId> = full_sequence; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if full_sequence.len() > suffix.len() { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let next_token = full_sequence[suffix.len()]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let current_logit = *logits.offset(next_token as isize); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let boost = bias_value.ln() * 2.0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| let new_logit = current_logit + boost; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| *logits.offset(next_token as isize) = new_logit; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+58
to
+92
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainUpdate call sites to pass n_vocab and validate usage. Use this script to find and fix all call sites: 🏁 Script executed: #!/bin/bash
# Find apply_bias_to_logits call sites and show context
rg -n -C3 --type=rust '\.apply_bias_to_logits\s*\('Length of output: 518 Update all call sites to pass .apply_bias_to_logits(tokens, n_tokens, logits)to .apply_bias_to_logits(tokens, n_tokens, n_vocab, logits)and ensure the method signature in 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.