-
Notifications
You must be signed in to change notification settings - Fork 6.5k
add padding_mask_crop to all inpaint pipelines #6360
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
7e3e66f
aa6e4ba
fdc7c5a
bb1596e
001c514
6d1a3fb
2fc751a
9d14a26
1ce9ecd
a0bedc8
98c5b37
3fa0a3e
3e00f00
93bb80e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -683,16 +683,19 @@ def check_inputs( | |
| self, | ||
| prompt, | ||
| image, | ||
| mask_image, | ||
| height, | ||
| width, | ||
| callback_steps, | ||
| output_type, | ||
| negative_prompt=None, | ||
| prompt_embeds=None, | ||
| negative_prompt_embeds=None, | ||
| controlnet_conditioning_scale=1.0, | ||
| control_guidance_start=0.0, | ||
| control_guidance_end=1.0, | ||
| callback_on_step_end_tensor_inputs=None, | ||
| padding_mask_crop=None, | ||
| ): | ||
| if height is not None and height % 8 != 0 or width is not None and width % 8 != 0: | ||
| raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") | ||
|
|
@@ -736,6 +739,24 @@ def check_inputs( | |
| f" {negative_prompt_embeds.shape}." | ||
| ) | ||
|
|
||
| if padding_mask_crop is not None: | ||
| if self.unet.config.in_channels != 4 and self.unet.config.in_channels != 9: | ||
| raise ValueError( | ||
| f"The UNet should have 4 or 9 input channels for inpainting mask crop, but has" | ||
| f" {self.unet.config.in_channels} input channels." | ||
| ) | ||
| if not isinstance(image, PIL.Image.Image): | ||
| raise ValueError( | ||
| f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." | ||
| ) | ||
| if not isinstance(mask_image, PIL.Image.Image): | ||
| raise ValueError( | ||
| f"The mask image should be a PIL image when inpainting mask crop, but is of type" | ||
| f" {type(mask_image)}." | ||
| ) | ||
| if output_type != "pil": | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
| raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") | ||
|
|
||
| # `prompt` needs more sophisticated handling when there are multiple | ||
| # conditionings. | ||
| if isinstance(self.controlnet, MultiControlNetModel): | ||
|
|
@@ -862,7 +883,6 @@ def check_image(self, image, prompt, prompt_embeds): | |
| f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" | ||
| ) | ||
|
|
||
| # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image | ||
| def prepare_control_image( | ||
| self, | ||
| image, | ||
|
|
@@ -872,10 +892,14 @@ def prepare_control_image( | |
| num_images_per_prompt, | ||
| device, | ||
| dtype, | ||
| crops_coords, | ||
| resize_mode, | ||
| do_classifier_free_guidance=False, | ||
| guess_mode=False, | ||
| ): | ||
| image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) | ||
| image = self.control_image_processor.preprocess( | ||
| image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode | ||
| ).to(dtype=torch.float32) | ||
| image_batch_size = image.shape[0] | ||
|
|
||
| if image_batch_size == 1: | ||
|
|
@@ -1074,6 +1098,7 @@ def __call__( | |
| control_image: PipelineImageInput = None, | ||
| height: Optional[int] = None, | ||
| width: Optional[int] = None, | ||
| padding_mask_crop: Optional[int] = None, | ||
| strength: float = 1.0, | ||
| num_inference_steps: int = 50, | ||
| guidance_scale: float = 7.5, | ||
|
|
@@ -1130,6 +1155,12 @@ def __call__( | |
| The height in pixels of the generated image. | ||
| width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): | ||
| The width in pixels of the generated image. | ||
| padding_mask_crop (`int`, *optional*, defaults to `None`): | ||
| The size of margin in the crop to be applied to the image and mask. If `None`, no crop is applied to image and mask_image. If | ||
| `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ratio as the image that | ||
| contains all masked areas, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on | ||
| the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large | ||
| and contains information irrelevant for inpainting, such as background. | ||
| strength (`float`, *optional*, defaults to 1.0): | ||
| Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a | ||
| starting point and more noise is added the higher the `strength`. The number of denoising steps depends | ||
|
|
@@ -1240,16 +1271,19 @@ def __call__( | |
| self.check_inputs( | ||
| prompt, | ||
| control_image, | ||
| mask_image, | ||
| height, | ||
| width, | ||
| callback_steps, | ||
| output_type, | ||
| negative_prompt, | ||
| prompt_embeds, | ||
| negative_prompt_embeds, | ||
| controlnet_conditioning_scale, | ||
| control_guidance_start, | ||
| control_guidance_end, | ||
| callback_on_step_end_tensor_inputs, | ||
| padding_mask_crop, | ||
| ) | ||
|
|
||
| self._guidance_scale = guidance_scale | ||
|
|
@@ -1264,6 +1298,17 @@ def __call__( | |
| else: | ||
| batch_size = prompt_embeds.shape[0] | ||
|
|
||
| if padding_mask_crop is not None: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. um just saw your issue #6435 see my comment here #6435 (comment) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it would work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok! |
||
| if width is None or height is None: | ||
| default_height, default_width = self.image_processor.get_default_height_width(image) | ||
| width = width or default_width | ||
| height = height or default_height | ||
rootonchair marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) | ||
| resize_mode = "fill" | ||
| else: | ||
| crops_coords = None | ||
| resize_mode = "default" | ||
|
|
||
| device = self._execution_device | ||
|
|
||
| if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): | ||
|
|
@@ -1315,6 +1360,8 @@ def __call__( | |
| num_images_per_prompt=num_images_per_prompt, | ||
| device=device, | ||
| dtype=controlnet.dtype, | ||
| crops_coords=crops_coords, | ||
| resize_mode=resize_mode, | ||
| do_classifier_free_guidance=self.do_classifier_free_guidance, | ||
| guess_mode=guess_mode, | ||
| ) | ||
|
|
@@ -1330,6 +1377,8 @@ def __call__( | |
| num_images_per_prompt=num_images_per_prompt, | ||
| device=device, | ||
| dtype=controlnet.dtype, | ||
| crops_coords=crops_coords, | ||
| resize_mode=resize_mode, | ||
| do_classifier_free_guidance=self.do_classifier_free_guidance, | ||
| guess_mode=guess_mode, | ||
| ) | ||
|
|
@@ -1341,10 +1390,15 @@ def __call__( | |
| assert False | ||
|
|
||
| # 4.1 Preprocess mask and image - resizes image and mask w.r.t height and width | ||
| init_image = self.image_processor.preprocess(image, height=height, width=width) | ||
| original_image = image | ||
| init_image = self.image_processor.preprocess( | ||
| image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode | ||
| ) | ||
| init_image = init_image.to(dtype=torch.float32) | ||
|
|
||
| mask = self.mask_processor.preprocess(mask_image, height=height, width=width) | ||
| mask = self.mask_processor.preprocess( | ||
| mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords | ||
| ) | ||
|
|
||
| masked_image = init_image * (mask < 0.5) | ||
| _, _, height, width = init_image.shape | ||
|
|
@@ -1534,6 +1588,9 @@ def __call__( | |
|
|
||
| image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) | ||
|
|
||
| if padding_mask_crop is not None: | ||
| image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image] | ||
|
|
||
| # Offload all models | ||
| self.maybe_free_model_hooks() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -642,6 +642,7 @@ def check_inputs( | |
| width, | ||
| strength, | ||
| callback_steps, | ||
| output_type, | ||
| negative_prompt=None, | ||
| prompt_embeds=None, | ||
| negative_prompt_embeds=None, | ||
|
|
@@ -693,9 +694,9 @@ def check_inputs( | |
| f" {negative_prompt_embeds.shape}." | ||
| ) | ||
| if padding_mask_crop is not None: | ||
| if self.unet.config.in_channels != 4: | ||
| if self.unet.config.in_channels != 4 and self.unet.config.in_channels != 9: | ||
| raise ValueError( | ||
| f"The UNet should have 4 input channels for inpainting mask crop, but has" | ||
| f"The UNet should have 4 or 9 input channels for inpainting mask crop, but has" | ||
| f" {self.unet.config.in_channels} input channels." | ||
| ) | ||
|
||
| if not isinstance(image, PIL.Image.Image): | ||
|
|
@@ -707,6 +708,8 @@ def check_inputs( | |
| f"The mask image should be a PIL image when inpainting mask crop, but is of type" | ||
| f" {type(mask_image)}." | ||
| ) | ||
| if output_type != "pil": | ||
| raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") | ||
|
|
||
| def prepare_latents( | ||
| self, | ||
|
|
@@ -1166,6 +1169,7 @@ def __call__( | |
| width, | ||
| strength, | ||
| callback_steps, | ||
| output_type, | ||
| negative_prompt, | ||
| prompt_embeds, | ||
| negative_prompt_embeds, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.