Skip to content

Conversation

@akash5474
Copy link
Contributor

@akash5474 akash5474 commented Oct 9, 2022

Summary

Addresses #621

Trickle down norm_num_groups and use it to replace hardcoded values in GroupNorm.

Details

This PR makes the following changes:

  • Replaces hardcoded value of num_groups param passed to GroupNorm
  • Adds a param to specify number of norm groups to the following classes:
    • FlaxResnetBlock2D
    • FlaxAttentionBlock
    • FlaxDownEncoderBlock2D
    • FlaxUpDecoderBlock2D
    • FlaxUNetMidBlock2D
      • Set a default value if param is None
  • Creates FlaxModelTesterMixin class with test_output and test_forward_with_norm_groups tests
  • Creates FlaxAutoencoderKLTests class and sets up tests for model
  • Creates require_flax testing util decorator
  • Fixes typos:
    • Rename class FlaxUpEncoderBlock2D => FlaxUpDecoderBlock2D
    • Fix docstring for param add_downsample => add_upsample in FlaxUpDecoderBlock2D
    • Fix norm_num_groups docstring default value in FlaxEncoder

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 9, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten patrickvonplaten requested review from patil-suraj and pcuenca and removed request for pcuenca October 10, 2022 13:27
Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me! Thanks @akash5474

@akash5474
Copy link
Contributor Author

akash5474 commented Oct 10, 2022

Awesome! Thanks for the opportunity.

I also ported a couple more tests from ModelTesterMixin (test_from_pretrained_save_pretrained and test_model_from_config) to the FlaxModelTesterMixin but I wasn't sure if I should include them in this PR or keep it more focused.

I'm happy to push those tests to this PR/branch or create a new one, do you have a preference?

@patrickvonplaten
Copy link
Contributor

Thanks a lot for adding the tests!

@patrickvonplaten patrickvonplaten merged commit a124204 into huggingface:main Oct 11, 2022
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* pass norm_num_groups param and add tests

* set resnet_groups for FlaxUNetMidBlock2D

* fixed docstrings

* fixed typo

* using is_flax_available util and created require_flax decorator
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* pass norm_num_groups param and add tests

* set resnet_groups for FlaxUNetMidBlock2D

* fixed docstrings

* fixed typo

* using is_flax_available util and created require_flax decorator
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants