
Commit 91db818

Adding pred_original_sample to SchedulerOutput for some samplers (#614)
* Adding pred_original_sample to the SchedulerOutput of the DDPMScheduler, DDIMScheduler, LMSDiscreteScheduler, and KarrasVeScheduler step methods so the predicted denoised outputs can be accessed
* Gave DDPMScheduler, DDIMScheduler, and LMSDiscreteScheduler their own output dataclasses so the default SchedulerOutput in scheduling_utils does not need pred_original_sample as an optional extra
* Reordered library imports to follow the standard
* Didn't get the import order quite right apparently
* Forgot to change the name of LMSDiscreteSchedulerOutput
* Aha, needed some extra libs for make style to fully work
1 parent f149d03 commit 91db818
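
For context on how a caller would consume the new field, here is a minimal sketch (not part of this commit) assuming a default-configured DDIMScheduler operating on torch tensors; the random tensors stand in for real latents and a UNet's noise prediction:

import torch

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(50)

sample = torch.randn(1, 3, 64, 64)        # current noisy sample x_t (placeholder)
model_output = torch.randn(1, 3, 64, 64)  # placeholder for the model's predicted noise
t = int(scheduler.timesteps[0])

out = scheduler.step(model_output, t, sample)
out.prev_sample           # x_{t-1}, fed back into the denoising loop as before
out.pred_original_sample  # newly exposed x_0 estimate, e.g. for previews or guidance

# return_dict=False keeps the old tuple behaviour; the tuple still only carries prev_sample
(prev_sample,) = scheduler.step(model_output, t, sample, return_dict=False)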

File tree

4 files changed, +91 -23 lines changed:
* src/diffusers/schedulers/scheduling_ddim.py
* src/diffusers/schedulers/scheduling_ddpm.py
* src/diffusers/schedulers/scheduling_karras_ve.py
* src/diffusers/schedulers/scheduling_lms_discrete.py

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 26 additions & 6 deletions

@@ -17,13 +17,33 @@
 
 import math
 import warnings
+from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class DDIMSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    pred_original_sample: Optional[torch.FloatTensor] = None
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):

@@ -179,7 +199,7 @@ def step(
         use_clipped_model_output: bool = False,
         generator=None,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[DDIMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).

@@ -192,11 +212,11 @@ def step(
             eta (`float`): weight of noise for added noise in diffusion step.
             use_clipped_model_output (`bool`): TODO
             generator: random number generator.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
+            [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
 
         """

@@ -261,7 +281,7 @@ def step(
         if not return_dict:
             return (prev_sample,)
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     def add_noise(
         self,
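
Because DDIMSchedulerOutput subclasses BaseOutput (the same base KarrasVeOutput already uses), existing call sites that index the step result should keep working. A small illustrative sketch, assuming BaseOutput keeps its ModelOutput-style dict/tuple access:

import torch

from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput

out = DDIMSchedulerOutput(prev_sample=torch.zeros(1), pred_original_sample=torch.ones(1))

out.prev_sample     # attribute access, as before
out["prev_sample"]  # dict-style access provided by BaseOutput
out[0]              # tuple-style access; index 0 is still the previous sample
out.to_tuple()      # (prev_sample, pred_original_sample)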

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 26 additions & 6 deletions

@@ -15,13 +15,33 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
+from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class DDPMSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    pred_original_sample: Optional[torch.FloatTensor] = None
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):

@@ -177,7 +197,7 @@ def step(
         predict_epsilon=True,
         generator=None,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[DDPMSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).

@@ -190,11 +210,11 @@ def step(
             predict_epsilon (`bool`):
                 optional flag to use when model predicts the samples directly instead of the noise, epsilon.
             generator: random number generator.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
+            [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
 
         """

@@ -242,7 +262,7 @@ def step(
         if not return_dict:
             return (pred_prev_sample,)
 
-        return SchedulerOutput(prev_sample=pred_prev_sample)
+        return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
 
     def add_noise(
         self,
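
A toy DDPM loop that collects the intermediate x_0 estimates now returned at every step; this is only a sketch, with a random tensor standing in for a real UNet's epsilon prediction:

import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
scheduler.set_timesteps(25)

sample = torch.randn(1, 3, 32, 32)
x0_previews = []

for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)       # placeholder for unet(sample, t) output
    out = scheduler.step(model_output, int(t), sample)
    sample = out.prev_sample                      # next model input, unchanged behaviour
    x0_previews.append(out.pred_original_sample)  # new: denoised estimate per step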

src/diffusers/schedulers/scheduling_karras_ve.py

Lines changed: 12 additions & 4 deletions

@@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput):
             denoising loop.
         derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
             Derivative of predicted original image sample (x_0).
+        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
     """
 
     prev_sample: torch.FloatTensor
     derivative: torch.FloatTensor
+    pred_original_sample: Optional[torch.FloatTensor] = None
 
 
 class KarrasVeScheduler(SchedulerMixin, ConfigMixin):

@@ -153,7 +157,7 @@ def step(
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
             sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
 
         KarrasVeOutput: updated sample in the diffusion chain and derivative (TODO double check).
         Returns:

@@ -170,7 +174,9 @@ def step(
         if not return_dict:
             return (sample_prev, derivative)
 
-        return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
+        return KarrasVeOutput(
+            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
+        )
 
     def step_correct(
         self,

@@ -192,7 +198,7 @@ def step_correct(
             sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
             sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
             derivative (`torch.FloatTensor` or `np.ndarray`): TODO
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
 
         Returns:
             prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO

@@ -205,7 +211,9 @@ def step_correct(
         if not return_dict:
             return (sample_prev, derivative)
 
-        return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
+        return KarrasVeOutput(
+            prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample
+        )
 
     def add_noise(self, original_samples, noise, timesteps):
         raise NotImplementedError()
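
Since the new field is Optional and defaults to None, any code that still builds KarrasVeOutput the old way is unaffected; a quick illustration with placeholder tensors:

import torch

from diffusers.schedulers.scheduling_karras_ve import KarrasVeOutput

prev = torch.randn(1, 3, 32, 32)
deriv = torch.randn(1, 3, 32, 32)
x0_hat = torch.randn(1, 3, 32, 32)

old_style = KarrasVeOutput(prev_sample=prev, derivative=deriv)
assert old_style.pred_original_sample is None  # optional field defaults to None

new_style = KarrasVeOutput(prev_sample=prev, derivative=deriv, pred_original_sample=x0_hat)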

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 27 additions & 7 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import numpy as np

@@ -20,7 +21,26 @@
 from scipy import integrate
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class LMSDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    pred_original_sample: Optional[torch.FloatTensor] = None
 
 
 class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):

@@ -133,7 +153,7 @@ def step(
         sample: Union[torch.FloatTensor, np.ndarray],
         order: int = 4,
         return_dict: bool = True,
-    ) -> Union[SchedulerOutput, Tuple]:
+    ) -> Union[LMSDiscreteSchedulerOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
         process from the learned model outputs (most often the predicted noise).

@@ -144,12 +164,12 @@ def step(
             sample (`torch.FloatTensor` or `np.ndarray`):
                 current instance of sample being created by diffusion process.
             order: coefficient for multi-step inference.
-            return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+            return_dict (`bool`): option for returning tuple rather than LMSDiscreteSchedulerOutput class
 
         Returns:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
-            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
-            returning a tuple, the first element is the sample tensor.
+            [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] or `tuple`:
+            [`~schedulers.scheduling_utils.LMSDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is the sample tensor.
 
         """
         sigma = self.sigmas[timestep]

@@ -175,7 +195,7 @@ def step(
         if not return_dict:
             return (prev_sample,)
 
-        return SchedulerOutput(prev_sample=prev_sample)
+        return LMSDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
 
     def add_noise(
         self,
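
The same pattern applied to the LMS scheduler; note that, as the `sigma = self.sigmas[timestep]` line above suggests, `step` is given an index into the sigma schedule here. A rough sketch with random tensors in place of a real model:

import torch

from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
num_inference_steps = 10
scheduler.set_timesteps(num_inference_steps)

sample = torch.randn(1, 4, 64, 64) * float(scheduler.sigmas[0])  # start at the highest noise level

for i in range(num_inference_steps):
    model_output = torch.randn_like(sample)  # placeholder noise prediction
    out = scheduler.step(model_output, i, sample)
    sample = out.prev_sample
    x0_hat = out.pred_original_sample        # newly exposed denoised estimate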
