
Conversation

@pcuenca (Member) commented Sep 13, 2022

This is an alternative to #485 that incorporates #493.

We'll probably close #485, but I'm having a strange bug where the configuration is not correctly applied. The JSON file is correctly read in `from_pretrained`, but the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`.

Thanks to @mishig25, the aforementioned bug has been resolved.

pcuenca and others added 8 commits September 12, 2022 18:23

  • Mimic the structure of the PyTorch files. The model classes themselves need work, depending on what we do about configuration and initialization.
  • For some reason the configuration is not correctly applied; the signature of the `__init__` method does not contain all the parameters by the time it's inspected in `extract_init_dict`.
@pcuenca pcuenca marked this pull request as draft September 13, 2022 17:30
@HuggingFaceDocBuilderDev commented Sep 13, 2022

The documentation is not available anymore as the PR was closed or merged.

@pcuenca pcuenca marked this pull request as ready for review September 14, 2022 17:09
@patrickvonplaten (Contributor) left a comment

Looks very nice! I think it's just about deleting some dead code and maybe fixing the dropouts everywhere :-)

@pcuenca (Member, Author) commented Sep 15, 2022

Changed since last review:

  • Trickle down deterministic
  • Rename q, k, v
  • Docstrings (for Flax UNet only)
  • FlaxUNet2DConditionOutput
  • init_weights (not yet in use)

@patil-suraj (Contributor) left a comment

Looks good in general! Left some comments; more specifically:

  • We should add the dropout layers; I didn't add them in the original repo as it was just for inference.
  • Make sure that the module and weight names match 1:1 with PyTorch. This is required because we need to provide interoperability between PT and Flax models.
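To make the interoperability concern concrete, here is a rough, hypothetical sanity check (not part of this PR; the helper name and the kernel/scale-to-weight remapping are assumptions) that compares a Flax parameter tree against a PyTorch state_dict to spot naming mismatches:

```python
from flax.traverse_util import flatten_dict


def report_name_mismatches(flax_params, pt_state_dict):
    """Hypothetical helper: list parameter paths that do not line up 1:1
    between a Flax param tree and a PyTorch state_dict.

    Flax calls Dense/Conv weights "kernel" and norm scales "scale" where
    PyTorch uses "weight", so leaf names are remapped before comparing.
    This only checks names; real conversion also transposes/reshapes kernels.
    """
    leaf_map = {"kernel": "weight", "scale": "weight"}
    flax_keys = {
        ".".join(path[:-1] + (leaf_map.get(path[-1], path[-1]),))
        for path in flatten_dict(flax_params).keys()
    }
    pt_keys = set(pt_state_dict.keys())
    return sorted(flax_keys ^ pt_keys)  # ideally empty if names match 1:1
```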

Comment on lines +30 to +35
# Weights were exported with old names {to_q, to_k, to_v, to_out}
self.query = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_q")
self.key = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_k")
self.value = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, name="to_v")

self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype, name="to_out")
Contributor:

(nit): since we are using `setup` here, we could just use `self.to_q = nn.Dense(...)` instead of passing `name`. This will also make it easier to compare the Flax and PT code when reading.

Member Author:

Yes, the original name was `self.to_q`; I changed it here to match the renamed PyTorch version but kept the same weight names.
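To make the trade-off concrete, here is a minimal Flax sketch (illustrative only; "old_to_q" stands in for whatever exported weight name needs to be preserved) of the two ways a parameter scope can be named in `setup`: via the attribute name itself, or via an explicit `name=` that decouples it from the attribute:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class AttentionProjections(nn.Module):
    inner_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Option A (reviewer's suggestion): the attribute name doubles as the
        # parameter scope, so the weights live under "to_q".
        self.to_q = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)

        # Option B (this PR): keep a PyTorch-like attribute name but pin the
        # parameter scope to the exported weight name via `name=`.
        self.query = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype, name="old_to_q")

    def __call__(self, x):
        return self.to_q(x) + self.query(x)


params = AttentionProjections().init(jax.random.PRNGKey(0), jnp.ones((1, 32)))
print(list(params["params"].keys()))  # e.g. ['to_q', 'old_to_q']
```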

import jax.numpy as jnp


class FlaxAttentionBlock(nn.Module):
Contributor:

We should use the same names as the PyTorch modules

Suggested change
class FlaxAttentionBlock(nn.Module):
class FlaxCrossAttention(nn.Module):

query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0
Contributor:

`dropout` is not used; we should add the dropout layer here.
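For reference, a minimal sketch (module and attribute names are illustrative, not this PR's code) of what wiring in the dropout layer could look like in a setup-style Flax module, with a `deterministic` flag to disable it at inference:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class AttentionBlockSketch(nn.Module):
    query_dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.proj_attn = nn.Dense(self.query_dim, dtype=self.dtype)
        # The dropout layer the review asks for; the rate comes from the field above.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = self.proj_attn(hidden_states)
        # Only active when deterministic=False (training); a "dropout" RNG is needed then.
        return self.dropout_layer(hidden_states, deterministic=deterministic)


model = AttentionBlockSketch(query_dim=8, dropout=0.1)
x = jnp.ones((1, 8))
params = model.init(jax.random.PRNGKey(0), x)
out = model.apply(params, x, deterministic=True)
```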

dim: int
n_heads: int
d_head: int
dropout: float = 0.0
Contributor:

Let's add the dropout layer

Comment on lines +85 to +87
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
Contributor:

The names should match the PT version for auto-conversion.

Suggested change
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)

Comment on lines +216 to +218
# 1. time
t_emb = self.time_proj(timesteps)
t_emb = self.time_embedding(t_emb)
Contributor:

This expects `timesteps` to be an array; it might not work if it's a scalar. We should check this and handle the scalar-to-array conversion.
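A minimal sketch of the scalar handling being asked for (the helper name is made up for illustration):

```python
import jax.numpy as jnp


def ensure_timestep_array(timesteps):
    """Hypothetical helper: make sure timesteps is at least a 1D array so the
    sinusoidal time projection can broadcast over a batch dimension."""
    if not isinstance(timesteps, jnp.ndarray):
        timesteps = jnp.array([timesteps], dtype=jnp.int32)
    elif timesteps.ndim == 0:
        timesteps = jnp.expand_dims(timesteps, 0)
    return timesteps


print(ensure_timestep_array(10).shape)             # (1,)
print(ensure_timestep_array(jnp.array(10)).shape)  # (1,)
print(ensure_timestep_array(jnp.arange(4)).shape)  # (4,)
```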

Comment on lines +57 to +58
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
Contributor:

This should be a list, same as in PT
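For illustration, a minimal Flax sketch of the list-based layout being suggested (nn.Conv stands in for FlaxDownsample2D; all names here are assumptions). Storing the submodule in a list produces parameter scopes like `downsamplers_0`, mirroring PyTorch's ModuleList indices:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class DownBlockSketch(nn.Module):
    out_channels: int
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.add_downsample:
            # A list of submodules yields scopes like "downsamplers_0",
            # which keeps weight names aligned with PyTorch's ModuleList.
            self.downsamplers = [
                nn.Conv(self.out_channels, kernel_size=(3, 3), strides=(2, 2),
                        padding=((1, 1), (1, 1)), dtype=self.dtype)
            ]

    def __call__(self, hidden_states):
        if self.add_downsample:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
        return hidden_states


x = jnp.ones((1, 16, 16, 4))  # NHWC
params = DownBlockSketch(out_channels=8).init(jax.random.PRNGKey(0), x)
print(list(params["params"].keys()))  # e.g. ['downsamplers_0']
```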

Comment on lines +98 to +99
if self.add_downsample:
self.downsample = FlaxDownsample2D(self.out_channels, dtype=self.dtype)
Contributor:

same comment as above.

Comment on lines +153 to +154
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
Contributor:

same comment as above.

Comment on lines +198 to +199
if self.add_upsample:
self.upsample = FlaxUpsample2D(self.out_channels, dtype=self.dtype)
Contributor:

same comment as above.

@patrickvonplaten (Contributor):

Thanks a lot for the review here @patil-suraj - you're 100% right. To move fast, I'd say we merge this PR anyway and solve the conversion/weight naming, as well as the dropout layers, in a new PR (opening an issue for this).

As discussed offline, feel free to merge, @pcuenca, and we'll adapt in a future PR according to @patil-suraj's comments here :-)

@patrickvonplaten (Contributor):

Opened two issues here for future PRs :-)

@pcuenca pcuenca merged commit d8b0e4f into main Sep 15, 2022
@pcuenca pcuenca deleted the flax-unet-flaxmodelmixin branch September 15, 2022 16:07
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* First UNet Flax modeling blocks.

Mimic the structure of the PyTorch files.
The model classes themselves need work, depending on what we do about
configuration and initialization.

* Remove FlaxUNet2DConfig class.

* ignore_for_config non-config args.

* Implement `FlaxModelMixin`

* Use new mixins for Flax UNet.

For some reason the configuration is not correctly applied; the
signature of the `__init__` method does not contain all the parameters
by the time it's inspected in `extract_init_dict`.

* Import `FlaxUNet2DConditionModel` if flax is available.

* Rm unused method `framework`

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Suraj Patil <[email protected]>

* Indicate types in flax.struct.dataclass as pointed out by @mishig25

Co-authored-by: Mishig Davaadorj <[email protected]>

* Fix typo in transformer block.

* make style

* some more changes

* make style

* Add comment

* Update src/diffusers/modeling_flax_utils.py

Co-authored-by: Patrick von Platen <[email protected]>

* Rm unneeded comment

* Update docstrings

* correct ignore kwargs

* make style

* Update docstring examples

* Make style

* Style: remove empty line.

* Apply style (after upgrading black from pinned version)

* Remove some commented code and unused imports.

* Add init_weights (not yet in use until huggingface#513).

* Trickle down deterministic to blocks.

* Rename q, k, v according to the latest PyTorch version.

Note that weights were exported with the old names, so we need to be
careful.

* Flax UNet docstrings, default props as in PyTorch.

* Fix minor typos in PyTorch docstrings.

* Use FlaxUNet2DConditionOutput as output from UNet.

* make style

Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Mishig Davaadorj <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>