Commit 958d9ec

Ldm3d first PR (#3668)
* added ldm3d pipeline and updated image processor to support depth
* added description
* added paper reference
* added docs
* fixed bug
* added test
* Update tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update tests/pipelines/stable_diffusion/test_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* added reference in index.mdx
* reverted changes to image processor
* added LDM3DOutput
* Fixes with make style
* fix failing tests for make fix-copies
* aligned with our version
* Update pipeline_stable_diffusion_ldm3d.py: updated the guidance scale
* Fix for failing check_code_quality test
* Code review feedback
* Fix typo in ldm3d_diffusion.mdx
* updated the doc accordingly
* copyrights
* fixed test failure
* make style
* added image processor of LDM3D in the documentation
* added ldm3d doc to toctree
* run make style && make quality
* run make fix-copies
* Update docs/source/en/api/image_processor.mdx (Co-authored-by: Sayak Paul <[email protected]>)
* Update docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.mdx (Co-authored-by: Sayak Paul <[email protected]>)
* Update docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.mdx (Co-authored-by: Sayak Paul <[email protected]>)
* updated the safety checker to accept tuple
* make style and make quality
* Update src/diffusers/pipelines/stable_diffusion/__init__.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py (Co-authored-by: Patrick von Platen <[email protected]>)
* LDM3D output
* up

---------

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Anahita Bhiwandiwalla <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Aflalo <[email protected]>
Co-authored-by: Aflalo <[email protected]>
1 parent 77f9137 commit 958d9ec

File tree

11 files changed: +1201 −1 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -221,6 +221,8 @@
       title: Stable-Diffusion-Latent-Upscaler
     - local: api/pipelines/stable_diffusion/upscale
       title: Super-Resolution
+    - local: api/pipelines/stable_diffusion/ldm3d_diffusion
+      title: LDM3D Text-to-(RGB, Depth)
     title: Stable Diffusion
   - local: api/pipelines/stable_unclip
     title: Stable unCLIP

docs/source/en/api/image_processor.mdx

Lines changed: 12 additions & 1 deletion

@@ -17,6 +17,17 @@ Image processor provides a unified API for Stable Diffusion pipelines to prepare
 All pipelines with the VAE image processor accept image inputs in the format of PIL Image, PyTorch tensor, or NumPy array, and are able to return outputs in the format of PIL Image, PyTorch tensor, or NumPy array based on the `output_type` argument from the user. Additionally, the user can pass encoded image latents directly to the pipeline, or ask the pipeline to return latents as output with the `output_type = 'pt'` argument. This allows you to take the generated latents from one pipeline and pass them to another pipeline as input, without ever having to leave the latent space. It also makes it much easier to use multiple pipelines together, by passing PyTorch tensors directly between different pipelines.
 
 
+# Image Processor for VAE adapted to LDM3D
+
+The LDM3D image processor does the same as the image processor for VAE, but accepts both RGB and depth inputs and returns RGB and depth outputs.
+
+
+
 ## VaeImageProcessor
 
-[[autodoc]] image_processor.VaeImageProcessor
+[[autodoc]] image_processor.VaeImageProcessor
+
+
+## VaeImageProcessorLDM3D
+
+[[autodoc]] image_processor.VaeImageProcessorLDM3D
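
For a concrete sense of what the new processor returns, here is a minimal sketch (not part of this commit) that runs a stand-in 6-channel decoder output — 3 RGB channels plus 3 depth-encoding channels, the layout expected by the `VaeImageProcessorLDM3D` code added later in this commit — through `postprocess`; the tensor shape and value range here are illustrative assumptions:

```python
import torch
from diffusers.image_processor import VaeImageProcessorLDM3D

processor = VaeImageProcessorLDM3D(vae_scale_factor=8)

# Stand-in for a VAE decoder output: batch of 1, 6 channels
# (3 RGB + 3 depth-encoding), 64x64 pixels, values in [-1, 1].
decoded = torch.rand(1, 6, 64, 64) * 2 - 1

# Returns a list of RGB PIL images and a list of 16-bit depth PIL images.
rgb_images, depth_images = processor.postprocess(decoded, output_type="pil")
print(rgb_images[0].size, depth_images[0].mode)  # (64, 64) I;16
```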
docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.mdx

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+<!--Copyright 2023 The Intel Labs Team Authors and HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+-->
+
+# LDM3D
+
+LDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://arxiv.org/abs/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal.
+The abstract of the paper is the following:
+
+*This research paper proposes a Latent Diffusion Model for 3D (LDM3D) that generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. The LDM3D model is fine-tuned on a dataset of tuples containing an RGB image, depth map and caption, and validated through extensive experiments. We also develop an application called DepthFusion, which uses the generated RGB images and depth maps to create immersive and interactive 360-degree-view experiences using TouchDesigner. This technology has the potential to transform a wide range of industries, from entertainment and gaming to architecture and design. Overall, this paper presents a significant contribution to the field of generative AI and computer vision, and showcases the potential of LDM3D and DepthFusion to revolutionize content creation and digital experiences. A short video summarizing the approach can be found at [this url](https://t.ly/tdi2).*
+
+
+*Overview*:
+
+| Pipeline | Tasks | Colab | Demo |
+|---|---|:---:|:---:|
+| [pipeline_stable_diffusion_ldm3d.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py) | *Text-to-Image Generation* | - | - |
+
+## Tips
+
+- LDM3D generates both an image and a depth map from a given text prompt, whereas existing text-to-image diffusion models such as [Stable Diffusion](./stable_diffusion/overview) generate only an image.
+- With almost the same number of parameters, LDM3D creates a latent space that can compress both the RGB images and the depth maps.
+
+
+Running LDM3D is straightforward with the [`StableDiffusionLDM3DPipeline`]:
+
+```python
+from diffusers import StableDiffusionLDM3DPipeline
+
+pipe_ldm3d = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d")
+prompt = "A picture of some lemons on a table"
+output = pipe_ldm3d(prompt)
+rgb_image, depth_image = output.rgb, output.depth
+rgb_image[0].save("lemons_ldm3d_rgb.jpg")
+depth_image[0].save("lemons_ldm3d_depth.png")
+```
+
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
+- all
+- __call__
+
+## StableDiffusionLDM3DPipeline
+[[autodoc]] StableDiffusionLDM3DPipeline
+- all
+- __call__
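
The depth image in the example above is written out as a 16-bit PNG (PIL mode `I;16`). As a minimal follow-up sketch — not part of this commit, and assuming the stored [0, 65535] range maps linearly onto the model's normalized depth — it can be loaded back into a float array like so:

```python
import numpy as np
from PIL import Image

# Read the 16-bit depth map saved by the pipeline example above
# and rescale it from [0, 65535] to [0, 1].
depth = Image.open("lemons_ldm3d_depth.png")
depth_np = np.asarray(depth, dtype=np.float32) / (2**16 - 1)
print(depth_np.shape, depth_np.min(), depth_np.max())
```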

docs/source/en/index.mdx

Lines changed: 1 addition & 0 deletions

@@ -94,3 +94,4 @@ The library has three main components:
 | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation |
 | [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
 | [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
+| [stable_diffusion_ldm3d](./api/pipelines/stable_diffusion/ldm3d_diffusion) | [LDM3D: Latent Diffusion Model for 3D](https://arxiv.org/abs/2305.10853) | Text to Image and Depth Generation |

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -149,6 +149,7 @@
         StableDiffusionInpaintPipelineLegacy,
         StableDiffusionInstructPix2PixPipeline,
         StableDiffusionLatentUpscalePipeline,
+        StableDiffusionLDM3DPipeline,
         StableDiffusionModelEditingPipeline,
         StableDiffusionPanoramaPipeline,
         StableDiffusionPipeline,

src/diffusers/image_processor.py

Lines changed: 106 additions & 0 deletions

@@ -251,3 +251,109 @@ def postprocess(
 
         if output_type == "pil":
             return self.numpy_to_pil(image)
+
+
+class VaeImageProcessorLDM3D(VaeImageProcessor):
+    """
+    Image Processor for VAE LDM3D.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
+        vae_scale_factor (`int`, *optional*, defaults to `8`):
+            VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this
+            factor.
+        resample (`str`, *optional*, defaults to `lanczos`):
+            Resampling filter to use when resizing the image.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the image to [-1,1]
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        resample: str = "lanczos",
+        do_normalize: bool = True,
+    ):
+        super().__init__()
+
+    @staticmethod
+    def numpy_to_pil(images):
+        """
+        Convert a numpy image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        if images.shape[-1] == 1:
+            # special case for grayscale (single channel) images
+            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
+        else:
+            pil_images = [Image.fromarray(image[:, :, :3]) for image in images]
+
+        return pil_images
+
+    @staticmethod
+    def rgblike_to_depthmap(image):
+        """
+        Args:
+            image: RGB-like depth image
+
+        Returns: depth map
+
+        """
+        return image[:, :, 1] * 2**8 + image[:, :, 2]
+
+    def numpy_to_depth(self, images):
+        """
+        Convert a numpy depth image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        if images.shape[-1] == 1:
+            # special case for grayscale (single channel) images
+            raise Exception("Not supported")
+        else:
+            pil_images = [Image.fromarray(self.rgblike_to_depthmap(image[:, :, 3:]), mode="I;16") for image in images]
+
+        return pil_images
+
+    def postprocess(
+        self,
+        image: torch.FloatTensor,
+        output_type: str = "pil",
+        do_denormalize: Optional[List[bool]] = None,
+    ):
+        if not isinstance(image, torch.Tensor):
+            raise ValueError(
+                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
+            )
+        if output_type not in ["latent", "pt", "np", "pil"]:
+            deprecation_message = (
+                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
+                "`pil`, `np`, `pt`, `latent`"
+            )
+            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
+            output_type = "np"
+
+        if do_denormalize is None:
+            do_denormalize = [self.config.do_normalize] * image.shape[0]
+
+        image = torch.stack(
+            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
+        )
+
+        image = self.pt_to_numpy(image)
+
+        if output_type == "np":
+            return image[:, :, :, :3], np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
+
+        if output_type == "pil":
+            return self.numpy_to_pil(image), self.numpy_to_depth(image)
+        else:
+            raise Exception(f"This type {output_type} is not supported")

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -77,6 +77,7 @@
         StableDiffusionInpaintPipelineLegacy,
         StableDiffusionInstructPix2PixPipeline,
         StableDiffusionLatentUpscalePipeline,
+        StableDiffusionLDM3DPipeline,
         StableDiffusionModelEditingPipeline,
         StableDiffusionPanoramaPipeline,
         StableDiffusionPipeline,

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
     from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
     from .pipeline_stable_diffusion_instruct_pix2pix import StableDiffusionInstructPix2PixPipeline
     from .pipeline_stable_diffusion_latent_upscale import StableDiffusionLatentUpscalePipeline
+    from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
     from .pipeline_stable_diffusion_model_editing import StableDiffusionModelEditingPipeline
     from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
     from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
