Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,13 @@ def __init__(
scheduler=scheduler,
)
self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128

def _get_t5_prompt_embeds(
self,
Expand Down Expand Up @@ -425,9 +425,9 @@ def check_inputs(

@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

Expand All @@ -452,10 +452,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
height = height // vae_scale_factor
width = width // vae_scale_factor

- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents

Expand Down Expand Up @@ -499,8 +499,8 @@ def prepare_latents(
generator,
latents=None,
):
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor

shape = (batch_size, num_channels_latents, height, width)

Expand All @@ -517,7 +517,7 @@ def prepare_latents(
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

return latents, latent_image_ids

Expand Down
22 changes: 11 additions & 11 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ def __init__(
controlnet=controlnet,
)
self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128

def _get_t5_prompt_embeds(
self,
Expand Down Expand Up @@ -450,9 +450,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

Expand All @@ -479,10 +479,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
height = height // vae_scale_factor
width = width // vae_scale_factor

- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents

Expand All @@ -498,13 +498,13 @@ def prepare_latents(
generator,
latents=None,
):
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor

shape = (batch_size, num_channels_latents, height, width)

if latents is not None:
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids

if isinstance(generator, list) and len(generator) != batch_size:
Expand All @@ -516,7 +516,7 @@ def prepare_latents(
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)

- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

return latents, latent_image_ids

Expand Down
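[File header lost in capture: a new file's diff begins here. 22 changes: 11 additions & 11 deletions. Likely src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py (unconfirmed).]
Original file line number Diff line number Diff line change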
Expand Up @@ -228,13 +228,13 @@ def __init__(
controlnet=controlnet,
)
self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128

# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
Expand Down Expand Up @@ -493,9 +493,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

Expand All @@ -522,10 +522,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
height = height // vae_scale_factor
width = width // vae_scale_factor

- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents

Expand All @@ -549,11 +549,11 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor

shape = (batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
Expand Down Expand Up @@ -852,7 +852,7 @@ def __call__(
control_mode = control_mode.reshape([-1, 1])

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
Expand Down
26 changes: 13 additions & 13 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def __init__(
)

self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
Expand All @@ -244,7 +244,7 @@ def __init__(
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128

# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
Expand Down Expand Up @@ -520,9 +520,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

Expand All @@ -549,10 +549,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
height = height // vae_scale_factor
width = width // vae_scale_factor

- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents

Expand All @@ -576,11 +576,11 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor

shape = (batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
Expand Down Expand Up @@ -622,8 +622,8 @@ def prepare_mask_latents(
device,
generator,
):
- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
Expand Down Expand Up @@ -996,7 +996,7 @@ def __call__(
# 6. Prepare timesteps

sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
+ image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (int(global_width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
Expand Down
22 changes: 11 additions & 11 deletions src/diffusers/pipelines/flux/pipeline_flux_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,13 @@ def __init__(
scheduler=scheduler,
)
self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
- self.default_sample_size = 64
+ self.default_sample_size = 128

# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
Expand Down Expand Up @@ -477,9 +477,9 @@ def check_inputs(
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height // 2, width // 2, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]

latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape

Expand All @@ -506,10 +506,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
height = height // vae_scale_factor
width = width // vae_scale_factor

- latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)

- latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

return latents

Expand All @@ -532,11 +532,11 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

- height = 2 * (int(height) // self.vae_scale_factor)
- width = 2 * (int(width) // self.vae_scale_factor)
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor

shape = (batch_size, num_channels_latents, height, width)
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
Expand Down Expand Up @@ -736,7 +736,7 @@ def __call__(

# 4.Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
Expand Down
Loading
Loading