Merged

54 commits
80dd04f
added ldm3d pipeline and updated image processor to support depth
estelleafl Jun 4, 2023
dcb2518
added description
estelleafl Jun 4, 2023
62914ab
added paper reference
estelleafl Jun 4, 2023
ec51d63
added docs
estelleafl Jun 6, 2023
5c6de02
fixed bug
estelleafl Jun 7, 2023
ec05757
added test
estelleafl Jun 7, 2023
21ef8be
Merge branch 'main' into ldm3d_first_commit
estelleafl Jun 7, 2023
c541015
Merge branch 'main' into ldm3d_first_commit
patrickvonplaten Jun 7, 2023
57e8ce0
Update tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py
estelleafl Jun 8, 2023
426620f
Update tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py
estelleafl Jun 8, 2023
a272128
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
81ee433
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
24107ee
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
8fef68a
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
f6aa65f
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
b2b8c92
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
ec1d1cd
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
0be8912
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
c2f5ff7
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
5e6c3ad
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
2023a00
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 8, 2023
7fd6bdb
added reference in indexmdx
estelleafl Jun 8, 2023
b4cedf4
reverted changes to image processor
estelleafl Jun 8, 2023
42989ff
added LDM3DOutput
Jun 8, 2023
9ffe6bc
Merge branch 'main' into ldm3d_first_commit
patrickvonplaten Jun 8, 2023
d357649
Fixes with make style
abhiwand Jun 9, 2023
868dd00
fix failing tests for make fix-copies
abhiwand Jun 9, 2023
e14fd33
aligned with our version
Jun 11, 2023
1759a91
Update pipeline_stable_diffusion_ldm3d.py
estelleafl Jun 12, 2023
8b905c5
Merge branch 'main' into ldm3d_first_commit
patrickvonplaten Jun 12, 2023
f61108d
Fix for failing check_code_quality test
abhiwand Jun 12, 2023
2c6db8e
Code review feedback
abhiwand Jun 12, 2023
5755331
Fix typo in ldm3d_diffusion.mdx
abhiwand Jun 12, 2023
c8a9574
updated the doc accordingly
Jun 13, 2023
70b26e1
copyrights
Jun 13, 2023
2054c08
fixed test failure
Jun 13, 2023
e7c5690
make style
Jun 13, 2023
ea7b5b9
added image processor of LDM3D in the documentation:
Jun 14, 2023
1b9248b
added ldm3d doc to toctree
Jun 14, 2023
ff2290f
run make style && make quality
sayakpaul Jun 14, 2023
256e248
run make fix-copies
sayakpaul Jun 14, 2023
aff8da2
Merge branch 'main' into ldm3d_first_commit
sayakpaul Jun 14, 2023
b12a2f1
Update docs/source/en/api/image_processor.mdx
estelleafl Jun 14, 2023
cc3c64f
Update docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.mdx
estelleafl Jun 14, 2023
29b2432
Update docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.mdx
estelleafl Jun 14, 2023
3a93f8a
updated the safety checker to accept tuple
Jun 14, 2023
aeee2ea
make style and make quality
Jun 14, 2023
562ea35
Update src/diffusers/pipelines/stable_diffusion/__init__.py
estelleafl Jun 15, 2023
5f4ed2b
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 15, 2023
750248e
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 15, 2023
39ea8a6
Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffu…
estelleafl Jun 15, 2023
647038b
LDM3D output
Jun 15, 2023
a21c4a5
Merge branch 'main' into ldm3d_first_commit
patrickvonplaten Jun 15, 2023
f66881d
up
patrickvonplaten Jun 15, 2023
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -221,6 +221,8 @@
title: Stable-Diffusion-Latent-Upscaler
- local: api/pipelines/stable_diffusion/upscale
title: Super-Resolution
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth)
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
13 changes: 12 additions & 1 deletion docs/source/en/api/image_processor.mdx
@@ -17,6 +17,17 @@ Image processor provides a unified API for Stable Diffusion pipelines to prepare
All pipelines with a VAE image processor accept image inputs as PIL images, PyTorch tensors, or NumPy arrays, and can return outputs as PIL images, PyTorch tensors, or NumPy arrays based on the `output_type` argument from the user. Additionally, the user can pass encoded image latents directly to the pipeline, or ask the pipeline to return latents as output with the `output_type="latent"` argument. This allows you to take the generated latents from one pipeline and pass them to another pipeline as input without ever leaving the latent space. It also makes it much easier to use multiple pipelines together, by passing PyTorch tensors directly between them.
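For illustration, here is a minimal sketch of that latent hand-off between two pipelines (the model ids and the latent-upscaler pairing are illustrative assumptions, not part of this diff):

```python
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionLatentUpscalePipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")
upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
    "stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse"
# Ask the first pipeline for latents instead of decoded images
latents = pipe(prompt, output_type="latent").images
# Feed the latents straight into the upscaler, never leaving latent space
image = upscaler(prompt=prompt, image=latents, num_inference_steps=20, guidance_scale=0).images[0]
image.save("astronaut_2x.png")
```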


# Image Processor for VAE adapted to LDM3D

The LDM3D image processor does the same as the VAE image processor, but it accepts both RGB and depth inputs and returns RGB and depth outputs.
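As a minimal sketch of what this means in practice, the random tensor below stands in for a decoded 6-channel batch (3 RGB channels plus 3 RGB-encoded depth channels); the shapes are assumptions for illustration:

```python
import torch
from diffusers.image_processor import VaeImageProcessorLDM3D

processor = VaeImageProcessorLDM3D(vae_scale_factor=8)

# Stand-in for a decoded batch in [-1, 1]: 3 RGB + 3 depth channels, 64x64
decoded = torch.rand(1, 6, 64, 64) * 2 - 1

rgb_np, depth_np = processor.postprocess(decoded, output_type="np")
print(rgb_np.shape, depth_np.shape)  # (1, 64, 64, 3) (1, 64, 64)
```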



## VaeImageProcessor

[[autodoc]] image_processor.VaeImageProcessor


## VaeImageProcessorLDM3D

[[autodoc]] image_processor.VaeImageProcessorLDM3D
55 changes: 55 additions & 0 deletions docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.mdx
@@ -0,0 +1,55 @@
<!--Copyright 2023 The Intel Labs Team Authors and HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# LDM3D

LDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://arxiv.org/abs/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal.

The abstract of the paper is the following:

*This research paper proposes a Latent Diffusion Model for 3D (LDM3D) that generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. The LDM3D model is fine-tuned on a dataset of tuples containing an RGB image, depth map and caption, and validated through extensive experiments. We also develop an application called DepthFusion, which uses the generated RGB images and depth maps to create immersive and interactive 360-degree-view experiences using TouchDesigner. This technology has the potential to transform a wide range of industries, from entertainment and gaming to architecture and design. Overall, this paper presents a significant contribution to the field of generative AI and computer vision, and showcases the potential of LDM3D and DepthFusion to revolutionize content creation and digital experiences. A short video summarizing the approach can be found at [this url](https://t.ly/tdi2).*


*Overview*:

| Pipeline | Tasks | Colab | Demo |
|---|---|:---:|:---:|
| [pipeline_stable_diffusion_ldm3d.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py) | *Text-to-Image Generation* | - | - |

## Tips

- LDM3D generates both an image and a depth map from a given text prompt, whereas existing text-to-image diffusion models such as [Stable Diffusion](./overview) generate only an image.
- With almost the same number of parameters, LDM3D creates a latent space that can compress both RGB images and depth maps.


Running LDM3D is straightforward with the [`StableDiffusionLDM3DPipeline`]:

```python
from diffusers import StableDiffusionLDM3DPipeline

pipe_ldm3d = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d")

prompt = "A picture of some lemons on a table"
output = pipe_ldm3d(prompt)
rgb_image, depth_image = output.rgb, output.depth
rgb_image[0].save("lemons_ldm3d_rgb.jpg")
depth_image[0].save("lemons_ldm3d_depth.png")
```
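The depth output is a 16-bit PIL image (mode `I;16`). A short hedged snippet to inspect the raw depth values (NumPy assumed; the exact array dtype can vary with the Pillow version):

```python
import numpy as np

depth = np.asarray(depth_image[0])  # typically uint16 for a mode "I;16" image
print(depth.shape, depth.dtype, depth.min(), depth.max())
```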


## StableDiffusionPipelineOutput
[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
- all
- __call__

## StableDiffusionLDM3DPipeline
[[autodoc]] StableDiffusionLDM3DPipeline
- all
- __call__
1 change: 1 addition & 0 deletions docs/source/en/index.mdx
@@ -94,3 +94,4 @@ The library has three main components:
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
| [stable_diffusion_ldm3d](./api/pipelines/stable_diffusion/ldm3d_diffusion) | [LDM3D: Latent Diffusion Model for 3D](https://arxiv.org/abs/2305.10853) | Text to Image and Depth Generation |
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -149,6 +149,7 @@
StableDiffusionInpaintPipelineLegacy,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionLatentUpscalePipeline,
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionPipeline,
106 changes: 106 additions & 0 deletions src/diffusers/image_processor.py
@@ -251,3 +251,109 @@ def postprocess(

if output_type == "pil":
return self.numpy_to_pil(image)


class VaeImageProcessorLDM3D(VaeImageProcessor):
"""
Image Processor for VAE LDM3D.

Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
vae_scale_factor (`int`, *optional*, defaults to `8`):
VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
factor.
resample (`str`, *optional*, defaults to `lanczos`):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1]
"""

config_name = CONFIG_NAME

@register_to_config
def __init__(
self,
do_resize: bool = True,
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = True,
):
super().__init__()

@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]

return pil_images

@staticmethod
def rgblike_to_depthmap(image):
"""
Args:
image: RGB-like depth image

Returns: depth map

"""
return image[:, :, 1] * 2**8 + image[:, :, 2]

def numpy_to_depth(self, images):
"""
Convert a numpy depth image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
raise Exception("Not supported")
else:
pil_images = [Image.fromarray(self.rgblike_to_depthmap(image[:, :, 3:]), mode="I;16") for image in images]

return pil_images

def postprocess(
self,
image: torch.FloatTensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
):
if not isinstance(image, torch.Tensor):
raise ValueError(
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
)
if output_type not in ["latent", "pt", "np", "pil"]:
deprecation_message = (
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
"`pil`, `np`, `pt`, `latent`"
)
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
output_type = "np"

if do_denormalize is None:
do_denormalize = [self.config.do_normalize] * image.shape[0]

image = torch.stack(
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
)

image = self.pt_to_numpy(image)

if output_type == "np":
return image[:, :, :, :3], np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)

if output_type == "pil":
return self.numpy_to_pil(image), self.numpy_to_depth(image)
else:
raise Exception(f"This type {output_type} is not supported")
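For intuition about `rgblike_to_depthmap` above: a 16-bit depth value is split across two 8-bit channels and recombined as `high * 2**8 + low`. A hedged round-trip sketch of that encoding follows; the `depthmap_to_rgblike` helper is hypothetical and not part of this PR:

```python
import numpy as np

def depthmap_to_rgblike(depth):
    # Hypothetical inverse of rgblike_to_depthmap: pack a 16-bit depth map
    # into two 8-bit channels, leaving the first channel unused.
    high, low = depth >> 8, depth & 0xFF
    return np.stack([np.zeros_like(high), high, low], axis=-1).astype(np.uint8)

depth = np.array([[0, 513, 65535]], dtype=np.uint16)
rgb_like = depthmap_to_rgblike(depth)
restored = rgb_like[:, :, 1].astype(np.uint16) * 2**8 + rgb_like[:, :, 2]
assert np.array_equal(restored, depth)
```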
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -77,6 +77,7 @@
StableDiffusionInpaintPipelineLegacy,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionLatentUpscalePipeline,
StableDiffusionLDM3DPipeline,
StableDiffusionModelEditingPipeline,
StableDiffusionPanoramaPipeline,
StableDiffusionPipeline,
20 changes: 20 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -36,6 +36,25 @@ class StableDiffusionPipelineOutput(BaseOutput):
nsfw_content_detected: Optional[List[bool]]


Review thread on this addition:

Member: @patrickvonplaten should we add this directly to the pipeline script?
Contributor (Author): Let me know what to do here :)
Member: Let's go with directly adding the class to the pipeline script :)

@dataclass
class LDM3DPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.

Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, or `None` if safety checking could not be performed.
"""

rgb: Union[List[PIL.Image.Image], np.ndarray]
depth: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: Optional[List[bool]]


try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
@@ -50,6 +69,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline