Skip to content

Commit b34be03

Browse files
authored
Karras VE, DDIM and DDPM flax schedulers (#508)
* beta never changes removed from state * fix typos in docs * removed unused var * initial ddim flax scheduler * import * added dummy objects * fix style * fix typo * docs * fix typo in comment * set return type * added flax ddom * fix style * remake * pass PRNG key as argument and split before use * fix doc string * use config * added flax Karras VE scheduler * make style * fix dummy * fix ndarray type annotation * replace returns a new state * added lms_discrete scheduler * use self.config * add_noise needs state * use config * use config * docstring * added flax score sde ve * fix imports * fix typos
1 parent 83a7bb2 commit b34be03

21 files changed

+1351
-66
lines changed

examples/textual_inversion/textual_inversion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,9 @@ def main():
504504
noise = torch.randn(latents.shape).to(latents.device)
505505
bsz = latents.shape[0]
506506
# Sample a random timestep for each image
507-
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device).long()
507+
timesteps = torch.randint(
508+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
509+
).long()
508510

509511
# Add noise to the latents according to the noise magnitude at each timestep
510512
# (this is the forward diffusion process)

examples/unconditional_image_generation/train_unconditional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def transforms(examples):
130130
bsz = clean_images.shape[0]
131131
# Sample a random timestep for each image
132132
timesteps = torch.randint(
133-
0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
133+
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=clean_images.device
134134
).long()
135135

136136
# Add noise to the clean images according to the noise magnitude at each timestep

src/diffusers/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@
6464

6565
if is_flax_available():
6666
from .modeling_flax_utils import FlaxModelMixin
67-
from .schedulers import FlaxPNDMScheduler
67+
from .schedulers import (
68+
FlaxDDIMScheduler,
69+
FlaxDDPMScheduler,
70+
FlaxKarrasVeScheduler,
71+
FlaxLMSDiscreteScheduler,
72+
FlaxPNDMScheduler,
73+
FlaxScoreSdeVeScheduler,
74+
)
6875
else:
6976
from .utils.dummy_flax_objects import * # noqa F403

