Merged

Changes from all commits
12 changes: 7 additions & 5 deletions examples/community/bit_diffusion.py
@@ -138,7 +138,7 @@ def ddpm_bit_scheduler_step(
     model_output: torch.FloatTensor,
     timestep: int,
     sample: torch.FloatTensor,
-    predict_epsilon=True,
+    prediction_type="epsilon",
     generator=None,
     return_dict: bool = True,
 ) -> Union[DDPMSchedulerOutput, Tuple]:
@@ -150,8 +150,8 @@ def ddpm_bit_scheduler_step(
         timestep (`int`): current discrete timestep in the diffusion chain.
         sample (`torch.FloatTensor`):
             current instance of sample being created by diffusion process.
-        predict_epsilon (`bool`):
-            optional flag to use when model predicts the samples directly instead of the noise, epsilon.
+        prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon), or the samples (`sample`).
         generator: random number generator.
         return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
     Returns:
@@ -174,10 +174,12 @@ def ddpm_bit_scheduler_step(

     # 2. compute predicted original sample from predicted noise also called
     # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-    if predict_epsilon:
+    if prediction_type == "epsilon":
         pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-    else:
+    elif prediction_type == "sample":
         pred_original_sample = model_output
+    else:
+        raise ValueError(f"Unsupported prediction_type {prediction_type}.")

     # 3. Clip "predicted x_0"
     scale = self.bit_scale
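The two branches above are the heart of the change: `epsilon` treats the model output as predicted noise and inverts the forward-noising equation, while `sample` treats it as x_0 directly. A minimal sketch checking that inversion numerically (tensor shapes and the alpha value are made up for illustration):

```python
import torch

alpha_prod_t = torch.tensor(0.7)  # arbitrary cumulative alpha for one timestep
beta_prod_t = 1 - alpha_prod_t

x0 = torch.randn(2, 3, 8, 8)
eps = torch.randn_like(x0)
# forward noising: x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * eps
sample = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps

# prediction_type == "epsilon": model output is the noise, x_0 is recovered
pred_x0 = (sample - beta_prod_t**0.5 * eps) / alpha_prod_t**0.5
assert torch.allclose(pred_x0, x0, atol=1e-5)

# prediction_type == "sample": the model output already is x_0, used as-is
```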
19 changes: 11 additions & 8 deletions examples/unconditional_image_generation/train_unconditional.py
@@ -194,9 +194,10 @@ def parse_args():
     )

     parser.add_argument(
-        "--predict_epsilon",
-        action="store_true",
-        default=True,
+        "--prediction_type",
+        type=str,
+        default="epsilon",
+        choices=["epsilon", "sample"],
         help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
     )

@@ -256,13 +257,13 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())

-    if accepts_predict_epsilon:
+    if accepts_prediction_type:
         noise_scheduler = DDPMScheduler(
             num_train_timesteps=args.ddpm_num_steps,
             beta_schedule=args.ddpm_beta_schedule,
-            predict_epsilon=args.predict_epsilon,
+            prediction_type=args.prediction_type,
         )
     else:
         noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
@@ -365,9 +366,9 @@ def transforms(examples):
                # Predict the noise residual
                model_output = model(noisy_images, timesteps).sample

-                if args.predict_epsilon:
+                if args.prediction_type == "epsilon":
                    loss = F.mse_loss(model_output, noise)  # this could have different weights!
-                else:
+                elif args.prediction_type == "sample":
                    alpha_t = _extract_into_tensor(
                        noise_scheduler.alphas_cumprod, timesteps, (clean_images.shape[0], 1, 1, 1)
                    )
@@ -376,6 +377,8 @@ def transforms(examples):
                        model_output, clean_images, reduction="none"
                    )  # use SNR weighting from distillation paper
                    loss = loss.mean()
+                else:
+                    raise ValueError(f"Unsupported prediction type: {args.prediction_type}")

                accelerator.backward(loss)
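The `inspect.signature` probe keeps the training script compatible with older diffusers releases whose `DDPMScheduler.__init__` does not yet accept `prediction_type`. A minimal sketch of the same feature-detection idiom, using a hypothetical `LegacyScheduler` stand-in rather than the real class:

```python
import inspect


class LegacyScheduler:
    """Hypothetical older class that predates the prediction_type kwarg."""

    def __init__(self, num_train_timesteps=1000):
        self.num_train_timesteps = num_train_timesteps


def make_scheduler(cls, prediction_type="epsilon"):
    # Only pass the newer kwarg if the installed version understands it.
    accepts = "prediction_type" in inspect.signature(cls.__init__).parameters
    if accepts:
        return cls(prediction_type=prediction_type)
    return cls()  # fall back to the older constructor


scheduler = make_scheduler(LegacyScheduler)
print(type(scheduler).__name__)  # LegacyScheduler, built without the new kwarg
```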
5 changes: 5 additions & 0 deletions src/diffusers/configuration_utils.py
@@ -195,6 +195,11 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un
         if "dtype" in unused_kwargs:
             init_dict["dtype"] = unused_kwargs.pop("dtype")

+        if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
+            deprecate("remove this", "0.10.0", "remove")
+            predict_epsilon = unused_kwargs.pop("predict_epsilon")
+            init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"
+
         # Return model and optionally state and/or unused_kwargs
         model = cls(**init_dict)

Reviewer note (Contributor), on the new `predict_epsilon` check: "let's see if we can find a more general solution here"
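This is the central backward-compatibility shim: configs saved with the old boolean are transparently rewritten to the new string field, and an explicit `prediction_type` always takes precedence. A self-contained sketch of that mapping, with plain dicts standing in for the `FrozenDict`/`init_dict` machinery of `configuration_utils`:

```python
def migrate_config(init_dict, unused_kwargs):
    # Translate the legacy boolean into the new string-valued field,
    # but never overwrite a prediction_type that is already set.
    if "predict_epsilon" in unused_kwargs and "prediction_type" not in init_dict:
        predict_epsilon = unused_kwargs.pop("predict_epsilon")
        init_dict["prediction_type"] = "epsilon" if predict_epsilon else "sample"
    return init_dict


old_style = migrate_config({}, {"predict_epsilon": False})
assert old_style == {"prediction_type": "sample"}

# An explicit prediction_type wins over the legacy flag.
new_style = migrate_config({"prediction_type": "epsilon"}, {"predict_epsilon": False})
assert new_style == {"prediction_type": "epsilon"}
```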
1 change: 1 addition & 0 deletions src/diffusers/experimental/rl/value_guided_sampling.py
@@ -89,6 +89,7 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale):
             x = x + scale * grad
             x = self.reset_x0(x, conditions, self.action_dim)
             prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
+            # TODO: set prediction_type when instantiating the model
             x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

             # apply conditions to the trajectory
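The TODO points at the intended end state: the prediction type should live on the scheduler's config rather than be passed per step call. A hedged sketch of what that could look like (whether `value_guided_sampling` will adopt exactly this shape is an assumption; the constructor kwargs follow this PR's API):

```python
# Assumed future shape of the call site above, once the scheduler carries
# the prediction type in its config instead of receiving a per-call flag.
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="sample")
# Inside run_diffusion, the step call would then need no extra keyword:
#     x = scheduler.step(prev_x, i, x)["prev_sample"]
```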
10 changes: 4 additions & 6 deletions src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -70,14 +70,14 @@ def __call__(
                 generated images.
         """
         message = (
-            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
         predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)

         if predict_epsilon is not None:
             new_config = dict(self.scheduler.config)
-            new_config["predict_epsilon"] = predict_epsilon
+            new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
             self.scheduler._internal_dict = FrozenDict(new_config)

         if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":

Reviewer note (Contributor), on the `prediction_type` mapping: "nice!"

@@ -114,9 +114,7 @@ def __call__(
             model_output = self.unet(image, t).sample

             # 2. compute previous image: x_t -> x_t-1
-            image = self.scheduler.step(
-                model_output, t, image, generator=generator, predict_epsilon=predict_epsilon
-            ).prev_sample
+            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()

Reviewer note (Contributor), on dropping the kwarg from `scheduler.step`: "good catch we don't need to pass it anymore!"
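The pipeline leans on `deprecate(..., take_from=kwargs)` to intercept callers that still pass `predict_epsilon` at call time. A sketch of how such a helper can behave, as an illustration of the pattern rather than the actual `diffusers.utils.deprecate` implementation:

```python
import warnings


def deprecate(name, version, message, take_from=None):
    # Pop the legacy kwarg if the caller passed it, warn, and hand the value
    # back so the caller can translate it; return None when it was absent.
    if take_from is not None and name in take_from:
        warnings.warn(
            f"`{name}` is deprecated and will be removed in {version}. {message}",
            FutureWarning,
        )
        return take_from.pop(name)
    return None


kwargs = {"predict_epsilon": True}
value = deprecate("predict_epsilon", "0.10.0", "Use `prediction_type` instead.", take_from=kwargs)
assert value is True and "predict_epsilon" not in kwargs
```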
24 changes: 17 additions & 7 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -23,7 +23,7 @@
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
 from .scheduling_utils import SchedulerMixin

@@ -106,6 +106,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
+        prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon) or the samples. One of `epsilon`, `sample`,
+            or `v_prediction`.

     """

@@ -123,7 +126,16 @@ def __init__(
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
         prediction_type: str = "epsilon",
+        **kwargs,
     ):
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None:
+            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
         elif beta_schedule == "linear":
@@ -139,8 +151,6 @@ def __init__(
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

-        self.prediction_type = prediction_type
-
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

@@ -261,17 +271,17 @@ def step(

         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        if self.prediction_type == "epsilon":
+        if self.config.prediction_type == "epsilon":
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-        elif self.prediction_type == "sample":
+        elif self.config.prediction_type == "sample":
             pred_original_sample = model_output
-        elif self.prediction_type == "v_prediction":
+        elif self.config.prediction_type == "v_prediction":
             pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
             # predict V
             model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
         else:
             raise ValueError(
-                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                 " `v_prediction`"
             )
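The `v_prediction` branch relies on the identity that, under the standard v-parameterization, sqrt(alpha)·sample − sqrt(beta)·v recovers x_0 and sqrt(alpha)·v + sqrt(beta)·sample recovers epsilon, because alpha_prod_t + beta_prod_t = 1. A numeric sanity check of both identities (shapes and the alpha value are arbitrary; the definition of `v` follows the progressive-distillation parameterization, which is an assumption here):

```python
import torch

alpha_prod_t = torch.tensor(0.55)  # arbitrary cumulative alpha
beta_prod_t = 1 - alpha_prod_t

x0 = torch.randn(4, 8)
eps = torch.randn_like(x0)
sample = alpha_prod_t**0.5 * x0 + beta_prod_t**0.5 * eps
v = alpha_prod_t**0.5 * eps - beta_prod_t**0.5 * x0  # assumed v-parameterization

pred_x0 = alpha_prod_t**0.5 * sample - beta_prod_t**0.5 * v   # first line of the branch
pred_eps = alpha_prod_t**0.5 * v + beta_prod_t**0.5 * sample  # the "predict V" line
assert torch.allclose(pred_x0, x0, atol=1e-5)
assert torch.allclose(pred_eps, eps, atol=1e-5)
```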
29 changes: 28 additions & 1 deletion src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -23,6 +23,7 @@
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import deprecate
 from .scheduling_utils_flax import (
     _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     FlaxSchedulerMixin,
@@ -108,6 +109,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
             an offset added to the inference steps. You can use a combination of `offset=1` and
             `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
             stable diffusion.
+        prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon) or the samples. One of `epsilon`, `sample`,
+            or `v_prediction`.
+
     """

     _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@@ -125,7 +130,17 @@ def __init__(
         beta_schedule: str = "linear",
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
+        prediction_type: str = "epsilon",
+        **kwargs,
     ):
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None:
+            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
         if beta_schedule == "linear":
             self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
         elif beta_schedule == "scaled_linear":
@@ -259,7 +274,19 @@ def step(

         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        if self.config.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        elif self.config.prediction_type == "sample":
+            pred_original_sample = model_output
+        elif self.config.prediction_type == "v_prediction":
+            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+            # predict V
+            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+                " `v_prediction`"
+            )

         # 4. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
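Both DDIM schedulers now route the legacy kwarg through `**kwargs`: `deprecate` consumes `predict_epsilon` and, if it was passed, re-registers `prediction_type` so the old flag overrides the new argument's default. A minimal stand-in sketch of that precedence (not the real `ConfigMixin` machinery):

```python
class MiniScheduler:
    """Stand-in for the ConfigMixin pattern; not the real diffusers class."""

    def __init__(self, prediction_type="epsilon", **kwargs):
        predict_epsilon = kwargs.pop("predict_epsilon", None)  # deprecate() stand-in
        if predict_epsilon is not None:
            prediction_type = "epsilon" if predict_epsilon else "sample"
        self.config = {"prediction_type": prediction_type}


assert MiniScheduler().config["prediction_type"] == "epsilon"
assert MiniScheduler(predict_epsilon=False).config["prediction_type"] == "sample"
# The legacy flag wins even when prediction_type is passed explicitly,
# mirroring the re-register call in the diff.
assert MiniScheduler(prediction_type="epsilon", predict_epsilon=False).config["prediction_type"] == "sample"
```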
34 changes: 24 additions & 10 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -99,9 +99,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
-        predict_epsilon (`bool`):
-            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
-
+        prediction_type (`str`, default `epsilon`):
+            indicates whether the model predicts the noise (epsilon) or the samples. One of `epsilon`, `sample`.
+            `v_prediction` is not supported for this scheduler.
     """

     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@@ -116,8 +116,17 @@ def __init__(
         trained_betas: Optional[np.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
-        predict_epsilon: bool = True,
+        prediction_type: str = "epsilon",
+        **kwargs,
     ):
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
+        )
+        predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
+        if predict_epsilon is not None:
+            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
+
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
         elif beta_schedule == "linear":
@@ -241,13 +250,13 @@ def step(

         """
         message = (
-            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
+            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
         )
         predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
-        if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
+        if predict_epsilon is not None:
             new_config = dict(self.config)
-            new_config["predict_epsilon"] = predict_epsilon
+            new_config["prediction_type"] = "epsilon" if predict_epsilon else "sample"
             self._internal_dict = FrozenDict(new_config)

         t = timestep
@@ -265,10 +274,15 @@ def step(

         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if self.config.predict_epsilon:
+        if self.config.prediction_type == "epsilon":
             pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-        else:
+        elif self.config.prediction_type == "sample":
             pred_original_sample = model_output
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`"
+                " for the DDPMScheduler."
+            )

         # 3. Clip "predicted x_0"
         if self.config.clip_sample:
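With the flag moved onto the scheduler config, a caller sets the prediction type once at construction and never mentions it again. A hedged usage sketch against the API as it stands after this PR (assumes a diffusers install of that vintage; the random tensors stand in for a real UNet output):

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="epsilon")
scheduler.set_timesteps(50)

sample = torch.randn(1, 3, 32, 32)        # current noisy image x_t
model_output = torch.randn_like(sample)   # stand-in for unet(sample, t).sample
t = scheduler.timesteps[0]

# No predict_epsilon here: the scheduler reads prediction_type from its config.
prev = scheduler.step(model_output, t, sample).prev_sample
print(prev.shape)  # torch.Size([1, 3, 32, 32])
```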