diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index 209125f156d1..8263e487fa44 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -373,10 +373,11 @@ def set_timesteps( ) # LCM Timesteps Setting - # Currently, only linear spacing is supported. - c = self.config.num_train_timesteps // original_steps - # LCM Training Steps Schedule - lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1 + # The skipping step parameter k from the paper. + k = self.config.num_train_timesteps // original_steps + # LCM Training/Distillation Steps Schedule + # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts). + lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1 skipping_step = len(lcm_origin_timesteps) // num_inference_steps if skipping_step < 1: @@ -385,9 +386,13 @@ def set_timesteps( ) # LCM Inference Steps Schedule - timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] + lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy() + # Select (approximately) evenly spaced indices from lcm_origin_timesteps. + inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = lcm_origin_timesteps[inference_indices] - self.timesteps = torch.from_numpy(timesteps.copy()).to(device=device, dtype=torch.long) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long) self._step_index = None diff --git a/tests/schedulers/test_scheduler_lcm.py b/tests/schedulers/test_scheduler_lcm.py index f7d511ff0573..014cdca90479 100644 --- a/tests/schedulers/test_scheduler_lcm.py +++ b/tests/schedulers/test_scheduler_lcm.py @@ -84,7 +84,7 @@ def test_time_indices(self): def test_inference_steps(self): # Hardcoded for now - for t, num_inference_steps in zip([99, 39, 19], [10, 25, 50]): + for t, num_inference_steps in zip([99, 39, 39, 19], [10, 25, 26, 50]): self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) # Override test_add_noise_device because the hardcoded num_inference_steps of 100 doesn't work