
Add logits support to whisper backbone #2134


Open. shanky-kapoor wants to merge 3 commits into master.

Conversation


@shanky-kapoor commented on Mar 10, 2025

Currently, the WhisperBackbone does not support logits output, as we can see here.

This would be a building block for other tasks, such as a Whisper causal LM.

Test output has been validated against Whisper's output: Colab

Building block for issue: #2074
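
A minimal usage sketch of the new method (the keras_hub import path, the tiny config, and the input key names below follow the existing WhisperBackbone docstring example rather than this PR, and are illustrative only):

    import numpy as np
    import keras_hub

    # Small illustrative config; real presets use much larger dimensions.
    backbone = keras_hub.models.WhisperBackbone(
        vocabulary_size=1000,
        num_layers=2,
        num_heads=2,
        hidden_dim=32,
        intermediate_dim=64,
        max_encoder_sequence_length=128,
        max_decoder_sequence_length=64,
    )

    inputs = {
        # 80 mel features assumes the default num_mels.
        "encoder_features": np.ones((1, 16, 80), dtype="float32"),
        "decoder_token_ids": np.ones((1, 8), dtype="int32"),
        "decoder_padding_mask": np.ones((1, 8), dtype="int32"),
    }

    # Calling the backbone returns hidden states; the new logits method
    # projects the decoder hidden states onto the vocabulary.
    hidden = backbone(inputs)["decoder_sequence_output"]  # shape (1, 8, 32)
    logits = backbone.logits(inputs)                       # shape (1, 8, 1000)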

@divyashreepathihalli
Collaborator

/gemini review


@gemini-code-assist (bot) left a comment


Code Review

This pull request adds a logits method to the WhisperBackbone to support tasks like text generation. The change includes a new method in the backbone and a corresponding unit test using a preset model.

My review focuses on improving the robustness of the new logits method by suggesting a more direct way to access the token embedding weights, which will make the code less brittle to future changes.

Comment on lines +290 to +298
    def logits(self, *args, **kwargs):
        result = self(*args, **kwargs)
        token_embedding = None
        for embedding_type in self.decoder_embeddings.weights:
            if "token_embedding" in embedding_type.path:
                token_embedding = embedding_type
        return keras.ops.matmul(
            result["decoder_sequence_output"], keras.ops.transpose(token_embedding)
        )


Severity: medium

The current approach to retrieve the token embedding matrix by iterating through all decoder weights and matching a substring in the weight's path is fragile. It would be more robust to directly access the weights from the self.token_embedding layer, which is already an attribute of the class. This avoids a dependency on layer and weight naming conventions.

    def logits(self, *args, **kwargs):
        result = self(*args, **kwargs)
        # Directly access the embedding matrix from the token embedding layer.
        # This is more robust than searching for the weight by name in the path.
        token_embedding = self.token_embedding.weights[0]
        return keras.ops.matmul(
            result["decoder_sequence_output"], keras.ops.transpose(token_embedding)
        )
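
As a side note, reusing the decoder token embedding as the output projection (weight tying) matches how the original Whisper decoder computes its logits: the embedding weight has shape (vocabulary_size, hidden_dim), so multiplying by its transpose yields one score per vocabulary token. A quick shape sketch with purely illustrative sizes:

    import keras

    # Illustrative shapes only: batch=1, decoder length=8, hidden_dim=32, vocab=1000.
    decoder_output = keras.ops.ones((1, 8, 32))    # (batch, dec_seq, hidden_dim)
    embedding_matrix = keras.ops.ones((1000, 32))  # (vocabulary_size, hidden_dim)

    logits = keras.ops.matmul(
        decoder_output, keras.ops.transpose(embedding_matrix)
    )
    print(logits.shape)  # (1, 8, 1000): one logit per vocabulary token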
