Skip to content

Commit b8271d9

Browse files
committed
Revert "Add DPMSolverScheduler trait"
This reverts commit 83b28b3.
1 parent 83b28b3 commit b8271d9

File tree

3 files changed

+39
-86
lines changed

3 files changed

+39
-86
lines changed

src/schedulers/dpmsolver.rs

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use tch::Tensor;
2-
31
use crate::schedulers::BetaSchedule;
42
use crate::schedulers::PredictionType;
53

@@ -67,46 +65,3 @@ impl Default for DPMSolverSchedulerConfig {
6765
}
6866
}
6967
}
70-
71-
pub trait DPMSolverScheduler {
72-
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self;
73-
fn convert_model_output(
74-
&self,
75-
model_output: &Tensor,
76-
timestep: usize,
77-
sample: &Tensor,
78-
) -> Tensor;
79-
80-
fn first_order_update(
81-
&self,
82-
model_output: Tensor,
83-
timestep: usize,
84-
prev_timestep: usize,
85-
sample: &Tensor,
86-
) -> Tensor;
87-
88-
fn second_order_update(
89-
&self,
90-
model_output_list: &Vec<Tensor>,
91-
timestep_list: [usize; 2],
92-
prev_timestep: usize,
93-
sample: &Tensor,
94-
) -> Tensor;
95-
96-
fn third_order_update(
97-
&self,
98-
model_output_list: &Vec<Tensor>,
99-
timestep_list: [usize; 3],
100-
prev_timestep: usize,
101-
sample: &Tensor,
102-
) -> Tensor;
103-
104-
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor;
105-
106-
fn timesteps(&self) -> &[usize];
107-
fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Tensor;
108-
109-
110-
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor;
111-
fn init_noise_sigma(&self) -> f64;
112-
}

src/schedulers/dpmsolver_multistep.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,4 @@
1-
use super::{
2-
betas_for_alpha_bar,
3-
dpmsolver::{
4-
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
5-
},
6-
BetaSchedule, PredictionType,
7-
};
1+
use super::{betas_for_alpha_bar, BetaSchedule, PredictionType, dpmsolver::{DPMSolverSchedulerConfig, DPMSolverAlgorithmType, DPMSolverType}};
82
use tch::{kind, Kind, Tensor};
93

