Skip to content

Commit b46b813

Browse files
patil-surajsliard
authored andcommitted
fix heun scheduler (huggingface#1512)
1 parent 134fe54 commit b46b813

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/diffusers/schedulers/scheduling_heun.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,13 @@ def step(
186186

187187
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
188188
if self.config.prediction_type == "epsilon":
189-
pred_original_sample = sample - sigma_hat * model_output
189+
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
190+
pred_original_sample = sample - sigma_input * model_output
190191
elif self.config.prediction_type == "v_prediction":
191-
# * c_out + input * c_skip
192-
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
192+
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
193+
pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
194+
sample / (sigma_input**2 + 1)
195+
)
193196
else:
194197
raise ValueError(
195198
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
@@ -207,7 +210,7 @@ def step(
207210
self.sample = sample
208211
else:
209212
# 2. 2nd order / Heun's method
210-
derivative = (sample - pred_original_sample) / sigma_hat
213+
derivative = (sample - pred_original_sample) / sigma_next
211214
derivative = (self.prev_derivative + derivative) / 2
212215

213216
# 3. Retrieve 1st order derivative

0 commit comments

Comments
 (0)