@@ -2,9 +2,7 @@ use std::iter::repeat;
22
33use super :: {
44 betas_for_alpha_bar,
5- dpmsolver:: {
6- DPMSolverAlgorithmType , DPMSolverScheduler , DPMSolverSchedulerConfig , DPMSolverType ,
7- } ,
5+ dpmsolver:: { DPMSolverAlgorithmType , DPMSolverSchedulerConfig , DPMSolverType } ,
86 BetaSchedule , PredictionType ,
97} ;
108use tch:: { kind, Kind , Tensor } ;
@@ -25,8 +23,8 @@ pub struct DPMSolverSinglestepScheduler {
2523 pub config : DPMSolverSchedulerConfig ,
2624}
2725
28- impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
29- fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
26+ impl DPMSolverSinglestepScheduler {
27+ pub fn new ( inference_steps : usize , config : DPMSolverSchedulerConfig ) -> Self {
3028 let betas = match config. beta_schedule {
3129 BetaSchedule :: ScaledLinear => Tensor :: linspace (
3230 config. beta_start . sqrt ( ) ,
@@ -143,9 +141,9 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
143141 /// * `timestep` - current discrete timestep in the diffusion chain
144142 /// * `prev_timestep` - previous discrete timestep in the diffusion chain
145143 /// * `sample` - current instance of sample being created by diffusion process
146- fn first_order_update (
144+ fn dpm_solver_first_order_update (
147145 & self ,
148- model_output : Tensor ,
146+ model_output : & Tensor ,
149147 timestep : usize ,
150148 prev_timestep : usize ,
151149 sample : & Tensor ,
@@ -173,7 +171,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
173171 /// * `timestep_list` - current and latter discrete timestep in the diffusion chain
174172 /// * `prev_timestep` - previous discrete timestep in the diffusion chain
175173 /// * `sample` - current instance of sample being created by diffusion process
176- fn second_order_update (
174+ fn singlestep_dpm_solver_second_order_update (
177175 & self ,
178176 model_output_list : & Vec < Tensor > ,
179177 timestep_list : [ usize ; 2 ] ,
@@ -234,7 +232,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
234232 /// * `timestep_list` - current and latter discrete timestep in the diffusion chain
235233 /// * `prev_timestep` - previous discrete timestep in the diffusion chain
236234 /// * `sample` - current instance of sample being created by diffusion process
237- fn third_order_update (
235+ fn singlestep_dpm_solver_third_order_update (
238236 & self ,
239237 model_output_list : & Vec < Tensor > ,
240238 timestep_list : [ usize ; 3 ] ,
@@ -292,13 +290,13 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
292290 }
293291 }
294292
295- fn timesteps ( & self ) -> & [ usize ] {
293+ pub fn timesteps ( & self ) -> & [ usize ] {
296294 self . timesteps . as_slice ( )
297295 }
298296
299297 /// Ensures interchangeability with schedulers that need to scale the denoising model input
300298 /// depending on the current timestep.
301- fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
299+ pub fn scale_model_input ( & self , sample : Tensor , _timestep : usize ) -> Tensor {
302300 sample
303301 }
304302
@@ -309,7 +307,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
309307 /// * `model_output` - direct output from learned diffusion model
310308 /// * `timestep` - current discrete timestep in the diffusion chain
311309 /// * `sample` - current instance of sample being created by diffusion process
312- fn step ( & mut self , model_output : & Tensor , timestep : usize , sample : & Tensor ) -> Tensor {
310+ pub fn step ( & mut self , model_output : & Tensor , timestep : usize , sample : & Tensor ) -> Tensor {
313311 // https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535
314312 let step_index: usize = self . timesteps . iter ( ) . position ( |& t| t == timestep) . unwrap ( ) ;
315313
@@ -331,19 +329,19 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
331329 } ;
332330
333331 match order {
334- 1 => self . first_order_update (
335- model_output ,
332+ 1 => self . dpm_solver_first_order_update (
333+ & self . model_outputs [ self . model_outputs . len ( ) - 1 ] ,
336334 timestep,
337335 prev_timestep,
338336 & self . sample . as_ref ( ) . unwrap ( ) ,
339337 ) ,
340- 2 => self . second_order_update (
338+ 2 => self . singlestep_dpm_solver_second_order_update (
341339 & self . model_outputs ,
342340 [ self . timesteps [ step_index - 1 ] , self . timesteps [ step_index] ] ,
343341 prev_timestep,
344342 & self . sample . as_ref ( ) . unwrap ( ) ,
345343 ) ,
346- 3 => self . third_order_update (
344+ 3 => self . singlestep_dpm_solver_third_order_update (
347345 & self . model_outputs ,
348346 [
349347 self . timesteps [ step_index - 2 ] ,
@@ -359,12 +357,12 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
359357 }
360358 }
361359
362- fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
360+ pub fn add_noise ( & self , original_samples : & Tensor , noise : Tensor , timestep : usize ) -> Tensor {
363361 self . alphas_cumprod [ timestep] . sqrt ( ) * original_samples. to_owned ( )
364362 + ( 1.0 - self . alphas_cumprod [ timestep] ) . sqrt ( ) * noise
365363 }
366364
367- fn init_noise_sigma ( & self ) -> f64 {
365+ pub fn init_noise_sigma ( & self ) -> f64 {
368366 self . init_noise_sigma
369367 }
370368}
0 commit comments