Merged

Changes from 8 commits

Commits (40)
67e245c
First UNet Flax modeling blocks.
pcuenca Sep 12, 2022
c3fdbf9
Remove FlaxUNet2DConfig class.
pcuenca Sep 12, 2022
1067e34
ignore_for_config non-config args.
pcuenca Sep 12, 2022
95073e1
Implement `FlaxModelMixin`
mishig25 Sep 13, 2022
b9f6eb4
Merge remote-tracking branch 'origin/flax_model_mixin' into flax-unet…
pcuenca Sep 13, 2022
9891e5c
Use new mixins for Flax UNet.
pcuenca Sep 13, 2022
2d90544
Merge remote-tracking branch 'origin/main' into flax-unet-flaxmodelmixin
pcuenca Sep 13, 2022
25c615a
Import `FlaxUNet2DConditionModel` if flax is available.
pcuenca Sep 13, 2022
91559f3
Rm unused method `framework`
mishig25 Sep 14, 2022
f7a0ab2
Update src/diffusers/modeling_flax_utils.py
Sep 14, 2022
d41f2bf
Indicate types in flax.struct.dataclass as pointed out by @mishig25
pcuenca Sep 14, 2022
e0ec7bf
Fix typo in transformer block.
pcuenca Sep 14, 2022
5e7aeea
make style
pcuenca Sep 14, 2022
70ce383
Merge remote-tracking branch 'origin/main' into flax-unet-flaxmodelmixin
pcuenca Sep 14, 2022
5d81bf8
some more changes
patrickvonplaten Sep 14, 2022
1430ab8
make style
patrickvonplaten Sep 14, 2022
6a2a4c1
Add comment
mishig25 Sep 14, 2022
8d20417
Merge remote-tracking branch 'origin/flax_model_mixin' into flax-unet…
pcuenca Sep 14, 2022
2bf0267
Update src/diffusers/modeling_flax_utils.py
Sep 14, 2022
25ab3ca
Rm unneeded comment
mishig25 Sep 14, 2022
1e8466e
Update docstrings
mishig25 Sep 14, 2022
6842d29
correct ignore kwargs
patrickvonplaten Sep 14, 2022
4f6b01b
Merge branch 'flax_model_mixin' of https://github.com/huggingface/dif…
patrickvonplaten Sep 14, 2022
0f26c05
make style
patrickvonplaten Sep 14, 2022
d98e8c7
Update docstring examples
mishig25 Sep 14, 2022
5a7b784
Merge branch 'flax_model_mixin' of https://github.com/huggingface/dif…
mishig25 Sep 14, 2022
5d08577
Make style
mishig25 Sep 14, 2022
31caae9
Merge remote-tracking branch 'origin/flax_model_mixin' into flax-unet…
pcuenca Sep 14, 2022
0611b17
Merge remote-tracking branch 'origin/main' into flax-unet-flaxmodelmixin
pcuenca Sep 14, 2022
39bbd13
Style: remove empty line.
pcuenca Sep 14, 2022
ea99f35
Apply style (after upgrading black from pinned version)
pcuenca Sep 14, 2022
2d896f6
Remove some commented code and unused imports.
pcuenca Sep 15, 2022
da6ddfd
Add init_weights (not yet in use until #513).
pcuenca Sep 15, 2022
e7347c0
Trickle down deterministic to blocks.
pcuenca Sep 15, 2022
cfca52f
Rename q, k, v according to the latest PyTorch version.
pcuenca Sep 15, 2022
a48500a
Flax UNet docstrings, default props as in PyTorch.
pcuenca Sep 15, 2022
b33ef5e
Fix minor typos in PyTorch docstrings.
pcuenca Sep 15, 2022
b8798ba
Use FlaxUNet2DConditionOutput as output from UNet.
pcuenca Sep 15, 2022
da97b21
make style
pcuenca Sep 15, 2022
802e710
Merge remote-tracking branch 'origin/main' into flax-unet-flaxmodelmixin
pcuenca Sep 15, 2022
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -64,5 +64,6 @@

if is_flax_available():
from .schedulers import FlaxPNDMScheduler
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
else:
from .utils.dummy_flax_objects import * # noqa F403
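With this change, `FlaxUNet2DConditionModel` becomes importable from the top-level package whenever JAX/Flax are installed. A small usage sketch, assuming `is_flax_available` from `diffusers.utils`:

# Sketch: guard the Flax import on flax availability, mirroring the pattern above.
from diffusers.utils import is_flax_available

if is_flax_available():
    from diffusers import FlaxUNet2DConditionModel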
40 changes: 40 additions & 0 deletions src/diffusers/configuration_utils.py
@@ -401,3 +401,43 @@ def inner_init(self, *args, **kwargs):
getattr(self, "register_to_config")(**new_kwargs)

return inner_init


def flax_register_to_config(cls):
original_init = cls.__init__

@functools.wraps(original_init)
def init(self, *args, **kwargs):
# Ignore private kwargs in the init.
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
# original_init(self, *args, **init_kwargs)
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
"not inherit from `ConfigMixin`."
)

ignore = getattr(self, "ignore_for_config", [])
# Get positional arguments aligned with kwargs
new_kwargs = {}
signature = inspect.signature(init)
parameters = {
name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
}
for arg, name in zip(args, parameters.keys()):
new_kwargs[name] = arg

# Then add all kwargs
new_kwargs.update(
{
k: init_kwargs.get(k, default)
for k, default in parameters.items()
if k not in ignore and k not in new_kwargs
}
)
getattr(self, "register_to_config")(**new_kwargs)

original_init(self, *args, **init_kwargs)

cls.__init__ = init
return cls
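As a quick illustration of how the decorator above is intended to be used (a sketch only; `MyFlaxBlock` is a hypothetical module, not part of this PR):

# Hypothetical example: a Flax module whose init arguments are captured as config.
import flax.linen as nn
import jax.numpy as jnp

from diffusers.configuration_utils import ConfigMixin, flax_register_to_config


@flax_register_to_config
class MyFlaxBlock(nn.Module, ConfigMixin):
    config_name = "config.json"  # required by ConfigMixin before register_to_config is called

    hidden_dim: int = 32
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.hidden_dim, dtype=self.dtype)

    def __call__(self, x):
        return self.dense(x)


block = MyFlaxBlock(hidden_dim=64)
# Positional and keyword init arguments (plus unspecified defaults) end up in the frozen config.
print(block.config.hidden_dim)  # 64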
500 changes: 500 additions & 0 deletions src/diffusers/modeling_flax_utils.py

Large diffs are not rendered by default.

181 changes: 181 additions & 0 deletions src/diffusers/models/attention_flax.py
@@ -0,0 +1,181 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax.linen as nn
import jax.numpy as jnp


class FlaxAttentionBlock(nn.Module):

Contributor comment: We should use the same names as the PyTorch modules.

Suggested change:
- class FlaxAttentionBlock(nn.Module):
+ class FlaxCrossAttention(nn.Module):

query_dim: int
heads: int = 8
dim_head: int = 64
dropout: float = 0.0

Contributor comment: dropout is not used; we should add the dropout layer here.

dtype: jnp.dtype = jnp.float32

def setup(self):
inner_dim = self.dim_head * self.heads
self.scale = self.dim_head**-0.5

self.to_q = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype)
self.to_k = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype)
self.to_v = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype)

self.to_out = nn.Dense(self.query_dim, dtype=self.dtype)

def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor

def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor

def __call__(self, hidden_states, context=None, deterministic=True):
context = hidden_states if context is None else context

q = self.to_q(hidden_states)
k = self.to_k(context)
v = self.to_v(context)

q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)

# compute attentions
attn_weights = jnp.einsum("b i d, b j d->b i j", q, k)
attn_weights = attn_weights * self.scale
attn_weights = nn.softmax(attn_weights, axis=2)

## attend to values
hidden_states = jnp.einsum("b i j, b j d -> b i d", attn_weights, v)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
hidden_states = self.to_out(hidden_states)
return hidden_states
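A short sketch of how this block is exercised, assuming the `FlaxAttentionBlock` definition above (shapes are illustrative, e.g. a 64x64 latent attending over 77 text tokens):

# Sketch: initialize and run FlaxAttentionBlock on dummy inputs.
import jax
import jax.numpy as jnp

attn = FlaxAttentionBlock(query_dim=320, heads=8, dim_head=40)

hidden_states = jnp.ones((1, 64 * 64, 320))  # (batch, sequence, query_dim)
context = jnp.ones((1, 77, 768))             # e.g. text encoder hidden states
params = attn.init(jax.random.PRNGKey(0), hidden_states, context)
out = attn.apply(params, hidden_states, context)  # -> (1, 4096, 320), same shape as hidden_states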


class FlaxBasicTransformerBlock(nn.Module):
dim: int
n_heads: int
d_head: int
dropout: float = 0.0

Contributor comment: Let's add the dropout layer.

dtype: jnp.dtype = jnp.float32

def setup(self):
# self attention
self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
# cross attention
self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
Comment on lines +85 to +87
Contributor comment: The names should match the PyTorch version for autoconversion.

Suggested change:
- self.self_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
- # cross attention
- self.cross_attn = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+ self.attn1 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)
+ # cross attention
+ self.attn2 = FlaxAttentionBlock(self.dim, self.n_heads, self.d_head, self.dropout, dtype=self.dtype)

self.ff = FlaxGluFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)

def __call__(self, hidden_states, context, deterministic=True):
# self attention
residual = hidden_states
hidden_states = self.self_attn(self.norm1(hidden_states))
hidden_states = hidden_states + residual

# cross attention
residual = hidden_states
hidden_states = self.cross_attn(self.norm2(hidden_states), context)
hidden_states = hidden_states + residual

# feed forward
residual = hidden_states
hidden_states = self.ff(self.norm3(hidden_states))
hidden_states = hidden_states + residual

return hidden_states


class FlaxSpatialTransformer(nn.Module):
in_channels: int
n_heads: int
d_head: int
depth: int = 1
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32

def setup(self):
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

inner_dim = self.n_heads * self.d_head
self.proj_in = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

self.transformer_blocks = [
FlaxBasicTransformerBlock(inner_dim, self.n_heads, self.d_head, dropout=self.dropout, dtype=self.dtype)
for _ in range(self.depth)
]

self.proj_out = nn.Conv(
inner_dim,
kernel_size=(1, 1),
strides=(1, 1),
padding="VALID",
dtype=self.dtype,
)

def __call__(self, hidden_states, context, deterministic=True):
batch, height, width, channels = hidden_states.shape
# import ipdb; ipdb.set_trace()

Contributor suggested change: remove the leftover debug statement.
- # import ipdb; ipdb.set_trace()

residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)

# hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 1))
hidden_states = hidden_states.reshape(batch, height * width, channels)

for transformer_block in self.transformer_blocks:
hidden_states = transformer_block(hidden_states, context)

hidden_states = hidden_states.reshape(batch, height, width, channels)
# hidden_states = jnp.transpose(hidden_states, (0, 3, 1, 2))

hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states + residual

return hidden_states
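Note that the reshape above assumes channels-last (NHWC) inputs, unlike the NCHW layout of the PyTorch module. A minimal forward-pass sketch, assuming the classes in this diff and illustrative shapes:

# Sketch: FlaxSpatialTransformer on an NHWC feature map; channels == n_heads * d_head.
import jax
import jax.numpy as jnp

transformer = FlaxSpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1)

sample = jnp.ones((1, 32, 32, 320))  # (batch, height, width, channels)
context = jnp.ones((1, 77, 768))     # cross-attention conditioning
params = transformer.init(jax.random.PRNGKey(0), sample, context)
out = transformer.apply(params, sample, context)  # -> (1, 32, 32, 320)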


class FlaxGluFeedForward(nn.Module):

Contributor comment: We will have to split this into two modules, FeedForward and GEGLU, like in PyTorch.

dim: int
dropout: float = 0.0

Contributor comment: Let's add the dropout layer.

dtype: jnp.dtype = jnp.float32

def setup(self):
inner_dim = self.dim * 4
self.dense1 = nn.Dense(inner_dim * 2, dtype=self.dtype)
self.dense2 = nn.Dense(self.dim, dtype=self.dtype)

def __call__(self, hidden_states, deterministic=True):
hidden_states = self.dense1(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
hidden_states = hidden_linear * nn.gelu(hidden_gelu)
hidden_states = self.dense2(hidden_states)
return hidden_states
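For reference, one possible shape of the split the reviewer asks for above, mirroring the PyTorch FeedForward/GEGLU pair (a sketch only; the module names and the `net_0`/`net_2` attribute names are assumptions, not code from this PR):

# Sketch: split FlaxGluFeedForward into a GEGLU projection plus an output projection.
import flax.linen as nn
import jax.numpy as jnp


class FlaxGEGLU(nn.Module):
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        # Single projection producing both the linear and the gate halves.
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
        return hidden_linear * nn.gelu(hidden_gelu)


class FlaxFeedForward(nn.Module):
    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.net_0(hidden_states, deterministic=deterministic)
        hidden_states = self.net_2(hidden_states)
        return hidden_states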
56 changes: 56 additions & 0 deletions src/diffusers/models/embeddings_flax.py
@@ -0,0 +1,56 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import flax.linen as nn
import jax.numpy as jnp


# This is like models.embeddings.get_timestep_embedding (PyTorch) but
# less general (only handles the case we currently need).
Comment on lines +20 to +21
Contributor comment: Yes, we could update this once we start converting other models.

def get_sinusoidal_embeddings(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

:param timesteps: a 1-D tensor of N indices, one per batch element. These may be fractional.
:param embedding_dim: the dimension of the output.
:return: an [N x dim] tensor of positional embeddings.
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = jnp.exp(jnp.arange(half_dim) * -emb)
emb = timesteps[:, None] * emb[None, :]
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
return emb
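A quick shape check of the helper above (illustrative values):

# Sketch: three timesteps embedded into 320-dimensional sinusoidal features.
import jax.numpy as jnp

timesteps = jnp.array([0.0, 10.0, 999.0])  # one (possibly fractional) timestep per batch element
emb = get_sinusoidal_embeddings(timesteps, embedding_dim=320)
print(emb.shape)  # (3, 320): cosine half concatenated with sine half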


class FlaxTimestepEmbedding(nn.Module):
time_embed_dim: int = 32
dtype: jnp.dtype = jnp.float32

@nn.compact
def __call__(self, temb):
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
temb = nn.silu(temb)
temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
return temb


class FlaxTimesteps(nn.Module):
dim: int = 32

@nn.compact
def __call__(self, timesteps):
return get_sinusoidal_embeddings(timesteps, self.dim)
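And a sketch of how the two modules above compose into the UNet time embedding (illustrative dimensions):

# Sketch: project scalar timesteps to sinusoidal features, then through the MLP embedding.
import jax
import jax.numpy as jnp

time_proj = FlaxTimesteps(dim=320)
time_embed = FlaxTimestepEmbedding(time_embed_dim=1280)

timesteps = jnp.array([999])
# FlaxTimesteps has no learnable parameters, but still goes through init/apply.
t_emb = time_proj.apply(time_proj.init(jax.random.PRNGKey(0), timesteps), timesteps)  # (1, 320)
params = time_embed.init(jax.random.PRNGKey(1), t_emb)
t_emb = time_embed.apply(params, t_emb)  # (1, 1280)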