Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/ddpm/pipeline_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
model_output = self.unet(image, t)["sample"]

# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image)["prev_sample"]
image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The DDPM scheduler is also stochastic, so the generator is passed to its step as well.


image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
Expand Down
6 changes: 3 additions & 3 deletions src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu

model = self.unet

sample = torch.randn(*shape) * self.scheduler.config.sigma_max
sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
sample = sample.to(self.device)

self.scheduler.set_timesteps(num_inference_steps)
Expand All @@ -42,11 +42,11 @@ def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, outpu
# correction step
for _ in range(self.scheduler.correct_steps):
model_output = self.unet(sample, sigma_t)["sample"]
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's pass the generator here


# prediction step
model_output = model(sample, sigma_t)["sample"]
output = self.scheduler.step_pred(model_output, t, sample)
output = self.scheduler.step_pred(model_output, t, sample, generator=generator)

sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class KarrasVePipeline(DiffusionPipeline):
differential equations." https://arxiv.org/abs/2011.13456
"""

# add type hints for linting
unet: UNet2DModel
scheduler: KarrasVeScheduler

Expand Down
26 changes: 16 additions & 10 deletions src/diffusers/schedulers/scheduling_sde_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

import warnings

# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
from typing import Union
from typing import Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -98,6 +100,9 @@ def get_adjacent_sigma(self, timesteps, t):
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

def set_seed(self, seed):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not work with set_seed(...) here, as it goes against the generator design. IMO passing generators around is the correct thing to do here because:

  • We cannot really pass seeds around and globally call manual_seed(...) every time:
    • It's not a great design to retrieve the current seed from a generator and then pass it on -> better to pass the generator directly
    • Flax passes PRNG keys around that can be split -> we cannot split PyTorch seeds in the same way, and we cannot pass a seed into a torch.randn(...) function -> so let's go for the generator here

warnings.warn(
"The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a generator instead.", DeprecationWarning
)
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
np.random.seed(seed)
Expand All @@ -111,14 +116,14 @@ def step_pred(
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
seed=None,
generator: Optional[torch.Generator] = None,
**kwargs,
):
"""
Predict the sample at the previous timestep by reversing the SDE.
"""
if seed is not None:
self.set_seed(seed)
# TODO(Patrick) non-PyTorch
if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"])

if self.timesteps is None:
raise ValueError(
Expand All @@ -140,7 +145,7 @@ def step_pred(
drift = drift - diffusion[:, None, None, None] ** 2 * model_output

# equation 6: sample noise for the diffusion term of the reverse-time SDE
noise = self.randn_like(sample)
noise = self.randn_like(sample, generator=generator)
prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
Expand All @@ -151,14 +156,15 @@ def step_correct(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
sample: Union[torch.FloatTensor, np.ndarray],
seed=None,
generator: Optional[torch.Generator] = None,
**kwargs,
):
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
"""
if seed is not None:
self.set_seed(seed)
if "seed" in kwargs and kwargs["seed"] is not None:
self.set_seed(kwargs["seed"])

if self.timesteps is None:
raise ValueError(
Expand All @@ -167,7 +173,7 @@ def step_correct(

# For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
# sample noise for correction
noise = self.randn_like(sample)
noise = self.randn_like(sample, generator=generator)

# compute step size from the model_output, the noise, and the snr
grad_norm = self.norm(model_output)
Expand Down