| 1 | +# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion |
| 16 | +# and https://github.com/hojonathanho/diffusion |
| 17 | + |
| 18 | +import math |
| 19 | +from dataclasses import dataclass |
| 20 | +from typing import Optional, Tuple, Union |
| 21 | + |
| 22 | +import flax |
| 23 | +import jax.numpy as jnp |
| 24 | +from jax import random |
| 25 | + |
| 26 | +from ..configuration_utils import ConfigMixin, register_to_config |
| 27 | +from .scheduling_utils import SchedulerMixin, SchedulerOutput |
| 28 | + |
| 29 | + |
| 30 | +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: |
| 31 | + """ |
| 32 | + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of |
| 33 | + (1-beta) over time from t = [0,1]. |
| 34 | +
|
| 35 | + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up |
| 36 | + to that part of the diffusion process. |
| 37 | +
| 39 | + Args: |
| 40 | + num_diffusion_timesteps (`int`): the number of betas to produce. |
| 41 | + max_beta (`float`): the maximum beta to use; use values lower than 1 to |
| 42 | + prevent singularities. |
| 43 | +
|
| 44 | + Returns: |
| 45 | + betas (`jnp.ndarray`): the betas used by the scheduler to step the model outputs |
| 46 | + """ |
| 47 | + |
| 48 | + def alpha_bar(time_step): |
| 49 | + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 |
| 50 | + |
| 51 | + betas = [] |
| 52 | + for i in range(num_diffusion_timesteps): |
| 53 | + t1 = i / num_diffusion_timesteps |
| 54 | + t2 = (i + 1) / num_diffusion_timesteps |
| 55 | + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) |
| 56 | + return jnp.array(betas, dtype=jnp.float32) |
| 57 | + |
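+# A minimal illustrative sketch (not part of the module) of how this schedule is consumed downstream:
+# the betas are turned into a decreasing cumulative product of alphas, exactly as
+# `FlaxDDIMScheduler.__init__` does below.
+#
+#     betas = betas_for_alpha_bar(1000)          # shape (1000,), each beta capped at max_beta
+#     alphas_cumprod = jnp.cumprod(1.0 - betas)  # decreases from ~1.0 towards 0.0
+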
| 58 | + |
| 59 | +@flax.struct.dataclass |
| 60 | +class DDIMSchedulerState: |
| 61 | + # setable values |
| 62 | + timesteps: jnp.ndarray |
| 63 | + num_inference_steps: Optional[int] = None |
| 64 | + |
| 65 | + @classmethod |
| 66 | + def create(cls, num_train_timesteps: int): |
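+        # defaults to the full training schedule in reverse order,
+        # e.g. [999, 998, ..., 1, 0] for num_train_timesteps=1000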
| 67 | + return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1]) |
| 68 | + |
| 69 | + |
| 70 | +@dataclass |
| 71 | +class FlaxSchedulerOutput(SchedulerOutput): |
| 72 | + state: DDIMSchedulerState |
| 73 | + |
| 74 | + |
| 75 | +class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): |
| 76 | + """ |
| 77 | + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising |
| 78 | + diffusion probabilistic models (DDPMs) with non-Markovian guidance. |
| 79 | +
|
| 80 | + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` |
| 81 | + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. |
| 82 | + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and |
| 83 | + [`~ConfigMixin.from_config`] functions. |
| 84 | +
|
| 85 | + For more details, see the original paper: https://arxiv.org/abs/2010.02502 |
| 86 | +
|
| 87 | + Args: |
| 88 | + num_train_timesteps (`int`): number of diffusion steps used to train the model. |
| 89 | + beta_start (`float`): the starting `beta` value of inference. |
| 90 | + beta_end (`float`): the final `beta` value. |
| 91 | + beta_schedule (`str`): |
| 92 | + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from |
| 93 | + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. |
| 94 | + trained_betas (`jnp.ndarray`, optional): |
| 95 | + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. |
| 96 | + clip_sample (`bool`, default `True`): |
| 97 | + option to clip predicted sample between -1 and 1 for numerical stability. |
| 98 | + set_alpha_to_one (`bool`, default `True`): |
| 99 | +            each diffusion step uses the alphas product value at that step and at the previous one. For the final
+            step there is no previous alpha; when this option is `True` the previous alpha product is fixed to `1`,
+            otherwise the alpha product at step 0 is used.
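+
+    Usage sketch (illustrative; `model_output`, `sample` and `key` are assumed to be provided by the
+    surrounding sampling loop, e.g. a Flax UNet prediction and a PRNG key):
+
+        >>> scheduler = FlaxDDIMScheduler(num_train_timesteps=1000)
+        >>> state = scheduler.set_timesteps(scheduler.state, num_inference_steps=50)
+        >>> for t in state.timesteps:
+        ...     sample, state = scheduler.step(state, model_output, t, sample, key, return_dict=False)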
| 100 | + """ |
| 101 | + |
| 102 | + @register_to_config |
| 103 | + def __init__( |
| 104 | + self, |
| 105 | + num_train_timesteps: int = 1000, |
| 106 | + beta_start: float = 0.0001, |
| 107 | + beta_end: float = 0.02, |
| 108 | + beta_schedule: str = "linear", |
| 109 | + trained_betas: Optional[jnp.ndarray] = None, |
| 110 | + clip_sample: bool = True, |
| 111 | + set_alpha_to_one: bool = True, |
| 112 | + ): |
| 113 | + if trained_betas is not None: |
| 114 | + self.betas = jnp.asarray(trained_betas) |
| 115 | +        elif beta_schedule == "linear":
| 116 | + self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) |
| 117 | + elif beta_schedule == "scaled_linear": |
| 118 | + # this schedule is very specific to the latent diffusion model. |
| 119 | + self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 |
| 120 | + elif beta_schedule == "squaredcos_cap_v2": |
| 121 | + # Glide cosine schedule |
| 122 | + self.betas = betas_for_alpha_bar(num_train_timesteps) |
| 123 | + else: |
| 124 | +            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
| 125 | + |
| 126 | + self.alphas = 1.0 - self.betas |
| 127 | + self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) |
| 128 | + |
| 129 | + # At every step in ddim, we are looking into the previous alphas_cumprod |
| 130 | + # For the final step, there is no previous alphas_cumprod because we are already at 0 |
| 131 | + # `set_alpha_to_one` decides whether we set this parameter simply to one or |
| 132 | + # whether we use the final alpha of the "non-previous" one. |
| 133 | + self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0] |
| 134 | + |
| 135 | + self.state = DDIMSchedulerState.create(num_train_timesteps=num_train_timesteps) |
| 136 | + |
| 137 | + def _get_variance(self, timestep, prev_timestep): |
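+        # DDIM variance, formula (16) of https://arxiv.org/abs/2010.02502:
+        #     sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_{t-1})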
| 138 | + alpha_prod_t = self.alphas_cumprod[timestep] |
| 139 | + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod |
| 140 | + beta_prod_t = 1 - alpha_prod_t |
| 141 | + beta_prod_t_prev = 1 - alpha_prod_t_prev |
| 142 | + |
| 143 | + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) |
| 144 | + |
| 145 | + return variance |
| 146 | + |
| 147 | + def set_timesteps( |
| 148 | + self, state: DDIMSchedulerState, num_inference_steps: int, offset: int = 0 |
| 149 | + ) -> DDIMSchedulerState: |
| 150 | + """ |
| 151 | + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. |
| 152 | +
|
| 153 | + Args: |
| 154 | + state (`DDIMSchedulerState`): |
| 155 | + the `FlaxDDIMScheduler` state data class instance. |
| 156 | + num_inference_steps (`int`): |
| 157 | + the number of diffusion steps used when generating samples with a pre-trained model. |
| 158 | + offset (`int`): |
| 159 | + optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. |
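+
+        Example (illustrative): with `num_train_timesteps=1000`, `num_inference_steps=50` and `offset=1`,
+        the resulting timesteps are `[981, 961, ..., 21, 1]` (a step ratio of 20, in descending order).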
| 160 | + """ |
| 161 | + step_ratio = self.config.num_train_timesteps // num_inference_steps |
| 162 | + # creates integer timesteps by multiplying by ratio |
| 163 | +        # casting to int to avoid issues when num_inference_steps is a power of 3
| 164 | + timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] |
| 165 | + timesteps = timesteps + offset |
| 166 | + |
| 167 | + return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps) |
| 168 | + |
| 169 | + def step( |
| 170 | + self, |
| 171 | + state: DDIMSchedulerState, |
| 172 | + model_output: jnp.ndarray, |
| 173 | + timestep: int, |
| 174 | + sample: jnp.ndarray, |
| 175 | + key: random.KeyArray, |
| 176 | + eta: float = 0.0, |
| 177 | + use_clipped_model_output: bool = False, |
| 178 | + return_dict: bool = True, |
| 179 | + ) -> Union[FlaxSchedulerOutput, Tuple]: |
| 180 | + """ |
| 181 | + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion |
| 182 | + process from the learned model outputs (most often the predicted noise). |
| 183 | +
|
| 184 | + Args: |
| 185 | + state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance. |
| 186 | + model_output (`jnp.ndarray`): direct output from learned diffusion model. |
| 187 | + timestep (`int`): current discrete timestep in the diffusion chain. |
| 188 | + sample (`jnp.ndarray`): |
| 189 | + current instance of sample being created by diffusion process. |
| 190 | + key (`random.KeyArray`): a PRNG key. |
| 191 | + eta (`float`): weight of noise for added noise in diffusion step. |
| 192 | +            use_clipped_model_output (`bool`): if `True`, re-derive `model_output` from the clipped predicted
+                original sample (as done in GLIDE). Useful when `clip_sample` is `True`, since clipping changes the
+                noise prediction implied by the sample.
| 193 | + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class |
| 194 | +
|
| 195 | + Returns: |
| 196 | + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. |
| 197 | + When returning a tuple, the first element is the sample tensor. |
| 198 | +
|
| 199 | + """ |
| 200 | + if state.num_inference_steps is None: |
| 201 | + raise ValueError( |
| 202 | + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" |
| 203 | + ) |
| 204 | + |
| 205 | + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf |
| 206 | +        # Ideally, read the DDIM paper in detail to understand the following
| 207 | + |
| 208 | +        # Notation (<variable name> -> <name in paper>)
| 209 | + # - pred_noise_t -> e_theta(x_t, t) |
| 210 | + # - pred_original_sample -> f_theta(x_t, t) or x_0 |
| 211 | + # - std_dev_t -> sigma_t |
| 212 | + # - eta -> η |
| 213 | + # - pred_sample_direction -> "direction pointing to x_t" |
| 214 | + # - pred_prev_sample -> "x_t-1" |
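+        #
+        # The update implemented in the numbered steps below is formula (12):
+        #     x_{t-1} = sqrt(alpha_bar_{t-1}) * pred_original_sample
+        #               + sqrt(1 - alpha_bar_{t-1} - sigma_t^2) * pred_noise_t
+        #               + sigma_t * noise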
| 215 | + |
| 216 | + # 1. get previous step value (=t-1) |
| 217 | + prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps |
| 218 | + |
| 219 | + # 2. compute alphas, betas |
| 220 | + alpha_prod_t = self.alphas_cumprod[timestep] |
| 221 | + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod |
| 222 | + beta_prod_t = 1 - alpha_prod_t |
| 223 | + |
| 224 | + # 3. compute predicted original sample from predicted noise also called |
| 225 | + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf |
| 226 | + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) |
| 227 | + |
| 228 | + # 4. Clip "predicted x_0" |
| 229 | + if self.config.clip_sample: |
| 230 | + pred_original_sample = jnp.clip(pred_original_sample, -1, 1) |
| 231 | + |
| 232 | + # 5. compute variance: "sigma_t(η)" -> see formula (16) |
| 233 | + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) |
| 234 | + variance = self._get_variance(timestep, prev_timestep) |
| 235 | + std_dev_t = eta * variance ** (0.5) |
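+        # eta = 0.0 gives deterministic DDIM sampling; eta = 1.0 recovers the DDPM noise level (see formula (16))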
| 236 | + |
| 237 | + if use_clipped_model_output: |
| 238 | + # the model_output is always re-derived from the clipped x_0 in Glide |
| 239 | + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) |
| 240 | + |
| 241 | + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf |
| 242 | + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output |
| 243 | + |
| 244 | + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf |
| 245 | + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction |
| 246 | + |
| 247 | + if eta > 0: |
| 248 | +            key = random.split(key, num=1)[0]  # use a single sub-key; random.normal expects an unbatched PRNG key
| 249 | + noise = random.normal(key=key, shape=model_output.shape) |
| 250 | + variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise |
| 251 | + |
| 252 | + prev_sample = prev_sample + variance |
| 253 | + |
| 254 | + if not return_dict: |
| 255 | + return (prev_sample, state) |
| 256 | + |
| 257 | + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) |
| 258 | + |
| 259 | + def add_noise( |
| 260 | + self, |
| 261 | + original_samples: jnp.ndarray, |
| 262 | + noise: jnp.ndarray, |
| 263 | + timesteps: jnp.ndarray, |
| 264 | + ) -> jnp.ndarray: |
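+        # forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise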
| 265 | + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 |
| 266 | + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) |
| 267 | + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 |
| 268 | + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) |
| 269 | + |
| 270 | + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise |
| 271 | + return noisy_samples |
| 272 | + |
| 273 | + def __len__(self): |
| 274 | + return self.config.num_train_timesteps |