From 8856677256c9ef6e5f7387eef47e0c6f6b60ee3e Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Wed, 26 Jun 2024 18:42:15 +0300
Subject: [PATCH 1/9] Initial new backend implementation

---
 .../stable_diffusion/addons/__init__.py       |   9 +
 .../backend/stable_diffusion/addons/base.py   |  23 ++
 .../stable_diffusion/diffusers_pipeline.py    | 361 ++++++++++++++++++
 .../diffusion/conditioning_data.py            |  97 +++++
 4 files changed, 490 insertions(+)
 create mode 100644 invokeai/backend/stable_diffusion/addons/__init__.py
 create mode 100644 invokeai/backend/stable_diffusion/addons/base.py

diff --git a/invokeai/backend/stable_diffusion/addons/__init__.py b/invokeai/backend/stable_diffusion/addons/__init__.py
new file mode 100644
index 00000000000..0b62b1aaaf9
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/addons/__init__.py
@@ -0,0 +1,9 @@
+"""
+Initialization file for the invokeai.backend.stable_diffusion.addons package
+"""
+
+from .base import AddonBase  # noqa: F401
+
+__all__ = [
+    "AddonBase",
+]
\ No newline at end of file
diff --git a/invokeai/backend/stable_diffusion/addons/base.py b/invokeai/backend/stable_diffusion/addons/base.py
new file mode 100644
index 00000000000..d1996ea8392
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/addons/base.py
@@ -0,0 +1,23 @@
+from __future__ import annotations
+
+import torch
+from typing import Any, Dict
+from abc import ABC, abstractmethod
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
+
+
+class AddonBase(ABC):
+
+    @abstractmethod
+    def pre_unet_step(
+        self,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_steps: int,
+        conditioning_data: TextConditioningData,
+
+        unet_kwargs: Dict[str, Any],
+        conditioning_mode: str,
+    ):
+        pass
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index ee464f73e1f..8eb6129d9dc 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -18,8 +18,10 @@
 from diffusers.utils.import_utils import is_xformers_available
 from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from tqdm.auto import tqdm
 
 from invokeai.app.services.config.config_default import get_config
+from invokeai.backend.stable_diffusion.addons import AddonBase
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
@@ -315,6 +317,26 @@ def latents_from_embeddings(
                 SD UNet model.
             is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
""" + if True: + new_back = StableDiffusionBackend(self.unet, self.scheduler) + return new_back.latents_from_embeddings( + latents=latents, + scheduler_step_kwargs=scheduler_step_kwargs, + conditioning_data=conditioning_data, + noise=noise, + timesteps=timesteps, + init_timestep=init_timestep, + callback=callback, + control_data=control_data, + ip_adapter_data=ip_adapter_data, + t2i_adapter_data=t2i_adapter_data, + mask=mask, + masked_latents=masked_latents, + is_gradient_mask=is_gradient_mask, + seed=seed, + ) + + if init_timestep.shape[0] == 0: return latents @@ -590,3 +612,342 @@ def _unet_forward( cross_attention_kwargs=cross_attention_kwargs, **kwargs, ).sample + + +class StableDiffusionBackend: + + def __init__(self, unet, scheduler): + self.unet = unet + self.scheduler = scheduler + self.sequential_guidance = False # TODO: + + def latents_from_embeddings( + self, + latents: torch.Tensor, + scheduler_step_kwargs: dict[str, Any], + conditioning_data: TextConditioningData, + noise: Optional[torch.Tensor], + seed: int, + timesteps: torch.Tensor, + init_timestep: torch.Tensor, + callback: Callable[[PipelineIntermediateState], None], + control_data: list[ControlNetData] | None = None, # + ip_adapter_data: Optional[list[IPAdapterData]] = None, # + t2i_adapter_data: Optional[list[T2IAdapterData]] = None, # + mask: Optional[torch.Tensor] = None, + masked_latents: Optional[torch.Tensor] = None, + is_gradient_mask: bool = False, + ) -> torch.Tensor: + + addons = [] + # TODO: convert controlnet/ip/t2i + + return self.latents_from_embeddings2( + latents=latents, + scheduler_step_kwargs=scheduler_step_kwargs, + conditioning_data=conditioning_data, + noise=noise, + seed=seed, + timesteps=timesteps, + init_timestep=init_timestep, + callback=callback, + mask=mask, + masked_latents=masked_latents, + is_gradient_mask=is_gradient_mask, + addons=addons, + ) + + + def latents_from_embeddings2( + self, + latents: torch.Tensor, + scheduler_step_kwargs: Dict[str, Any], + conditioning_data: TextConditioningData, + noise: Optional[torch.Tensor], + seed: int, + timesteps: torch.Tensor, + init_timestep: torch.Tensor, + callback: Callable[[PipelineIntermediateState], None], + mask: Optional[torch.Tensor] = None, + masked_latents: Optional[torch.Tensor] = None, + is_gradient_mask: bool = False, + addons: List[AddonBase] = [], + ) -> torch.Tensor: + + if init_timestep.shape[0] == 0: + return latents + + orig_latents = latents.clone() + batch_size = latents.shape[0] + + if noise is not None: + # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers + latents = self.scheduler.add_noise(latents, noise, init_timestep.expand(batch_size)) + + # if no work to do, return latents + if timesteps.shape[0] == 0: + return latents + + + # TODO: inpaint + + # TODO: attention patcher + with nullcontext(): + callback( + PipelineIntermediateState( + step=-1, + order=self.scheduler.order, + total_steps=len(timesteps), + timestep=self.scheduler.config.num_train_timesteps, + latents=latents, + ) + ) + + for step_index, timestep in enumerate(tqdm(timesteps)): + step_output = self.step( + timestep, + latents, + conditioning_data, + step_index=step_index, + total_steps=len(timesteps), + scheduler_step_kwargs=scheduler_step_kwargs, + addons=addons, + ) + + latents = step_output.prev_sample + if hasattr(step_output, "denoised"): + predicted_original = step_output.denoised + elif hasattr(step_output, "pred_original_sample"): + predicted_original = step_output.pred_original_sample + else: + 
+                    predicted_original = latents
+
+                callback(
+                    PipelineIntermediateState(
+                        step=step_index,
+                        order=self.scheduler.order,
+                        total_steps=len(timesteps),
+                        timestep=int(timestep),
+                        latents=latents,
+                        predicted_original=predicted_original,
+                    )
+                )
+
+        # restore unmasked part after the last step is completed
+        # in-process masking happens before each step
+        if mask is not None:
+            if is_gradient_mask:
+                latents = torch.where(mask > 0, latents, orig_latents)
+            else:
+                latents = torch.lerp(
+                    orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
+                )
+
+        return latents
+
+
+    @torch.inference_mode()
+    def step(
+        self,
+        timestep: torch.Tensor,
+        latents: torch.Tensor,
+        conditioning_data: TextConditioningData,
+        step_index: int,
+        total_steps: int,
+        scheduler_step_kwargs: dict[str, Any],
+        addons: List[AddonBase],
+    ):
+        latent_model_input = self.scheduler.scale_model_input(latents, timestep)
+
+        if self.sequential_guidance:
+            negative_noise_pred, positive_noise_pred = self._apply_standard_conditioning_sequentially(
+                sample=latent_model_input,
+                timestep=timestep,
+                step_index=step_index,
+                total_steps=total_steps,
+                conditioning_data=conditioning_data,
+                addons=addons,
+            )
+        else:
+            negative_noise_pred, positive_noise_pred = self._apply_standard_conditioning(
+                sample=latent_model_input,
+                timestep=timestep,
+                step_index=step_index,
+                total_steps=total_steps,
+                conditioning_data=conditioning_data,
+                addons=addons,
+            )
+
+
+        # noise mix algo
+        guidance_scale = conditioning_data.guidance_scale
+        if isinstance(guidance_scale, list):
+            guidance_scale = guidance_scale[step_index]
+
+        # classifier-free guidance is a lerp that extrapolates past the positive prediction when guidance_scale > 1:
+        # out = start + weight * (end - start)
+        noise_pred = torch.lerp(negative_noise_pred, positive_noise_pred, guidance_scale)
+        # noise_pred = negative_noise_pred + guidance_scale * (positive_noise_pred - negative_noise_pred)
+
+
+        # cfg rescale
+        guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
+        if guidance_rescale_multiplier > 0:
+            noise_pred = self._rescale_cfg(
+                noise_pred,
+                positive_noise_pred,
+                guidance_rescale_multiplier,
+            )
+
+        # compute the previous noisy sample x_t -> x_t-1
+        return self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
+
+    @staticmethod
+    def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7):
+        """Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
+        ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
+        ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
+
+        x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
+        x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
+        return x_final
+
+    def _apply_standard_conditioning(
+        self,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        conditioning_data: TextConditioningData,
+        step_index: int,
+        total_steps: int,
+        addons: List[AddonBase],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
+        the cost of higher memory usage.
+        """
+        unet_kwargs = dict()
+
+        for addon in addons:
+            addon.pre_unet_step(
+                sample=sample,
+                timestep=timestep,
+                step_index=step_index,
+                total_steps=total_steps,
+                conditioning_data=conditioning_data,
+
+                unet_kwargs=unet_kwargs,
+                conditioning_mode="both",
+            )
+
+        # TODO: encoder_attention_mask
+        conditioning_data.to_unet_kwargs(unet_kwargs, "both")
+
+        sample_twice = torch.cat([sample] * 2)
+        both_results = self._unet_forward(
+            sample_twice,
+            timestep,
+            # conditioning_data
+            #both_conditionings,
+            #encoder_attention_mask=encoder_attention_mask,
+            #added_cond_kwargs=added_cond_kwargs,
+
+            # extra_guidance
+            #down_block_additional_residuals=down_block_additional_residuals,  # cn
+            #mid_block_additional_residual=mid_block_additional_residual,  # cn
+            #down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # t2i
+
+            #cross_attention_kwargs=cross_attention_kwargs,
+
+            **unet_kwargs,
+        )
+        negative_next_x, positive_next_x = both_results.chunk(2)
+        return negative_next_x, positive_next_x
+
+
+    def _apply_standard_conditioning_sequentially(
+        self,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        conditioning_data: TextConditioningData,
+        step_index: int,
+        total_steps: int,
+        addons: List[AddonBase],
+    ):
+        """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
+        slower execution speed.
+        """
+
+        ###################
+        # Negative pass
+        ###################
+
+        negative_unet_kwargs = dict()
+        for addon in addons:
+            addon.pre_unet_step(
+                sample=sample,
+                timestep=timestep,
+                step_index=step_index,
+                total_steps=total_steps,
+                conditioning_data=conditioning_data,
+
+                unet_kwargs=negative_unet_kwargs,
+                conditioning_mode="negative",
+            )
+
+        conditioning_data.to_unet_kwargs(negative_unet_kwargs, "negative")
+
+        negative_next_x = self._unet_forward(
+            sample,
+            timestep,
+            **negative_unet_kwargs,
+        )
+
+        del negative_unet_kwargs
+        # TODO: gc.collect() ?
+
+        ###################
+        # Positive pass
+        ###################
+
+        positive_unet_kwargs = dict()
+        for addon in addons:
+            addon.pre_unet_step(
+                sample=sample,
+                timestep=timestep,
+                step_index=step_index,
+                total_steps=total_steps,
+                conditioning_data=conditioning_data,
+
+                unet_kwargs=positive_unet_kwargs,
+                conditioning_mode="positive",
+            )
+
+        conditioning_data.to_unet_kwargs(positive_unet_kwargs, "positive")
+
+        # Run conditioned UNet denoising (i.e. positive prompt).
+        positive_next_x = self._unet_forward(
+            sample,
+            timestep,
+            **positive_unet_kwargs,
+        )
+
+        del positive_unet_kwargs
+        # TODO: gc.collect() ?
+
+        return negative_next_x, positive_next_x
+
+
+    def _unet_forward(
+        self,
+        sample: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        cross_attention_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
+    ):
+        return self.unet(
+            sample,
+            timestep,
+            encoder_hidden_states,
+            cross_attention_kwargs=cross_attention_kwargs,
+            **kwargs,
+        ).sample
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 85950a01df5..1c61f9151cb 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -121,3 +121,100 @@ def __init__(
     def is_sdxl(self):
         assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
         return isinstance(self.cond_text, SDXLConditioningInfo)
+
+    # TODO: prompt regions
+    def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
+        if conditioning_mode == "both":
+            encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
+                self.uncond_text.embeds, self.cond_text.embeds
+            )
+        elif conditioning_mode == "positive":
+            encoder_hidden_states = self.cond_text.embeds
+            encoder_attention_mask = None
+        else:  # elif conditioning_mode == "negative":
+            encoder_hidden_states = self.uncond_text.embeds
+            encoder_attention_mask = None
+
+        unet_kwargs.update(dict(
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+        ))
+
+        if self.is_sdxl():
+            if conditioning_mode == "negative":
+                added_cond_kwargs = dict(
+                    text_embeds=self.uncond_text.pooled_embeds,
+                    time_ids=self.uncond_text.add_time_ids,
+                )
+            elif conditioning_mode == "positive":
+                added_cond_kwargs = dict(
+                    text_embeds=self.cond_text.pooled_embeds,
+                    time_ids=self.cond_text.add_time_ids,
+                )
+            else:  # elif conditioning_mode == "both":
+                added_cond_kwargs = dict(
+                    text_embeds=torch.cat(
+                        [
+                            # TODO: how to pad? just by zeros? or even truncate?
+                            self.uncond_text.pooled_embeds,
+                            self.cond_text.pooled_embeds,
+                        ],
+                    ),
+                    time_ids=torch.cat(
+                        [
+                            self.uncond_text.add_time_ids,
+                            self.cond_text.add_time_ids,
+                        ],
+                    ),
+                )
+
+            unet_kwargs.update(dict(
+                added_cond_kwargs=added_cond_kwargs,
+            ))
+
+    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+        def _pad_conditioning(cond, target_len, encoder_attention_mask):
+            conditioning_attention_mask = torch.ones(
+                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
+            )
+
+            if cond.shape[1] < target_len:
+                conditioning_attention_mask = torch.cat(
+                    [
+                        conditioning_attention_mask,
+                        torch.zeros((cond.shape[0], target_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
+                    ],
+                    dim=1,
+                )
+
+                cond = torch.cat(
+                    [
+                        cond,
+                        torch.zeros(
+                            (cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
+                            device=cond.device,
+                            dtype=cond.dtype,
+                        ),
+                    ],
+                    dim=1,
+                )
+
+            if encoder_attention_mask is None:
+                encoder_attention_mask = conditioning_attention_mask
+            else:
+                encoder_attention_mask = torch.cat(
+                    [
+                        encoder_attention_mask,
+                        conditioning_attention_mask,
+                    ]
+                )
+
+            return cond, encoder_attention_mask
+
+        encoder_attention_mask = None
+        if unconditioning.shape[1] != conditioning.shape[1]:
+            max_len = max(unconditioning.shape[1], conditioning.shape[1])
+            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
+            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+
+        return torch.cat([unconditioning, conditioning]), encoder_attention_mask

From cac1eb551ff39af50e0183c3c012e379e3103c6a Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Wed, 26 Jun 2024 19:51:25 +0300
Subject: [PATCH 2/9] Handle regional prompts

---
 .../stable_diffusion/diffusers_pipeline.py    | 24 ++++--
 .../diffusion/conditioning_data.py            | 39 +++++++++-
 .../diffusion/regional_prompt_data.py         |  9 ++-
 .../diffusion/unet_attention_patcher.py       | 60 ++++++++++++++++
 4 files changed, 122 insertions(+), 10 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 8eb6129d9dc..32ec8075cd2 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -24,7 +24,7 @@
 from invokeai.backend.stable_diffusion.addons import AddonBase
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetAttentionPatcher_new, UNetIPAdapterData
 from invokeai.backend.util.attention import auto_detect_slice_size
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.hotfixes import ControlNetModel
@@ -691,8 +691,8 @@ def latents_from_embeddings2(
 
         # TODO: inpaint
 
-        # TODO: attention patcher
-        with nullcontext():
+        # TODO: ip_adapters
+        with UNetAttentionPatcher_new(self.unet):
             callback(
                 PipelineIntermediateState(
                     step=-1,
@@ -824,7 +824,11 @@ def _apply_standard_conditioning(
         """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
         the cost of higher memory usage.
""" - unet_kwargs = dict() + unet_kwargs = unet_kwargs = dict( + cross_attention_kwargs=dict( + percent_through=step_index / total_steps, + ) + ) for addon in addons: addon.pre_unet_step( @@ -880,7 +884,11 @@ def _apply_standard_conditioning_sequentially( # Negative pass ################### - negative_unet_kwargs = dict() + negative_unet_kwargs = unet_kwargs = dict( + cross_attention_kwargs=dict( + percent_through=step_index / total_steps, + ) + ) for addon in addons: addon.pre_unet_step( sample=sample, @@ -908,7 +916,11 @@ def _apply_standard_conditioning_sequentially( # Positive pass ################### - positive_unet_kwargs = dict() + positive_unet_kwargs = unet_kwargs = dict( + cross_attention_kwargs=dict( + percent_through=step_index / total_steps, + ) + ) for addon in addons: addon.pre_unet_step( sample=sample, diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 1c61f9151cb..5cbfc7e8f9f 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -5,6 +5,7 @@ import torch from invokeai.backend.ip_adapter.ip_adapter import IPAdapter +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData @dataclass @@ -122,7 +123,6 @@ def is_sdxl(self): assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) return isinstance(self.cond_text, SDXLConditioningInfo) - # TODO: prompt regions def to_unet_kwargs(self, unet_kwargs, conditioning_mode): if conditioning_mode == "both": encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch( @@ -172,6 +172,42 @@ def to_unet_kwargs(self, unet_kwargs, conditioning_mode): added_cond_kwargs=added_cond_kwargs, )) + if self.cond_regions is not None or self.uncond_regions is not None: + # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings + # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems + # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of + # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly + # awkward to handle both standard conditioning and sequential conditioning further up the stack. + + _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions + _, _, h, w = _tmp_regions.masks.shape + dtype = self.cond_text.embeds.dtype + device = self.cond_text.embeds.device + + regions = [] + for c, r in [ + (self.uncond_text, self.uncond_regions), + (self.cond_text, self.cond_regions), + ]: + if r is None: + # Create a dummy mask and range for text conditioning that doesn't have region masks. 
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=dtype),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    )
+                regions.append(r)
+
+            cross_attention_kwargs = unet_kwargs.get("cross_attention_kwargs", None)
+            if cross_attention_kwargs is None:
+                cross_attention_kwargs = dict()
+                unet_kwargs.update(dict(cross_attention_kwargs=cross_attention_kwargs))
+
+            cross_attention_kwargs.update(dict(
+                regional_prompt_data=RegionalPromptData(
+                    regions=regions, device=device, dtype=dtype
+                ),
+            ))
+
     def _concat_conditionings_for_batch(self, unconditioning, conditioning):
         def _pad_conditioning(cond, target_len, encoder_attention_mask):
             conditioning_attention_mask = torch.ones(
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
index f09cc0a0d21..7489f9a1659 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
@@ -1,9 +1,12 @@
+from __future__ import annotations
 import torch
 import torch.nn.functional as F
 
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    TextConditioningRegions,
-)
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+        TextConditioningRegions,
+    )
 
 
 class RegionalPromptData:
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index ac00a8e06ea..809d152b405 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -66,3 +66,63 @@ def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
             yield None
         finally:
             unet.set_attn_processor(orig_attn_processors)
+
+
+class UNetAttentionPatcher_new:
+    """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
+
+    def __init__(
+        self,
+        unet: UNet2DConditionModel,
+    ):
+        self.unet = unet
+        self._orig_attn_processors = None
+
+    def __enter__(self):
+        ip_adapters = []
+
+        attn_procs = self._prepare_attention_processors(self.unet, ip_adapters)
+        self._orig_attn_processors = self.unet.attn_processors
+        self.unet.set_attn_processor(attn_procs)
+
+        return None
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        self.unet.set_attn_processor(self._orig_attn_processors)
+        self._orig_attn_processors = None
+
+
+    def _prepare_attention_processors(self, unet: UNet2DConditionModel, ip_adapters: list):
+        """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
+        weights into them (if IP-Adapters are being applied).
+        Note that the `unet` param is only used to determine attention block dimensions and naming.
+        """
+        # Construct a dict of attention processors based on the UNet's architecture.
+
+        # TODO: add xformers/normal(?)/sliced
+        attn_processor_cls = CustomAttnProcessor2_0
+
+        attn_procs = dict()
+        for idx, name in enumerate(unet.attn_processors.keys()):
+            if name.endswith("attn1.processor") or len(ip_adapters) == 0:
+                # "attn1" processors do not use IP-Adapters.
+                attn_procs[name] = attn_processor_cls()
+            else:
+                # Collect the weights from each IP Adapter for the idx'th attention processor.
+                ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
+
+                for ip_adapter in ip_adapters:
+                    ip_adapter_weights = ip_adapter.model.attn_weights.get_attention_processor_weights(idx)
+                    skip = True
+                    for block in ip_adapter.target_blocks:
+                        if block in name:
+                            skip = False
+                            break
+                    ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
+                        ip_adapter_weights=ip_adapter_weights, skip=skip
+                    )
+                    ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
+
+                attn_procs[name] = attn_processor_cls(ip_adapter_attention_weights_collection)
+
+        return attn_procs

From 3ed219b6e1ddaf2f5ccf8bdd63b17927475f814d Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Wed, 26 Jun 2024 20:02:23 +0300
Subject: [PATCH 3/9] Add inpaint support

---
 .../stable_diffusion/addons/__init__.py       |  3 ++
 .../stable_diffusion/addons/inpaint_model.py  | 46 ++++++++++++++++++
 .../stable_diffusion/diffusers_pipeline.py    | 33 +++++++++++-
 3 files changed, 80 insertions(+), 2 deletions(-)
 create mode 100644 invokeai/backend/stable_diffusion/addons/inpaint_model.py

diff --git a/invokeai/backend/stable_diffusion/addons/__init__.py b/invokeai/backend/stable_diffusion/addons/__init__.py
index 0b62b1aaaf9..1cdaf1940ba 100644
--- a/invokeai/backend/stable_diffusion/addons/__init__.py
+++ b/invokeai/backend/stable_diffusion/addons/__init__.py
@@ -4,6 +4,9 @@
 
 from .base import AddonBase  # noqa: F401
 
+from .inpaint_model import InpaintModelAddon  # noqa: F401
+
 __all__ = [
     "AddonBase",
+    "InpaintModelAddon",
 ]
\ No newline at end of file
diff --git a/invokeai/backend/stable_diffusion/addons/inpaint_model.py b/invokeai/backend/stable_diffusion/addons/inpaint_model.py
new file mode 100644
index 00000000000..455a9a54014
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/addons/inpaint_model.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+from pydantic import Field
+
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
+from .base import AddonBase
+
+
+@dataclass
+class InpaintModelAddon(AddonBase):
+    mask: Optional[torch.Tensor] = None
+    masked_latents: Optional[torch.Tensor] = None
+
+    def pre_unet_step(
+        self,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_steps: int,
+        conditioning_data: TextConditioningData,
+
+        unet_kwargs: Dict[str, Any],
+        conditioning_mode: str,
+    ):
+        batch_size = sample.shape[0]
+        if conditioning_mode == "both":
+            batch_size *= 2
+
+        if self.mask is None:
+            self.mask = torch.ones_like(sample[:1, :1])
+
+        if self.masked_latents is None:
+            self.masked_latents = torch.zeros_like(sample[:1])
+
+        b_mask = torch.cat([self.mask] * batch_size)
+        b_masked_latents = torch.cat([self.masked_latents] * batch_size)
+
+        extra_channels = torch.cat([b_mask, b_masked_latents], dim=1).to(device=sample.device, dtype=sample.dtype)
+
+        unet_kwargs.update(dict(
+            extra_channels=extra_channels,
+        ))
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 32ec8075cd2..cd37f474fb0 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -21,7 +21,7 @@
 from tqdm.auto import tqdm
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.addons import AddonBase
+from invokeai.backend.stable_diffusion.addons import AddonBase, InpaintModelAddon
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetAttentionPatcher_new, UNetIPAdapterData
@@ -688,8 +688,27 @@ def latents_from_embeddings2(
         if timesteps.shape[0] == 0:
             return latents
 
+        inpaint_helper = None
+
+        # an inpaint model is used
+        if is_inpainting_model(self.unet):
+            if mask is not None and masked_latents is None:
+                raise Exception("A source image is required for the inpaint mask when an inpaint model is used!")
+            # avoid mutating the (mutable default) addons list in place
+            addons = addons + [InpaintModelAddon(mask=mask, masked_latents=masked_latents)]
+
+        # a normal model is used for inpainting
+        elif mask is not None:
+            # if no noise provided, noisify unmasked area based on seed
+            if noise is None:
+                noise = torch.randn(
+                    orig_latents.shape,
+                    dtype=torch.float32,
+                    device="cpu",
+                    generator=torch.Generator(device="cpu").manual_seed(seed),
+                ).to(device=orig_latents.device, dtype=orig_latents.dtype)
+
+            inpaint_helper = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
 
-        # TODO: inpaint
 
         # TODO: ip_adapters
         with UNetAttentionPatcher_new(self.unet):
@@ -704,6 +723,8 @@ def latents_from_embeddings2(
             )
 
             for step_index, timestep in enumerate(tqdm(timesteps)):
+                if inpaint_helper is not None:
+                    latents = inpaint_helper.apply_mask(latents, timestep)
                 step_output = self.step(
                     timestep,
                     latents,
@@ -722,6 +743,10 @@ def latents_from_embeddings2(
                 else:
                     predicted_original = latents
 
+                if inpaint_helper is not None:
+                    # TODO: or timesteps[-1]
+                    predicted_original = inpaint_helper.apply_mask(predicted_original, self.scheduler.timesteps[-1])
+
                 callback(
                     PipelineIntermediateState(
                         step=step_index,
@@ -954,8 +979,12 @@ def _unet_forward(
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         cross_attention_kwargs: Optional[dict[str, Any]] = None,
+        extra_channels: Optional[torch.Tensor] = None,
         **kwargs,
     ):
+        if extra_channels is not None:
+            sample = torch.cat([sample, extra_channels], dim=1)
+
         return self.unet(
             sample,
             timestep,

From c09b3deaa6555da7b798eb5b649a8c02546de0fe Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Wed, 26 Jun 2024 20:23:38 +0300
Subject: [PATCH 4/9] Add ip adapter support

---
 .../stable_diffusion/addons/__init__.py       |  2 +
 .../stable_diffusion/addons/ip_adapter.py     | 79 +++++++++++++++++++
 .../stable_diffusion/diffusers_pipeline.py    | 19 ++++-
 .../diffusion/regional_ip_data.py             | 14 +++-
 .../diffusion/unet_attention_patcher.py       |  7 ++
 5 files changed, 117 insertions(+), 4 deletions(-)
 create mode 100644 invokeai/backend/stable_diffusion/addons/ip_adapter.py

diff --git a/invokeai/backend/stable_diffusion/addons/__init__.py b/invokeai/backend/stable_diffusion/addons/__init__.py
index 1cdaf1940ba..6246e770d5e 100644
--- a/invokeai/backend/stable_diffusion/addons/__init__.py
+++ b/invokeai/backend/stable_diffusion/addons/__init__.py
@@ -5,8 +5,10 @@
 from .base import AddonBase  # noqa: F401
 
 from .inpaint_model import InpaintModelAddon  # noqa: F401
+from .ip_adapter import IPAdapterAddon
 
 __all__ = [
     "AddonBase",
     "InpaintModelAddon",
+    "IPAdapterAddon",
 ]
\ No newline at end of file
diff --git a/invokeai/backend/stable_diffusion/addons/ip_adapter.py b/invokeai/backend/stable_diffusion/addons/ip_adapter.py
new file mode 100644
index 00000000000..e0fe5d46c2c
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/addons/ip_adapter.py
@@ -0,0 +1,79 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from typing import Any, List, Dict, Union
+
+import torch
+from pydantic import Field
+
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterConditioningInfo, TextConditioningData
+from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
+from .base import AddonBase
+
+
+@dataclass
+class IPAdapterAddon(AddonBase):
+    model: IPAdapter
+    conditioning: IPAdapterConditioningInfo
+    mask: torch.Tensor
+    target_blocks: List[str]
+    weight: Union[float, List[float]] = 1.0
+    begin_step_percent: float = 0.0
+    end_step_percent: float = 1.0
+
+    def pre_unet_step(
+        self,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_steps: int,
+        conditioning_data: TextConditioningData,
+
+        unet_kwargs: Dict[str, Any],
+        conditioning_mode: str,
+    ):
+        # skip if model not active in current step
+        first_step = math.floor(self.begin_step_percent * total_steps)
+        last_step = math.ceil(self.end_step_percent * total_steps)
+        if step_index < first_step or step_index > last_step:
+            return
+
+        weight = self.weight
+        if isinstance(weight, list):
+            weight = weight[step_index]
+
+        if conditioning_mode == "both":
+            embeds = torch.stack([self.conditioning.uncond_image_prompt_embeds, self.conditioning.cond_image_prompt_embeds])
+        elif conditioning_mode == "negative":
+            embeds = torch.stack([self.conditioning.uncond_image_prompt_embeds])
+        else:  # elif conditioning_mode == "positive":
+            embeds = torch.stack([self.conditioning.cond_image_prompt_embeds])
+
+
+        cross_attention_kwargs = unet_kwargs.get("cross_attention_kwargs", None)
+        if cross_attention_kwargs is None:
+            cross_attention_kwargs = dict()
+            unet_kwargs.update(dict(cross_attention_kwargs=cross_attention_kwargs))
+
+
+        regional_ip_data = cross_attention_kwargs.get("regional_ip_data", None)
+        if regional_ip_data is None:
+            regional_ip_data = RegionalIPData(
+                image_prompt_embeds=[],
+                scales=[],
+                masks=[],
+                dtype=sample.dtype,
+                device=sample.device,
+            )
+            cross_attention_kwargs.update(dict(
+                regional_ip_data=regional_ip_data,
+            ))
+
+
+        regional_ip_data.add(
+            embeds=embeds,
+            scale=weight,
+            mask=self.mask,
+        )
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index cd37f474fb0..0c05371c8b5 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -21,7 +21,7 @@
 from tqdm.auto import tqdm
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.addons import AddonBase, InpaintModelAddon
+from invokeai.backend.stable_diffusion.addons import AddonBase, InpaintModelAddon, IPAdapterAddon
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetAttentionPatcher_new, UNetIPAdapterData
@@ -642,6 +642,20 @@ def latents_from_embeddings(
 
         addons = []
         # TODO: convert controlnet/ip/t2i
+
+        if ip_adapter_data is not None:
+            for ip_info in ip_adapter_data:
+                addons.append(
+                    IPAdapterAddon(
+                        model=ip_info.ip_adapter_model,
+                        conditioning=ip_info.ip_adapter_conditioning,
+                        mask=ip_info.mask,
+                        target_blocks=ip_info.target_blocks,
+                        weight=ip_info.weight,
+                        begin_step_percent=ip_info.begin_step_percent,
+                        end_step_percent=ip_info.end_step_percent,
+                    )
+                )
 
         return self.latents_from_embeddings2(
             latents=latents,
@@ -710,8 +724,7 @@ def latents_from_embeddings2(
 
             inpaint_helper = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
 
-        # TODO: ip_adapters
-        with UNetAttentionPatcher_new(self.unet):
+        with UNetAttentionPatcher_new(self.unet, addons):
             callback(
                 PipelineIntermediateState(
                     step=-1,
diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py
index 792c97114da..845de6a3f39 100644
--- a/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py
@@ -25,11 +25,16 @@ def __init__(
         # scales[i] contains the attention scale for the i'th IP-Adapter.
         self.scales = scales
 
+        self.masks = masks
+        self.device = device
+        self.dtype = dtype
+        self.max_downscale_factor = max_downscale_factor
+
         # The IP-Adapter masks.
         # self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
         # s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
         # regions and 0.0 for excluded regions.
-        self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)
+        self._masks_by_seq_len = None  # self._prepare_masks(masks, max_downscale_factor, device, dtype)
 
     def _prepare_masks(
         self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
@@ -69,4 +74,11 @@ def _prepare_masks(
 
     def get_masks(self, query_seq_len: int) -> torch.Tensor:
         """Get the mask for the given query sequence length."""
+        if self._masks_by_seq_len is None:
+            self._masks_by_seq_len = self._prepare_masks(self.masks, self.max_downscale_factor, self.device, self.dtype)
         return self._masks_by_seq_len[query_seq_len]
+
+    def add(self, embeds: torch.Tensor, scale: float, mask: torch.Tensor):
+        self.image_prompt_embeds.append(embeds)
+        self.scales.append(scale)
+        self.masks.append(mask)
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index 809d152b405..f3a5aa07373 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -4,6 +4,7 @@
 from diffusers.models import UNet2DConditionModel
 
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.stable_diffusion.addons import AddonBase, IPAdapterAddon
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
     CustomAttnProcessor2_0,
     IPAdapterAttentionWeights,
@@ -74,13 +75,19 @@ class UNetAttentionPatcher_new:
     def __init__(
         self,
         unet: UNet2DConditionModel,
+        addons: List[AddonBase],
     ):
         self.unet = unet
+        self.addons = addons
         self._orig_attn_processors = None
 
     def __enter__(self):
         ip_adapters = []
+
+        for addon in self.addons:
+            if isinstance(addon, IPAdapterAddon):
+                ip_adapters.append(addon)
 
         attn_procs = self._prepare_attention_processors(self.unet, ip_adapters)

From 87b375f41c56fb79d00f1926bbcbe961a7be80bf Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Wed, 26 Jun 2024 20:37:40 +0300
Subject: [PATCH 5/9] Add controlnet support

---
 .../stable_diffusion/addons/__init__.py       |   2 +
 .../stable_diffusion/addons/controlnet.py     | 142 ++++++++++++++++++
 .../stable_diffusion/diffusers_pipeline.py    |  18 ++-
 .../diffusion/unet_attention_patcher.py       |  22 ++-
 4 files changed, 175 insertions(+), 9 deletions(-)
 create mode 100644 invokeai/backend/stable_diffusion/addons/controlnet.py

diff --git a/invokeai/backend/stable_diffusion/addons/__init__.py b/invokeai/backend/stable_diffusion/addons/__init__.py
index 6246e770d5e..13f3eecaf07 100644
--- a/invokeai/backend/stable_diffusion/addons/__init__.py
+++ b/invokeai/backend/stable_diffusion/addons/__init__.py
@@ -6,9 +6,11 @@
 
 from .inpaint_model import InpaintModelAddon  # noqa: F401
 from .ip_adapter import IPAdapterAddon
+from .controlnet import ControlNetAddon
 
 __all__ = [
     "AddonBase",
     "InpaintModelAddon",
     "IPAdapterAddon",
+    "ControlNetAddon",
 ]
\ No newline at end of file
diff --git a/invokeai/backend/stable_diffusion/addons/controlnet.py b/invokeai/backend/stable_diffusion/addons/controlnet.py
new file mode 100644
index 00000000000..fd28d1b187c
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/addons/controlnet.py
@@ -0,0 +1,142 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from typing import Any, List, Dict, Union
+
+import torch
+from pydantic import Field
+
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
+from invokeai.backend.util.hotfixes import ControlNetModel
+from .base import AddonBase
+
+
+@dataclass
+class ControlNetAddon(AddonBase):
+    model: ControlNetModel = Field(default=None)
+    image_tensor: torch.Tensor = Field(default=None)
+    weight: Union[float, List[float]] = Field(default=1.0)
+    begin_step_percent: float = Field(default=0.0)
+    end_step_percent: float = Field(default=1.0)
+    control_mode: str = Field(default="balanced")
+    resize_mode: str = Field(default="just_resize")
+
+    def pre_unet_step(
+        self,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_steps: int,
+        conditioning_data: TextConditioningData,
+
+        unet_kwargs: Dict[str, Any],
+        conditioning_mode: str,
+    ):
+        # skip if model not active in current step
+        first_step = math.floor(self.begin_step_percent * total_steps)
+        last_step = math.ceil(self.end_step_percent * total_steps)
+        if step_index < first_step or step_index > last_step:
+            return
+
+        # convert mode to internal flags
+        soft_injection = self.control_mode in ["more_prompt", "more_control"]
+        cfg_injection = self.control_mode in ["more_control", "unbalanced"]
+
+        # skip, as the negative pass is not run in cfg_injection mode
+        if cfg_injection and conditioning_mode == "negative":
+            return
+
+        cn_unet_kwargs = dict(
+            cross_attention_kwargs=dict(
+                percent_through=step_index / total_steps,
+            )
+        )
+
+        if conditioning_mode == "both":
+            if cfg_injection:
+                conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode="positive")
+
+                down_samples, mid_sample = self._run(
+                    sample=sample,
+                    timestep=timestep,
+                    step_index=step_index,
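+                    # guess_mode is diffusers' name for soft injection; it down-weights the shallower ControlNet residuals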
+                    guess_mode=soft_injection,
+                    unet_kwargs=cn_unet_kwargs,
+                )
+                # add zeros as samples for negative conditioning
+                down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
+                mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
+
+            else:
+                conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode="both")
+                down_samples, mid_sample = self._run(
+                    sample=torch.cat([sample] * 2),
+                    timestep=timestep,
+                    step_index=step_index,
+                    guess_mode=soft_injection,
+                    unet_kwargs=cn_unet_kwargs,
+                )
+
+        else:  # elif in ["negative", "positive"]:
+            conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
+
+            down_samples, mid_sample = self._run(
+                sample=sample,
+                timestep=timestep,
+                step_index=step_index,
+                guess_mode=soft_injection,
+                unet_kwargs=cn_unet_kwargs,
+            )
+
+
+        down_block_additional_residuals = unet_kwargs.get("down_block_additional_residuals", None)
+        mid_block_additional_residual = unet_kwargs.get("mid_block_additional_residual", None)
+
+        if down_block_additional_residuals is None and mid_block_additional_residual is None:
+            down_block_additional_residuals, mid_block_additional_residual = down_samples, mid_sample
+        else:
+            # add controlnet outputs together if have multiple controlnets
+            down_block_additional_residuals = [
+                samples_prev + samples_curr
+                for samples_prev, samples_curr in zip(down_block_additional_residuals, down_samples, strict=True)
+            ]
+            mid_block_additional_residual += mid_sample
+
+        unet_kwargs.update(dict(
+            down_block_additional_residuals=down_block_additional_residuals,
+            mid_block_additional_residual=mid_block_additional_residual,
+        ))
+
+
+    def _run(
+        self,
+        sample,
+        timestep,
+        step_index,
+        guess_mode,
+        unet_kwargs,
+    ):
+        # get static weight, or weight corresponding to current step
+        weight = self.weight
+        if isinstance(weight, list):
+            weight = weight[step_index]
+
+        # controlnet(s) inference
+        down_samples, mid_sample = self.model(
+            sample=sample,
+            timestep=timestep,
+            controlnet_cond=self.image_tensor,
+            conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+            guess_mode=guess_mode,  # this is still called guess_mode in diffusers ControlNetModel
+            return_dict=False,
+
+
+            **unet_kwargs,
+            #added_cond_kwargs=added_cond_kwargs,
+            #encoder_hidden_states=encoder_hidden_states,
+            #encoder_attention_mask=encoder_attention_mask,
+        )
+
+        return down_samples, mid_sample
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 0c05371c8b5..09f80ba2e4d 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -21,7 +21,7 @@
 from tqdm.auto import tqdm
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.addons import AddonBase, InpaintModelAddon, IPAdapterAddon
+from invokeai.backend.stable_diffusion.addons import AddonBase, InpaintModelAddon, IPAdapterAddon, ControlNetAddon
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetAttentionPatcher_new, UNetIPAdapterData
@@ -640,7 +640,21 @@ def latents_from_embeddings(
     ) -> torch.Tensor:
 
         addons = []
-        # TODO: convert controlnet/ip/t2i
+        # TODO: convert t2i
+
+        if control_data is not None:
+            for cn_info in control_data:
+                addons.append(
+                    ControlNetAddon(
+                        model=cn_info.model,
+                        image_tensor=cn_info.image_tensor,
+                        weight=cn_info.weight,
+                        begin_step_percent=cn_info.begin_step_percent,
+                        end_step_percent=cn_info.end_step_percent,
+                        control_mode=cn_info.control_mode,
+                        resize_mode=cn_info.resize_mode, #?
+                    )
+                )
 
         if ip_adapter_data is not None:
             for ip_info in ip_adapter_data:
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index f3a5aa07373..8ed7b825729 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -4,7 +4,7 @@
 from diffusers.models import UNet2DConditionModel
 
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.stable_diffusion.addons import AddonBase, IPAdapterAddon
+from invokeai.backend.stable_diffusion.addons import AddonBase, IPAdapterAddon, ControlNetAddon
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
     CustomAttnProcessor2_0,
     IPAdapterAttentionWeights,
@@ -79,24 +79,32 @@ def __init__(
     ):
         self.unet = unet
         self.addons = addons
-        self._orig_attn_processors = None
+        self._orig_attn_processors = dict()
 
     def __enter__(self):
         ip_adapters = []
-
         for addon in self.addons:
-            if isinstance(addon, IPAdapterAddon):
+            # apply attention processor to controlnets for handling prompt regions
+            if isinstance(addon, ControlNetAddon):
+                attn_procs = self._prepare_attention_processors(addon.model, ip_adapters=[])
+                self._orig_attn_processors[addon.model] = addon.model.attn_processors
+                addon.model.set_attn_processor(attn_procs)
+
+            # collect ip adapters for main unet
+            elif isinstance(addon, IPAdapterAddon):
                 ip_adapters.append(addon)
 
+        # apply attention processor with ip adapters to main unet
         attn_procs = self._prepare_attention_processors(self.unet, ip_adapters)
-        self._orig_attn_processors = self.unet.attn_processors
+        self._orig_attn_processors[self.unet] = self.unet.attn_processors
         self.unet.set_attn_processor(attn_procs)
 
         return None
 
     def __exit__(self, exc_type, exc_value, exc_tb):
-        self.unet.set_attn_processor(self._orig_attn_processors)
-        self._orig_attn_processors = None
+        for model, attn_processors in self._orig_attn_processors.items():
+            model.set_attn_processor(attn_processors)
+        self._orig_attn_processors.clear()
 
 
     def _prepare_attention_processors(self, unet: UNet2DConditionModel, ip_adapters: list):

From 1c3e2bb138b7733efe3b1663f00e1733a9551ecf Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Wed, 26 Jun 2024 20:44:51 +0300
Subject: [PATCH 6/9] Add t2i adapter support

---
 .../stable_diffusion/addons/__init__.py       |  6 ++-
 .../stable_diffusion/addons/t2i_adapter.py    | 52 ++++++++++++++++++
 .../stable_diffusion/diffusers_pipeline.py    | 14 ++++-
 3 files changed, 68 insertions(+), 4 deletions(-)
 create mode 100644 invokeai/backend/stable_diffusion/addons/t2i_adapter.py

diff --git a/invokeai/backend/stable_diffusion/addons/__init__.py b/invokeai/backend/stable_diffusion/addons/__init__.py
index 13f3eecaf07..32b491ae847 100644
--- a/invokeai/backend/stable_diffusion/addons/__init__.py
+++ b/invokeai/backend/stable_diffusion/addons/__init__.py
@@ -5,12 +5,14 @@
 from .base import AddonBase  # noqa: F401
 
 from .inpaint_model import InpaintModelAddon  # noqa: F401
-from .ip_adapter import IPAdapterAddon
-from .controlnet import ControlNetAddon
+from .ip_adapter import IPAdapterAddon  # noqa: F401
+from .controlnet import ControlNetAddon  # noqa: F401
+from .t2i_adapter import T2IAdapterAddon  # noqa: F401
 
 __all__ = [
     "AddonBase",
     "InpaintModelAddon",
     "IPAdapterAddon",
    "ControlNetAddon",
+    "T2IAdapterAddon",
 ]
\ No newline at end of file
diff --git a/invokeai/backend/stable_diffusion/addons/t2i_adapter.py b/invokeai/backend/stable_diffusion/addons/t2i_adapter.py
new file mode 100644
index 00000000000..ac431213341
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/addons/t2i_adapter.py
@@ -0,0 +1,52 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass
+from typing import Any, List, Dict, Union
+
+import torch
+from pydantic import Field
+
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
+from .base import AddonBase
+
+
+@dataclass
+class T2IAdapterAddon(AddonBase):
+    adapter_state: List[torch.Tensor] = Field()  # TODO: why was this a dict before?
+    weight: Union[float, List[float]] = Field(default=1.0)
+    begin_step_percent: float = Field(default=0.0)
+    end_step_percent: float = Field(default=1.0)
+
+    def pre_unet_step(
+        self,
+        sample: torch.Tensor,
+        timestep: torch.Tensor,
+        step_index: int,
+        total_steps: int,
+        conditioning_data: TextConditioningData,
+
+        unet_kwargs: Dict[str, Any],
+        conditioning_mode: str,
+    ):
+        # skip if model not active in current step
+        first_step = math.floor(self.begin_step_percent * total_steps)
+        last_step = math.ceil(self.end_step_percent * total_steps)
+        if step_index < first_step or step_index > last_step:
+            return
+
+        weight = self.weight
+        if isinstance(weight, list):
+            weight = weight[step_index]
+
+        # TODO: conditioning_mode?
+        down_intrablock_additional_residuals = unet_kwargs.get("down_intrablock_additional_residuals", None)
+        if down_intrablock_additional_residuals is None:
+            down_intrablock_additional_residuals = [v * weight for v in self.adapter_state]
+        else:
+            for i, value in enumerate(self.adapter_state):
+                down_intrablock_additional_residuals[i] += value * weight
+
+        unet_kwargs.update(dict(
+            down_intrablock_additional_residuals=down_intrablock_additional_residuals,
+        ))
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 09f80ba2e4d..df0a30cf544 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -21,7 +21,7 @@
 from tqdm.auto import tqdm
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.addons import AddonBase, InpaintModelAddon, IPAdapterAddon, ControlNetAddon
+from invokeai.backend.stable_diffusion.addons import AddonBase, InpaintModelAddon, IPAdapterAddon, ControlNetAddon, T2IAdapterAddon
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetAttentionPatcher_new, UNetIPAdapterData
@@ -640,7 +640,6 @@ def latents_from_embeddings(
     ) -> torch.Tensor:
 
         addons = []
-        # TODO: convert t2i
 
         if control_data is not None:
             for cn_info in control_data:
@@ -670,6 +669,17 @@ def latents_from_embeddings(
                     )
                 )
 
+        if t2i_adapter_data is not None:
+            for t2i_info in t2i_adapter_data:
+                addons.append(
+                    T2IAdapterAddon(
+                        adapter_state=t2i_info.adapter_state,
+                        weight=t2i_info.weight,
+                        begin_step_percent=t2i_info.begin_step_percent,
+                        end_step_percent=t2i_info.end_step_percent,
+                    )
+                )
+
         return self.latents_from_embeddings2(
             latents=latents,
             scheduler_step_kwargs=scheduler_step_kwargs,

From a92379f252ef910befaf7b56f1383f212215b87d Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Wed, 26 Jun 2024 20:48:31 +0300
Subject: [PATCH 7/9] Fix copypaste

---
 invokeai/backend/stable_diffusion/diffusers_pipeline.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index df0a30cf544..cbbdf957814 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -886,7 +886,7 @@ def _apply_standard_conditioning(
         """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
         the cost of higher memory usage.
         """
-        unet_kwargs = unet_kwargs = dict(
+        unet_kwargs = dict(
             cross_attention_kwargs=dict(
                 percent_through=step_index / total_steps,
             )
@@ -946,7 +946,7 @@ def _apply_standard_conditioning_sequentially(
         # Negative pass
         ###################
 
-        negative_unet_kwargs = unet_kwargs = dict(
+        negative_unet_kwargs = dict(
             cross_attention_kwargs=dict(
                 percent_through=step_index / total_steps,
             )
@@ -978,7 +978,7 @@ def _apply_standard_conditioning_sequentially(
         # Positive pass
         ###################
 
-        positive_unet_kwargs = unet_kwargs = dict(
+        positive_unet_kwargs = dict(
             cross_attention_kwargs=dict(
                 percent_through=step_index / total_steps,
             )

From 1a550e4446b16a1eda0932088b67ef67c95fbe62 Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Wed, 26 Jun 2024 21:18:35 +0300
Subject: [PATCH 8/9] Slightly simplify preview event calling

---
 .../stable_diffusion/diffusers_pipeline.py   | 22 +++++++++----------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index cbbdf957814..2074cbd595c 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -748,17 +748,22 @@ def latents_from_embeddings2(
 
             inpaint_helper = AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, is_gradient_mask)
 
-        with UNetAttentionPatcher_new(self.unet, addons):
+        def report_progress(step: int, latents: torch.Tensor):
             callback(
                 PipelineIntermediateState(
-                    step=-1,
+                    step=step,
                     order=self.scheduler.order,
                     total_steps=len(timesteps),
-                    timestep=self.scheduler.config.num_train_timesteps,
+                    timestep=int(timesteps[step]),  # TODO: is there any code which uses it?
                     latents=latents,
+                    predicted_original=latents,  # TODO: is there any reason for additional field?
                 )
             )
+
+        with UNetAttentionPatcher_new(self.unet, addons):
+            report_progress(step=-1, latents=latents)
+
             for step_index, timestep in enumerate(tqdm(timesteps)):
                 if inpaint_helper is not None:
                     latents = inpaint_helper.apply_mask(latents, timestep)
@@ -784,16 +789,7 @@ def report_progress(step: int, latents: torch.Tensor):
                 if inpaint_helper is not None:
                     # TODO: or timesteps[-1]
                     predicted_original = inpaint_helper.apply_mask(predicted_original, self.scheduler.timesteps[-1])
 
-                callback(
-                    PipelineIntermediateState(
-                        step=step_index,
-                        order=self.scheduler.order,
-                        total_steps=len(timesteps),
-                        timestep=int(timestep),
-                        latents=latents,
-                        predicted_original=predicted_original,
-                    )
-                )
+                report_progress(step=-1, latents=predicted_original)
 
         # restore unmasked part after the last step is completed
         # in-process masking happens before each step

From 6027f2a2ecc53d1bc946f7a84254b06933d59a0e Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Thu, 27 Jun 2024 02:31:57 +0300
Subject: [PATCH 9/9] Fix step count in progress event

---
 invokeai/backend/stable_diffusion/diffusers_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index 2074cbd595c..76ff5cda64a 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -789,7 +789,7 @@ def report_progress(step: int, latents: torch.Tensor):
                     # TODO: or timesteps[-1]
                     predicted_original = inpaint_helper.apply_mask(predicted_original, self.scheduler.timesteps[-1])
 
-                report_progress(step=-1, latents=predicted_original)
+                report_progress(step=step_index, latents=predicted_original)
 
         # restore unmasked part after the last step is completed
         # in-process masking happens before each step