Skip to content
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def step(
if not return_dict:
return (prev_sample,)

return SchedulerOutput(prev_sample=prev_sample)
return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def step(
if not return_dict:
return (pred_prev_sample,)

return SchedulerOutput(prev_sample=pred_prev_sample)
return SchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
self,
Expand Down
8 changes: 6 additions & 2 deletions src/diffusers/schedulers/scheduling_karras_ve.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@ class KarrasVeOutput(BaseOutput):
denoising loop.
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Derivative of predicted original image sample (x_0).
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
derivative: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
Expand Down Expand Up @@ -170,7 +174,7 @@ def step(
if not return_dict:
return (sample_prev, derivative)

return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample)

def step_correct(
self,
Expand Down Expand Up @@ -205,7 +209,7 @@ def step_correct(
if not return_dict:
return (sample_prev, derivative)

return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative, pred_original_sample=pred_original_sample)

def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def step(
if not return_dict:
return (prev_sample,)

return SchedulerOutput(prev_sample=prev_sample)
return SchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

def add_noise(
self,
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/schedulers/scheduling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Union
from typing import Union, Optional

import numpy as np
import torch
Expand All @@ -32,9 +32,13 @@ class SchedulerOutput(BaseOutput):
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""

prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None


class SchedulerMixin:
Expand Down