File tree Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Expand file tree Collapse file tree 1 file changed +7
-4
lines changed Original file line number Diff line number Diff line change @@ -186,10 +186,13 @@ def step(
186186
187187 # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
188188 if self .config .prediction_type == "epsilon" :
189- pred_original_sample = sample - sigma_hat * model_output
189+ sigma_input = sigma_hat if self .state_in_first_order else sigma_next
190+ pred_original_sample = sample - sigma_input * model_output
190191 elif self .config .prediction_type == "v_prediction" :
191- # * c_out + input * c_skip
192- pred_original_sample = model_output * (- sigma / (sigma ** 2 + 1 ) ** 0.5 ) + (sample / (sigma ** 2 + 1 ))
192+ sigma_input = sigma_hat if self .state_in_first_order else sigma_next
193+ pred_original_sample = model_output * (- sigma_input / (sigma_input ** 2 + 1 ) ** 0.5 ) + (
194+ sample / (sigma_input ** 2 + 1 )
195+ )
193196 else :
194197 raise ValueError (
195198 f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, or `v_prediction`"
@@ -207,7 +210,7 @@ def step(
207210 self .sample = sample
208211 else :
209212 # 2. 2nd order / Heun's method
210- derivative = (sample - pred_original_sample ) / sigma_hat
213+ derivative = (sample - pred_original_sample ) / sigma_next
211214 derivative = (self .prev_derivative + derivative ) / 2
212215
213216 # 3. Retrieve 1st order derivative
You can’t perform that action at this time.
0 commit comments