24 changes: 15 additions & 9 deletions README.md
@@ -11,6 +11,12 @@ Improved decoding for stable diffusion vaes.
```
$ pip install git+https://github.com/openai/consistencydecoder.git
```
or
```
$ git clone https://github.com/openai/consistencydecoder
$ cd consistencydecoder
$ python setup.py install
```

## Usage

@@ -26,21 +32,21 @@ pipe = StableDiffusionPipeline.from_pretrained(
```
pipe.vae.cuda()
decoder_consistency = ConsistencyDecoder(device="cuda:0") # Model size: 2.49 GB

image = load_image("assets/gt1.png", size=(256, 256), center_crop=True)
image = load_image("assets/gt1.png", size=(256, 256), center_crop=True) # alternatively, one can provide image url
latent = pipe.vae.encode(image.half().cuda()).latent_dist.mean

# decode with gan
sample_gan = pipe.vae.decode(latent).sample.detach()
save_image(sample_gan, "gan.png")
# decode with VAE
sample_vae = pipe.vae.decode(latent).sample.detach()
save_image(sample_vae, "vae.png")

# decode with consistency decoder
sample_consistency = decoder_consistency(latent)
save_image(sample_consistency, "con.png")
save_image(sample_consistency, "consistency.png")
```

## Examples
Original Image | GAN Decoder | Consistency Decoder |
Original Image | VAE Decoder | Consistency Decoder |
:---:|:---:|:---:|
![Original Image](assets/gt1.png) | ![GAN Image](assets/gan1.png) | ![VAE Image](assets/con1.png) |
![Original Image](assets/gt2.png) | ![GAN Image](assets/gan2.png) | ![VAE Image](assets/con2.png) |
![Original Image](assets/gt3.png) | ![GAN Image](assets/gan3.png) | ![VAE Image](assets/con3.png) |
![Original Image](assets/gt1.png) | ![VAE Image](assets/vae1.png) | ![Consistency Image](assets/consistency1.png) |
![Original Image](assets/gt2.png) | ![VAE Image](assets/vae2.png) | ![Consistency Image](assets/consistency2.png) |
![Original Image](assets/gt3.png) | ![VAE Image](assets/vae3.png) | ![Consistency Image](assets/consistency3.png) |
6 files renamed without changes
11 changes: 10 additions & 1 deletion consistencydecoder/__init__.py
@@ -2,6 +2,7 @@
import math
import os
import urllib
import requests
import warnings

import torch
@@ -190,7 +191,15 @@ def load_image(uri, size=None, center_crop=False):
    import numpy as np
    from PIL import Image

    image = Image.open(uri)
    if os.path.isfile(uri):
        # load image from local
        image = Image.open(uri)
    else:
        # load image by url
        image = Image.open(requests.get(uri, stream=True).raw)
    # handle case of grayscale and RGBA images
    image = image.convert("RGB")

    if center_crop:
        image = image.crop(
            (
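The `load_image` hunk above treats any `uri` that is not an existing local file as a URL, so a mistyped path ends up in `requests.get`. A minimal sketch of one possible alternative follows. It is not part of this change: the names `is_remote` and `open_rgb_image` and the `timeout` parameter are hypothetical. It decides by URL scheme first, fails early on a missing local file, and surfaces HTTP errors before decoding.

```python
import os
from io import BytesIO
from urllib.parse import urlparse


def is_remote(uri):
    """Return True when uri has an http or https scheme (hypothetical helper)."""
    return urlparse(str(uri)).scheme in ("http", "https")


def open_rgb_image(uri, timeout=10):
    """Open a local file or an http(s) URL and return an RGB PIL image.

    Hypothetical helper for illustration; the diff keeps this logic inline
    in load_image.
    """
    remote = is_remote(uri)
    if not remote and not os.path.isfile(uri):
        # Fail early instead of sending a mistyped path to the network.
        raise FileNotFoundError(f"not a local file or http(s) URL: {uri}")

    # Deferred import, mirroring the function-level imports in load_image.
    from PIL import Image

    if remote:
        import requests

        response = requests.get(uri, timeout=timeout)
        response.raise_for_status()  # raise on 4xx/5xx before decoding
        image = Image.open(BytesIO(response.content))
    else:
        image = Image.open(uri)

    # Same normalization as the diff: grayscale and RGBA inputs become RGB.
    return image.convert("RGB")
```

Under these assumptions, `open_rgb_image("assets/gt1.png")` and `open_rgb_image` with an image URL would both return an RGB image that the existing crop and resize steps could consume. Buffering the body in `BytesIO` trades a little memory for a seekable stream and an explicit status check.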
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,2 +1,3 @@
tqdm
torch
torch
diffusers