diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 2ef59c438889..94185c5b9a2e 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -56,6 +56,7 @@ def __init__( beta_end=0.02, beta_schedule="linear", tensor_format="pt", + skip_prk_steps=False, ): if beta_schedule == "linear": @@ -88,6 +89,7 @@ def __init__( # setable values self.num_inference_steps = None self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self._offset = 0 self.prk_timesteps = None self.plms_timesteps = None self.timesteps = None @@ -95,17 +97,27 @@ def __init__( self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps): + def set_timesteps(self, num_inference_steps, offset=0): self.num_inference_steps = num_inference_steps self._timesteps = list( range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) ) + self._offset = offset + self._timesteps = [t + self._offset for t in self._timesteps] + + if self.config.skip_prk_steps: + # for some models like stable diffusion the prk steps can/should be skipped to + # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation + # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 + self.prk_timesteps = [] + self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])) + else: + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( + np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order + ) + self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1])) + self.plms_timesteps = list(reversed(self._timesteps[:-3])) - prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( - np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order - ) - self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1])) - self.plms_timesteps = list(reversed(self._timesteps[:-3])) self.timesteps = self.prk_timesteps + self.plms_timesteps self.counter = 0 @@ -117,7 +129,7 @@ def step( timestep: int, sample: Union[torch.FloatTensor, np.ndarray], ): - if self.counter < len(self.prk_timesteps): + if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: return self.step_prk(model_output=model_output, timestep=timestep, sample=sample) else: return self.step_plms(model_output=model_output, timestep=timestep, sample=sample) @@ -166,7 +178,7 @@ def step_plms( Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. """ - if len(self.ets) < 3: + if not self.config.skip_prk_steps and len(self.ets) < 3: raise ValueError( f"{self.__class__} can only be run AFTER scheduler has been run " "in 'prk' mode for at least 12 iterations " @@ -175,9 +187,26 @@ def step_plms( ) prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) - self.ets.append(model_output) - model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) + if self.counter != 1: + self.ets.append(model_output) + else: + prev_timestep = timestep + timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps + + if len(self.ets) == 1 and self.counter == 0: + model_output = model_output + self.cur_sample = sample + elif len(self.ets) == 1 and self.counter == 1: + model_output = (model_output + self.ets[-1]) / 2 + sample = self.cur_sample + self.cur_sample = None + elif len(self.ets) == 2: + model_output = (3 * self.ets[-1] - self.ets[-2]) / 2 + elif len(self.ets) == 3: + model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12 + else: + model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) self.counter += 1 @@ -197,8 +226,8 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): # sample -> x_t # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) - alpha_prod_t = self.alphas_cumprod[timestep + 1] - alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1] + alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset] + alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset] beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 894a4294d664..d38a0ff9cb4a 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -843,6 +843,7 @@ def test_ldm_text2img_fast(self): @slow @unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU") def test_stable_diffusion(self): + # make sure here that pndm scheduler skips prk sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers") prompt = "A painting of a squirrel eating a burger" @@ -857,7 +858,7 @@ def test_stable_diffusion(self): image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 512, 512, 3) - expected_slice = np.array([0.898, 0.9194, 0.91, 0.8955, 0.915, 0.919, 0.9233, 0.9307, 0.8887]) + expected_slice = np.array([0.8887, 0.915, 0.91, 0.894, 0.909, 0.912, 0.919, 0.925, 0.883]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow