UNet Flax with FlaxModelMixin #502
Conversation
Mimic the structure of the PyTorch files. The model classes themselves need work, depending on what we do about configuration and initialization.
For some reason the configuration is not correctly applied; the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`.
The documentation is not available anymore as the PR was closed or merged.
patrickvonplaten left a comment:
Looks very nice! I think it's just about deleting some dead code and maybe fixing the dropouts everywhere :-)
Note that weights were exported with the old names, so we need to be careful.
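Since the PyTorch attribute names changed but the exported weights still use the old {to_q, to_k, to_v, to_out} names, "being careful" here may mean remapping parameter paths during conversion. A minimal sketch of such a remap, assuming the old/new name pairs below; the helper name and mapping are illustrative, not the PR's actual conversion code:

```python
# Illustrative sketch: remap old attention weight names in a Flax params tree.
# The RENAMES mapping and helper name are assumptions, not code from this PR.
from flax.traverse_util import flatten_dict, unflatten_dict

RENAMES = {"to_q": "query", "to_k": "key", "to_v": "value", "to_out": "proj_attn"}

def rename_attention_params(params):
    flat = flatten_dict(params)  # keys become tuples of path components
    renamed = {
        tuple(RENAMES.get(part, part) for part in path): value
        for path, value in flat.items()
    }
    return unflatten_dict(renamed)
```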
Looks good in general! Left some comments, more specifically:
- We should add the `dropout` layers; I didn't add them in the original repo as it was just for inference.
- Make sure that the module and weight names match 1:1 with PyTorch. This is required as we need to provide interoperability with PT and Flax models.
# Weights were exported with old names {to_q, to_k, to_v, to_out}
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
(nit) Since we are using `setup` here, we could just use `self.to_q = nn.Dense(...)` instead of passing `name`. This will also make it easier to compare the Flax and PT code when reading.
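For reference, a minimal sketch of that suggestion (only `setup` is shown): in Flax, an attribute assigned in `setup` uses the attribute name as its parameter scope, so the explicit `name=` argument becomes unnecessary. The `inner_dim` field here is an illustrative stand-in for `dim_head * heads`:

```python
import flax.linen as nn
import jax.numpy as jnp

class FlaxAttentionBlock(nn.Module):
    query_dim: int
    inner_dim: int                      # illustrative; derived from dim_head * heads in the PR
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # the attribute name doubles as the parameter name, matching the exported weights
        self.to_q = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
        self.to_k = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
        self.to_v = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
        self.to_out = nn.Dense(self.query_dim, dtype=self.dtype)
```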
Yes, the original name was `self.to_q`; I changed it here to match the renamed PyTorch version but kept the same weight names.
import jax.numpy as jnp

class FlaxAttentionBlock(nn.Module):
We should use the same names as the PyTorch modules
Suggested change:
- class FlaxAttentionBlock(nn.Module):
+ class FlaxCrossAttention(nn.Module):
query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
`dropout` is not used; we should add the dropout layer here.
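A minimal sketch of wiring the currently unused `dropout` field into an `nn.Dropout` layer, assuming the block takes a `deterministic` flag at call time; only the dropout plumbing is shown, not the attention math:

```python
import flax.linen as nn

class FlaxAttentionBlock(nn.Module):
    query_dim: int
    heads: int = 8
    dim_head: int = 64
    dropout: float = 0.0

    def setup(self):
        # ... projection layers as in the diff above ...
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, context=None, deterministic=True):
        # ... attention computation producing hidden_states ...
        # dropout is only active when deterministic=False (training)
        return self.dropout_layer(hidden_states, deterministic=deterministic)
```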
dim: int
n_heads: int
d_head: int
dropout: float = 0.0
Let's add the dropout layer
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
The names should match the PT version for auto-conversion.
Suggested change:
- self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
- # cross attention
- self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+ self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+ # cross attention
+ self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# 1. time
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)
This expects that `timesteps` is an array; it might not work if it's a scalar. We should check this and handle the scalar-to-array conversion.
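A minimal sketch of the kind of guard being asked for, assuming a scalar timestep should be broadcast to the whole batch; the helper name is illustrative:

```python
import jax.numpy as jnp

def ensure_timestep_array(timesteps, batch_size):
    # accept Python ints/floats as well as JAX arrays
    timesteps = jnp.asarray(timesteps)
    if timesteps.ndim == 0:
        # scalar -> 1D array, then broadcast to one timestep per batch element
        timesteps = jnp.expand_dims(timesteps, 0)
    return jnp.broadcast_to(timesteps, (batch_size,))
```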
if self.add_downsample:
    self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
This should be a list, same as in PT
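A minimal sketch of the list form, mirroring PyTorch's `nn.ModuleList`; Flax accepts plain Python lists of submodules assigned in `setup`. The class here is reduced to just the downsampling part and assumes the PR's `FlaxDownsample2D` is in scope:

```python
import flax.linen as nn
import jax.numpy as jnp

class FlaxDownBlock2D(nn.Module):
    out_channels: int
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # ... resnet blocks elided ...
        if self.add_downsample:
            # a list, even with a single element, to mirror PyTorch's ModuleList
            self.downsamplers = [FlaxDownsample2D(self.out_channels, dtype=self.dtype)]

    def __call__(self, hidden_states):
        if self.add_downsample:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
        return hidden_states
```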
if self.add_downsample:
    self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
same comment as above.
if self.add_upsample:
    self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
same comment as above.
if self.add_upsample:
    self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
same comment as above.
Thanks a lot for the review here @patil-suraj - you're 100% right. To move fast, I'd say we merge this PR though and solve the conversion/weight naming in a new PR (opening an issue for this), as well as the dropout layers. As discussed offline, feel free to merge @pcuenca and we'll adapt in a future PR according to @patil-suraj's comments here :-)
Opened two issues here for future PRs :-)
* First UNet Flax modeling blocks. Mimic the structure of the PyTorch files. The model classes themselves need work, depending on what we do about configuration and initialization.
* Remove FlaxUNet2DConfig class.
* ignore_for_config non-config args.
* Implement `FlaxModelMixin`
* Use new mixins for Flax UNet. For some reason the configuration is not correctly applied; the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`.
* Import `FlaxUNet2DConditionModel` if flax is available.
* Rm unused method `framework`
* Update src/diffusers/modeling_flax_utils.py (Co-authored-by: Suraj Patil <[email protected]>)
* Indicate types in flax.struct.dataclass as pointed out by @mishig25 (Co-authored-by: Mishig Davaadorj <[email protected]>)
* Fix typo in transformer block.
* make style
* some more changes
* make style
* Add comment
* Update src/diffusers/modeling_flax_utils.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Rm unneeded comment
* Update docstrings
* correct ignore kwargs
* make style
* Update docstring examples
* Make style
* Style: remove empty line.
* Apply style (after upgrading black from pinned version)
* Remove some commented code and unused imports.
* Add init_weights (not yet in use until huggingface#513).
* Trickle down deterministic to blocks.
* Rename q, k, v according to the latest PyTorch version. Note that weights were exported with the old names, so we need to be careful.
* Flax UNet docstrings, default props as in PyTorch.
* Fix minor typos in PyTorch docstrings.
* Use FlaxUNet2DConditionOutput as output from UNet.
* make style

Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
This is an alternative to #485 that incorporates #493.
We'll probably close #485, but I'm having a strange bug where the configuration is not correctly applied. The JSON file is correctly read in `from_pretrained`, but the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`. Thanks to @mishig25, the aforementioned bug was resolved.
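For context on that symptom, and purely as a generic Python illustration (not necessarily what happened in this PR): a decorator that replaces `__init__` without `functools.wraps` makes `inspect.signature` see only `(*args, **kwargs)`, so any code that inspects the init signature loses the real parameter names. The decorator and class names below are hypothetical:

```python
import inspect

def config_decorator(init):                 # hypothetical decorator, for illustration only
    def wrapper(self, *args, **kwargs):     # note: no @functools.wraps(init)
        return init(self, *args, **kwargs)
    return wrapper

class Model:
    @config_decorator
    def __init__(self, sample_size=32, in_channels=4):
        pass

print(inspect.signature(Model.__init__))    # (self, *args, **kwargs) - real parameters hidden
# Decorating the wrapper with functools.wraps(init) preserves the original signature.
```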