Merged
36 commits:
7664874 Add ldm super resolution pipeline (duongna21, Nov 3, 2022)
2d3c98a style (duongna21, Nov 3, 2022)
0c44672 fix copies (duongna21, Nov 3, 2022)
5519da2 style (duongna21, Nov 3, 2022)
8af0ade fix doc (duongna21, Nov 3, 2022)
82623e0 Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu… (duongna21, Nov 5, 2022)
9ef5ba1 Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu… (duongna21, Nov 5, 2022)
9977636 Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu… (duongna21, Nov 5, 2022)
226bbc0 Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu… (duongna21, Nov 5, 2022)
d52704f add doc (duongna21, Nov 5, 2022)
b3e9cff Merge branch 'add-sr-pipeline' of https://github.com/duongna21/diffus… (duongna21, Nov 5, 2022)
16584f7 address comments (duongna21, Nov 5, 2022)
003e185 address comments (duongna21, Nov 5, 2022)
e360ce6 fix doc (duongna21, Nov 5, 2022)
d189eea minor (duongna21, Nov 5, 2022)
b2d5e21 add tests (duongna21, Nov 5, 2022)
4ca74e8 add tests (duongna21, Nov 5, 2022)
69daedc load text encoder from subfolder (duongna21, Nov 5, 2022)
ac78735 fix test (duongna21, Nov 5, 2022)
9c5134c fix test (duongna21, Nov 6, 2022)
7115557 style (duongna21, Nov 6, 2022)
afc4462 style (duongna21, Nov 6, 2022)
5708a2c handle mps latents (duongna21, Nov 6, 2022)
b4fbb2b unfix typo (duongna21, Nov 7, 2022)
f02b34b unfix typo (duongna21, Nov 7, 2022)
9606d01 Update tests/pipelines/latent_diffusion/test_latent_diffusion_superre… (duongna21, Nov 8, 2022)
dc7de80 fix set_timesteps mps (duongna21, Nov 8, 2022)
11e3d7b Merge branch 'add-sr-pipeline' of https://github.com/duongna21/diffus… (duongna21, Nov 8, 2022)
6f98543 fix set_timesteps mps (duongna21, Nov 8, 2022)
3f6e1fa Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu… (duongna21, Nov 8, 2022)
ef0c091 Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu… (duongna21, Nov 8, 2022)
47593f2 Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu… (duongna21, Nov 8, 2022)
1f808a1 Update src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffu… (duongna21, Nov 8, 2022)
5308ff5 style (duongna21, Nov 8, 2022)
6f122a7 test 64x64 instead of 256x256 (duongna21, Nov 8, 2022)
903bba0 Merge branch 'main' into add-sr-pipeline (patil-suraj, Nov 9, 2022)
5 changes: 5 additions & 0 deletions docs/source/api/pipelines/latent_diffusion.mdx
@@ -33,10 +33,15 @@ The original codebase can be found [here](https://github.com/CompVis/latent-diff
| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_latent_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) | *Text-to-Image Generation* | - |
| [pipeline_latent_diffusion_superresolution.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py) | *Super Resolution* | - |

## Examples:


## LDMTextToImagePipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion.LDMTextToImagePipeline
- __call__

## LDMSuperResolutionPipeline
[[autodoc]] pipelines.latent_diffusion.pipeline_latent_diffusion_superresolution.LDMSuperResolutionPipeline
- __call__
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -35,6 +35,7 @@
DDPMPipeline,
KarrasVePipeline,
LDMPipeline,
LDMSuperResolutionPipeline,
PNDMPipeline,
ScoreSdeVePipeline,
)
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -5,6 +5,7 @@
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .score_sde_ve import ScoreSdeVePipeline
1 change: 1 addition & 0 deletions src/diffusers/pipelines/latent_diffusion/__init__.py
@@ -1,5 +1,6 @@
# flake8: noqa
from ...utils import is_transformers_available
from .pipeline_latent_diffusion_superresolution import LDMSuperResolutionPipeline


if is_transformers_available():
165 changes: 165 additions & 0 deletions src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -0,0 +1,165 @@
import inspect
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint

import PIL

from ...models import UNet2DModel, VQModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import (
DDIMScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)


def preprocess(image):
w, h = image.size
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
pcuenca (Member):
Why is this necessary? An alternative would be to pad and then crop the upscaled image. Not sure if it's worth it, slightly worried that this might skew images a little bit.

duongna21 (Contributor, Author), Nov 8, 2022:
@pcuenca This is how other pipelines resize the image so that it can be successfully passed through the UNet (I agree that it might skew the image). Sorry, I can't fully understand your suggestion; could you kindly push a commit for it?

patil-suraj (Contributor):
Here the preprocessing should be similar to how it's done in the original repo, since the model was trained on the preprocessed images. @duongna21, could you post a link to the original inference code?

duongna21 (Contributor, Author), Nov 8, 2022:
@patil-suraj Look at this and this. It works great with varying image sizes, but I can't spend time on this in the next few days.

patil-suraj (Contributor):
Thanks, and no worries. We'll try to take a look at this; we can merge the PR without it as well.
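
For illustration, a minimal sketch of the pad-then-crop alternative discussed above. This is an assumption about how it could look, not code from this PR; `preprocess_pad`, `postprocess_crop`, and the 4x upscaling factor are hypothetical.

import PIL.Image


def preprocess_pad(image):
    # Pad up to the next multiple of 32 instead of resizing,
    # so the image content is never skewed.
    w, h = image.size
    pad_w = (32 - w % 32) % 32
    pad_h = (32 - h % 32) % 32
    padded = PIL.Image.new(image.mode, (w + pad_w, h + pad_h))
    padded.paste(image, (0, 0))
    return padded, (w, h)


def postprocess_crop(upscaled, original_size, scale=4):
    # Crop the upscaled result back to `scale` times the original size
    # (assuming a 4x super-resolution model), discarding the padded region.
    w, h = original_size
    return upscaled.crop((0, 0, scale * w, scale * h))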

image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.0 * image - 1.0


class LDMSuperResolutionPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

Parameters:
vqvae ([`VQModel`]):
Vector-quantized (VQ) VAE Model to encode and decode images to and from latent representations.
unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`],
[`EulerAncestralDiscreteScheduler`], or [`PNDMScheduler`].
"""

def __init__(
self,
vqvae: VQModel,
unet: UNet2DModel,
scheduler: Union[
DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
],
):
super().__init__()
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)

@torch.no_grad()
def __call__(
self,
init_image: Union[torch.Tensor, PIL.Image.Image],
batch_size: Optional[int] = 1,
num_inference_steps: Optional[int] = 100,
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
) -> Union[Tuple, ImagePipelineOutput]:
r"""
Args:
init_image (`torch.Tensor` or `PIL.Image.Image`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
batch_size (`int`, *optional*, defaults to 1):
Number of images to generate.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`]; it is ignored for other schedulers.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.

Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""

if isinstance(init_image, PIL.Image.Image):
batch_size = 1
elif isinstance(init_image, torch.Tensor):
batch_size = init_image.shape[0]
else:
raise ValueError(
f"`init_image` has to be of type `PIL.Image.Image` or `torch.Tensor` but is {type(init_image)}"
)

if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)

height, width = init_image.shape[-2:]

# in_channels should be 6: 3 for latents, 3 for low resolution image
latents_shape = (batch_size, self.unet.in_channels // 2, height, width)
latents_dtype = self.unet.dtype

if self.device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

init_image = init_image.to(device=self.device, dtype=latents_dtype)

# set timesteps
self.scheduler.set_timesteps(num_inference_steps)

# Some schedulers like PNDM have timesteps as arrays
# It's more optimized to move all timesteps to correct device beforehand
timesteps_tensor = self.scheduler.timesteps.to(self.device)

# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma

# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_kwargs = {}
if accepts_eta:
extra_kwargs["eta"] = eta

for t in self.progress_bar(timesteps_tensor):
# concat latents and low resolution image in the channel dimension.
latents_input = torch.cat([latents, init_image], dim=1)
latents_input = self.scheduler.scale_model_input(latents_input, t)
# predict the noise residual
noise_pred = self.unet(latents_input, t).sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

# decode the image latents with the VQVAE
image = self.vqvae.decode(latents).sample
image = torch.clamp(image, -1.0, 1.0)
image = image / 2 + 0.5
image = image.cpu().permute(0, 2, 3, 1).numpy()

if output_type == "pil":
image = self.numpy_to_pil(image)

if not return_dict:
return (image,)

return ImagePipelineOutput(images=image)
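
For context, a minimal usage sketch of the new pipeline. The checkpoint id matches the integration test below; the input path, output path, and CUDA device are assumptions.

import PIL.Image
import torch

from diffusers import LDMSuperResolutionPipeline

pipe = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution")
pipe.to("cuda")  # assumes a CUDA device is available

low_res = PIL.Image.open("low_res.png").convert("RGB")  # placeholder input path
generator = torch.Generator(device="cuda").manual_seed(0)
upscaled = pipe(low_res, generator=generator, num_inference_steps=100).images[0]
upscaled.save("upscaled.png")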
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
@@ -212,6 +212,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class LDMSuperResolutionPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class PNDMPipeline(metaclass=DummyObject):
_backends = ["torch"]

115 changes: 115 additions & 0 deletions tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
@@ -0,0 +1,115 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
Contributor:
nice tests!

import unittest

import numpy as np
import torch

from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel
from diffusers.utils import floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch

from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@property
def dummy_image(self):
batch_size = 1
num_channels = 3
sizes = (32, 32)

image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
return image

@property
def dummy_uncond_unet(self):
torch.manual_seed(0)
model = UNet2DModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=6,
out_channels=3,
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
return model

@property
def dummy_vq_model(self):
torch.manual_seed(0)
model = VQModel(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=3,
)
return model

def test_inference_superresolution(self):
unet = self.dummy_uncond_unet
scheduler = DDIMScheduler()
vqvae = self.dummy_vq_model

ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)

init_image = self.dummy_image.to(torch_device)

# Warmup pass when using mps (see #372)
if torch_device == "mps":
generator = torch.manual_seed(0)
_ = ldm(init_image, generator=generator, num_inference_steps=1, output_type="numpy").images

generator = torch.manual_seed(0)
image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images

image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.8534, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


@slow
@require_torch
class LDMSuperResolutionPipelineIntegrationTests(unittest.TestCase):
def test_inference_superresolution(self):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/vq_diffusion/teddy_bear_pool.png"
)

ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)

generator = torch.Generator(device=torch_device).manual_seed(0)
image = ldm(init_image, generator=generator, num_inference_steps=20, output_type="numpy").images

image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 1024, 1024, 3)
expected_slice = np.array([0.726, 0.7249, 0.7085, 0.774, 0.7419, 0.7188, 0.8359, 0.8031, 0.7158])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2