From 3945f212a44fa74cf8e8ba97bdbdd1ec0576eda7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 09:15:07 +0200 Subject: [PATCH 01/18] initial flax pndm --- .../schedulers/scheduling_pndm_flax.py | 418 ++++++++++++++++++ 1 file changed, 418 insertions(+) create mode 100644 src/diffusers/schedulers/scheduling_pndm_flax.py diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py new file mode 100644 index 000000000000..a3a982090b4c --- /dev/null +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -0,0 +1,418 @@ +# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +import math +from typing import Optional, Tuple, Union, List + +import jax +import jax.numpy as jnp +import flax + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return jnp.array(betas, dtype=jnp.float32) + + +@flax.struct.dataclass +class PNDMSchedulerState: + betas: jnp.array + + # setable values + _timesteps: jnp.array + num_inference_steps: Optional[int] = None + _offset: int = 0 + prk_timesteps: Optional[jnp.array] = None + plms_timesteps: Optional[jnp.array] = None + timesteps: Optional[jnp.array] = None + + # running values + cur_model_output: Optional[jnp.ndarray] = None + counter: int = 0 + cur_sample: Optional[jnp.ndarray] = None + ets: List = [] + + @property + def alphas(self) -> jnp.array: + return 1.0 - self.betas + + @property + def alphas_cumprod(self) -> jnp.array: + return jnp.cumprod(self.alphas, axis=0) + + @classmethod + def create(cls, beta: jnp.array, num_train_timesteps: int): + state = cls( + betas=beta, + _timesteps=jnp.arange(0, num_train_timesteps)[::-1].copy(), + ) + + +class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): + """ + Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, + namely Runge-Kutta method and a linear multi-step method. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2202.09778 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): TODO + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays + skip_prk_steps (`bool`): + allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required + before plms steps; defaults to `False`. + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[jnp.array] = None, + tensor_format: str = "np", + skip_prk_steps: bool = False, + ): + if trained_betas is not None: + betas = jnp.asarray(trained_betas) + if beta_schedule == "linear": + betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + # For now we only support F-PNDM, i.e. the runge-kutta method + # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf + # mainly at formula (9), (12), (13) and the Algorithm 2. + self.pndm_order = 4 + + self.state = PNDMSchedulerState.create(betas=betas, num_train_timesteps=num_train_timesteps) + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps( + self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0 + ) -> PNDMSchedulerState: + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + state (`PNDMSchedulerState`): + the PNDMScheduler state data class instance. + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + offset (`int`): TODO + """ + self.num_inference_steps = num_inference_steps + + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).astype(int) + _timesteps = _timesteps + offset + + state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps) + + if self.config.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + state = state.replace( + prk_timesteps=jnp.array([]), + plms_timesteps=jnp.concatenate( + [state._timesteps[:-1], state._timesteps[-2:-1], state._timesteps[-1:]] + )[::-1].copy(), + ) + else: + prk_timesteps = jnp.array(state._timesteps[-self.pndm_order :]).repeat(2) + jnp.tile( + jnp.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + + state = state.replace( + prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy(), + plms_timesteps=state._timesteps[:-3][ + ::-1 + ].copy(), # we copy to avoid having negative strides which are not supported by torch.from_numpy + ) + + return state.replace( + timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64), ets=[], counter=0 + ) + + def step( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`. + + Args: + state (`PNDMSchedulerState`): the PNDMScheduler state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps: + return self.step_prk( + state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + ) + else: + return self.step_plms( + state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + ) + + def step_prk( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + solution to the differential equation. + + Args: + state (`PNDMSchedulerState`): the PNDMScheduler state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 + prev_timestep = max(timestep - diff_to_prev, state.prk_timesteps[-1]) + timestep = state.prk_timesteps[state.counter // 4 * 4] + + if state.counter % 4 == 0: + state.replace( + cur_model_output=state.cur_model_output + 1 / 6 * model_output, + ets=state.ets.append(model_output), + cur_sample=sample, + ) + elif (self.counter - 1) % 4 == 0: + state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + elif (self.counter - 2) % 4 == 0: + state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + elif (self.counter - 3) % 4 == 0: + model_output = state.cur_model_output + 1 / 6 * model_output + state.replace(cur_model_output=0) + + # cur_sample should not be `None` + cur_sample = state.cur_sample if state.cur_sample is not None else sample + + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output, state=state) + state.replace(counter=state.counter + 1) + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def step_plms( + self, + state: PNDMSchedulerState, + model_output: jnp.ndarray, + timestep: int, + sample: jnp.ndarray, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple + times to approximate the solution. + + Args: + state (`PNDMSchedulerState`): the PNDMScheduler state data class instance. + model_output (`jnp.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`jnp.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if state.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if not self.config.skip_prk_steps and len(state.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " + "for more information." + ) + + prev_timestep = max(timestep - self.config.num_train_timesteps // state.num_inference_steps, 0) + + if state.counter != 1: + state.replace(ets=state.ets.append(model_output)) + else: + prev_timestep = timestep + timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps + + if len(state.ets) == 1 and state.counter == 0: + model_output = model_output + state.replace(cur_sampe=sample) + elif len(state.ets) == 1 and state.counter == 1: + model_output = (model_output + state.ets[-1]) / 2 + sample = state.cur_sample + state.replace(cur_sample=None) + elif len(state.ets) == 2: + model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 + elif len(state.ets) == 3: + model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 + else: + model_output = (1 / 24) * ( + 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4] + ) + + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output, state=state) + state.replace(counter=state.counter + 1) + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation + + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # model_output -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = state.alphas_cumprod[timestep + 1 - state._offset] + alpha_prod_t_prev = state.alphas_cumprod[timestep_prev + 1 - state._offset] + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) + + # corresponds to denominator of e_θ(x_t, t) in formula (9) + model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = ( + sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff + ) + + return prev_sample + + def add_noise( + self, + state: PNDMSchedulerState, + original_samples: jnp.ndarray, + noise: jnp.ndarray, + timesteps: jnp.ndarray, + ) -> jnp.ndarray: + sqrt_alpha_prod = state.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) + sqrt_one_minus_alpha_prod = (1 - state.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps From 48f671720078531eb38d3611aaea32685ac13390 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 09:18:28 +0200 Subject: [PATCH 02/18] fix typo --- src/diffusers/schedulers/scheduling_pndm_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index a3a982090b4c..55f3fa3e872d 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -166,7 +166,7 @@ def set_timesteps( step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).astype(int) + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).astype(int)[::-1].copy() _timesteps = _timesteps + offset state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps) From 8b93a164dace11632611a17df7e106c4fb52ee9a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 09:22:00 +0200 Subject: [PATCH 03/18] use state --- src/diffusers/schedulers/scheduling_pndm_flax.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 55f3fa3e872d..2a08a88e715c 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -161,9 +161,7 @@ def set_timesteps( the number of diffusion steps used when generating samples with a pre-trained model. offset (`int`): TODO """ - self.num_inference_steps = num_inference_steps - - step_ratio = self.config.num_train_timesteps // self.num_inference_steps + step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).astype(int)[::-1].copy() @@ -259,7 +257,7 @@ def step_prk( True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ - if self.num_inference_steps is None: + if state.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" ) From c8a3d2dba9e60107485d6800121a8fd33929e8cb Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 09:27:25 +0200 Subject: [PATCH 04/18] return state --- src/diffusers/schedulers/scheduling_pndm_flax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 2a08a88e715c..a032ad8f1c82 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -164,7 +164,7 @@ def set_timesteps( step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).astype(int)[::-1].copy() + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).astype(jnp.int64)[::-1].copy() _timesteps = _timesteps + offset state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps) @@ -287,7 +287,7 @@ def step_prk( state.replace(counter=state.counter + 1) if not return_dict: - return (prev_sample,) + return (prev_sample, state) return SchedulerOutput(prev_sample=prev_sample) @@ -357,7 +357,7 @@ def step_plms( state.replace(counter=state.counter + 1) if not return_dict: - return (prev_sample,) + return (prev_sample, state) return SchedulerOutput(prev_sample=prev_sample) From 7d045a55fc0dcef9e7404246b9b0b82fd5a3a4d5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 11:40:48 +0200 Subject: [PATCH 05/18] add FlaxSchedulerOutput --- .../schedulers/scheduling_pndm_flax.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index a032ad8f1c82..ea1e20c1ae17 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -13,7 +13,7 @@ # limitations under the License. # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim - +from dataclasses import dataclass import math from typing import Optional, Tuple, Union, List @@ -40,7 +40,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999): prevent singularities. Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + betas (`jnp.array`): the betas used by the scheduler to step the model outputs """ def alpha_bar(time_step): @@ -88,6 +88,11 @@ def create(cls, beta: jnp.array, num_train_timesteps: int): ) +@dataclass +class FlaxSchedulerOutput(SchedulerOutput): + state: PNDMSchedulerState + + class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): """ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques, @@ -202,7 +207,7 @@ def step( timestep: int, sample: jnp.ndarray, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[FlaxSchedulerOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). @@ -218,8 +223,8 @@ def step( return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + [`FlaxSchedulerOutput`] or `tuple`: + [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -239,7 +244,7 @@ def step_prk( timestep: int, sample: jnp.ndarray, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[FlaxSchedulerOutput, Tuple]: """ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the solution to the differential equation. @@ -253,7 +258,7 @@ def step_prk( return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -289,7 +294,7 @@ def step_prk( if not return_dict: return (prev_sample, state) - return SchedulerOutput(prev_sample=prev_sample) + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) def step_plms( self, @@ -298,7 +303,7 @@ def step_plms( timestep: int, sample: jnp.ndarray, return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: + ) -> Union[FlaxSchedulerOutput, Tuple]: """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. @@ -312,7 +317,7 @@ def step_plms( return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ @@ -359,7 +364,7 @@ def step_plms( if not return_dict: return (prev_sample, state) - return SchedulerOutput(prev_sample=prev_sample) + return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) def _get_prev_sample(self, sample, timestep, timestep_prev, model_output, state): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf From 7d2fffc5fc64f5e9c5a9c3fc7cdbd920a393c254 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 11:49:15 +0200 Subject: [PATCH 06/18] fix style --- .../schedulers/scheduling_pndm_flax.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index ea1e20c1ae17..2a4d40312103 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math + # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim from dataclasses import dataclass -import math -from typing import Optional, Tuple, Union, List +from typing import List, Optional, Tuple, Union +import flax import jax import jax.numpy as jnp -import flax from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -223,9 +224,8 @@ def step( return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - [`FlaxSchedulerOutput`] or `tuple`: - [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. """ if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps: @@ -258,8 +258,8 @@ def step_prk( return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: @@ -317,8 +317,8 @@ def step_plms( return_dict (`bool`): option for returning tuple rather than SchedulerOutput class Returns: - [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. """ if state.num_inference_steps is None: From 3ec603674cd5adf0c18a4cdc2a93e14b303ec940 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 14:29:02 +0200 Subject: [PATCH 07/18] add flax imports --- src/diffusers/__init__.py | 6 ++++++ src/diffusers/schedulers/__init__.py | 2 +- src/diffusers/utils/dummy_flax_objects.py | 11 +++++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/utils/dummy_flax_objects.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 219f2d8bf9d1..34cc16591d40 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,4 +1,5 @@ from .utils import ( + is_flax_available, is_inflect_available, is_onnx_available, is_scipy_available, @@ -60,3 +61,8 @@ from .pipelines import StableDiffusionOnnxPipeline else: from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 + +if is_flax_available(): + from .schedulers import FlaxPNDMScheduler +else: + from .utils.dummy_flax_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 20c25f35183f..3fe8a011b735 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -20,7 +20,7 @@ from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_utils import SchedulerMixin - +from .scheduling_pndm_flax import FlaxPNDMScheduler if is_scipy_available(): from .scheduling_lms_discrete import LMSDiscreteScheduler diff --git a/src/diffusers/utils/dummy_flax_objects.py b/src/diffusers/utils/dummy_flax_objects.py new file mode 100644 index 000000000000..b5f4362bcb6e --- /dev/null +++ b/src/diffusers/utils/dummy_flax_objects.py @@ -0,0 +1,11 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +# flake8: noqa + +from ..utils import DummyObject, requires_backends + + +class FlaxPNDMScheduler(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) From 769c88a8f8f43071601102c91daa5a660381ad7b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 14:31:43 +0200 Subject: [PATCH 08/18] make style --- src/diffusers/schedulers/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 3fe8a011b735..56d31abd99c8 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -17,10 +17,11 @@ from .scheduling_ddpm import DDPMScheduler from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_pndm import PNDMScheduler +from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_utils import SchedulerMixin -from .scheduling_pndm_flax import FlaxPNDMScheduler + if is_scipy_available(): from .scheduling_lms_discrete import LMSDiscreteScheduler From f83c3e34a66bcf522ec10b28c3e8337c6034ab19 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 14:42:12 +0200 Subject: [PATCH 09/18] fix typos --- src/diffusers/schedulers/scheduling_pndm_flax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 2a4d40312103..6b1971a9253b 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -71,7 +71,7 @@ class PNDMSchedulerState: cur_model_output: Optional[jnp.ndarray] = None counter: int = 0 cur_sample: Optional[jnp.ndarray] = None - ets: List = [] + ets: jnp.array = jnp.array([]) @property def alphas(self) -> jnp.array: @@ -82,9 +82,9 @@ def alphas_cumprod(self) -> jnp.array: return jnp.cumprod(self.alphas, axis=0) @classmethod - def create(cls, beta: jnp.array, num_train_timesteps: int): + def create(cls, betas: jnp.array, num_train_timesteps: int): state = cls( - betas=beta, + betas=betas, _timesteps=jnp.arange(0, num_train_timesteps)[::-1].copy(), ) From 28e6377fd2a267ddb70aaab9913f33f4076d3e69 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 16:08:20 +0200 Subject: [PATCH 10/18] return created state --- src/diffusers/schedulers/scheduling_pndm_flax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 6b1971a9253b..3180906fffd3 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -87,6 +87,7 @@ def create(cls, betas: jnp.array, num_train_timesteps: int): betas=betas, _timesteps=jnp.arange(0, num_train_timesteps)[::-1].copy(), ) + return state @dataclass From 5aab5412592b8690f720256c0d1be6d51cecfd5d Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 16:11:52 +0200 Subject: [PATCH 11/18] make style --- src/diffusers/schedulers/scheduling_pndm_flax.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 3180906fffd3..574e1b2d179b 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -16,10 +16,9 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import flax -import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config @@ -83,11 +82,10 @@ def alphas_cumprod(self) -> jnp.array: @classmethod def create(cls, betas: jnp.array, num_train_timesteps: int): - state = cls( + return cls( betas=betas, _timesteps=jnp.arange(0, num_train_timesteps)[::-1].copy(), ) - return state @dataclass From 5b67ffcdc41e65e8975f783909402e0cda0cffce Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 16:59:26 +0200 Subject: [PATCH 12/18] add torch/flax imports --- src/diffusers/schedulers/__init__.py | 29 ++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 56d31abd99c8..c53101a9bb4a 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -12,18 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ..utils import is_scipy_available -from .scheduling_ddim import DDIMScheduler -from .scheduling_ddpm import DDPMScheduler -from .scheduling_karras_ve import KarrasVeScheduler -from .scheduling_pndm import PNDMScheduler -from .scheduling_pndm_flax import FlaxPNDMScheduler -from .scheduling_sde_ve import ScoreSdeVeScheduler -from .scheduling_sde_vp import ScoreSdeVpScheduler -from .scheduling_utils import SchedulerMixin +from ..utils import is_flax_available, is_scipy_available, is_torch_available + + +if is_torch_available(): + from .scheduling_ddim import DDIMScheduler + from .scheduling_ddpm import DDPMScheduler + from .scheduling_karras_ve import KarrasVeScheduler + from .scheduling_pndm import PNDMScheduler + from .scheduling_sde_ve import ScoreSdeVeScheduler + from .scheduling_sde_vp import ScoreSdeVpScheduler + from .scheduling_utils import SchedulerMixin +else: + from ..utils.dummy_pt_objects import * # noqa F403 + +if is_flax_available(): + from .scheduling_pndm_flax import FlaxPNDMScheduler +else: + from ..utils.dummy_flax_objects import * # noqa F403 if is_scipy_available(): from .scheduling_lms_discrete import LMSDiscreteScheduler else: - from ..utils.dummy_scipy_objects import * # noqa F403 + from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 From 71e29e37aa6edc7f7c0a22177b173c821fbd85d4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 17:05:47 +0200 Subject: [PATCH 13/18] docs --- src/diffusers/schedulers/scheduling_pndm_flax.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 574e1b2d179b..647aa47639cb 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -112,7 +112,8 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): TODO + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required @@ -164,7 +165,8 @@ def set_timesteps( the PNDMScheduler state data class instance. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. - offset (`int`): TODO + offset (`int`): + optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference. """ step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio From 4330661bfe261c5b6fa79de902354031a5274855 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 17:08:48 +0200 Subject: [PATCH 14/18] fixed typo --- src/diffusers/schedulers/scheduling_pndm_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 647aa47639cb..c4f032d6d93b 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -345,7 +345,7 @@ def step_plms( if len(state.ets) == 1 and state.counter == 0: model_output = model_output - state.replace(cur_sampe=sample) + state.replace(cur_sample=sample) elif len(state.ets) == 1 and state.counter == 1: model_output = (model_output + state.ets[-1]) / 2 sample = state.cur_sample From 209c7f4b4d585856bfe17984bce910cf2c5d07ab Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 17:37:05 +0200 Subject: [PATCH 15/18] remove tensor_format --- src/diffusers/schedulers/scheduling_pndm_flax.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index c4f032d6d93b..1d4689ec1563 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -114,7 +114,6 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays skip_prk_steps (`bool`): allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required before plms steps; defaults to `False`. @@ -128,7 +127,6 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[jnp.array] = None, - tensor_format: str = "np", skip_prk_steps: bool = False, ): if trained_betas is not None: @@ -151,9 +149,6 @@ def __init__( self.state = PNDMSchedulerState.create(betas=betas, num_train_timesteps=num_train_timesteps) - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - def set_timesteps( self, state: PNDMSchedulerState, num_inference_steps: int, offset: int = 0 ) -> PNDMSchedulerState: From 359ec4e2d7503a66bd7c222b1ae639d0df5e19f4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 18:02:05 +0200 Subject: [PATCH 16/18] round instead of cast --- src/diffusers/schedulers/scheduling_pndm_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 1d4689ec1563..8736a446efb9 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -166,7 +166,7 @@ def set_timesteps( step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).astype(jnp.int64)[::-1].copy() + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() _timesteps = _timesteps + offset state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps) From 5b32c1a2d266be14850ca1e7d686a0e6df20f356 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 19:04:05 +0200 Subject: [PATCH 17/18] ets is jnp array --- src/diffusers/schedulers/scheduling_pndm_flax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 8736a446efb9..e4a06e0121d8 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -194,7 +194,9 @@ def set_timesteps( ) return state.replace( - timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64), ets=[], counter=0 + timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64), + ets=jnp.array([]), + counter=0, ) def step( From 295a7d7130f347cca68e99d5b88509c0a607850a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 13 Sep 2022 19:08:11 +0200 Subject: [PATCH 18/18] remove copy --- src/diffusers/schedulers/scheduling_pndm_flax.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index e4a06e0121d8..53b0126c51c0 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -84,7 +84,7 @@ def alphas_cumprod(self) -> jnp.array: def create(cls, betas: jnp.array, num_train_timesteps: int): return cls( betas=betas, - _timesteps=jnp.arange(0, num_train_timesteps)[::-1].copy(), + _timesteps=jnp.arange(0, num_train_timesteps)[::-1], ) @@ -166,7 +166,7 @@ def set_timesteps( step_ratio = self.config.num_train_timesteps // num_inference_steps # creates integer timesteps by multiplying by ratio # casting to int to avoid issues when num_inference_step is power of 3 - _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy() + _timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1] _timesteps = _timesteps + offset state = state.replace(num_inference_steps=num_inference_steps, _offset=offset, _timesteps=_timesteps) @@ -179,7 +179,7 @@ def set_timesteps( prk_timesteps=jnp.array([]), plms_timesteps=jnp.concatenate( [state._timesteps[:-1], state._timesteps[-2:-1], state._timesteps[-1:]] - )[::-1].copy(), + )[::-1], ) else: prk_timesteps = jnp.array(state._timesteps[-self.pndm_order :]).repeat(2) + jnp.tile( @@ -187,10 +187,8 @@ def set_timesteps( ) state = state.replace( - prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy(), - plms_timesteps=state._timesteps[:-3][ - ::-1 - ].copy(), # we copy to avoid having negative strides which are not supported by torch.from_numpy + prk_timesteps=(prk_timesteps[:-1].repeat(2)[1:-1])[::-1], + plms_timesteps=state._timesteps[:-3][::-1], ) return state.replace(