Commit 5e6417e

kashif, patrickvonplaten, and pcuenca authored

[Docs] Models (#416)

* docs for attention
* types for embeddings
* unet2d docstrings
* UNet2DConditionModel docstrings
* fix typos
* style and vq-vae docstrings
* docstrings for VAE
* Update src/diffusers/models/unet_2d.py

  Co-authored-by: Patrick von Platen <[email protected]>

* make style
* added inherits from sentence
* docstring to forward
* make style
* Apply suggestions from code review

  Co-authored-by: Pedro Cuenca <[email protected]>

* finish model docs
* up

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>

1 parent 234e90c, commit 5e6417e

File tree

7 files changed: +284, -88 lines changed

docs/source/api/models.mdx

Lines changed: 26 additions & 7 deletions

```diff
@@ -16,13 +16,32 @@ Diffusers contains pretrained models for popular algorithms and modules for crea
 The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$.
 The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub.
 
-## API
+## ModelMixin
+[[autodoc]] ModelMixin
 
-Models should provide the `def forward` function and initialization of the model.
-All saving, loading, and utilities should be in the base ['ModelMixin'] class.
+## UNet2DOutput
+[[autodoc]] models.unet_2d.UNet2DOutput
 
-## Examples
+## UNet2DModel
+[[autodoc]] UNet2DModel
 
-- The ['UNetModel'] was proposed in [TODO](https://arxiv.org/) and has been used in paper1, paper2, paper3.
-- Extensions of the ['UNetModel'] include the ['UNetGlideModel'] that uses attention and timestep embeddings for the [GLIDE](https://arxiv.org/abs/2112.10741) paper, the ['UNetGradTTS'] model from this [paper](https://arxiv.org/abs/2105.06337) for text-to-speech, ['UNetLDMModel'] for latent-diffusion models in this [paper](https://arxiv.org/abs/2112.10752), and the ['TemporalUNet'] used for time-series prediciton in this reinforcement learning [paper](https://arxiv.org/abs/2205.09991).
-- TODO: mention VAE / SDE score estimation
+## UNet2DConditionOutput
+[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput
+
+## UNet2DConditionModel
+[[autodoc]] UNet2DConditionModel
+
+## DecoderOutput
+[[autodoc]] models.vae.DecoderOutput
+
+## VQEncoderOutput
+[[autodoc]] models.vae.VQEncoderOutput
+
+## VQModel
+[[autodoc]] VQModel
+
+## AutoencoderKLOutput
+[[autodoc]] models.vae.AutoencoderKLOutput
+
+## AutoencoderKL
+[[autodoc]] AutoencoderKL
```
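To make the new API page concrete, here is a minimal sketch of how the documented classes fit together: `ModelMixin` supplies `from_pretrained`, and calling a `UNet2DModel` returns a `UNet2DOutput`. The checkpoint id and tensor shapes are illustrative assumptions, not part of this commit.

```python
# Minimal sketch of the documented API; the Hub checkpoint id
# ("google/ddpm-cat-256") and shapes are illustrative assumptions.
import torch
from diffusers import UNet2DModel

# ModelMixin.from_pretrained handles download, caching, and weight loading.
model = UNet2DModel.from_pretrained("google/ddpm-cat-256")

# One denoising step: the forward pass returns a UNet2DOutput whose
# `.sample` field holds the predicted noise residual.
noisy = torch.randn(1, 3, 256, 256)
timestep = torch.tensor([10])
with torch.no_grad():
    noise_pred = model(noisy, timestep).sample
```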

src/diffusers/modeling_utils.py

Lines changed: 13 additions & 65 deletions

```diff
@@ -117,27 +117,12 @@ class ModelMixin(torch.nn.Module):
     Base class for all models.
 
     [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
-    and saving models as well as a few methods common to all models to:
+    and saving models.
 
-    - resize the input embeddings,
-    - prune heads in the self-attention heads.
+    Class attributes:
 
-    Class attributes (overridden by derived classes):
-
-    - **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this
-      model architecture.
-    - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
-      taking as arguments:
-
-        - **model** ([`ModelMixin`]) -- An instance of the model on which to load the TensorFlow checkpoint.
-        - **config** ([`PreTrainedConfigMixin`]) -- An instance of the configuration associated to the model.
-        - **path** (`str`) -- A path to the TensorFlow checkpoint.
-
-    - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
-      classes of the same architecture adding modules on top of the base model.
-    - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
-    - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
-      models, `pixel_values` for vision models and `input_values` for speech models).
+    - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+      [`~modeling_utils.ModelMixin.save_pretrained`].
     """
     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
@@ -150,11 +135,10 @@ def save_pretrained(
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
         save_function: Callable = torch.save,
-        **kwargs,
     ):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the
-        `[`~ModelMixin.from_pretrained`]` class method.
+        `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
 
         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -166,9 +150,6 @@ def save_pretrained(
             save_function (`Callable`):
                 The function to use to save the state dictionary. Useful on distributed training like TPUs when one
                 need to replace `torch.save` by another method.
-
-            kwargs:
-                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -224,34 +205,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`],
                   e.g., `./my_model_directory/`.
 
-            config (`Union[ConfigMixin, str, os.PathLike]`, *optional*):
-                Can be either:
-
-                    - an instance of a class derived from [`ConfigMixin`],
-                    - a string or path valid as input to [`~ConfigMixin.from_pretrained`].
-
-                ConfigMixinuration for the model to use instead of an automatically loaded configuration.
-                ConfigMixinuration can be automatically loaded when:
-
-                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
-                      model).
-                    - The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save
-                      directory.
-                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
-                      configuration JSON file named *config.json* is found in the directory.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory in which a downloaded pretrained model configuration should be cached if the
                 standard cache should not be used.
-            from_tf (`bool`, *optional*, defaults to `False`):
-                Load the model weights from a TensorFlow checkpoint save file (see docstring of
-                `pretrained_model_name_or_path` argument).
-            from_flax (`bool`, *optional*, defaults to `False`):
-                Load the model weights from a Flax checkpoint save file (see docstring of
-                `pretrained_model_name_or_path` argument).
-            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
-                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
-                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
-                checkpoint with 3 labels).
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
@@ -267,7 +226,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Whether or not to only look at local files (i.e., do not try to download the model).
             use_auth_token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                when running `transformers-cli login` (stored in `~/.huggingface`).
+                when running `diffusers-cli login` (stored in `~/.huggingface`).
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
@@ -278,18 +237,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 Please refer to the mirror site for more information.
 
             kwargs (remaining dictionary of keyword arguments, *optional*):
-                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
-                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
-                automatically loaded:
-
-                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
-                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
-                      already been done)
-                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
-                      initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds
-                      to a configuration attribute will be used to override said attribute with the supplied `kwargs`
-                      value. Remaining keys that do not correspond to any configuration attribute will be passed to the
-                      underlying model's `__init__` function.
+                Can be used to update the [`ConfigMixin`] of the model (after it being loaded).
 
         <Tip>
 
@@ -299,8 +247,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         <Tip>
 
-        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
-        use this method in a firewalled environment.
+        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+        this method in a firewalled environment.
 
         </Tip>
 
@@ -404,7 +352,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                     f" directory containing a file named {WEIGHTS_NAME} or"
                     " \nCheckout your internet connection or see how to run the library in"
-                    " offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
+                    " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                 )
             except EnvironmentError:
                 raise EnvironmentError(
```
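The save/load round trip described by these docstrings can be sketched as follows; the directory path is illustrative, the model is randomly initialized, and `torch_dtype` is the loading override documented in this diff.

```python
# Sketch of the save_pretrained / from_pretrained round trip; the
# directory path and config values are illustrative assumptions.
import torch
from diffusers import UNet2DModel

model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
model.save_pretrained("./my_model_directory")  # writes config + weights

# Reload, overriding the default dtype via the newly documented `torch_dtype`.
reloaded = UNet2DModel.from_pretrained("./my_model_directory", torch_dtype=torch.float16)
```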

src/diffusers/models/attention.py

Lines changed: 94 additions & 12 deletions

```diff
@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -10,16 +11,24 @@ class AttentionBlock(nn.Module):
     An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
     to the N-d case.
     https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
-    Uses three q, k, v linear layers to compute attention
+    Uses three q, k, v linear layers to compute attention.
+
+    Parameters:
+        channels (:obj:`int`): The number of channels in the input and output.
+        num_head_channels (:obj:`int`, *optional*):
+            The number of channels in each head. If None, then `num_heads` = 1.
+        num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+        rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+        eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
     """
 
     def __init__(
         self,
-        channels,
-        num_head_channels=None,
-        num_groups=32,
-        rescale_output_factor=1.0,
-        eps=1e-5,
+        channels: int,
+        num_head_channels: Optional[int] = None,
+        num_groups: int = 32,
+        rescale_output_factor: float = 1.0,
+        eps: float = 1e-5,
     ):
         super().__init__()
         self.channels = channels
@@ -86,10 +95,26 @@ def forward(self, hidden_states):
 class SpatialTransformer(nn.Module):
     """
     Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
-    standard transformer action. Finally, reshape to image
+    standard transformer action. Finally, reshape to image.
+
+    Parameters:
+        in_channels (:obj:`int`): The number of channels in the input and output.
+        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+        d_head (:obj:`int`): The number of channels in each head.
+        depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
+        context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
     """
 
-    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
+    def __init__(
+        self,
+        in_channels: int,
+        n_heads: int,
+        d_head: int,
+        depth: int = 1,
+        dropout: float = 0.0,
+        context_dim: Optional[int] = None,
+    ):
         super().__init__()
         self.n_heads = n_heads
         self.d_head = d_head
@@ -127,7 +152,29 @@ def forward(self, x, context=None):
 
 
 class BasicTransformerBlock(nn.Module):
-    def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (:obj:`int`): The number of channels in the input and output.
+        n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+        d_head (:obj:`int`): The number of channels in each head.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
+        gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
+        checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        n_heads: int,
+        d_head: int,
+        dropout=0.0,
+        context_dim: Optional[int] = None,
+        gated_ff: bool = True,
+        checkpoint: bool = True,
+    ):
         super().__init__()
         self.attn1 = CrossAttention(
             query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
@@ -154,7 +201,21 @@ def forward(self, x, context=None):
 
 
 class CrossAttention(nn.Module):
-    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (:obj:`int`): The number of channels in the query.
+        context_dim (:obj:`int`, *optional*):
+            The number of channels in the context. If not given, defaults to `query_dim`.
+        heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+        dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+    """
+
+    def __init__(
+        self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
+    ):
         super().__init__()
         inner_dim = dim_head * heads
         context_dim = context_dim if context_dim is not None else query_dim
@@ -228,7 +289,20 @@ def _attention(self, query, key, value, sequence_length, dim):
 
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (:obj:`int`): The number of channels in the input.
+        dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
+        dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+    """
+
+    def __init__(
+        self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
+    ):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
@@ -242,7 +316,15 @@ def forward(self, x):
 
 # feedforward
 class GEGLU(nn.Module):
-    def __init__(self, dim_in, dim_out):
+    r"""
+    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+    Parameters:
+        dim_in (:obj:`int`): The number of channels in the input.
+        dim_out (:obj:`int`): The number of channels in the output.
+    """
+
+    def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)
```

src/diffusers/models/embeddings.py

Lines changed: 9 additions & 4 deletions

```diff
@@ -19,7 +19,12 @@
 
 
 def get_timestep_embedding(
-    timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
 ):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -55,7 +60,7 @@ def get_timestep_embedding(
 
 
 class TimestepEmbedding(nn.Module):
-    def __init__(self, channel, time_embed_dim, act_fn="silu"):
+    def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
         super().__init__()
 
         self.linear_1 = nn.Linear(channel, time_embed_dim)
@@ -75,7 +80,7 @@ def forward(self, sample):
 
 
 class Timesteps(nn.Module):
-    def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
         super().__init__()
         self.num_channels = num_channels
         self.flip_sin_to_cos = flip_sin_to_cos
@@ -94,7 +99,7 @@ def forward(self, timesteps):
 class GaussianFourierProjection(nn.Module):
     """Gaussian Fourier embeddings for noise levels."""
 
-    def __init__(self, embedding_size=256, scale=1.0):
+    def __init__(self, embedding_size: int = 256, scale: float = 1.0):
         super().__init__()
         self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
 
```
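The sinusoidal embedding whose signature was typed above can be exercised directly; a short sketch with illustrative values:

```python
# Sketch of get_timestep_embedding with the newly typed signature.
import torch
from diffusers.models.embeddings import get_timestep_embedding

timesteps = torch.arange(0, 1000, 100)  # ten diffusion timesteps
emb = get_timestep_embedding(timesteps, embedding_dim=128, flip_sin_to_cos=False)
print(emb.shape)  # torch.Size([10, 128]): half sin, half cos features
```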
