diff --git a/README.md b/README.md index 54ed22b..452df33 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,12 @@ Improved decoding for stable diffusion vaes. ``` $ pip install git+https://github.com/openai/consistencydecoder.git ``` +or +``` +$ git clone https://github.com/openai/consistencydecoder +$ cd consistencydecoder +$ python setup.py install +``` ## Usage @@ -26,21 +32,21 @@ pipe = StableDiffusionPipeline.from_pretrained( pipe.vae.cuda() decoder_consistency = ConsistencyDecoder(device="cuda:0") # Model size: 2.49 GB -image = load_image("assets/gt1.png", size=(256, 256), center_crop=True) +image = load_image("assets/gt1.png", size=(256, 256), center_crop=True) # alternatively, one can provide image url latent = pipe.vae.encode(image.half().cuda()).latent_dist.mean -# decode with gan -sample_gan = pipe.vae.decode(latent).sample.detach() -save_image(sample_gan, "gan.png") +# decode with VAE +sample_vae = pipe.vae.decode(latent).sample.detach() +save_image(sample_vae, "vae.png") # decode with vae sample_consistency = decoder_consistency(latent) -save_image(sample_consistency, "con.png") +save_image(sample_consistency, "consistency.png") ``` ## Examples - Original Image | GAN Decoder | Consistency Decoder | + Original Image | VAE Decoder | Consistency Decoder | :---:|:---:|:---:| -![Original Image](assets/gt1.png) | ![GAN Image](assets/gan1.png) | ![VAE Image](assets/con1.png) | -![Original Image](assets/gt2.png) | ![GAN Image](assets/gan2.png) | ![VAE Image](assets/con2.png) | -![Original Image](assets/gt3.png) | ![GAN Image](assets/gan3.png) | ![VAE Image](assets/con3.png) | +![Original Image](assets/gt1.png) | ![VAE Image](assets/vae1.png) | ![Consistency Image](assets/consistency1.png) | +![Original Image](assets/gt2.png) | ![VAE Image](assets/vae2.png) | ![Consistency Image](assets/consistency2.png) | +![Original Image](assets/gt3.png) | ![VAE Image](assets/vae3.png) | ![Consistency Image](assets/consistency3.png) | diff --git a/assets/con1.png b/assets/consistency1.png similarity index 100% rename from assets/con1.png rename to assets/consistency1.png diff --git a/assets/con2.png b/assets/consistency2.png similarity index 100% rename from assets/con2.png rename to assets/consistency2.png diff --git a/assets/con3.png b/assets/consistency3.png similarity index 100% rename from assets/con3.png rename to assets/consistency3.png diff --git a/assets/gan1.png b/assets/vae1.png similarity index 100% rename from assets/gan1.png rename to assets/vae1.png diff --git a/assets/gan2.png b/assets/vae2.png similarity index 100% rename from assets/gan2.png rename to assets/vae2.png diff --git a/assets/gan3.png b/assets/vae3.png similarity index 100% rename from assets/gan3.png rename to assets/vae3.png diff --git a/consistencydecoder/__init__.py b/consistencydecoder/__init__.py index f153dc1..b9e79d4 100644 --- a/consistencydecoder/__init__.py +++ b/consistencydecoder/__init__.py @@ -2,6 +2,7 @@ import math import os import urllib +import requests import warnings import torch @@ -190,7 +191,15 @@ def load_image(uri, size=None, center_crop=False): import numpy as np from PIL import Image - image = Image.open(uri) + if os.path.isfile(uri): + # load image from local + image = Image.open(uri) + else: + # load image by url + image = Image.open(requests.get(uri, stream=True).raw) + # handle case of grayscale and RGBA images + image = image.convert("RGB") + if center_crop: image = image.crop( ( diff --git a/requirements.txt b/requirements.txt index e32f875..e0b49bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ tqdm -torch \ No newline at end of file +torch +diffusers \ No newline at end of file