From 7d90f1c95d58da40d81c6dc64711e222a6b36f3a Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 8 Nov 2022 12:30:16 +0100 Subject: [PATCH 1/6] fix noise device in ddim sched --- src/diffusers/schedulers/scheduling_ddim.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 8d4407c16c30..0fd9f4820f6e 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -296,9 +296,15 @@ def step( ) if variance_noise is None: - variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to( - device - ) + if variance_noise is None: + if self.device.type == "mps": + # randn does not work reproducibly on mps + variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) + variance_noise = variance_noise.to(self.device) + else: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise prev_sample = prev_sample + variance From 19449145ab7123ec2bfcefcd799cc01a889067b3 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 8 Nov 2022 12:38:03 +0100 Subject: [PATCH 2/6] fix typo --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 0fd9f4820f6e..d8021f372d51 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -300,7 +300,7 @@ def step( if self.device.type == "mps": # randn does not work reproducibly on mps variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) - variance_noise = variance_noise.to(self.device) + variance_noise = variance_noise.to(device) else: variance_noise = torch.randn( model_output.shape, generator=generator, device=device, dtype=model_output.dtype From 4bc7d6d70da2def70fb812323c5fcdac70e1588b Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 8 Nov 2022 12:43:25 +0100 Subject: [PATCH 3/6] self.device -> device --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index d8021f372d51..52fa3e27b1de 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -297,7 +297,7 @@ def step( if variance_noise is None: if variance_noise is None: - if self.device.type == "mps": + if device.type == "mps": # randn does not work reproducibly on mps variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) variance_noise = variance_noise.to(device) From 8644fd01b53f70985d27e8ecdcbe30c2a146f8f1 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 8 Nov 2022 12:46:15 +0100 Subject: [PATCH 4/6] remove duplicated if --- src/diffusers/schedulers/scheduling_ddim.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 52fa3e27b1de..98b629e6c52f 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -296,15 +296,14 @@ def step( ) if variance_noise is None: - if variance_noise is None: - if device.type == "mps": - # randn does not work reproducibly on mps - variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) - variance_noise = variance_noise.to(device) - else: - variance_noise = torch.randn( - model_output.shape, generator=generator, device=device, dtype=model_output.dtype - ) + if device.type == "mps": + # randn does not work reproducibly on mps + variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) + variance_noise = variance_noise.to(device) + else: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise prev_sample = prev_sample + variance From 93b2ceee1996cd7ad9e99810b37b1df4d559938e Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 8 Nov 2022 12:46:36 +0100 Subject: [PATCH 5/6] use str device --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 98b629e6c52f..a181c5d14f43 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -296,7 +296,7 @@ def step( ) if variance_noise is None: - if device.type == "mps": + if str(device) == "mps": # randn does not work reproducibly on mps variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) variance_noise = variance_noise.to(device) From e7d9fc6df566645bf75b047a65f6c40cf9ddf340 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 8 Nov 2022 12:58:54 +0100 Subject: [PATCH 6/6] don't use str for device --- src/diffusers/schedulers/scheduling_ddim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index a181c5d14f43..1acb81764d32 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -288,7 +288,7 @@ def step( if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 - device = model_output.device if torch.is_tensor(model_output) else "cpu" + device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu") if variance_noise is not None and generator is not None: raise ValueError( "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" @@ -296,7 +296,7 @@ def step( ) if variance_noise is None: - if str(device) == "mps": + if device.type == "mps": # randn does not work reproducibly on mps variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator) variance_noise = variance_noise.to(device)