adding callback_on_step_end for StableDiffusionLDM3DPipeline #7149
Merged
@@ -59,6 +59,64 @@
"""


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
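Editor's note: for readers skimming the diff, here is a minimal, self-contained sketch of how a helper like `rescale_noise_cfg` is applied to a classifier-free-guidance prediction. The tensor shapes and guidance values are illustrative only, not taken from this PR.

```python
import torch


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # same std-matching logic as the helper added above
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg


# dummy CFG step on random "latents" (batch=2, 4 channels, 8x8)
noise_pred_uncond = torch.randn(2, 4, 8, 8)
noise_pred_text = torch.randn(2, 4, 8, 8)
guidance_scale, guidance_rescale = 5.0, 0.7

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
print(noise_pred.shape)  # torch.Size([2, 4, 8, 8])
```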


def retrieve_timesteps(
Collaborator: add a `# Copied from`?

Contributor (Author): Sure, I will
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
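A rough sketch of the two code paths in `retrieve_timesteps`. The toy scheduler below is a stand-in so the example is self-contained; whether a real scheduler accepts custom `timesteps` depends on its `set_timesteps` signature, which is exactly what the `inspect` check above guards against. Assume `retrieve_timesteps` from the hunk above is in scope (copied locally or imported from the pipeline module).

```python
import torch


class ToyScheduler:
    # minimal stand-in whose `set_timesteps` accepts custom timesteps
    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None):
        if timesteps is not None:
            self.timesteps = torch.tensor(timesteps, device=device)
        else:
            self.timesteps = torch.linspace(999, 0, num_inference_steps, device=device).long()


scheduler = ToyScheduler()

# default path: the scheduler decides the spacing
ts, n = retrieve_timesteps(scheduler, num_inference_steps=4)
print(ts.tolist(), n)  # [999, 666, 333, 0] 4

# custom path: accepted because `set_timesteps` exposes a `timesteps` argument
ts, n = retrieve_timesteps(scheduler, timesteps=[900, 600, 300, 0])
print(ts.tolist(), n)  # [900, 600, 300, 0] 4
```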


@dataclass
class LDM3DPipelineOutput(BaseOutput):
    """
@@ -125,6 +183,7 @@ class StableDiffusionLDM3DPipeline(
    model_cpu_offload_seq = "text_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
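The new `_callback_tensor_inputs` list is what `check_inputs` validates `callback_on_step_end_tensor_inputs` against: a step-end callback can only request tensors named here. A small observer callback compatible with that list, as a hedged sketch (the logging itself is illustrative):

```python
def log_latent_stats(pipe, step_index, timestep, callback_kwargs):
    # only names passed via `callback_on_step_end_tensor_inputs` (and present in
    # `_callback_tensor_inputs`) show up in `callback_kwargs`
    latents = callback_kwargs["latents"]
    print(f"step {step_index:3d} | t={int(timestep):4d} | latents std={latents.std().item():.4f}")
    return callback_kwargs  # returning the dict unchanged leaves generation untouched
```

It would be wired up with `callback_on_step_end=log_latent_stats` and `callback_on_step_end_tensor_inputs=["latents"]`.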
@@ -575,6 +634,66 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            timesteps (`torch.Tensor`):
                generate embedding vectors at these timesteps
            embedding_dim (`int`, *optional*, defaults to 512):
                dimension of the embeddings to generate
            dtype:
                data type of the generated embeddings

        Returns:
            `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb
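For intuition, the embedding can be reproduced standalone; the guidance values below are arbitrary, and the `- 1` shift mirrors how the pipeline later calls this method when `time_cond_proj_dim` is set.

```python
import torch

w = torch.tensor([5.0, 7.5]) - 1.0      # the pipeline passes `guidance_scale - 1`
embedding_dim = 256
half_dim = embedding_dim // 2

emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = (w * 1000.0)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
print(emb.shape)  # torch.Size([2, 256]) -> one sinusoidal embedding per guidance value
```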

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -583,6 +702,7 @@ def __call__(
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 49,
        timesteps: List[int] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
@@ -595,10 +715,12 @@ def __call__(
        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.
@@ -649,18 +771,21 @@ def __call__(
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
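The `Examples:` placeholder above is filled in by `EXAMPLE_DOC_STRING`. As a concrete illustration of the new arguments added in this PR, here is a hedged usage sketch; the `Intel/ldm3d-4c` checkpoint and the idea of stashing mid-run latents are illustrative, not part of the diff.

```python
import torch
from diffusers import StableDiffusionLDM3DPipeline

pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

captured = {}


def grab_midway_latents(pipe, step_index, timestep, callback_kwargs):
    # stash the latents halfway through denoising for later inspection
    if step_index == pipe.num_timesteps // 2:
        captured["latents"] = callback_kwargs["latents"].detach().clone()
    return callback_kwargs


output = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=50,
    callback_on_step_end=grab_midway_latents,
    callback_on_step_end_tensor_inputs=["latents"],
)
rgb, depth = output.rgb, output.depth
print(captured["latents"].shape)
```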
@@ -670,6 +795,22 @@ def __call__(
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
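For users who start seeing these deprecation warnings, porting an old callback is mostly a signature change. A hedged before/after sketch (the progress-print body is illustrative):

```python
# before (deprecated): positional args, invoked every `callback_steps` steps
def old_callback(step, timestep, latents):
    print(f"step {step}: latent mean {latents.mean().item():.4f}")

# output = pipe(prompt, callback=old_callback, callback_steps=5)


# after: invoked at the end of every step; tensors arrive via `callback_kwargs`
def new_callback(pipe, step, timestep, callback_kwargs):
    if step % 5 == 0:  # reproduce the old `callback_steps=5` cadence manually
        latents = callback_kwargs["latents"]
        print(f"step {step}: latent mean {latents.mean().item():.4f}")
    return callback_kwargs

# output = pipe(prompt, callback_on_step_end=new_callback,
#               callback_on_step_end_tensor_inputs=["latents"])
```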
@@ -685,8 +826,15 @@ def __call__(
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
@@ -696,26 +844,22 @@ def __call__(
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                do_classifier_free_guidance,
                self.do_classifier_free_guidance,
            )

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
@@ -724,12 +868,11 @@ def __call__(
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if do_classifier_free_guidance:
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
@@ -750,32 +893,59 @@ def __call__(
        # 6.1 Add image embeds for IP-Adapter
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 6.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
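Closing note on the new `interrupt` property checked at the top of the loop: it allows a step-end callback to stop generation early. A hedged sketch; setting the private `_interrupt` flag follows the pattern used by other diffusers pipelines that expose this property, and the step threshold is arbitrary.

```python
def stop_after_25_steps(pipe, step_index, timestep, callback_kwargs):
    # once the flag is set, the remaining iterations are skipped via `if self.interrupt: continue`
    if step_index >= 25:
        pipe._interrupt = True
    return callback_kwargs

# output = pipe(prompt, num_inference_steps=50, callback_on_step_end=stop_after_25_steps)
```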