104
pub struct DPMSolverMultistepScheduler {
@@ -21,8 +15,8 @@ pub struct DPMSolverMultistepScheduler {
2115
pub config: DPMSolverSchedulerConfig,
2216
}
2317

24-
impl DPMSolverScheduler for DPMSolverMultistepScheduler {
25-
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
18+
impl DPMSolverMultistepScheduler {
19+
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
2620
let betas = match config.beta_schedule {
2721
BetaSchedule::ScaledLinear => Tensor::linspace(
2822
config.beta_start.sqrt(),
@@ -123,7 +117,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
123117

124118
/// One step for the first-order DPM-Solver (equivalent to DDIM).
125119
/// See https://arxiv.org/abs/2206.00927 for the detailed derivation.
126-
fn first_order_update(
120+
fn dpm_solver_first_order_update(
127121
&self,
128122
model_output: Tensor,
129123
timestep: usize,
@@ -145,7 +139,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
145139
}
146140

147141
/// One step for the second-order multistep DPM-Solver.
148-
fn second_order_update(
142+
fn multistep_dpm_solver_second_order_update(
149143
&self,
150144
model_output_list: &Vec<Tensor>,
151145
timestep_list: [usize; 2],
@@ -198,7 +192,7 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
198192
}
199193

200194
/// One step for the third-order multistep DPM-Solver
201-
fn third_order_update(
195+
fn multistep_dpm_solver_third_order_update(
202196
&self,
203197
model_output_list: &Vec<Tensor>,
204198
timestep_list: [usize; 3],
@@ -243,11 +237,11 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
243237
}
244238
}
245239

246-
fn timesteps(&self) -> &[usize] {
240+
pub fn timesteps(&self) -> &[usize] {
247241
self.timesteps.as_slice()
248242
}
249243

250-
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
244+
pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
251245
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py#L457
252246
let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap();
253247

@@ -272,14 +266,24 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
272266
|| self.lower_order_nums < 1
273267
|| lower_order_final
274268
{
275-
self.first_order_update(model_output, timestep, prev_timestep, sample)
269+
self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample)
276270
} else if self.config.solver_order == 2 || self.lower_order_nums < 2 || lower_order_second {
277271
let timestep_list = [self.timesteps[step_index - 1], timestep];
278-
self.second_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
272+
self.multistep_dpm_solver_second_order_update(
273+
&self.model_outputs,
274+
timestep_list,
275+
prev_timestep,
276+
sample,
277+
)
279278
} else {
280279
let timestep_list =
281280
[self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep];
282-
self.third_order_update(&self.model_outputs, timestep_list, prev_timestep, sample)
281+
self.multistep_dpm_solver_third_order_update(
282+
&self.model_outputs,
283+
timestep_list,
284+
prev_timestep,
285+
sample,
286+
)
283287
};
284288

285289
if self.lower_order_nums < self.config.solver_order {
@@ -289,16 +293,12 @@ impl DPMSolverScheduler for DPMSolverMultistepScheduler {
289293
prev_sample
290294
}
291295

292-
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
296+
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
293297
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
294298
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
295299
}
296300

297-
fn init_noise_sigma(&self) -> f64 {
301+
pub fn init_noise_sigma(&self) -> f64 {
298302
self.init_noise_sigma
299303
}
300-
301-
fn scale_model_input(&self, _sample: Tensor, _timestep: usize) -> Tensor {
302-
todo!()
303-
}
304304
}

src/schedulers/dpmsolver_singlestep.rs

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ use std::iter::repeat;
22

33
use super::{
44
betas_for_alpha_bar,
5-
dpmsolver::{
6-
DPMSolverAlgorithmType, DPMSolverScheduler, DPMSolverSchedulerConfig, DPMSolverType,
7-
},
5+
dpmsolver::{DPMSolverAlgorithmType, DPMSolverSchedulerConfig, DPMSolverType},
86
BetaSchedule, PredictionType,
97
};
108
use tch::{kind, Kind, Tensor};
@@ -25,8 +23,8 @@ pub struct DPMSolverSinglestepScheduler {
2523
pub config: DPMSolverSchedulerConfig,
2624
}
2725

28-
impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
29-
fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
26+
impl DPMSolverSinglestepScheduler {
27+
pub fn new(inference_steps: usize, config: DPMSolverSchedulerConfig) -> Self {
3028
let betas = match config.beta_schedule {
3129
BetaSchedule::ScaledLinear => Tensor::linspace(
3230
config.beta_start.sqrt(),
@@ -143,9 +141,9 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
143141
/// * `timestep` - current discrete timestep in the diffusion chain
144142
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
145143
/// * `sample` - current instance of sample being created by diffusion process
146-
fn first_order_update(
144+
fn dpm_solver_first_order_update(
147145
&self,
148-
model_output: Tensor,
146+
model_output: &Tensor,
149147
timestep: usize,
150148
prev_timestep: usize,
151149
sample: &Tensor,
@@ -173,7 +171,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
173171
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
174172
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
175173
/// * `sample` - current instance of sample being created by diffusion process
176-
fn second_order_update(
174+
fn singlestep_dpm_solver_second_order_update(
177175
&self,
178176
model_output_list: &Vec<Tensor>,
179177
timestep_list: [usize; 2],
@@ -234,7 +232,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
234232
/// * `timestep_list` - current and latter discrete timestep in the diffusion chain
235233
/// * `prev_timestep` - previous discrete timestep in the diffusion chain
236234
/// * `sample` - current instance of sample being created by diffusion process
237-
fn third_order_update(
235+
fn singlestep_dpm_solver_third_order_update(
238236
&self,
239237
model_output_list: &Vec<Tensor>,
240238
timestep_list: [usize; 3],
@@ -292,13 +290,13 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
292290
}
293291
}
294292

295-
fn timesteps(&self) -> &[usize] {
293+
pub fn timesteps(&self) -> &[usize] {
296294
self.timesteps.as_slice()
297295
}
298296

299297
/// Ensures interchangeability with schedulers that need to scale the denoising model input
300298
/// depending on the current timestep.
301-
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
299+
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor {
302300
sample
303301
}
304302

@@ -309,7 +307,7 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
309307
/// * `model_output` - direct output from learned diffusion model
310308
/// * `timestep` - current discrete timestep in the diffusion chain
311309
/// * `sample` - current instance of sample being created by diffusion process
312-
fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
310+
pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
313311
// https://github.com/huggingface/diffusers/blob/e4fe9413121b78c4c1f109b50f0f3cc1c320a1a2/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py#L535
314312
let step_index: usize = self.timesteps.iter().position(|&t| t == timestep).unwrap();
315313

@@ -331,19 +329,19 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
331329
};
332330

333331
match order {
334-
1 => self.first_order_update(
335-
model_output,
332+
1 => self.dpm_solver_first_order_update(
333+
&self.model_outputs[self.model_outputs.len() - 1],
336334
timestep,
337335
prev_timestep,
338336
&self.sample.as_ref().unwrap(),
339337
),
340-
2 => self.second_order_update(
338+
2 => self.singlestep_dpm_solver_second_order_update(
341339
&self.model_outputs,
342340
[self.timesteps[step_index - 1], self.timesteps[step_index]],
343341
prev_timestep,
344342
&self.sample.as_ref().unwrap(),
345343
),
346-
3 => self.third_order_update(
344+
3 => self.singlestep_dpm_solver_third_order_update(
347345
&self.model_outputs,
348346
[
349347
self.timesteps[step_index - 2],
@@ -359,12 +357,12 @@ impl DPMSolverScheduler for DPMSolverSinglestepScheduler {
359357
}
360358
}
361359

362-
fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
360+
pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor {
363361
self.alphas_cumprod[timestep].sqrt() * original_samples.to_owned()
364362
+ (1.0 - self.alphas_cumprod[timestep]).sqrt() * noise
365363
}
366364

367-
fn init_noise_sigma(&self) -> f64 {
365+
pub fn init_noise_sigma(&self) -> f64 {
368366
self.init_noise_sigma
369367
}
370368
}

0 commit comments

Comments
 (0)