Add LDM Super Resolution pipeline #1116
Merged
Changes from 25 commits
36 commits
7664874
Add ldm super resolution pipeline
duongna21 2d3c98a
style
duongna21 0c44672
fix copies
duongna21 5519da2
style
duongna21 8af0ade
fix doc
duongna21 82623e0
Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu…
duongna21 9ef5ba1
Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu…
duongna21 9977636
Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu…
duongna21 226bbc0
Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu…
duongna21 d52704f
add doc
duongna21 b3e9cff
Merge branch 'add-sr-pipeline' of https://github.com/duongna21/diffus…
duongna21 16584f7
address comments
duongna21 003e185
address comments
duongna21 e360ce6
fix doc
duongna21 d189eea
minor
duongna21 b2d5e21
add tests
duongna21 4ca74e8
add tests
duongna21 69daedc
load text encoder from subfolder
duongna21 ac78735
fix test
duongna21 9c5134c
fix test
duongna21 7115557
style
duongna21 afc4462
style
duongna21 5708a2c
handle mps latents
duongna21 b4fbb2b
unfix typo
duongna21 f02b34b
unfix typo
duongna21 9606d01
Update tests/pipelines/latent_diffusion/test_latent_diffusion_superre…
duongna21 dc7de80
fix set_timesteps mps
duongna21 11e3d7b
Merge branch 'add-sr-pipeline' of https://github.com/duongna21/diffus…
duongna21 6f98543
fix set_timesteps mps
duongna21 3f6e1fa
Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu…
duongna21 ef0c091
Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu…
duongna21 47593f2
Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu…
duongna21 1f808a1
Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu…
duongna21 5308ff5
style
duongna21 6f122a7
test 64x64 instead of 256x256
duongna21 903bba0
Merge branch 'main' into add-sr-pipeline
patil-suraj
165 changes: 165 additions & 0 deletions
src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| import inspect | ||
| from typing import Optional, Tuple, Union | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| import torch.utils.checkpoint | ||
|
|
||
| import PIL | ||
|
|
||
| from ...models import UNet2DModel, VQModel | ||
| from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput | ||
| from ...schedulers import ( | ||
| DDIMScheduler, | ||
| EulerAncestralDiscreteScheduler, | ||
| EulerDiscreteScheduler, | ||
| LMSDiscreteScheduler, | ||
| PNDMScheduler, | ||
| ) | ||
|
|
||
|
|
||
| def preprocess(image): | ||
| w, h = image.size | ||
| w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 | ||
| image = image.resize((w, h), resample=PIL.Image.LANCZOS) | ||
| image = np.array(image).astype(np.float32) / 255.0 | ||
| image = image[None].transpose(0, 3, 1, 2) | ||
| image = torch.from_numpy(image) | ||
| return 2.0 * image - 1.0 | ||
|
|
||
|
|
||
| class LDMSuperResolutionPipeline(DiffusionPipeline): | ||
| r""" | ||
| This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the | ||
| library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) | ||
|
||
|
|
||
| Parameters: | ||
| vqvae ([`VQModel`]): | ||
| Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations. | ||
| unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. | ||
| scheduler ([`SchedulerMixin`]): | ||
| A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of | ||
| [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], | ||
| [`EulerAncestralDiscreteScheduler`], or [`PNDMScheduler`]. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| vqvae: VQModel, | ||
| unet: UNet2DModel, | ||
| scheduler: Union[ | ||
| DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler | ||
|
||
| ], | ||
| ): | ||
| super().__init__() | ||
| self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler) | ||
|
|
||
| @torch.no_grad() | ||
| def __call__( | ||
| self, | ||
| init_image: Union[torch.Tensor, PIL.Image.Image], | ||
| batch_size: Optional[int] = 1, | ||
| num_inference_steps: Optional[int] = 100, | ||
| eta: Optional[float] = 0.0, | ||
| generator: Optional[torch.Generator] = None, | ||
| output_type: Optional[str] = "pil", | ||
| return_dict: bool = True, | ||
| **kwargs, | ||
| ) -> Union[Tuple, ImagePipelineOutput]: | ||
| r""" | ||
| Args: | ||
| init_image (`torch.Tensor` or `PIL.Image.Image`): | ||
| `Image`, or tensor representing an image batch, that will be used as the starting point for the | ||
| process. | ||
| batch_size (`int`, *optional*, defaults to 1): | ||
| Number of images to generate. | ||
| num_inference_steps (`int`, *optional*, defaults to 100): | ||
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the | ||
| expense of slower inference. | ||
| eta (`float`, *optional*, defaults to 0.0): | ||
| Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to | ||
| [`schedulers.DDIMScheduler`], will be ignored for others. | ||
| generator (`torch.Generator`, *optional*): | ||
| A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation | ||
| deterministic. | ||
| output_type (`str`, *optional*, defaults to `"pil"`): | ||
| The output format of the generated image. Choose between | ||
| [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. | ||
| return_dict (`bool`, *optional*): | ||
| Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple. | ||
|
|
||
| Returns: | ||
| [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if | ||
| `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the | ||
| generated images. | ||
| """ | ||
|
|
||
| if isinstance(init_image, PIL.Image.Image): | ||
| batch_size = 1 | ||
| elif isinstance(init_image, torch.Tensor): | ||
| batch_size = init_image.shape[0] | ||
| else: | ||
| raise ValueError( | ||
| f"`init_image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(init_image)}" | ||
| ) | ||
|
|
||
| if isinstance(init_image, PIL.Image.Image): | ||
| init_image = preprocess(init_image) | ||
|
|
||
| height, width = init_image.shape[-2:] | ||
|
|
||
| # in_channels should be 6: 3 for latents, 3 for low resolution image | ||
| latents_shape = (batch_size, self.unet.in_channels // 2, height, width) | ||
| latents_dtype = self.unet.dtype | ||
|
||
|
|
||
| if self.device.type == "mps": | ||
| # randn does not work reproducibly on mps | ||
| latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( | ||
| self.device | ||
| ) | ||
|
||
| else: | ||
| latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) | ||
|
|
||
| init_image = init_image.to(device=self.device, dtype=latents_dtype) | ||
|
|
||
| # set timesteps | ||
| self.scheduler.set_timesteps(num_inference_steps) | ||
|
||
|
|
||
| # Some schedulers like PNDM have timesteps as arrays | ||
| # It's more optimized to move all timesteps to correct device beforehand | ||
| timesteps_tensor = self.scheduler.timesteps.to(self.device) | ||
|
|
||
| # scale the initial noise by the standard deviation required by the scheduler | ||
| latents = latents * self.scheduler.init_noise_sigma | ||
|
|
||
| # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature. | ||
| # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. | ||
| # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 | ||
| # and should be between [0, 1] | ||
| accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) | ||
| extra_kwargs = {} | ||
| if accepts_eta: | ||
| extra_kwargs["eta"] = eta | ||
|
|
||
| for t in self.progress_bar(timesteps_tensor): | ||
| # concat latents and low resolution image in the channel dimension. | ||
| latents_input = torch.cat([latents, init_image], dim=1) | ||
|
||
| latents_input = self.scheduler.scale_model_input(latents_input, t) | ||
| # predict the noise residual | ||
|
||
| noise_pred = self.unet(latents_input, t).sample | ||
| # compute the previous noisy sample x_t -> x_t-1 | ||
| latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample | ||
|
|
||
| # decode the image latents with the VQVAE | ||
| image = self.vqvae.decode(latents).sample | ||
| image = torch.clamp(image, -1.0, 1.0) | ||
| image = image / 2 + 0.5 | ||
| image = image.cpu().permute(0, 2, 3, 1).numpy() | ||
|
|
||
| if output_type == "pil": | ||
| image = self.numpy_to_pil(image) | ||
|
|
||
| if not return_dict: | ||
| return (image,) | ||
|
|
||
| return ImagePipelineOutput(images=image) | ||
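The denoising loop in the file above does two things that are easy to miss in a diff: it forwards `eta` only when the scheduler's `step()` declares it, and it concatenates the noisy latents with the low-resolution image along the channel axis before each U-Net call. The following is a minimal sketch of that control flow using NumPy only. `ToySchedulerWithEta`, `ToySchedulerWithoutEta`, `toy_unet`, and `denoise` are hypothetical stand-ins invented for illustration; they are not part of diffusers, and the update rule is deliberately not a real sampler.

```python
import inspect

import numpy as np


class ToySchedulerWithEta:
    """Hypothetical stand-in for a scheduler whose step() accepts eta."""

    timesteps = [2, 1, 0]

    def step(self, noise_pred, t, sample, eta=0.0):
        # Not a real update rule: subtract a fraction of the prediction.
        return sample - 0.1 * noise_pred


class ToySchedulerWithoutEta:
    """Hypothetical stand-in for a scheduler whose step() has no eta."""

    timesteps = [2, 1, 0]

    def step(self, noise_pred, t, sample):
        return sample - 0.1 * noise_pred


def toy_unet(model_input, t):
    # Hypothetical "U-Net": takes 6 channels and returns 3.
    noisy, low_res = model_input[:, :3], model_input[:, 3:]
    return 0.5 * (noisy + low_res)


def denoise(scheduler, low_res, latents, eta=0.0):
    # Pass eta only if this scheduler's step() declares the parameter.
    accepts_eta = "eta" in inspect.signature(scheduler.step).parameters
    extra_kwargs = {"eta": eta} if accepts_eta else {}

    for t in scheduler.timesteps:
        # 3 latent channels + 3 image channels = 6 input channels.
        model_input = np.concatenate([latents, low_res], axis=1)
        noise_pred = toy_unet(model_input, t)
        latents = scheduler.step(noise_pred, t, latents, **extra_kwargs)
    return latents, extra_kwargs


rng = np.random.default_rng(0)
low_res = rng.standard_normal((1, 3, 32, 32)).astype(np.float32)
latents = rng.standard_normal((1, 3, 32, 32)).astype(np.float32)

out_a, kwargs_a = denoise(ToySchedulerWithEta(), low_res, latents)
out_b, kwargs_b = denoise(ToySchedulerWithoutEta(), low_res, latents)
```

Inspecting the signature keeps the pipeline body free of per-scheduler branches; a scheduler without `eta` simply receives an empty `extra_kwargs`.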
115 changes: 115 additions & 0 deletions
tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| # coding=utf-8 | ||
| # Copyright 2022 HuggingFace Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import random | ||
|
Review comment (Contributor): nice tests! |
||
| import unittest | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel | ||
| from diffusers.utils import floats_tensor, load_image, slow, torch_device | ||
| from diffusers.utils.testing_utils import require_torch | ||
|
|
||
| from ...test_pipelines_common import PipelineTesterMixin | ||
|
|
||
|
|
||
| torch.backends.cuda.matmul.allow_tf32 = False | ||
|
|
||
|
|
||
| class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): | ||
| @property | ||
| def dummy_image(self): | ||
| batch_size = 1 | ||
| num_channels = 3 | ||
| sizes = (32, 32) | ||
|
|
||
| image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device) | ||
| return image | ||
|
|
||
| @property | ||
| def dummy_uncond_unet(self): | ||
| torch.manual_seed(0) | ||
| model = UNet2DModel( | ||
| block_out_channels=(32, 64), | ||
| layers_per_block=2, | ||
| sample_size=32, | ||
| in_channels=6, | ||
| out_channels=3, | ||
| down_block_types=("DownBlock2D", "AttnDownBlock2D"), | ||
| up_block_types=("AttnUpBlock2D", "UpBlock2D"), | ||
| ) | ||
| return model | ||
|
|
||
| @property | ||
| def dummy_vq_model(self): | ||
| torch.manual_seed(0) | ||
| model = VQModel( | ||
| block_out_channels=[32, 64], | ||
| in_channels=3, | ||
| out_channels=3, | ||
| down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], | ||
| up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], | ||
| latent_channels=3, | ||
| ) | ||
| return model | ||
|
|
||
| def test_inference_superresolution(self): | ||
| unet = self.dummy_uncond_unet | ||
| scheduler = DDIMScheduler() | ||
| vqvae = self.dummy_vq_model | ||
|
|
||
| ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler) | ||
| ldm.to(torch_device) | ||
| ldm.set_progress_bar_config(disable=None) | ||
|
|
||
| init_image = self.dummy_image.to(torch_device) | ||
|
|
||
| # Warmup pass when using mps (see #372) | ||
| if torch_device == "mps": | ||
| generator = torch.manual_seed(0) | ||
| _ = ldm(init_image, generator=generator, num_inference_steps=1, output_type="numpy").images | ||
|
|
||
| generator = torch.manual_seed(0) | ||
| image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images | ||
|
|
||
| image_slice = image[0, -3:, -3:, -1] | ||
|
|
||
| assert image.shape == (1, 64, 64, 3) | ||
| expected_slice = np.array([0.8534, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176]) | ||
| assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 | ||
|
||
|
|
||
|
|
||
| @slow | ||
| @require_torch | ||
| class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase): | ||
| def test_inference_superresolution(self): | ||
| init_image = load_image( | ||
| "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" | ||
| "/vq_diffusion/teddy_bear_pool.png" | ||
| ) | ||
|
||
|
|
||
| ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto") | ||
| ldm.to(torch_device) | ||
| ldm.set_progress_bar_config(disable=None) | ||
|
|
||
| generator = torch.Generator(device=torch_device).manual_seed(0) | ||
| image = ldm(init_image, generator=generator, num_inference_steps=20, output_type="numpy").images | ||
|
|
||
| image_slice = image[0, -3:, -3:, -1] | ||
|
|
||
| assert image.shape == (1, 1024, 1024, 3) | ||
| expected_slice = np.array([0.726, 0.7249, 0.7085, 0.774, 0.7419, 0.7188, 0.8359, 0.8031, 0.7158]) | ||
| assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 | ||
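Both tests above compare a 3x3 corner of the output against stored reference values rather than the whole image. As a rough illustration of what that check measures, here is a small NumPy sketch on a synthetic array; `check_corner_slice` is a hypothetical helper name, and the array values are made up rather than taken from any model output.

```python
import numpy as np


def check_corner_slice(image, expected_slice, tol=1e-2):
    # image has shape (batch, height, width, channels).
    # Take the bottom-right 3x3 patch of the last channel of the first image.
    image_slice = image[0, -3:, -3:, -1]
    max_abs_diff = float(np.abs(image_slice.flatten() - expected_slice).max())
    return max_abs_diff < tol, max_abs_diff


# Synthetic "output" whose corner patch is known in advance.
image = np.zeros((1, 64, 64, 3), dtype=np.float32)
image[0, -3:, -3:, -1] = np.arange(9, dtype=np.float32).reshape(3, 3) / 10.0
expected = np.arange(9) / 10.0

ok, diff = check_corner_slice(image, expected)

# A uniform drift of 0.05 exceeds the 1e-2 tolerance and fails the check.
ok_drifted, diff_drifted = check_corner_slice(image + 0.05, expected)
```

The check only covers nine values, so it detects numerical drift cheaply but says nothing about the rest of the image.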
Why is this necessary? An alternative would be to pad and then crop the upscaled image. Not sure if it's worth it, slightly worried that this might skew images a little bit.
@pcuenca This is how other pipelines resize the image so it can successfully forward over UNet (agree that it might skew the image). Really sorry I can't fully understand your suggestion, could you kindly push a commit for it?
Here the preprocessing should be similar to how it's done in the original repo, since the model is trained on the preprocessed image. @duongna21 could you post a link to the original inference code?
@patil-suraj Look at this and this. It works great with varying img size. But I can't spend time on this in the next few days.
Thanks and no worries. We'll try to take a look at this, we can merge the PR without that also.
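For readers following the thread, the pad-then-crop alternative mentioned in the first comment could look roughly like the sketch below. It is not code from this PR or from the original repository. It assumes edge-replication padding on the bottom and right, and an integer upscaling factor `scale` that the caller must supply; `toy_upscale` is a hypothetical nearest-neighbour stand-in for the pipeline, and all helper names are invented for illustration.

```python
import numpy as np

MULTIPLE = 32  # preprocess() in this PR works with sizes that are multiples of 32


def round_down(size, multiple=MULTIPLE):
    # Current behaviour in preprocess(): x - x % 32, followed by a resize.
    return size - size % multiple


def round_up(size, multiple=MULTIPLE):
    # Smallest multiple of `multiple` that is >= size.
    return -(-size // multiple) * multiple


def pad_to_multiple(image, multiple=MULTIPLE):
    # image has shape (height, width, channels); pad bottom and right edges.
    h, w = image.shape[:2]
    pad_h, pad_w = round_up(h, multiple) - h, round_up(w, multiple) - w
    padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")
    return padded, (h, w)


def crop_upscaled(upscaled, original_size, scale):
    # Drop the rows and columns that came from padding.
    h, w = original_size
    return upscaled[: h * scale, : w * scale]


def toy_upscale(image, scale):
    # Hypothetical stand-in for the pipeline: nearest-neighbour upscaling.
    return image.repeat(scale, axis=0).repeat(scale, axis=1)


image = np.random.default_rng(0).random((100, 70, 3)).astype(np.float32)
padded, original_size = pad_to_multiple(image)
result = crop_upscaled(toy_upscale(padded, scale=4), original_size, scale=4)
```

With a 100x70 input, rounding down resizes the image to 96x64 and changes its aspect ratio slightly, whereas padding keeps every original pixel and yields a 400x280 result after cropping. Whether a model trained on resized inputs behaves well on padded borders is a separate question that this sketch does not address.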