Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions bindings/go/pkg/whisper/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,12 @@ func (context *context) SetMaxTokensPerSegment(n uint) {

// ResetTimings resets the mode timings. Should be called before processing
func (context *context) ResetTimings() {
context.model.ctx.Whisper_reset_timings()
context.model.state.Whisper_reset_timings()
}

// PrintTimings prints the model timings to stdout.
func (context *context) PrintTimings() {
context.model.ctx.Whisper_print_timings()
context.model.ctx.Whisper_print_timings(context.model.state)
}

// SystemInfo returns the system information
Expand All @@ -139,7 +139,7 @@ func (context *context) SystemInfo() string {
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages.
func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
langProbs, err := context.model.ctx.Whisper_lang_auto_detect(context.model.state, offset_ms, n_threads)
if err != nil {
return nil, err
}
Expand All @@ -159,23 +159,23 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {
// We don't do parallel processing at the moment
processors := 0
if processors > 1 {
if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) {
if err := context.model.ctx.Whisper_full_parallel(context.model.state, context.params, data, processors, nil, func(new int) {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
num_segments := context.model.state.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
cb(toSegment(context.model, i))
}
}
}); err != nil {
return err
}
} else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) {
} else if err := context.model.ctx.Whisper_full_with_state( context.model.state, context.params, data, nil, func(new int) {
if cb != nil {
num_segments := context.model.ctx.Whisper_full_n_segments()
num_segments := context.model.state.Whisper_full_n_segments()
s0 := num_segments - new
for i := s0; i < num_segments; i++ {
cb(toSegment(context.model.ctx, i))
cb(toSegment(context.model, i))
}
}
}); err != nil {
Expand All @@ -188,15 +188,15 @@ func (context *context) Process(data []float32, cb SegmentCallback) error {

// Return the next segment of tokens
func (context *context) NextSegment() (Segment, error) {
if context.model.ctx == nil {
if context.model.state == nil {
return Segment{}, ErrInternalAppError
}
if context.n >= context.model.ctx.Whisper_full_n_segments() {
if context.n >= context.model.state.Whisper_full_n_segments() {
return Segment{}, io.EOF
}

// Populate result
result := toSegment(context.model.ctx, context.n)
result := toSegment(context.model, context.n)

// Increment the cursor
context.n++
Expand Down Expand Up @@ -267,23 +267,23 @@ func (context *context) IsLANG(t Token, lang string) bool {
///////////////////////////////////////////////////////////////////////////////
// PRIVATE METHODS

func toSegment(ctx *whisper.Context, n int) Segment {
func toSegment(model *model, n int) Segment {
return Segment{
Num: n,
Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)),
Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
Tokens: toTokens(ctx, n),
Text: strings.TrimSpace(model.state.Whisper_full_get_segment_text(n)),
Start: time.Duration(model.state.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10,
End: time.Duration(model.state.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10,
Tokens: toTokens(model, n),
}
}

func toTokens(ctx *whisper.Context, n int) []Token {
result := make([]Token, ctx.Whisper_full_n_tokens(n))
func toTokens(model *model, n int) []Token {
result := make([]Token, model.state.Whisper_full_n_tokens(n))
for i := 0; i < len(result); i++ {
result[i] = Token{
Id: int(ctx.Whisper_full_get_token_id(n, i)),
Text: strings.TrimSpace(ctx.Whisper_full_get_token_text(n, i)),
P: ctx.Whisper_full_get_token_p(n, i),
Id: int(model.state.Whisper_full_get_token_id(n, i)),
Text: strings.TrimSpace(model.ctx.Whisper_full_get_token_text(model.state, n, i)),
P: model.state.Whisper_full_get_token_p(n, i),
}
}
return result
Expand Down
11 changes: 11 additions & 0 deletions bindings/go/pkg/whisper/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
type model struct {
path string
ctx *whisper.Context
state *whisper.State
}

// Make sure model adheres to the interface
Expand All @@ -30,6 +31,11 @@ func New(path string) (Model, error) {
} else if ctx := whisper.Whisper_init(path); ctx == nil {
return nil, ErrUnableToLoadModel
} else {
state := ctx.Whisper_init_state();
if state == nil {
return nil, ErrUnableToLoadModel
}
model.state = state
model.ctx = ctx
model.path = path
}
Expand All @@ -43,8 +49,13 @@ func (model *model) Close() error {
model.ctx.Whisper_free()
}

if model.state != nil {
model.state.Whisper_free_state()
}

// Release resources
model.ctx = nil
model.state = nil

// Return success
return nil
Expand Down
108 changes: 68 additions & 40 deletions bindings/go/whisper.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ extern bool callEncoderBegin(void* user_data);
// Text segment callback
// Called on every newly generated text segment
// Use the whisper_full_...() functions to obtain the text segments
static void whisper_new_segment_cb(struct whisper_context* ctx, int n_new, void* user_data) {
if(user_data != NULL && ctx != NULL) {
static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_state* state, int n_new, void* user_data) {
if(user_data != NULL && ctx != NULL && state != NULL) {
callNewSegment(user_data, n_new);
}
}

// Encoder begin callback
// If not NULL, called before the encoder starts
// If it returns false, the computation is aborted
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, void* user_data) {
if(user_data != NULL && ctx != NULL) {
static bool whisper_encoder_begin_cb(struct whisper_context* ctx, struct whisper_state* state, void* user_data) {
if(user_data != NULL && ctx != NULL && state != NULL) {
return callEncoderBegin(user_data);
}
return false;
Expand All @@ -53,6 +53,7 @@ import "C"

type (
Context C.struct_whisper_context
State C.struct_whisper_state
Token C.whisper_token
TokenData C.struct_whisper_token_data
SamplingStrategy C.enum_whisper_sampling_strategy
Expand Down Expand Up @@ -98,15 +99,28 @@ func Whisper_init(path string) *Context {
}
}

func (ctx *Context) Whisper_init_state() *State {
state := C.whisper_init_state((*C.struct_whisper_context)(ctx))
if state != nil {
return (*State)(state)
} else {
return nil
}
}

// Frees all memory allocated by the model.
func (ctx *Context) Whisper_free() {
C.whisper_free((*C.struct_whisper_context)(ctx))
}

func (state *State) Whisper_free_state() {
C.whisper_free_state((*C.struct_whisper_state)(state))
}

// Convert RAW PCM audio to log mel spectrogram.
// The resulting spectrogram is stored inside the provided whisper context.
func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
func (ctx *Context) Whisper_pcm_to_mel(state *State, data []float32, threads int) error {
if C.whisper_pcm_to_mel((*C.struct_whisper_context)(ctx),(*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(threads)) == 0 {
return nil
} else {
return ErrConversionFailed
Expand All @@ -116,8 +130,8 @@ func (ctx *Context) Whisper_pcm_to_mel(data []float32, threads int) error {
// This can be used to set a custom log mel spectrogram inside the provided whisper context.
// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
// n_mel must be 80
func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
if C.whisper_set_mel((*C.struct_whisper_context)(ctx), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
func (state *State) Whisper_set_mel(data []float32, n_mel int) error {
if C.whisper_set_mel((*C.struct_whisper_state)(state), (*C.float)(&data[0]), C.int(len(data)), C.int(n_mel)) == 0 {
return nil
} else {
return ErrConversionFailed
Expand All @@ -127,8 +141,8 @@ func (ctx *Context) Whisper_set_mel(data []float32, n_mel int) error {
// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// offset can be used to specify the offset of the first frame in the spectrogram.
func (ctx *Context) Whisper_encode(offset, threads int) error {
if C.whisper_encode((*C.struct_whisper_context)(ctx), C.int(offset), C.int(threads)) == 0 {
func (ctx *Context) Whisper_encode(state *State, offset, threads int) error {
if C.whisper_encode((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset), C.int(threads)) == 0 {
return nil
} else {
return ErrConversionFailed
Expand All @@ -139,8 +153,8 @@ func (ctx *Context) Whisper_encode(offset, threads int) error {
// Make sure to call whisper_encode() first.
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
func (ctx *Context) Whisper_decode(tokens []Token, past, threads int) error {
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
func (ctx *Context) Whisper_decode(state *State, tokens []Token, past, threads int) error {
if C.whisper_decode((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (*C.whisper_token)(&tokens[0]), C.int(len(tokens)), C.int(past), C.int(threads)) == 0 {
return nil
} else {
return ErrConversionFailed
Expand Down Expand Up @@ -183,17 +197,17 @@ func Whisper_lang_str(id int) string {
// Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first.
// Returns the probabilities of all languages.
// ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69
func (ctx *Context) Whisper_lang_auto_detect(offset_ms, n_threads int) ([]float32, error) {
func (ctx *Context) Whisper_lang_auto_detect(state *State, offset_ms, n_threads int) ([]float32, error) {
probs := make([]float32, Whisper_lang_max_id()+1)
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
if n := int(C.whisper_lang_auto_detect((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(offset_ms), C.int(n_threads), (*C.float)(&probs[0]))); n < 0 {
return nil, ErrAutoDetectFailed
} else {
return probs, nil
}
}

func (ctx *Context) Whisper_n_len() int {
return int(C.whisper_n_len((*C.struct_whisper_context)(ctx)))
func (state *State) Whisper_n_len() int {
return int(C.whisper_n_len((*C.struct_whisper_state)(state)))
}

func (ctx *Context) Whisper_n_vocab() int {
Expand Down Expand Up @@ -268,13 +282,13 @@ func Whisper_token_transcribe() Token {
}

// Performance information
func (ctx *Context) Whisper_print_timings() {
C.whisper_print_timings((*C.struct_whisper_context)(ctx))
func (ctx *Context) Whisper_print_timings(state *State) {
C.whisper_print_timings((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state))
}

// Performance information
func (ctx *Context) Whisper_reset_timings() {
C.whisper_reset_timings((*C.struct_whisper_context)(ctx))
func (state *State) Whisper_reset_timings() {
C.whisper_reset_timings((*C.struct_whisper_state)(state))
}

// Print system information
Expand Down Expand Up @@ -302,16 +316,30 @@ func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginC
}
}

// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text with the given state
// Uses the specified decoding strategy to obtain the text.
func (ctx *Context) Whisper_full_with_state(state *State, params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)
if C.whisper_full_with_state((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
return nil
} else {
return ErrConversionFailed
}
}

// Split the input audio in chunks and process each chunk separately using whisper_full()
// It seems this approach can offer some speedup in some cases.
// However, the transcription accuracy can be worse at the beginning and end of each chunk.
func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
func (ctx *Context) Whisper_full_parallel(state *State, params Params, samples []float32, processors int, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
registerEncoderBeginCallback(ctx, encoderBeginCallback)
registerNewSegmentCallback(ctx, newSegmentCallback)
defer registerEncoderBeginCallback(ctx, nil)
defer registerNewSegmentCallback(ctx, nil)

if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
if C.whisper_full_parallel((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples)), C.int(processors)) == 0 {
return nil
} else {
return ErrConversionFailed
Expand All @@ -320,49 +348,49 @@ func (ctx *Context) Whisper_full_parallel(params Params, samples []float32, proc

// Number of generated text segments.
// A segment can be a few words, a sentence, or even a paragraph.
func (ctx *Context) Whisper_full_n_segments() int {
return int(C.whisper_full_n_segments((*C.struct_whisper_context)(ctx)))
func (state *State) Whisper_full_n_segments() int {
return int(C.whisper_full_n_segments((*C.struct_whisper_state)(state)))
}

// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t0(segment int) int64 {
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_context)(ctx), C.int(segment)))
func (state *State) Whisper_full_get_segment_t0(segment int) int64 {
return int64(C.whisper_full_get_segment_t0((*C.struct_whisper_state)(state), C.int(segment)))
}

// Get the start and end time of the specified segment.
func (ctx *Context) Whisper_full_get_segment_t1(segment int) int64 {
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_context)(ctx), C.int(segment)))
func (state *State) Whisper_full_get_segment_t1(segment int) int64 {
return int64(C.whisper_full_get_segment_t1((*C.struct_whisper_state)(state), C.int(segment)))
}

// Get the text of the specified segment.
func (ctx *Context) Whisper_full_get_segment_text(segment int) string {
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_context)(ctx), C.int(segment)))
func (state *State) Whisper_full_get_segment_text(segment int) string {
return C.GoString(C.whisper_full_get_segment_text((*C.struct_whisper_state)(state), C.int(segment)))
}

// Get number of tokens in the specified segment.
func (ctx *Context) Whisper_full_n_tokens(segment int) int {
return int(C.whisper_full_n_tokens((*C.struct_whisper_context)(ctx), C.int(segment)))
func (state *State) Whisper_full_n_tokens(segment int) int {
return int(C.whisper_full_n_tokens((*C.struct_whisper_state)(state), C.int(segment)))
}

// Get the token text of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_text(segment int, token int) string {
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
func (ctx *Context) Whisper_full_get_token_text(state *State, segment int, token int) string {
return C.GoString(C.whisper_full_get_token_text((*C.struct_whisper_context)(ctx), (*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
}

// Get the token of the specified token index in the specified segment.
func (ctx *Context) Whisper_full_get_token_id(segment int, token int) Token {
return Token(C.whisper_full_get_token_id((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
func (state *State) Whisper_full_get_token_id(segment int, token int) Token {
return Token(C.whisper_full_get_token_id((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
}

// Get token data for the specified token in the specified segment.
// This contains probabilities, timestamps, etc.
func (ctx *Context) whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
func (state *State) whisper_full_get_token_data(segment int, token int) TokenData {
return TokenData(C.whisper_full_get_token_data((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
}

// Get the probability of the specified token in the specified segment.
func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
return float32(C.whisper_full_get_token_p((*C.struct_whisper_context)(ctx), C.int(segment), C.int(token)))
func (state *State) Whisper_full_get_token_p(segment int, token int) float32 {
return float32(C.whisper_full_get_token_p((*C.struct_whisper_state)(state), C.int(segment), C.int(token)))
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
Loading