Skip to content

Commit 6a0a312

Browse files
authored
Fix bug in half precision for DPMSolverMultistepScheduler (#1349)
* cast to float for quantile method * add fp16 test for DPMSolverMultistepScheduler fix * formatting update
1 parent c28d3c8 commit 6a0a312

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@ def convert_model_output(
247247

248248
if self.config.thresholding:
249249
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
250+
orig_dtype = x0_pred.dtype
251+
if orig_dtype not in [torch.float, torch.double]:
252+
x0_pred = x0_pred.float()
250253
dynamic_max_val = torch.quantile(
251254
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
252255
)
@@ -255,6 +258,7 @@ def convert_model_output(
255258
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
256259
)[(...,) + (None,) * (x0_pred.ndim - 1)]
257260
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
261+
x0_pred = x0_pred.type(orig_dtype)
258262
return x0_pred
259263
# DPM-Solver needs to solve an integral of the noise prediction model.
260264
elif self.config.algorithm_type == "dpmsolver":

tests/test_scheduler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,22 @@ def test_full_loop_no_noise(self):
991991

992992
assert abs(result_mean.item() - 0.3301) < 1e-3
993993

994+
def test_fp16_support(self):
995+
scheduler_class = self.scheduler_classes[0]
996+
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
997+
scheduler = scheduler_class(**scheduler_config)
998+
999+
num_inference_steps = 10
1000+
model = self.dummy_model()
1001+
sample = self.dummy_sample_deter.half()
1002+
scheduler.set_timesteps(num_inference_steps)
1003+
1004+
for i, t in enumerate(scheduler.timesteps):
1005+
residual = model(sample, t)
1006+
sample = scheduler.step(residual, t, sample).prev_sample
1007+
1008+
assert sample.dtype == torch.float16
1009+
9941010

9951011
class PNDMSchedulerTest(SchedulerCommonTest):
9961012
scheduler_classes = (PNDMScheduler,)

0 commit comments

Comments
 (0)