Add logits support to whisper backbone #2134
Conversation
/gemini review
Code Review
This pull request adds a logits method to the WhisperBackbone to support tasks like text generation. The change includes a new method in the backbone and a corresponding unit test using a preset model.
My review focuses on improving the robustness of the new logits method by suggesting a more direct way to access the token embedding weights, which will make the code less brittle to future changes.
def logits(self, *args, **kwargs):
    result = self(*args, **kwargs)
    token_embedding = None
    for embedding_type in self.decoder_embeddings.weights:
        if "token_embedding" in embedding_type.path:
            token_embedding = embedding_type
    return keras.ops.matmul(
        result["decoder_sequence_output"], keras.ops.transpose(token_embedding)
    )
The current approach to retrieve the token embedding matrix by iterating through all decoder weights and matching a substring in the weight's path is fragile. It would be more robust to directly access the weights from the self.token_embedding layer, which is already an attribute of the class. This avoids a dependency on layer and weight naming conventions.
def logits(self, *args, **kwargs):
    result = self(*args, **kwargs)
    # Directly access the embedding matrix from the token embedding layer.
    # This is more robust than searching for the weight by name in the path.
    token_embedding = self.token_embedding.weights[0]
    return keras.ops.matmul(
        result["decoder_sequence_output"], keras.ops.transpose(token_embedding)
    )
Currently the WhisperBackbone does not support logits output, as we can see here.
This would be a building block for other tasks under Whisper Causal LM.
Test output has been validated against Whisper's output: Colab
Building block for issue: #2074
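For context, the logits computation in this PR is standard weight tying: decoder hidden states are projected through the transpose of the token embedding matrix, producing one score per vocabulary token. A minimal standalone sketch with made-up shapes (not the real model dimensions):

import keras
import numpy as np

# Made-up dimensions for illustration only.
vocab_size, hidden_dim, seq_len = 1000, 64, 3
embedding = np.random.rand(vocab_size, hidden_dim).astype("float32")
decoder_output = np.random.rand(1, seq_len, hidden_dim).astype("float32")

# Projecting onto the transposed embedding matrix yields logits of
# shape (1, seq_len, vocab_size).
logits = keras.ops.matmul(decoder_output, keras.ops.transpose(embedding))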