Skip to content

Commit 5a8b356

Browse files
authored
[DDIMScheduler] fix noise device in ddim step (#1189)
* fix noise device in ddim sched * fix typo * self.device -> device * remove duplicated if * use str device * don't use str for device
1 parent 20a05d6 commit 5a8b356

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -288,17 +288,22 @@ def step(
288288

289289
if eta > 0:
290290
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
291-
device = model_output.device if torch.is_tensor(model_output) else "cpu"
291+
device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
292292
if variance_noise is not None and generator is not None:
293293
raise ValueError(
294294
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
295295
" `variance_noise` stays `None`."
296296
)
297297

298298
if variance_noise is None:
299-
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(
300-
device
301-
)
299+
if device.type == "mps":
300+
# randn does not work reproducibly on mps
301+
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
302+
variance_noise = variance_noise.to(device)
303+
else:
304+
variance_noise = torch.randn(
305+
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
306+
)
302307
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
303308

304309
prev_sample = prev_sample + variance

0 commit comments

Comments
 (0)