Skip to content

Conversation

@mdda
Copy link
Contributor

@mdda mdda commented Feb 28, 2025

Adds Gemma2-2b (including GQA to Attention and fixes to Block)

Fixes #4567

It includes:

  • New TransformerConfig for gemma2-2b model
  • Renames existing configs to make them uniform
    • NB: This cannot affect existing code, since this module was previously unusable
  • Added additional params reading key adjustment (key in gemma2-2b needs remapping from kaggle download)
  • Adds GQA to Attention module
  • Reorders the operations in Block module so that logits output from overall Transformer are not gibberish
    • logits confirmed to (approximately) match those from GDE gemma (flax linen) model
  • No new documentation provided
    • This change would make the example in the nnx documentation work (did not work before)
  • No additional tests provided

@cgarciae
Copy link
Collaborator

cgarciae commented Mar 4, 2025

Hey Martin! Thanks for doing this.
Some folks are internally also improving the model. Let me merge their changes first and make sure there are no conflicts with those changes.

@mdda
Copy link
Contributor Author

mdda commented Mar 5, 2025

Ahhh - Brings back memories of PRs for TensorFlow. Good times! /s

I'll attempt fix the white-space check failures, when you let me know that it isn't a waste of my time.

@classmethod
def gemma_27b(cls):
num_layers = 46
def gemma2_2b(cls):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we add a new method for gemma_27b configuration instead of removing the gemma2_2b one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are both in the file - it's just the git that didn't pick up the diff properly.

But also, it makes sense to rename the different 'generations' of gemma, gemma2, gemma3, etc as separate classes, rather than relying on implicit knowledge about which size came from where.

Moreover, the gemma2_27 didn't have the right normalisation in the attention - will need a separate fix.

I also see that the Google-generated PR adopted the same fix as mine to the Block module. Good for you!

@cgarciae
Copy link
Collaborator

Thanks @mdda for doing this, it took a while because there where other changes to the model in the background that were pending. I think this is great. Can you please add a test to check that the GQA configuration works and matches the base version?

Also now need to solve the conflicts, sorry about this, this code is being used by a couple of users internally.

@mdda
Copy link
Contributor Author

mdda commented Mar 15, 2025

Surprised that the code was being used internally prior to my PR, since the Block module was entirely borked.

@cgarciae
Copy link
Collaborator

@mdda please take a look at CI, you probably need to run:

pip install pre-commit
pre-commit run --all-files

@vfdev-5
Copy link
Collaborator

vfdev-5 commented Jul 21, 2025

Hey @mdda , I wonder whether you would be able to finalize this PR? If you are busy, we can take over it keeping your commits (such that you will be credited for the work you have done). Please, let me know. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Some surprises when adding Gemma2-2b to flax/examples/gemma

3 participants