Skip to content

Commit 1a69c6f

Browse files
authored
Fix scheduler indexing when using the `mps` device (#450)
* Fix LMS scheduler indexing in `add_noise` (#358).
* Fix DDIM and DDPM indexing with the `mps` device.
* Verify that the tensor format is PyTorch before using `.to()`.
1 parent 7c4b38b commit 1a69c6f

File tree

4 files changed

+9
-3
lines changed

4 files changed

+9
-3
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ def add_noise(
250250
noise: Union[torch.FloatTensor, np.ndarray],
251251
timesteps: Union[torch.IntTensor, np.ndarray],
252252
) -> Union[torch.FloatTensor, np.ndarray]:
253+
if self.tensor_format == "pt":
254+
timesteps = timesteps.to(self.alphas_cumprod.device)
253255
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
254256
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
255257
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

src/diffusers/schedulers/scheduling_ddpm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ def add_noise(
251251
noise: Union[torch.FloatTensor, np.ndarray],
252252
timesteps: Union[torch.IntTensor, np.ndarray],
253253
) -> Union[torch.FloatTensor, np.ndarray]:
254+
if self.tensor_format == "pt":
255+
timesteps = timesteps.to(self.alphas_cumprod.device)
254256
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
255257
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
256258
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def set_timesteps(self, num_inference_steps: int):
120120
frac = np.mod(self.timesteps, 1.0)
121121
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
122122
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
123-
self.sigmas = np.concatenate([sigmas, [0.0]])
123+
self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
124124

125125
self.derivatives = []
126126

@@ -183,6 +183,8 @@ def add_noise(
183183
noise: Union[torch.FloatTensor, np.ndarray],
184184
timesteps: Union[torch.IntTensor, np.ndarray],
185185
) -> Union[torch.FloatTensor, np.ndarray]:
186+
if self.tensor_format == "pt":
187+
timesteps = timesteps.to(self.sigmas.device)
186188
sigmas = self.match_shape(self.sigmas[timesteps], noise)
187189
noisy_samples = original_samples + noise * sigmas
188190

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,8 @@ def add_noise(
367367
noise: Union[torch.FloatTensor, np.ndarray],
368368
timesteps: Union[torch.IntTensor, np.ndarray],
369369
) -> torch.Tensor:
370-
# mps requires indices to be in the same device, so we use cpu as is the default with cuda
371-
timesteps = timesteps.to(self.alphas_cumprod.device)
370+
if self.tensor_format == "pt":
371+
timesteps = timesteps.to(self.alphas_cumprod.device)
372372
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
373373
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
374374
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

0 commit comments

Comments
 (0)