3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_flax.py
@@ -477,6 +477,7 @@ def collate_fn(examples):
     noise_scheduler = FlaxDDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
+    noise_scheduler_state = noise_scheduler.create_state()
 
     # Initialize our training
     train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -513,7 +514,7 @@ def compute_loss(params):

             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
 
             # Get the text embedding for conditioning
             if args.train_text_encoder:
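The change is the same in all three training scripts below: Flax schedulers no longer keep mutable attributes on `self`, so a state object must be created once with `create_state()` and passed explicitly to every call that needs it. A minimal sketch of the new calling convention, assuming a `diffusers` version that includes this PR (shapes and timestep values are illustrative):

    import jax
    import jax.numpy as jnp
    from diffusers import FlaxDDPMScheduler

    noise_scheduler = FlaxDDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
    )
    # Derived quantities (betas, alphas, alphas_cumprod, ...) now live in an
    # immutable state object instead of on the scheduler instance.
    noise_scheduler_state = noise_scheduler.create_state()

    rng = jax.random.PRNGKey(0)
    latents = jax.random.normal(rng, (4, 4, 64, 64))   # stand-in for VAE latents
    noise = jax.random.normal(rng, latents.shape)
    timesteps = jnp.array([10, 20, 30, 40])            # one timestep per example

    # The state is now the first argument to add_noise.
    noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)

Because the state is a `flax.struct.dataclass` (a pytree, as the DDIM state later in this diff shows), it can safely be closed over by or passed through `jax.jit`/`jax.pmap`-compiled training steps.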
3 changes: 2 additions & 1 deletion examples/text_to_image/train_text_to_image_flax.py
@@ -417,6 +417,7 @@ def collate_fn(examples):
     noise_scheduler = FlaxDDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
+    noise_scheduler_state = noise_scheduler.create_state()
 
     # Initialize our training
     rng = jax.random.PRNGKey(args.seed)
@@ -449,7 +450,7 @@ def compute_loss(params):

             # Add noise to the latents according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
 
             # Get the text embedding for conditioning
             encoder_hidden_states = text_encoder(
3 changes: 2 additions & 1 deletion examples/textual_inversion/textual_inversion_flax.py
@@ -505,6 +505,7 @@ def update_fn(updates, state, params=None):
     noise_scheduler = FlaxDDPMScheduler(
         beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
     )
+    noise_scheduler_state = noise_scheduler.create_state()
 
     # Initialize our training
     train_rngs = jax.random.split(rng, jax.local_device_count())
@@ -531,7 +532,7 @@ def compute_loss(params):
                 0,
                 noise_scheduler.config.num_train_timesteps,
             )
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            noisy_latents = noise_scheduler.add_noise(noise_scheduler_state, latents, noise, timesteps)
             encoder_hidden_states = state.apply_fn(
                 batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
             )[0]
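As the hunk above shows, each training example gets a uniformly sampled diffusion timestep before `add_noise` is called, which is the standard DDPM training recipe. A self-contained sketch of that sampling step (batch size and the upper bound are illustrative):

    import jax

    rng = jax.random.PRNGKey(0)
    batch_size = 4
    # one random timestep per example, drawn from [0, num_train_timesteps)
    timesteps = jax.random.randint(rng, (batch_size,), 0, 1000)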
@@ -261,7 +261,8 @@ def loop_body(step, args):
         )
 
         # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = latents * params["scheduler"].init_noise_sigma
+
         if DEBUG:
             # run with python for loop
             for i in range(num_inference_steps):
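In the pipeline, the scheduler state now travels alongside the model parameters, so attributes such as `init_noise_sigma` are read from `params["scheduler"]` rather than from the scheduler object itself. A hedged sketch of how such a params dict can be assembled; the `"unet"` entry is a hypothetical placeholder for the real model parameters:

    import jax
    from diffusers import FlaxDDIMScheduler

    scheduler = FlaxDDIMScheduler(num_train_timesteps=1000)
    params = {
        "unet": {},  # placeholder for the real model parameters
        "scheduler": scheduler.create_state(),
    }

    latents = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 64))
    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * params["scheduler"].init_noise_sigma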
157 changes: 68 additions & 89 deletions src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -15,7 +15,6 @@
 # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
 # and https://github.com/hojonathanho/diffusion
 
-import math
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
@@ -26,51 +25,37 @@
 from ..utils import deprecate
 from .scheduling_utils_flax import (
     _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
+    CommonSchedulerState,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
-    broadcast_to_shape_from_left,
+    add_noise_common,
 )
 
 
-def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
-    (1-beta) over time from t = [0,1].
-
-    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
-    to that part of the diffusion process.
-
-    Args:
-        num_diffusion_timesteps (`int`): the number of betas to produce.
-        max_beta (`float`): the maximum beta to use; use values lower than 1 to
-            prevent singularities.
-
-    Returns:
-        betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs
-    """
-
-    def alpha_bar(time_step):
-        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
-
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return jnp.array(betas, dtype=jnp.float32)
-
-
 @flax.struct.dataclass
 class DDIMSchedulerState:
+    common: CommonSchedulerState
+    final_alpha_cumprod: jnp.ndarray
 
     # setable values
+    init_noise_sigma: jnp.ndarray
     timesteps: jnp.ndarray
-    alphas_cumprod: jnp.ndarray
     num_inference_steps: Optional[int] = None
 
     @classmethod
-    def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
-        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
+    def create(
+        cls,
+        common: CommonSchedulerState,
+        final_alpha_cumprod: jnp.ndarray,
+        init_noise_sigma: jnp.ndarray,
+        timesteps: jnp.ndarray,
+    ):
+        return cls(
+            common=common,
+            final_alpha_cumprod=final_alpha_cumprod,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+        )
 
 
 @dataclass
@@ -112,12 +97,15 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
         prediction_type (`str`, default `epsilon`):
             indicates whether the model predicts the noise (epsilon), or the samples. One of `epsilon`, `sample`.
             `v-prediction` is not supported for this scheduler.
-
+        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
+            the `dtype` used for params and computation.
     """
 
     _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
     _deprecated_kwargs = ["predict_epsilon"]
 
+    dtype: jnp.dtype
+
     @property
     def has_state(self):
         return True
@@ -129,43 +117,46 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
+        trained_betas: Optional[jnp.ndarray] = None,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
         dtype: jnp.dtype = jnp.float32,
         **kwargs,
     ):
         message = (
             "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
-            " FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+            f" {self.__class__.__name__}.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
         predict_epsilon = deprecate("predict_epsilon", "0.13.0", message, take_from=kwargs)
         if predict_epsilon is not None:
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
-        if beta_schedule == "linear":
-            self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
-        elif beta_schedule == "scaled_linear":
-            # this schedule is very specific to the latent diffusion model.
-            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
-        elif beta_schedule == "squaredcos_cap_v2":
-            # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(num_train_timesteps)
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-
-        self.alphas = 1.0 - self.betas
         self.dtype = dtype
 
-        # HACK for now - clean up later (PVP)
-        self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
+    def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
+        if common is None:
+            common = CommonSchedulerState.create(self)
 
         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
         # `set_alpha_to_one` decides whether we set this parameter simply to one or
         # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])
+        final_alpha_cumprod = (
+            jnp.array(1.0, dtype=self.dtype) if self.config.set_alpha_to_one else common.alphas_cumprod[0]
+        )
 
         # standard deviation of the initial noise distribution
-        self.init_noise_sigma = 1.0
+        init_noise_sigma = jnp.array(1.0, dtype=self.dtype)
 
+        timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
+
+        return DDIMSchedulerState.create(
+            common=common,
+            final_alpha_cumprod=final_alpha_cumprod,
+            init_noise_sigma=init_noise_sigma,
+            timesteps=timesteps,
+        )

     def scale_model_input(
         self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
@@ -181,21 +172,6 @@ def scale_model_input(
"""
return sample

def create_state(self):
return DDIMSchedulerState.create(
num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
)

def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

return variance

def set_timesteps(
self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> DDIMSchedulerState:
Expand All @@ -208,22 +184,35 @@ def set_timesteps(
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
         """
-        offset = self.config.steps_offset
-
         step_ratio = self.config.num_train_timesteps // num_inference_steps
         # creates integer timesteps by multiplying by ratio
-        # casting to int to avoid issues when num_inference_step is power of 3
-        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
-        timesteps = timesteps + offset
+        # rounding to avoid issues when num_inference_step is power of 3
+        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] + self.config.steps_offset
 
-        return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)
+        return state.replace(
+            num_inference_steps=num_inference_steps,
+            timesteps=timesteps,
+        )
+
+    def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
+        alpha_prod_t = state.common.alphas_cumprod[timestep]
+        alpha_prod_t_prev = jnp.where(
+            prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
+        )
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+        return variance

     def step(
         self,
         state: DDIMSchedulerState,
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
+        eta: float = 0.0,
         return_dict: bool = True,
     ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
         """
@@ -259,17 +248,15 @@ def step(
         # - pred_sample_direction -> "direction pointing to x_t"
         # - pred_prev_sample -> "x_t-1"
 
-        # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
-        eta = 0.0
-
         # 1. get previous step value (=t-1)
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
 
-        alphas_cumprod = state.alphas_cumprod
+        alphas_cumprod = state.common.alphas_cumprod
+        final_alpha_cumprod = state.final_alpha_cumprod
 
         # 2. compute alphas, betas
         alpha_prod_t = alphas_cumprod[timestep]
-        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], final_alpha_cumprod)
 
         beta_prod_t = 1 - alpha_prod_t
 
@@ -291,7 +278,7 @@ def step(
 
         # 4. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-        variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
+        variance = self._get_variance(state, timestep, prev_timestep)
         std_dev_t = eta * variance ** (0.5)
 
         # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
@@ -307,20 +294,12 @@ def step(

     def add_noise(
         self,
+        state: DDIMSchedulerState,
         original_samples: jnp.ndarray,
         noise: jnp.ndarray,
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
-        sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
-        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)
-
-        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
-        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
-
-        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
-        return noisy_samples
+        return add_noise_common(state.common, original_samples, noise, timesteps)

     def __len__(self):
         return self.config.num_train_timesteps
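Taken together, the rewrite makes `FlaxDDIMScheduler` fully functional: every method that used to read or mutate attributes on `self` now takes a `DDIMSchedulerState` and, where needed, returns an updated one. A hedged end-to-end sketch of the resulting inference flow, with a zero-output stand-in where a real UNet would go:

    import jax
    import jax.numpy as jnp
    from diffusers import FlaxDDIMScheduler

    scheduler = FlaxDDIMScheduler(num_train_timesteps=1000)
    state = scheduler.create_state()                                # immutable pytree
    state = scheduler.set_timesteps(state, num_inference_steps=50)  # returns a NEW state

    rng = jax.random.PRNGKey(0)
    sample = jax.random.normal(rng, (1, 4, 64, 64)) * state.init_noise_sigma

    def fake_model(sample, t):
        # stand-in for a UNet noise prediction
        return jnp.zeros_like(sample)

    for t in state.timesteps:
        model_output = fake_model(sample, t)
        # step() is pure: the state goes in, the denoised sample comes out
        sample = scheduler.step(state, model_output, t, sample).prev_sample

Since `add_noise` now delegates to `add_noise_common(state.common, ...)`, the forward-diffusion computation sqrt(ᾱ_t)·x₀ + sqrt(1 − ᾱ_t)·ε is shared across the Flax schedulers instead of being re-implemented per class.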