src/diffusers/modeling_flax_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def from_pretrained(
386386
raise ValueError from e
387387
except (UnicodeDecodeError, ValueError):
388388
raise EnvironmentError(f"Unable to convert {model_file} to Flax deserializable object. ")
389-
# make sure all arrays are stored as jnp.arrays
389+
# make sure all arrays are stored as jnp.ndarray
390390
# NOTE: This is to prevent a bug this will be fixed in Flax >= v0.3.4:
391391
# https://github.com/google/flax/issues/1261
392392
state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices("cpu")[0]), state)

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __call__(
8080
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
8181

8282
# correction step
83-
for _ in range(self.scheduler.correct_steps):
83+
for _ in range(self.scheduler.config.correct_steps):
8484
model_output = self.unet(sample, sigma_t).sample
8585
sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
8686

src/diffusers/schedulers/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828
from ..utils.dummy_pt_objects import * # noqa F403
2929

3030
if is_flax_available():
31+
from .scheduling_ddim_flax import FlaxDDIMScheduler
32+
from .scheduling_ddpm_flax import FlaxDDPMScheduler
33+
from .scheduling_karras_ve_flax import FlaxKarrasVeScheduler
34+
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
3135
from .scheduling_pndm_flax import FlaxPNDMScheduler
36+
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
3237
else:
3338
from ..utils.dummy_flax_objects import * # noqa F403
3439

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(
113113

114114
# At every step in ddim, we are looking into the previous alphas_cumprod
115115
# For the final step, there is no previous alphas_cumprod because we are already at 0
116-
# `set_alpha_to_one` decides whether we set this paratemer simply to one or
116+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
117117
# whether we use the final alpha of the "non-previous" one.
118118
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
119119

@@ -195,7 +195,7 @@ def step(
195195
# - pred_original_sample -> f_theta(x_t, t) or x_0
196196
# - std_dev_t -> sigma_t
197197
# - eta -> η
198-
# - pred_sample_direction -> "direction pointingc to x_t"
198+
# - pred_sample_direction -> "direction pointing to x_t"
199199
# - pred_prev_sample -> "x_t-1"
200200

201201
# 1. get previous step value (=t-1)
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16+
# and https://github.com/hojonathanho/diffusion
17+
18+
import math
19+
from dataclasses import dataclass
20+
from typing import Optional, Tuple, Union
21+
22+
import flax
23+
import jax.numpy as jnp
24+
from jax import random
25+
26+
from ..configuration_utils import ConfigMixin, register_to_config
27+
from .scheduling_utils import SchedulerMixin, SchedulerOutput
28+
29+
30+
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
31+
"""
32+
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
33+
(1-beta) over time from t = [0,1].
34+
35+
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
36+
to that part of the diffusion process.
37+
38+
39+
Args:
40+
num_diffusion_timesteps (`int`): the number of betas to produce.
41+
max_beta (`float`): the maximum beta to use; use values lower than 1 to
42+
prevent singularities.
43+
44+
Returns:
45+
betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
46+
"""
47+
48+
def alpha_bar(time_step):
49+
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
50+
51+
betas = []
52+
for i in range(num_diffusion_timesteps):
53+
t1 = i / num_diffusion_timesteps
54+
t2 = (i + 1) / num_diffusion_timesteps
55+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
56+
return jnp.array(betas, dtype=jnp.float32)
57+
58+
59+
@flax.struct.dataclass
60+
class DDIMSchedulerState:
61+
# setable values
62+
timesteps: jnp.ndarray
63+
num_inference_steps: Optional[int] = None
64+
65+
@classmethod
66+
def create(cls, num_train_timesteps: int):
67+
return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1])
68+
69+
70+
@dataclass
71+
class FlaxSchedulerOutput(SchedulerOutput):
72+
state: DDIMSchedulerState
73+
74+
75+
class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
76+
"""
77+
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
78+
diffusion probabilistic models (DDPMs) with non-Markovian guidance.
79+
80+
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
81+
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
82+
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
83+
[`~ConfigMixin.from_config`] functions.
84+
85+
For more details, see the original paper: https://arxiv.org/abs/2010.02502
86+
87+
Args:
88+
num_train_timesteps (`int`): number of diffusion steps used to train the model.
89+
beta_start (`float`): the starting `beta` value of inference.
90+
beta_end (`float`): the final `beta` value.
91+
beta_schedule (`str`):
92+
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
93+
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
94+
trained_betas (`jnp.ndarray`, optional):
95+
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
96+
clip_sample (`bool`, default `True`):
97+
option to clip predicted sample between -1 and 1 for numerical stability.
98+
set_alpha_to_one (`bool`, default `True`):
99+
if alpha for final step is 1 or the final alpha of the "non-previous" one.
100+
"""
101+
102+
@register_to_config
103+
def __init__(
104+
self,
105+
num_train_timesteps: int = 1000,
106+
beta_start: float = 0.0001,
107+
beta_end: float = 0.02,
108+
beta_schedule: str = "linear",
109+
trained_betas: Optional[jnp.ndarray] = None,
110+
clip_sample: bool = True,
111+
set_alpha_to_one: bool = True,
112+
):
113+
if trained_betas is not None:
114+
self.betas = jnp.asarray(trained_betas)
115+
if beta_schedule == "linear":
116+
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
117+
elif beta_schedule == "scaled_linear":
118+
# this schedule is very specific to the latent diffusion model.
119+
self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
120+
elif beta_schedule == "squaredcos_cap_v2":
121+
# Glide cosine schedule
122+
self.betas = betas_for_alpha_bar(num_train_timesteps)
123+
else:
124+
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
125+
126+
self.alphas = 1.0 - self.betas
127+
self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
128+
129+
# At every step in ddim, we are looking into the previous alphas_cumprod
130+
# For the final step, there is no previous alphas_cumprod because we are already at 0
131+
# `set_alpha_to_one` decides whether we set this parameter simply to one or
132+
# whether we use the final alpha of the "non-previous" one.
133+
self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
134+
135+
self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps)
136+
137+
def _get_variance(self, timestep, prev_timestep):
138+
alpha_prod_t = self.alphas_cumprod[timestep]
139+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
140+
beta_prod_t = 1 - alpha_prod_t
141+
beta_prod_t_prev = 1 - alpha_prod_t_prev
142+
143+
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
144+
145+
return variance
146+
147+
def set_timesteps(
148+
self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0
149+
) -> DDIMSchedulerState:
150+
"""
151+
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
152+
153+
Args:
154+
state (`DDIMSchedulerState`):
155+
the `FlaxDDIMScheduler` state data class instance.
156+
num_inference_steps (`int`):
157+
the number of diffusion steps used when generating samples with a pre-trained model.
158+
offset (`int`):
159+
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
160+
"""
161+
step_ratio = self.config.num_train_timesteps // num_inference_steps
162+
# creates integer timesteps by multiplying by ratio
163+
# casting to int to avoid issues when num_inference_step is power of 3
164+
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
165+
timesteps = timesteps + offset
166+
167+
return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
168+
169+
def step(
170+
self,
171+
state: DDIMSchedulerState,
172+
model_output: jnp.ndarray,
173+
timestep: int,
174+
sample: jnp.ndarray,
175+
key: random.KeyArray,
176+
eta: float = 0.0,
177+
use_clipped_model_output: bool = False,
178+
return_dict: bool = True,
179+
) -> Union[FlaxSchedulerOutput, Tuple]:
180+
"""
181+
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
182+
process from the learned model outputs (most often the predicted noise).
183+
184+
Args:
185+
state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
186+
model_output (`jnp.ndarray`): direct output from learned diffusion model.
187+
timestep (`int`): current discrete timestep in the diffusion chain.
188+
sample (`jnp.ndarray`):
189+
current instance of sample being created by diffusion process.
190+
key (`random.KeyArray`): a PRNG key.
191+
eta (`float`): weight of noise for added noise in diffusion step.
192+
use_clipped_model_output (`bool`): TODO
193+
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
194+
195+
Returns:
196+
[`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
197+
When returning a tuple, the first element is the sample tensor.
198+
199+
"""
200+
if state.num_inference_steps is None:
201+
raise ValueError(
202+
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
203+
)
204+
205+
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
206+
# Ideally, read DDIM paper in-detail understanding
207+
208+
# Notation (<variable name> -> <name in paper>
209+
# - pred_noise_t -> e_theta(x_t, t)
210+
# - pred_original_sample -> f_theta(x_t, t) or x_0
211+
# - std_dev_t -> sigma_t
212+
# - eta -> η
213+
# - pred_sample_direction -> "direction pointing to x_t"
214+
# - pred_prev_sample -> "x_t-1"
215+
216+
# 1. get previous step value (=t-1)
217+
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
218+
219+
# 2. compute alphas, betas
220+
alpha_prod_t = self.alphas_cumprod[timestep]
221+
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
222+
beta_prod_t = 1 - alpha_prod_t
223+
224+
# 3. compute predicted original sample from predicted noise also called
225+
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
226+
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
227+
228+
# 4. Clip "predicted x_0"
229+
if self.config.clip_sample:
230+
pred_original_sample = jnp.clip(pred_original_sample, -1, 1)
231+
232+
# 5. compute variance: "sigma_t(η)" -> see formula (16)
233+
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
234+
variance = self._get_variance(timestep, prev_timestep)
235+
std_dev_t = eta * variance ** (0.5)
236+
237+
if use_clipped_model_output:
238+
# the model_output is always re-derived from the clipped x_0 in Glide
239+
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
240+
241+
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
242+
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
243+
244+
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
245+
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
246+
247+
if eta > 0:
248+
key = random.split(key, num=1)
249+
noise = random.normal(key=key, shape=model_output.shape)
250+
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
251+
252+
prev_sample = prev_sample + variance
253+
254+
if not return_dict:
255+
return (prev_sample, state)
256+
257+
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
258+
259+
def add_noise(
260+
self,
261+
original_samples: jnp.ndarray,
262+
noise: jnp.ndarray,
263+
timesteps: jnp.ndarray,
264+
) -> jnp.ndarray:
265+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
266+
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
267+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
268+
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
269+
270+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
271+
return noisy_samples
272+
273+
def __len__(self):
274+
return self.config.num_train_timesteps

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None):
148148
if variance_type is None:
149149
variance_type = self.config.variance_type
150150

151-
# hacks - were probs added for training stability
151+
# hacks - were probably added for training stability
152152
if variance_type == "fixed_small":
153153
variance = self.clip(variance, min_value=1e-20)
154154
# for rl-diffuser https://arxiv.org/abs/2205.09991
@@ -187,7 +187,6 @@ def step(
187187
timestep (`int`): current discrete timestep in the diffusion chain.
188188
sample (`torch.FloatTensor` or `np.ndarray`):
189189
current instance of sample being created by diffusion process.
190-
eta (`float`): weight of noise for added noise in diffusion step.
191190
predict_epsilon (`bool`):
192191
optional flag to use when model predicts the samples directly instead of the noise, epsilon.
193192
generator: random number generator.

0 commit comments

Comments
 (0)