-
Notifications
You must be signed in to change notification settings - Fork 6.5k
[UNet2DConditionModel, UNet2DModel] pass norm_num_groups to all the blocks #442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
|
Thank you everyone! |
patrickvonplaten
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool @patil-suraj !
|
|
||
| self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2)) | ||
|
|
||
| def test_forward_with_norm_groups(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to adjust norm groups inside the blocks cf https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_blocks.py#L780
So not sure if it can work with any value because of the division with 4
patrickvonplaten
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry @patil-suraj - let's really try to not change existing test (except for they are wrong). It would be really nice to just add a new test here
Agree, I have already added a new common test |
…locks (huggingface#442) * pass norm_num_groups to unet blocs and attention * fix UNet2DConditionModel * add norm_num_groups arg in vae * add tests * remove comment * Apply suggestions from code review
This PR removed the hardcoded value of 32 for
norm_num_groupsand makes sure that the arg is passed to all up, down, res, and attention blocks.Fixes #410