18 changes: 18 additions & 0 deletions invokeai/backend/stable_diffusion/addons/__init__.py
@@ -0,0 +1,18 @@
"""
Initialization file for the invokeai.backend.stable_diffusion.addons package
"""

from .base import AddonBase # noqa: F401

from .inpaint_model import InpaintModelAddon # noqa: F401
from .ip_adapter import IPAdapterAddon # noqa: F401
from .controlnet import ControlNetAddon # noqa: F401
from .t2i_adapter import T2IAdapterAddon # noqa: F401

__all__ = [
"AddonBase",
"InpaintModelAddon",
"IPAdapterAddon",
"ControlNetAddon",
"T2IAdapterAddon",
]
23 changes: 23 additions & 0 deletions invokeai/backend/stable_diffusion/addons/base.py
@@ -0,0 +1,23 @@
from __future__ import annotations

import torch
from typing import Any, Dict
from abc import ABC, abstractmethod
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData


class AddonBase(ABC):
    """Base class for denoising-pipeline addons (inpaint model channels, IP-Adapter, ControlNet, T2I-Adapter)."""

    @abstractmethod
    def pre_unet_step(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        step_index: int,
        total_steps: int,
        conditioning_data: TextConditioningData,
        unet_kwargs: Dict[str, Any],
        conditioning_mode: str,
    ):
        """Called before every UNet forward; implementations update `unet_kwargs` in place for the current step."""
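
Reviewer note (not part of this diff): the interface is easiest to read from the caller's side. Below is a minimal sketch of a consumer, assuming the caller owns a list of addons; the helper name `apply_addons_for_step` and the idea of collecting kwargs per step are hypothetical illustrations, not code from this PR.

```python
from typing import Any, Dict, List

import torch

from invokeai.backend.stable_diffusion.addons import AddonBase
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData


def apply_addons_for_step(
    addons: List[AddonBase],
    sample: torch.Tensor,
    timestep: torch.Tensor,
    step_index: int,
    total_steps: int,
    conditioning_data: TextConditioningData,
    conditioning_mode: str = "both",
) -> Dict[str, Any]:
    """Collect the extra UNet kwargs contributed by every addon for a single denoising step."""
    unet_kwargs: Dict[str, Any] = {}
    for addon in addons:
        # Each addon mutates unet_kwargs in place: residuals, cross-attention data, extra channels, ...
        addon.pre_unet_step(
            sample=sample,
            timestep=timestep,
            step_index=step_index,
            total_steps=total_steps,
            conditioning_data=conditioning_data,
            unet_kwargs=unet_kwargs,
            conditioning_mode=conditioning_mode,
        )
    return unet_kwargs
```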
141 changes: 141 additions & 0 deletions invokeai/backend/stable_diffusion/addons/controlnet.py
@@ -0,0 +1,141 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, List, Dict, Union

import torch
from pydantic import Field

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.util.hotfixes import ControlNetModel
from .base import AddonBase


@dataclass
class ControlNetAddon(AddonBase):
model: ControlNetModel = Field(default=None)
image_tensor: torch.Tensor = Field(default=None)
weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
control_mode: str = Field(default="balanced")
resize_mode: str = Field(default="just_resize")

def pre_unet_step(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_steps: int,
        conditioning_data: TextConditioningData,
        unet_kwargs: Dict[str, Any],
conditioning_mode: str,
):
# skip if model not active in current step
first_step = math.floor(self.begin_step_percent * total_steps)
last_step = math.ceil(self.end_step_percent * total_steps)
if step_index < first_step or step_index > last_step:
return

# convert mode to internal flags
soft_injection = self.control_mode in ["more_prompt", "more_control"]
cfg_injection = self.control_mode in ["more_control", "unbalanced"]

        # skip, since the negative pass is not run in cfg_injection mode
if cfg_injection and conditioning_mode == "negative":
return

cn_unet_kwargs = dict(
cross_attention_kwargs=dict(
percent_through=step_index / total_steps,
)
)

if conditioning_mode == "both":
if cfg_injection:
conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode="positive")

down_samples, mid_sample = self._run(
sample=sample,
timestep=timestep,
step_index=step_index,
guess_mode=soft_injection,
unet_kwargs=cn_unet_kwargs,
)
# add zeros as samples for negative conditioning
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

else:
conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode="both")
down_samples, mid_sample = self._run(
sample=torch.cat([sample] * 2),
timestep=timestep,
step_index=step_index,
guess_mode=soft_injection,
unet_kwargs=cn_unet_kwargs,
)

else: # elif in ["negative", "positive"]:
conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)

down_samples, mid_sample = self._run(
sample=sample,
timestep=timestep,
step_index=step_index,
guess_mode=soft_injection,
unet_kwargs=cn_unet_kwargs,
)


down_block_additional_residuals = unet_kwargs.get("down_block_additional_residuals", None)
mid_block_additional_residual = unet_kwargs.get("mid_block_additional_residual", None)

if down_block_additional_residuals is None and mid_block_additional_residual is None:
down_block_additional_residuals, mid_block_additional_residual = down_samples, mid_sample
else:
            # if there are multiple controlnets, add their outputs together
down_block_additional_residuals = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(down_block_additional_residuals, down_samples, strict=True)
]
mid_block_additional_residual += mid_sample

unet_kwargs.update(dict(
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
))


def _run(
self,
sample,
timestep,
step_index,
guess_mode,
unet_kwargs,
):
# get static weight, or weight corresponding to current step
weight = self.weight
if isinstance(weight, list):
weight = weight[step_index]

        # controlnet(s) inference
        down_samples, mid_sample = self.model(
            sample=sample,
            timestep=timestep,
            controlnet_cond=self.image_tensor,
            conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
            guess_mode=guess_mode,  # this is still called guess_mode in diffusers ControlNetModel
            return_dict=False,
            **unet_kwargs,  # e.g. encoder_hidden_states, encoder_attention_mask, added_cond_kwargs
        )

return down_samples, mid_sample
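
Worth calling out for review: when `control_mode` is `more_control` or `unbalanced` and `conditioning_mode == "both"`, the ControlNet runs only on the positive half of the batch, and zero residuals are concatenated in front so the result still lines up with the UNet's `[negative, positive]` CFG batch. A toy illustration of that padding (shapes are hypothetical, not taken from this diff):

```python
import torch

# The UNet's CFG batch is [negative, positive]; the ControlNet only produced the positive half.
positive_residual = torch.randn(1, 320, 64, 64)  # one down-block residual, hypothetical shape
padded = torch.cat([torch.zeros_like(positive_residual), positive_residual])
print(padded.shape)  # torch.Size([2, 320, 64, 64]) -- matches the doubled CFG batch
```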
46 changes: 46 additions & 0 deletions invokeai/backend/stable_diffusion/addons/inpaint_model.py
@@ -0,0 +1,46 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
from pydantic import Field

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from .base import AddonBase


@dataclass
class InpaintModelAddon(AddonBase):
mask: Optional[torch.Tensor] = None
masked_latents: Optional[torch.Tensor] = None

def pre_unet_step(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_steps: int,
        conditioning_data: TextConditioningData,
        unet_kwargs: Dict[str, Any],
conditioning_mode: str,
):
batch_size = sample.shape[0]
if conditioning_mode == "both":
batch_size *= 2

if self.mask is None:
self.mask = torch.ones_like(sample[:1, :1])

if self.masked_latents is None:
self.masked_latents = torch.zeros_like(sample[:1])

b_mask = torch.cat([self.mask] * batch_size)
b_masked_latents = torch.cat([self.masked_latents] * batch_size)

extra_channels = torch.cat([b_mask, b_masked_latents], dim=1).to(device=sample.device, dtype=sample.dtype)

unet_kwargs.update(dict(
extra_channels=extra_channels,
))
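
A note on shapes (my reading of the code, not something this diff asserts): with the usual 4-channel latent sample, the 1-channel mask plus 4-channel masked latents give 5 extra channels, which together with the sample matches the 9-channel `conv_in` of standard SD inpaint UNets. A quick sanity check against the defaults, assuming this branch is installed:

```python
import torch

from invokeai.backend.stable_diffusion.addons import InpaintModelAddon

sample = torch.zeros(2, 4, 64, 64)  # (batch, latent channels, height, width)
unet_kwargs = {}
# mask/masked_latents omitted, so the addon falls back to an all-ones mask and zero masked latents
InpaintModelAddon().pre_unet_step(
    sample=sample,
    timestep=torch.tensor(0),
    step_index=0,
    total_steps=10,
    conditioning_data=None,  # unused by this addon
    unet_kwargs=unet_kwargs,
    conditioning_mode="positive",
)
print(unet_kwargs["extra_channels"].shape)  # torch.Size([2, 5, 64, 64])
```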
78 changes: 78 additions & 0 deletions invokeai/backend/stable_diffusion/addons/ip_adapter.py
@@ -0,0 +1,78 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, List, Dict, Union

import torch
from pydantic import Field

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    IPAdapterConditioningInfo,
    TextConditioningData,
)
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from .base import AddonBase


@dataclass
class IPAdapterAddon(AddonBase):
model: IPAdapter
conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
target_blocks: List[str]
weight: Union[float, List[float]] = 1.0
begin_step_percent: float = 0.0
end_step_percent: float = 1.0

def pre_unet_step(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_steps: int,
        conditioning_data: TextConditioningData,
        unet_kwargs: Dict[str, Any],
conditioning_mode: str,
):
# skip if model not active in current step
first_step = math.floor(self.begin_step_percent * total_steps)
last_step = math.ceil(self.end_step_percent * total_steps)
if step_index < first_step or step_index > last_step:
return

weight = self.weight
        if isinstance(weight, list):
weight = weight[step_index]

if conditioning_mode == "both":
embeds = torch.stack([self.conditioning.uncond_image_prompt_embeds, self.conditioning.cond_image_prompt_embeds])
elif conditioning_mode == "negative":
embeds = torch.stack([self.conditioning.uncond_image_prompt_embeds])
else: # elif conditioning_mode == "positive":
embeds = torch.stack([self.conditioning.cond_image_prompt_embeds])


cross_attention_kwargs = unet_kwargs.get("cross_attention_kwargs", None)
if cross_attention_kwargs is None:
cross_attention_kwargs = dict()
unet_kwargs.update(dict(cross_attention_kwargs=cross_attention_kwargs))


regional_ip_data = cross_attention_kwargs.get("regional_ip_data", None)
if regional_ip_data is None:
regional_ip_data = RegionalIPData(
image_prompt_embeds=[],
scales=[],
masks=[],
dtype=sample.dtype,
device=sample.device,
)
cross_attention_kwargs.update(dict(
regional_ip_data=regional_ip_data,
))


regional_ip_data.add(
embeds=embeds,
scale=weight,
mask=self.mask,
)
52 changes: 52 additions & 0 deletions invokeai/backend/stable_diffusion/addons/t2i_adapter.py
@@ -0,0 +1,52 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Any, List, Dict, Union

import torch
from pydantic import Field

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from .base import AddonBase


@dataclass
class T2IAdapterAddon(AddonBase):
    adapter_state: List[torch.Tensor] = Field()  # TODO: why was this a dict before?
weight: Union[float, List[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)

def pre_unet_step(
self,
sample: torch.Tensor,
timestep: torch.Tensor,
step_index: int,
total_steps: int,
        conditioning_data: TextConditioningData,
        unet_kwargs: Dict[str, Any],
conditioning_mode: str,
):
# skip if model not active in current step
first_step = math.floor(self.begin_step_percent * total_steps)
last_step = math.ceil(self.end_step_percent * total_steps)
if step_index < first_step or step_index > last_step:
return

weight = self.weight
if isinstance(weight, list):
weight = weight[step_index]

# TODO: conditioning_mode?
down_intrablock_additional_residuals = unet_kwargs.get("down_intrablock_additional_residuals", None)
if down_intrablock_additional_residuals is None:
down_intrablock_additional_residuals = [v * weight for v in self.adapter_state]
else:
for i, value in enumerate(self.adapter_state):
down_intrablock_additional_residuals[i] += value * weight

unet_kwargs.update(dict(
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
))