diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
index 81ba6f6ef..e7a475049 100644
--- a/.github/FUNDING.yml
+++ b/.github/FUNDING.yml
@@ -10,4 +10,4 @@ liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
-custom: ['https://paypal.me/basuj']
+custom: [ 'https://paypal.me/basuj' ]
diff --git a/README.md b/README.md
index ea33dcfdc..e5b6faba5 100644
--- a/README.md
+++ b/README.md
@@ -1,35 +1,53 @@
-# Update: v0.8
+# Update: v0.9 neon optimization edition
[Added support for inpainting](#inpainting)
# Optimized Stable Diffusion
-
-
-
+
+
+
+
+
+
-This repo is a modified version of the Stable Diffusion repo, optimized to use less VRAM than the original by sacrificing inference speed.
+This repo is a modified version of the Stable Diffusion repo, optimized to use less VRAM than the original by
+sacrificing inference speed.
-To achieve this, the stable diffusion model is fragmented into four parts which are sent to the GPU only when needed. After the calculation is done, they are moved back to the CPU. This allows us to run a bigger model while requiring less VRAM.
+To achieve this, the stable diffusion model is fragmented into four parts which are sent to the GPU only when needed.
+After the calculation is done, they are moved back to the CPU. This allows us to run a bigger model while requiring less
+VRAM.
# Installation
-All the modified files are in the [optimizedSD](optimizedSD) folder, so if you have already cloned the original repository you can just download and copy this folder into the original instead of cloning the entire repo. You can also clone this repo and follow the same installation steps as the original (mainly creating the conda environment and placing the weights at the specified location).
+All the modified files are in the [optimizedSD](optimizedSD) folder, so if you have already cloned the original
+repository you can just download and copy this folder into the original instead of cloning the entire repo. You can also
+clone this repo and follow the same installation steps as the original (mainly creating the conda environment and
+placing the weights at the specified location).
+To set up the conda environment, run:
+`conda env create -f environment.yaml`
+`conda activate ldm`
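+
+As a quick sanity check once the environment is active and the weights are in place, a minimal text-to-image run might
+look like the sketch below (the script name `optimizedSD/optimized_txt2img.py` and the prompt are assumptions for
+illustration; use the txt2img script that is actually present in the [optimizedSD](optimizedSD) folder):
+
+```bash
+# Hypothetical example: script name and prompt are placeholders, adjust them to your setup
+python optimizedSD/optimized_txt2img.py --prompt "a painting of a lighthouse at sunset" --H 512 --W 512 --n_samples 1
+```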
Alternatively, if you prefer to use Docker, you can do the following:
-1. Install [Docker](https://docs.docker.com/engine/install/), [Docker Compose plugin](https://docs.docker.com/compose/install/), and [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker)
+
+1. Install [Docker](https://docs.docker.com/engine/install/),
+   [Docker Compose plugin](https://docs.docker.com/compose/install/),
+   and [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker)
2. Clone this repo to, e.g., `~/stable-diffusion`
-3. Put your downloaded `model.ckpt` file into `~/sd-data` (it's a relative path, you can change it in `docker-compose.yml`)
+3. Put your downloaded `model.ckpt` file into `~/sd-data` (it's a relative path, you can change it
+ in `docker-compose.yml`)
4. `cd` into `~/stable-diffusion` and execute `docker compose up --build`
-This will launch gradio on port 7860 with txt2img. You can also use `docker compose run` to execute other Python scripts.
+This will launch gradio on port 7860 with txt2img. You can also use `docker compose run` to execute other Python
+scripts.
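+
+For example, a one-off run of another script might look like the sketch below (the service name `stable-diffusion` and
+the script name are assumptions; use the service defined in `docker-compose.yml` and a script from the optimizedSD
+folder):
+
+```bash
+# Hypothetical example: replace the service and script names with the ones from your docker-compose.yml / optimizedSD
+docker compose run --rm stable-diffusion python optimizedSD/optimized_txt2img.py --prompt "a forest cabin" --n_samples 1
+```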
# Usage
## img2img
-- `img2img` can generate _512x512 images from a prior image and prompt on a 4GB VRAM GPU in under 20 seconds per image_ on an RTX 2060.
+- `img2img` can generate _512x512 images from a prior image and prompt on a 4GB VRAM GPU in under 20 seconds per image_
+ on an RTX 2060.
- The maximum size that can fit on a 6GB GPU (RTX 2060) is around 576x768.
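+
+For example, a 512x512 img2img run might look like the sketch below (the script name `optimizedSD/optimized_img2img.py`,
+the input image path and the prompt are assumptions for illustration; adjust them to your setup):
+
+```bash
+# Hypothetical example: script name, input image and prompt are placeholders
+python optimizedSD/optimized_img2img.py --prompt "a watercolor landscape" --init-img ~/sketch.jpg --strength 0.75 --H 512 --W 512 --n_samples 2
+```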
@@ -47,45 +65,57 @@ This will launch gradio on port 7860 with txt2img. You can also use `docker comp
## inpainting
-- `inpaint_gradio.py` can fill masked parts of an image based on a given prompt. It can inpaint 512x512 images while using under 4GB of VRAM.
+- `inpaint_gradio.py` can fill masked parts of an image based on a given prompt. It can inpaint 512x512 images while
+ using under 4GB of VRAM.
-- To launch the gradio interface for inpainting, run `python optimizedSD/inpaint_gradio.py`. The mask for the image can be drawn on the selected image using the brush tool.
+- To launch the gradio interface for inpainting, run `python optimizedSD/inpaint_gradio.py`. The mask for the image can
+ be drawn on the selected image using the brush tool.
-- The results are not yet perfect but can be improved by using a combination of prompt weighting, prompt engineering and testing out multiple values of the `--strength` argument.
+- The results are not yet perfect but can be improved by using a combination of prompt weighting, prompt engineering and
+ testing out multiple values of the `--strength` argument.
- _Suggestions to improve the inpainting algorithm are most welcome_.
# Using the Gradio GUI
-- You can also use the built-in gradio interface for `img2img`, `txt2img` & `inpainting` instead of the command line interface. Activate the conda environment and install the latest version of gradio using `pip install gradio`,
+- You can also use the built-in gradio interface for `img2img`, `txt2img` & `inpainting` instead of the command line
+  interface. Activate the conda environment and install the latest version of gradio using `pip install gradio`.
-- Run img2img using `python optimizedSD/img2img_gradio.py`, txt2img using `python optimizedSD/txt2img_gradio.py` and inpainting using `python optimizedSD/inpaint_gradio.py`.
+- Run img2img using `python optimizedSD/img2img_gradio.py`, txt2img using `python optimizedSD/txt2img_gradio.py` and
+ inpainting using `python optimizedSD/inpaint_gradio.py`.
-- img2img_gradio.py has a feature to crop input images. Look for the pen symbol in the image box after selecting the image.
+- `img2img_gradio.py` has a feature to crop input images. Look for the pen symbol in the image box after selecting the
+  image.
# Arguments
## `--seed`
-**Seed for image generation**, can be used to reproduce previously generated images. Defaults to a random seed if unspecified.
+**Seed for image generation.** Can be used to reproduce previously generated images. Defaults to a random seed if
+unspecified.
-- The code will give the seed number along with each generated image. To generate the same image again, just specify the seed using `--seed` argument. Images are saved with its seed number as its name by default.
+- The code prints the seed number along with each generated image. To generate the same image again, just specify the
+  seed using the `--seed` argument. Images are saved with their seed number as the file name by default.
-- For example if the seed number for an image is `1234` and it's the 55th image in the folder, the image name will be named `seed_1234_00055.png`.
+- For example, if the seed number for an image is `1234` and it's the 55th image in the folder, the image will be
+  named `seed_1234_00055.png`.
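+
+A reproduction run then simply passes the reported seed back in (the script name is again an assumption; keep the
+prompt and image size identical to the original run):
+
+```bash
+# Hypothetical example: reuse the reported seed together with the original prompt and settings
+python optimizedSD/optimized_txt2img.py --prompt "the original prompt" --seed 1234 --n_samples 1
+```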
## `--n_samples`
**Batch size/amount of images to generate at once.**
-- To get the lowest inference time per image, use the maximum batch size `--n_samples` that can fit on the GPU. Inference time per image will reduce on increasing the batch size, but the required VRAM will increase.
+- To get the lowest inference time per image, use the maximum batch size `--n_samples` that can fit on the GPU.
+  Inference time per image decreases as the batch size increases, but the required VRAM increases.
-- If you get a CUDA out of memory error, try reducing the batch size `--n_samples`. If it doesn't work, the other option is to reduce the image width `--W` or height `--H` or both.
+- If you get a CUDA out of memory error, try reducing the batch size `--n_samples`. If that is not enough, reduce the
+  image width `--W` or height `--H`, or both.
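+
+For instance, a run that hits an out-of-memory error could be retried with a smaller batch and, if needed, a smaller
+resolution (the script name is an assumption):
+
+```bash
+# Hypothetical example: smaller batch first, then a smaller (multiple-of-64) resolution if needed
+python optimizedSD/optimized_txt2img.py --prompt "a painting of a lighthouse at sunset" --n_samples 1 --H 448 --W 448
+```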
## `--n_iter`
**Run _x_ amount of times**
-- Equivalent to running the script n_iter number of times. Only difference is that the model is loaded only once per n_iter iterations. Unlike `n_samples`, reducing it doesn't have an effect on VRAM required or inference time.
+- Equivalent to running the script `n_iter` times, except that the model is loaded only once per run. Unlike
+  `n_samples`, changing it does not affect the required VRAM or the inference time per image.
## `--H` & `--W`
@@ -97,19 +127,23 @@ This will launch gradio on port 7860 with txt2img. You can also use `docker comp
**Increases inference speed at the cost of extra VRAM usage.**
-- Using this argument increases the inference speed by using around 1GB of extra GPU VRAM. It is especially effective when generating a small batch of images (~ 1 to 4) images. It takes under 25 seconds for txt2img and 15 seconds for img2img (on an RTX 2060, excluding the time to load the model). Use it on larger batch sizes if GPU VRAM available.
+- Using this argument increases the inference speed by using around 1GB of extra GPU VRAM. It is especially effective
+  when generating a small batch of images (~1 to 4). It takes under 25 seconds for txt2img and 15 seconds for
+  img2img (on an RTX 2060, excluding the time to load the model). Use it on larger batch sizes if enough GPU VRAM is
+  available.
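+
+A sample invocation might look like the sketch below (assuming the flag described above is `--turbo`; the script name
+is an assumption):
+
+```bash
+# Hypothetical example: trade roughly 1GB of extra VRAM for faster sampling on a small batch
+python optimizedSD/optimized_txt2img.py --prompt "a painting of a lighthouse at sunset" --turbo --n_samples 2
+```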
## `--precision autocast` or `--precision full`
**Whether to use `full` or `mixed` precision**
-- Mixed Precision is enabled by default. If you don't have a GPU with tensor cores (any GTX 10 series card), you may not be able use mixed precision. Use the `--precision full` argument to disable it.
+- Mixed precision is enabled by default. If you don't have a GPU with tensor cores (e.g., any GTX 10 series card), you
+  may not be able to use mixed precision. Use the `--precision full` argument to disable it.
## `--format png` or `--format jpg`
**Output image format**
-- The default output format is `png`. While `png` is lossless, it takes up a lot of space (unless large portions of the image happen to be a single colour). Use lossy `jpg` to get smaller image file sizes.
+- The default output format is `png`. While `png` is lossless, it takes up a lot of space (unless large portions of the
+ image happen to be a single colour). Use lossy `jpg` to get smaller image file sizes.
## `--unet_bs`
@@ -124,13 +158,15 @@ This will launch gradio on port 7860 with txt2img. You can also use `docker comp
- Prompts can also be weighted to put relative emphasis on certain words.
eg. `--prompt tabby cat:0.25 white duck:0.75 hybrid`.
-- The number followed by the colon represents the weight given to the words before the colon. The weights can be both fractions or integers.
+- The number after the colon represents the weight given to the words before the colon. The weights can be either
+  fractions or integers.
## Changelog
- v0.8: Added gradio interface for inpainting.
- v0.7: Added support for logging, jpg file format
-- v0.6: Added support for using weighted prompts. (based on @lstein's [repo](https://github.com/lstein/stable-diffusion))
+- v0.6: Added support for using weighted prompts. (based on
+ @lstein's [repo](https://github.com/lstein/stable-diffusion))
- v0.5: Added support for using gradio interface.
- v0.4: Added support for specifying image seed.
- v0.3: Added support for using mixed precision.
diff --git a/Stable_Diffusion_v1_Model_Card.md b/Stable_Diffusion_v1_Model_Card.md
index 2cbf99bd2..c510940d1 100644
--- a/Stable_Diffusion_v1_Model_Card.md
+++ b/Stable_Diffusion_v1_Model_Card.md
@@ -1,13 +1,20 @@
# Stable Diffusion v1 Model Card
-This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion).
+
+This model card focuses on the model associated with the Stable Diffusion model,
+available [here](https://github.com/CompVis/stable-diffusion).
## Model Details
+
- **Developed by:** Robin Rombach, Patrick Esser
- **Model type:** Diffusion-based text-to-image generation model
- **Language(s):** English
- **License:** [Proprietary](LICENSE)
-- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
-- **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
+- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is
+ a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text
+ encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in
+ the [Imagen paper](https://arxiv.org/abs/2205.11487).
+- **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion),
+  [Paper](https://arxiv.org/abs/2112.10752).
- **Cite as:**
@InProceedings{Rombach_2022_CVPR,
@@ -21,9 +28,9 @@ This model card focuses on the model associated with the Stable Diffusion model,
# Uses
-## Direct Use
-The model is intended for research purposes only. Possible research areas and
-tasks include
+## Direct Use
+
+The model is intended for research purposes only. Possible research areas and tasks include
- Safe deployment of models which have the potential to generate harmful content.
- Probing and understanding the limitations and biases of generative models.
@@ -33,17 +40,27 @@ tasks include
Excluded uses are described below.
- ### Misuse, Malicious Use, and Out-of-Scope Use
-_Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
+### Misuse, Malicious Use, and Out-of-Scope Use
+_Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies
+in the same way to Stable Diffusion v1_.
+
+The model should not be used to intentionally create or disseminate images that create hostile or alienating
+environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or
+offensive; or content that propagates historical or current stereotypes.
-The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
#### Out-of-Scope Use
-The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
+
+The model was not trained to be factual or true representations of people or events, and therefore using the model to
+generate such content is out-of-scope for the abilities of this model.
+
#### Misuse and Malicious Use
-Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
-- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
+Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not
+limited to:
+
+- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures,
+ religions, etc.
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
- Impersonating individuals without their consent.
- Sexual content without consent of the people who might see it.
@@ -58,23 +75,23 @@ Using the model to generate content that is cruel to individuals is a misuse of
- The model does not achieve perfect photorealism
- The model cannot render legible text
-- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
+- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image
+ corresponding to “A red cube on top of a blue sphere”
- Faces and people in general may not be generated properly.
- The model was trained mainly with English captions and will not work as well in other languages.
- The autoencoding part of the model is lossy
- The model was trained on a large-scale dataset
- [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
- and is not fit for product use without additional safety mechanisms and
- considerations.
+ [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material and is not fit for product use without
+ additional safety mechanisms and considerations.
### Bias
-While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
-Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
-which consists of images that are primarily limited to English descriptions.
-Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
-This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
-ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
+While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
+Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/), which consists of images
+that are primarily limited to English descriptions. Texts and images from communities and cultures that use other
+languages are likely to be insufficiently accounted for. This affects the overall output of the model, as white and
+western cultures are often set as the default. Further, the ability of the model to generate content with non-English
+prompts is significantly worse than with English-language prompts.
## Training
@@ -84,22 +101,32 @@ The model developers used the following dataset for training the model:
- LAION-2B (en) and subsets thereof (see next section)
**Training Procedure**
-Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
+Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in
+the latent space of the autoencoder. During training,
-- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
+- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative
+ downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
- Text prompts are encoded through a ViT-L/14 text-encoder.
-- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
-- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
+- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via
+ cross-attention.
+- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the
+ UNet.
-We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`,
-which were trained as follows,
+We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`, which were trained as
+follows,
- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
- 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
-- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
- 515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
-filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
-- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
+ 194k steps at resolution `512x512`
+ on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B
+ with resolution `>= 1024x1024`).
+- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`. 515k steps at resolution `512x512` on "laion-improved-aesthetics" (a
+ subset of laion2B-en, filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and
+ an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score
+ is estimated using
+ an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
+- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and
+ 10\% dropping of the text-conditioning to
+ improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
- **Hardware:** 32 x 8 x A100 GPUs
@@ -108,25 +135,32 @@ filtered to images with an original size `>= 512x512`, estimated aesthetics scor
- **Batch:** 32 x 8 x 2 x 4 = 2048
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
-## Evaluation Results
-Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
-5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
+## Evaluation Results
+
+Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
steps show the relative improvements of the checkpoints:
-
+
+
+Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512
+resolution. Not optimized for FID scores.
-Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
## Environmental Impact
**Stable Diffusion v1** **Estimated Emissions**
-Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
+Based on that information, we estimate the following CO2 emissions using
+the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented
+in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region
+were utilized to estimate the carbon impact.
- **Hardware Type:** A100 PCIe 40GB
- **Hours used:** 150000
- **Cloud Provider:** AWS
- **Compute Region:** US-east
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
+
## Citation
+
@InProceedings{Rombach_2022_CVPR,
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
@@ -136,5 +170,6 @@ Based on that information, we estimate the following CO2 emissions using the [Ma
pages = {10684-10695}
}
-*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
+*This model card was written by: Robin Rombach and Patrick Esser and is based on
+the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
diff --git a/configs/autoencoder/autoencoder_kl_16x16x16.yaml b/configs/autoencoder/autoencoder_kl_16x16x16.yaml
index 5f1d10ec7..b218c5eed 100644
--- a/configs/autoencoder/autoencoder_kl_16x16x16.yaml
+++ b/configs/autoencoder/autoencoder_kl_16x16x16.yaml
@@ -18,9 +18,9 @@ model:
in_channels: 3
out_ch: 3
ch: 128
- ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
+ ch_mult: [ 1,1,2,2,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
- attn_resolutions: [16]
+ attn_resolutions: [ 16 ]
dropout: 0.0
diff --git a/configs/autoencoder/autoencoder_kl_8x8x64.yaml b/configs/autoencoder/autoencoder_kl_8x8x64.yaml
index 5ccd09d38..4d1bd20ce 100644
--- a/configs/autoencoder/autoencoder_kl_8x8x64.yaml
+++ b/configs/autoencoder/autoencoder_kl_8x8x64.yaml
@@ -18,9 +18,9 @@ model:
in_channels: 3
out_ch: 3
ch: 128
- ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
+ ch_mult: [ 1,1,2,2,4,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
- attn_resolutions: [16,8]
+ attn_resolutions: [ 16,8 ]
dropout: 0.0
data:
diff --git a/configs/latent-diffusion/celebahq-ldm-vq-4.yaml b/configs/latent-diffusion/celebahq-ldm-vq-4.yaml
index 89b3df4fe..9bc6b5c23 100644
--- a/configs/latent-diffusion/celebahq-ldm-vq-4.yaml
+++ b/configs/latent-diffusion/celebahq-ldm-vq-4.yaml
@@ -20,19 +20,19 @@ model:
out_channels: 3
model_channels: 224
attention_resolutions:
- # note: this isn\t actually the resolution but
- # the downsampling factor, i.e. this corresnponds to
- # attention on spatial resolution 8,16,32, as the
- # spatial reolution of the latents is 64 for f4
- - 8
- - 4
- - 2
+        # note: this isn't actually the resolution but
+        # the downsampling factor, i.e. this corresponds to
+        # attention on spatial resolution 8,16,32, as the
+        # spatial resolution of the latents is 64 for f4
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 4
+ - 1
+ - 2
+ - 3
+ - 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
@@ -48,11 +48,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/configs/latent-diffusion/cin-ldm-vq-f8.yaml b/configs/latent-diffusion/cin-ldm-vq-f8.yaml
index b8cd9e2ef..487183747 100644
--- a/configs/latent-diffusion/cin-ldm-vq-f8.yaml
+++ b/configs/latent-diffusion/cin-ldm-vq-f8.yaml
@@ -22,18 +22,18 @@ model:
out_channels: 4
model_channels: 256
attention_resolutions:
- #note: this isn\t actually the resolution but
- # the downsampling factor, i.e. this corresnponds to
- # attention on spatial resolution 8,16,32, as the
- # spatial reolution of the latents is 32 for f8
- - 4
- - 2
- - 1
+        # note: this isn't actually the resolution but
+        # the downsampling factor, i.e. this corresponds to
+        # attention on spatial resolution 8,16,32, as the
+        # spatial resolution of the latents is 32 for f8
+ - 4
+ - 2
+ - 1
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_head_channels: 32
use_spatial_transformer: true
transformer_depth: 1
@@ -52,13 +52,13 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 2
- - 4
+ - 1
+ - 2
+ - 2
+ - 4
num_res_blocks: 2
attn_resolutions:
- - 32
+ - 32
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/configs/latent-diffusion/cin256-v2.yaml b/configs/latent-diffusion/cin256-v2.yaml
index b7c1aa240..274806b10 100644
--- a/configs/latent-diffusion/cin256-v2.yaml
+++ b/configs/latent-diffusion/cin256-v2.yaml
@@ -15,7 +15,7 @@ model:
conditioning_key: crossattn
monitor: val/loss
use_ema: False
-
+
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
@@ -24,20 +24,20 @@ model:
out_channels: 3
model_channels: 192
attention_resolutions:
- - 8
- - 4
- - 2
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 5
+ - 1
+ - 2
+ - 3
+ - 5
num_heads: 1
use_spatial_transformer: true
transformer_depth: 1
context_dim: 512
-
+
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
params:
@@ -51,15 +51,15 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
-
+
cond_stage_config:
target: ldm.modules.encoders.modules.ClassEmbedder
params:
diff --git a/configs/latent-diffusion/ffhq-ldm-vq-4.yaml b/configs/latent-diffusion/ffhq-ldm-vq-4.yaml
index 1899e30f7..e0610cfb0 100644
--- a/configs/latent-diffusion/ffhq-ldm-vq-4.yaml
+++ b/configs/latent-diffusion/ffhq-ldm-vq-4.yaml
@@ -19,19 +19,19 @@ model:
out_channels: 3
model_channels: 224
attention_resolutions:
- # note: this isn\t actually the resolution but
- # the downsampling factor, i.e. this corresnponds to
- # attention on spatial resolution 8,16,32, as the
- # spatial reolution of the latents is 64 for f4
- - 8
- - 4
- - 2
+        # note: this isn't actually the resolution but
+        # the downsampling factor, i.e. this corresponds to
+        # attention on spatial resolution 8,16,32, as the
+        # spatial resolution of the latents is 64 for f4
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 4
+ - 1
+ - 2
+ - 3
+ - 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
@@ -47,11 +47,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml b/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
index c4ca66c16..6237814ba 100644
--- a/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
+++ b/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
@@ -19,19 +19,19 @@ model:
out_channels: 3
model_channels: 224
attention_resolutions:
- # note: this isn\t actually the resolution but
- # the downsampling factor, i.e. this corresnponds to
- # attention on spatial resolution 8,16,32, as the
- # spatial reolution of the latents is 64 for f4
- - 8
- - 4
- - 2
+        # note: this isn't actually the resolution but
+        # the downsampling factor, i.e. this corresponds to
+        # attention on spatial resolution 8,16,32, as the
+        # spatial resolution of the latents is 64 for f4
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 4
+ - 1
+ - 2
+ - 3
+ - 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
@@ -47,11 +47,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml b/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml
index 18dc8c2d9..ef69fba98 100644
--- a/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml
+++ b/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml
@@ -20,11 +20,11 @@ model:
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
- warm_up_steps: [10000]
- cycle_lengths: [10000000000000]
- f_start: [1.e-6]
- f_max: [1.]
- f_min: [ 1.]
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ]
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
diff --git a/configs/latent-diffusion/txt2img-1p4B-eval.yaml b/configs/latent-diffusion/txt2img-1p4B-eval.yaml
index 8e331cbfd..4f985d6a6 100644
--- a/configs/latent-diffusion/txt2img-1p4B-eval.yaml
+++ b/configs/latent-diffusion/txt2img-1p4B-eval.yaml
@@ -25,15 +25,15 @@ model:
out_channels: 4
model_channels: 320
attention_resolutions:
- - 4
- - 2
- - 1
+ - 4
+ - 2
+ - 1
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 4
- - 4
+ - 1
+ - 2
+ - 4
+ - 4
num_heads: 8
use_spatial_transformer: true
transformer_depth: 1
@@ -54,12 +54,12 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
- - 4
+ - 1
+ - 2
+ - 4
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/configs/retrieval-augmented-diffusion/768x768.yaml b/configs/retrieval-augmented-diffusion/768x768.yaml
index b51b1d837..5aca7ec0b 100644
--- a/configs/retrieval-augmented-diffusion/768x768.yaml
+++ b/configs/retrieval-augmented-diffusion/768x768.yaml
@@ -24,15 +24,15 @@ model:
out_channels: 16
model_channels: 448
attention_resolutions:
- - 4
- - 2
- - 1
+ - 4
+ - 2
+ - 1
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 4
+ - 1
+ - 2
+ - 3
+ - 4
use_scale_shift_norm: false
resblock_updown: false
num_head_channels: 32
@@ -53,14 +53,14 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 1
- - 2
- - 2
- - 4
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
num_res_blocks: 2
attn_resolutions:
- - 16
+ - 16
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml
index d4effe569..a7434069f 100644
--- a/configs/stable-diffusion/v1-inference.yaml
+++ b/configs/stable-diffusion/v1-inference.yaml
@@ -56,12 +56,12 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
- - 4
+ - 1
+ - 2
+ - 4
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/environment.yaml b/environment.yaml
index 7f25da800..a83aa546e 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -10,20 +10,20 @@ dependencies:
- torchvision=0.12.0
- numpy=1.19.2
- pip:
- - albumentations==0.4.3
- - opencv-python==4.1.2.30
- - pudb==2019.2
- - imageio==2.9.0
- - imageio-ffmpeg==0.4.2
- - pytorch-lightning==1.4.2
- - omegaconf==2.1.1
- - test-tube>=0.7.5
- - streamlit>=0.73.1
- - einops==0.3.0
- - torch-fidelity==0.3.0
- - transformers==4.19.2
- - torchmetrics==0.6.0
- - kornia==0.6
- - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- - -e git+https://github.com/openai/CLIP.git@main#egg=clip
- - -e .
+ - albumentations==0.4.3
+ - opencv-python==4.1.2.30
+ - pudb==2019.2
+ - imageio==2.9.0
+ - imageio-ffmpeg==0.4.2
+ - pytorch-lightning==1.4.2
+ - omegaconf==2.1.1
+ - test-tube>=0.7.5
+ - streamlit>=0.73.1
+ - einops==0.3.0
+ - torch-fidelity==0.3.0
+ - transformers==4.19.2
+ - torchmetrics==0.6.0
+ - kornia==0.6
+ - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
+ - -e git+https://github.com/openai/CLIP.git@main#egg=clip
+ - -e .
diff --git a/ldm/data/base.py b/ldm/data/base.py
index b196c2f7a..e1beae49f 100644
--- a/ldm/data/base.py
+++ b/ldm/data/base.py
@@ -6,6 +6,7 @@ class Txt2ImgIterableBaseDataset(IterableDataset):
'''
Define an interface to make the IterableDatasets for text2img data chainable
'''
+
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
@@ -20,4 +21,4 @@ def __len__(self):
@abstractmethod
def __iter__(self):
- pass
\ No newline at end of file
+ pass
diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py
index 1c473f9c6..8fd6dc48d 100644
--- a/ldm/data/imagenet.py
+++ b/ldm/data/imagenet.py
@@ -1,18 +1,22 @@
-import os, yaml, pickle, shutil, tarfile, glob
-import cv2
-import albumentations
import PIL
+import albumentations
+import cv2
+import glob
import numpy as np
+import os
+import pickle
+import shutil
+import taming.data.utils as tdu
+import tarfile
import torchvision.transforms.functional as TF
-from omegaconf import OmegaConf
-from functools import partial
+import yaml
from PIL import Image
-from tqdm import tqdm
-from torch.utils.data import Dataset, Subset
-
-import taming.data.utils as tdu
-from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
+from functools import partial
+from omegaconf import OmegaConf
from taming.data.imagenet import ImagePaths
+from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
+from torch.utils.data import Dataset, Subset
+from tqdm import tqdm
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
@@ -20,13 +24,13 @@
def synset2idx(path_to_yaml="data/index_synset.yaml"):
with open(path_to_yaml) as f:
di2s = yaml.load(f)
- return dict((v,k) for k,v in di2s.items())
+ return dict((v, k) for k, v in di2s.items())
class ImageNetBase(Dataset):
def __init__(self, config=None):
self.config = config or OmegaConf.create()
- if not type(self.config)==dict:
+ if not type(self.config) == dict:
self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
@@ -68,7 +72,7 @@ def _prepare_synset_to_human(self):
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (not os.path.exists(self.human_dict) or
- not os.path.getsize(self.human_dict)==SIZE):
+ not os.path.getsize(self.human_dict) == SIZE):
download(URL, self.human_dict)
def _prepare_idx_to_synset(self):
@@ -166,7 +170,7 @@ def _prepare(self):
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
@@ -187,7 +191,7 @@ def _prepare(self):
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
- filelist = "\n".join(filelist)+"\n"
+ filelist = "\n".join(filelist) + "\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
@@ -231,7 +235,7 @@ def _prepare(self):
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
@@ -242,7 +246,7 @@ def _prepare(self):
tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1])
- if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]:
download(self.VS_URL, vspath)
with open(vspath, "r") as f:
@@ -261,14 +265,13 @@ def _prepare(self):
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
- filelist = "\n".join(filelist)+"\n"
+ filelist = "\n".join(filelist) + "\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
tdu.mark_prepared(self.root)
-
class ImageNetSR(Dataset):
def __init__(self, size=None,
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
@@ -296,12 +299,12 @@ def __init__(self, size=None,
self.LR_size = int(size / downscale_f)
self.min_crop_f = min_crop_f
self.max_crop_f = max_crop_f
- assert(max_crop_f <= 1.)
+ assert (max_crop_f <= 1.)
self.center_crop = not random_crop
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
- self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
if degradation == "bsrgan":
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
@@ -311,17 +314,17 @@ def __init__(self, size=None,
else:
interpolation_fn = {
- "cv_nearest": cv2.INTER_NEAREST,
- "cv_bilinear": cv2.INTER_LINEAR,
- "cv_bicubic": cv2.INTER_CUBIC,
- "cv_area": cv2.INTER_AREA,
- "cv_lanczos": cv2.INTER_LANCZOS4,
- "pil_nearest": PIL.Image.NEAREST,
- "pil_bilinear": PIL.Image.BILINEAR,
- "pil_bicubic": PIL.Image.BICUBIC,
- "pil_box": PIL.Image.BOX,
- "pil_hamming": PIL.Image.HAMMING,
- "pil_lanczos": PIL.Image.LANCZOS,
+ "cv_nearest": cv2.INTER_NEAREST,
+ "cv_bilinear": cv2.INTER_LINEAR,
+ "cv_bicubic": cv2.INTER_CUBIC,
+ "cv_area": cv2.INTER_AREA,
+ "cv_lanczos": cv2.INTER_LANCZOS4,
+ "pil_nearest": PIL.Image.NEAREST,
+ "pil_bilinear": PIL.Image.BILINEAR,
+ "pil_bicubic": PIL.Image.BICUBIC,
+ "pil_box": PIL.Image.BOX,
+ "pil_hamming": PIL.Image.HAMMING,
+ "pil_lanczos": PIL.Image.LANCZOS,
}[degradation]
self.pil_interpolation = degradation.startswith("pil_")
@@ -366,8 +369,8 @@ def __getitem__(self, i):
else:
LR_image = self.degradation_process(image=image)["image"]
- example["image"] = (image/127.5 - 1.0).astype(np.float32)
- example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
+ example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32)
return example
@@ -379,7 +382,7 @@ def __init__(self, **kwargs):
def get_base(self):
with open("data/imagenet_train_hr_indices.p", "rb") as f:
indices = pickle.load(f)
- dset = ImageNetTrain(process_images=False,)
+ dset = ImageNetTrain(process_images=False, )
return Subset(dset, indices)
@@ -390,5 +393,5 @@ def __init__(self, **kwargs):
def get_base(self):
with open("data/imagenet_val_hr_indices.p", "rb") as f:
indices = pickle.load(f)
- dset = ImageNetValidation(process_images=False,)
+ dset = ImageNetValidation(process_images=False, )
return Subset(dset, indices)
diff --git a/ldm/data/lsun.py b/ldm/data/lsun.py
index 6256e4571..23f1607ee 100644
--- a/ldm/data/lsun.py
+++ b/ldm/data/lsun.py
@@ -1,6 +1,6 @@
-import os
-import numpy as np
import PIL
+import numpy as np
+import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py
index be39da9ca..d53a8bee7 100644
--- a/ldm/lr_scheduler.py
+++ b/ldm/lr_scheduler.py
@@ -5,6 +5,7 @@ class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
+
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
@@ -30,7 +31,7 @@ def schedule(self, n, **kwargs):
return lr
def __call__(self, n, **kwargs):
- return self.schedule(n,**kwargs)
+ return self.schedule(n, **kwargs)
class LambdaWarmUpCosineScheduler2:
@@ -38,6 +39,7 @@ class LambdaWarmUpCosineScheduler2:
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
+
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
self.lr_warm_up_steps = warm_up_steps
@@ -92,7 +94,7 @@ def schedule(self, n, **kwargs):
self.last_f = f
return f
else:
- f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (
+ self.cycle_lengths[cycle])
self.last_f = f
return f
-
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
index 6a9c4f454..90cb48010 100644
--- a/ldm/models/autoencoder.py
+++ b/ldm/models/autoencoder.py
@@ -1,13 +1,11 @@
-import torch
import pytorch_lightning as pl
+import torch
import torch.nn.functional as F
from contextlib import contextmanager
-
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-
from ldm.util import instantiate_from_config
@@ -26,7 +24,7 @@ def __init__(self,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
super().__init__()
@@ -42,7 +40,7 @@ def __init__(self,
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
- assert type(colorize_nlabels)==int
+ assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
@@ -115,7 +113,7 @@ def decode_code(self, code_b):
return dec
def forward(self, input, return_pred_indices=False):
- quant, diff, (_,_,ind) = self.encode(input)
+ quant, diff, (_, _, ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
@@ -133,7 +131,7 @@ def get_input(self, batch, k):
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
- new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
+ new_resize = np.random.choice(np.arange(lower_size, upper_size + 16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
@@ -157,7 +155,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
- last_layer=self.get_last_layer(), split="train")
+ last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
@@ -173,21 +171,21 @@ def _validation_step(self, batch, batch_idx, suffix=""):
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
- split="val"+suffix,
+ split="val" + suffix,
predicted_indices=ind
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
- split="val"+suffix,
+ split="val" + suffix,
predicted_indices=ind
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
@@ -196,13 +194,13 @@ def _validation_step(self, batch, batch_idx, suffix=""):
def configure_optimizers(self):
lr_d = self.learning_rate
- lr_g = self.lr_g_factor*self.learning_rate
+ lr_g = self.lr_g_factor * self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
- list(self.decoder.parameters())+
- list(self.quantize.parameters())+
- list(self.quant_conv.parameters())+
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
+ list(self.decoder.parameters()) +
+ list(self.quantize.parameters()) +
+ list(self.quant_conv.parameters()) +
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
@@ -257,7 +255,7 @@ def to_rgb(self, x):
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
@@ -299,11 +297,11 @@ def __init__(self,
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
- self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
- assert type(colorize_nlabels)==int
+ assert type(colorize_nlabels) == int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
@@ -385,9 +383,9 @@ def validation_step(self, batch, batch_idx):
def configure_optimizers(self):
lr = self.learning_rate
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
- list(self.decoder.parameters())+
- list(self.quant_conv.parameters())+
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
+ list(self.decoder.parameters()) +
+ list(self.quant_conv.parameters()) +
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
@@ -419,7 +417,7 @@ def to_rgb(self, x):
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py
index 67e98b9d8..f27384f38 100644
--- a/ldm/models/diffusion/classifier.py
+++ b/ldm/models/diffusion/classifier.py
@@ -1,14 +1,14 @@
import os
-import torch
import pytorch_lightning as pl
-from omegaconf import OmegaConf
-from torch.nn import functional as F
-from torch.optim import AdamW
-from torch.optim.lr_scheduler import LambdaLR
+import torch
from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted
+from omegaconf import OmegaConf
+from torch.nn import functional as F
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import LambdaLR
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index fb31215db..c62cc4c34 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -1,9 +1,9 @@
"""SAMPLING ONLY."""
-import torch
import numpy as np
-from tqdm import tqdm
+import torch
from functools import partial
+from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
extract_into_tensor
@@ -24,7 +24,7 @@ def register_buffer(self, name, attr):
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
@@ -43,14 +43,14 @@ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.,
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
- eta=ddim_eta,verbose=verbose)
+ eta=ddim_eta, verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
@@ -116,7 +116,7 @@ def ddim_sampling(self, cond, shape,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ unconditional_guidance_scale=1., unconditional_conditioning=None, ):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@@ -131,7 +131,7 @@ def ddim_sampling(self, cond, shape,
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
- time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
@@ -189,14 +189,14 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
@@ -238,4 +238,4 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
- return x_dec
\ No newline at end of file
+ return x_dec
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index bbedd04cf..8ef1d7836 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -6,25 +6,24 @@
-- merci
"""
-import torch
-import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
-from torch.optim.lr_scheduler import LambdaLR
-from einops import rearrange, repeat
+import torch
+import torch.nn as nn
from contextlib import contextmanager
+from einops import rearrange, repeat
from functools import partial
-from tqdm import tqdm
-from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
+from torch.optim.lr_scheduler import LambdaLR
+from torchvision.utils import make_grid
+from tqdm import tqdm
-from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
-
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.modules.ema import LitEma
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
__conditioning_keys__ = {'concat': 'c_concat',
'crossattn': 'c_crossattn',
@@ -113,7 +112,6 @@ def __init__(self,
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
@@ -146,7 +144,7 @@ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
- 1. - alphas_cumprod) + self.v_posterior * betas
+ 1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
@@ -158,7 +156,7 @@ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
else:
@@ -423,6 +421,7 @@ def configure_optimizers(self):
class LatentDiffusion(DDPM):
"""main class"""
+
def __init__(self,
first_stage_config,
cond_stage_config,
@@ -461,7 +460,7 @@ def __init__(self,
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
- self.bbox_tokenizer = None
+ self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
@@ -531,7 +530,7 @@ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantiza
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(self.decode_first_stage(zd.to(self.device),
- force_not_quantize=force_no_decoder_quantization))
+ force_not_quantize=force_no_decoder_quantization))
n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
@@ -793,7 +792,7 @@ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_qua
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
+ if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
@@ -901,7 +900,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False):
if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm
- assert not return_ids
+ assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
@@ -1216,7 +1215,7 @@ def p_sample_loop(self, cond, shape, return_intermediates=False,
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
verbose=True, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, shape=None,**kwargs):
+ mask=None, x0=None, shape=None, **kwargs):
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
@@ -1232,21 +1231,20 @@ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
mask=mask, x0=x0)
@torch.no_grad()
- def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
- samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
- shape,cond,verbose=False,**kwargs)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+ shape, cond, verbose=False, **kwargs)
else:
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
- return_intermediates=True,**kwargs)
+ return_intermediates=True, **kwargs)
return samples, intermediates
-
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
@@ -1300,8 +1298,8 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
if sample:
# get denoise row
with self.ema_scope("Plotting"):
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
- ddim_steps=ddim_steps,eta=ddim_eta)
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
@@ -1313,8 +1311,8 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling
with self.ema_scope("Plotting Quantized Denoised"):
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
- ddim_steps=ddim_steps,eta=ddim_eta,
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
quantize_denoised=True)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
@@ -1329,17 +1327,16 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with self.ema_scope("Plotting Inpaint"):
-
- samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
# outpaint
with self.ema_scope("Plotting Outpaint"):
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 78eeb1003..cbeb486ab 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -1,9 +1,9 @@
"""SAMPLING ONLY."""
-import torch
import numpy as np
-from tqdm import tqdm
+import torch
from functools import partial
+from tqdm import tqdm
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
@@ -25,7 +25,7 @@ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.,
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
@@ -44,14 +44,14 @@ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.,
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
- eta=ddim_eta,verbose=verbose)
+ eta=ddim_eta, verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
@@ -117,7 +117,7 @@ def plms_sampling(self, cond, shape,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
+ unconditional_guidance_scale=1., unconditional_conditioning=None, ):
device = self.model.betas.device
b = shape[0]
if x_T is None:
@@ -132,7 +132,7 @@ def plms_sampling(self, cond, shape,
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
- time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
@@ -201,14 +201,14 @@ def get_x_prev_and_pred_x0(e_t, index):
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index f4eff39cc..47dd9d41b 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -1,9 +1,9 @@
-from inspect import isfunction
import math
import torch
import torch.nn.functional as F
-from torch import nn, einsum
from einops import rearrange, repeat
+from inspect import isfunction
+from torch import nn, einsum
from ldm.modules.diffusionmodules.util import checkpoint
@@ -13,7 +13,7 @@ def exists(val):
def uniq(arr):
- return{el: True for el in arr}.keys()
+ return {el: True for el in arr}.keys()
def default(val, d):
@@ -82,14 +82,14 @@ def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
- q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
- k = k.softmax(dim=-1)
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
+ k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -131,12 +131,12 @@ def forward(self, x):
v = self.v(h_)
# compute attention
- b,c,h,w = q.shape
+ b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
- w_ = w_ * (int(c)**(-0.5))
+ w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
@@ -146,7 +146,7 @@ def forward(self, x):
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
- return x+h_
+ return x + h_
class CrossAttention(nn.Module):
@@ -174,29 +174,34 @@ def forward(self, x, context=None, mask=None):
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
+ del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40)
+ del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
+ del mask
- # attention, what we cannot get enough of
- attn = sim.softmax(dim=-1)
+ # attention, what we cannot get enough of, by halves
+ sim[4:] = sim[4:].softmax(dim=-1)
+ sim[:4] = sim[:4].softmax(dim=-1)
- out = einsum('b i j, b j d -> b i d', attn, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return self.to_out(out)
+ sim = einsum('b i j, b j d -> b i d', sim, v)
+ sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(sim)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
- self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
+ dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
@@ -223,6 +228,7 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action.
Finally, reshape to image
"""
+
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
super().__init__()
@@ -238,7 +244,7 @@ def __init__(self, in_channels, n_heads, d_head,
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
- for d in range(depth)]
+ for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
@@ -258,4 +264,4 @@ def forward(self, x, context=None):
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
- return x + x_in
\ No newline at end of file
+ return x + x_in
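
The CrossAttention change above is the only non-cosmetic edit in attention.py: it trims peak VRAM during attention by releasing intermediates (`context`, `x`, `q`, `k`, `mask`) with `del` as soon as they are consumed, and by writing the softmax back into `sim` half at a time (`sim[:4]`, `sim[4:]`) instead of allocating a separate `attn` tensor. The sketch below is a minimal, self-contained illustration of the same idea, not the repo's class; the `lowmem_attention` helper and its `chunks` parameter are hypothetical generalizations of the hard-coded split over the flattened (batch × heads) dimension.

```python
# Minimal sketch of the chunked, write-back softmax attention used in the patch.
# Assumes torch and einops, both already required by the repo.
import torch
from torch import einsum
from einops import rearrange


def lowmem_attention(q, k, v, heads, chunks=2):
    # q, k, v: (batch, seq_len, heads * dim_head)
    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=heads) for t in (q, k, v))
    scale = q.shape[-1] ** -0.5
    sim = einsum('b i d, b j d -> b i j', q, k) * scale   # ((b*h), n, n)
    del q, k                                              # drop inputs early
    # softmax slice-by-slice, writing back into `sim` rather than
    # allocating a full-size `attn` copy
    step = (sim.shape[0] + chunks - 1) // chunks
    for i in range(0, sim.shape[0], step):
        sim[i:i + step] = sim[i:i + step].softmax(dim=-1)
    out = einsum('b i j, b j d -> b i d', sim, v)
    return rearrange(out, '(b h) n d -> b n (h d)', h=heads)


# usage: output matches plain softmax attention; only peak memory differs
x = torch.randn(2, 64, 8 * 40)
print(lowmem_attention(x, x, x, heads=8).shape)  # torch.Size([2, 64, 320])
```

Chunking only changes where the softmax temporaries live, not the result, since the softmax is taken row-wise along the last dimension.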
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
index 533e589a2..cd1a04db3 100644
--- a/ldm/modules/diffusionmodules/model.py
+++ b/ldm/modules/diffusionmodules/model.py
@@ -1,12 +1,12 @@
# pytorch_diffusion + derived encoder decoder
import math
+import numpy as np
import torch
import torch.nn as nn
-import numpy as np
from einops import rearrange
-from ldm.util import instantiate_from_config
from ldm.modules.attention import LinearAttention
+from ldm.util import instantiate_from_config
def get_timestep_embedding(timesteps, embedding_dim):
@@ -26,13 +26,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
- return x*torch.sigmoid(x)
+ return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
@@ -71,7 +71,7 @@ def __init__(self, in_channels, with_conv):
def forward(self, x):
if self.with_conv:
- pad = (0,1,0,1)
+ pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
@@ -125,7 +125,7 @@ def forward(self, x, temb):
h = self.conv1(h)
if temb is not None:
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
@@ -138,11 +138,12 @@ def forward(self, x, temb):
else:
x = self.nin_shortcut(x)
- return x+h
+ return x + h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
+
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
@@ -174,7 +175,6 @@ def __init__(self, in_channels):
stride=1,
padding=0)
-
def forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -183,23 +183,23 @@ def forward(self, x):
v = self.v(h_)
# compute attention
- b,c,h,w = q.shape
- q = q.reshape(b,c,h*w)
- q = q.permute(0,2,1) # b,hw,c
- k = k.reshape(b,c,h*w) # b,c,hw
- w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w_ = w_ * (int(c)**(-0.5))
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
- v = v.reshape(b,c,h*w)
- w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- h_ = h_.reshape(b,c,h,w)
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
- return x+h_
+ return x + h_
def make_attn(in_channels, attn_type="vanilla"):
@@ -214,13 +214,13 @@ def make_attn(in_channels, attn_type="vanilla"):
class Model(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
- self.temb_ch = self.ch*4
+ self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
@@ -245,13 +245,13 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
padding=1)
curr_res = resolution
- in_ch_mult = (1,)+tuple(ch_mult)
+ in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
- block_in = ch*in_ch_mult[i_level]
- block_out = ch*ch_mult[i_level]
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
@@ -263,7 +263,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
down = nn.Module()
down.block = block
down.attn = attn
- if i_level != self.num_resolutions-1:
+ if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
@@ -285,12 +285,12 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
- block_out = ch*ch_mult[i_level]
- skip_in = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks+1):
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
- skip_in = ch*in_ch_mult[i_level]
- block.append(ResnetBlock(in_channels=block_in+skip_in,
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
@@ -303,7 +303,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
+ self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
@@ -314,7 +314,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
padding=1)
def forward(self, x, t=None, context=None):
- #assert x.shape[2] == x.shape[3] == self.resolution
+ # assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
@@ -336,7 +336,7 @@ def forward(self, x, t=None, context=None):
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
- if i_level != self.num_resolutions-1:
+ if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
@@ -347,7 +347,7 @@ def forward(self, x, t=None, context=None):
# upsampling
for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
+ for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
@@ -366,7 +366,7 @@ def get_last_layer(self):
class Encoder(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
**ignore_kwargs):
@@ -387,14 +387,14 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
padding=1)
curr_res = resolution
- in_ch_mult = (1,)+tuple(ch_mult)
+ in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
- block_in = ch*in_ch_mult[i_level]
- block_out = ch*ch_mult[i_level]
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
@@ -406,7 +406,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
down = nn.Module()
down.block = block
down.attn = attn
- if i_level != self.num_resolutions-1:
+ if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
@@ -426,7 +426,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
- 2*z_channels if double_z else z_channels,
+ 2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
@@ -443,7 +443,7 @@ def forward(self, x):
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
- if i_level != self.num_resolutions-1:
+ if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
@@ -460,7 +460,7 @@ def forward(self, x):
class Decoder(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs):
@@ -476,10 +476,10 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
- in_ch_mult = (1,)+tuple(ch_mult)
- block_in = ch*ch_mult[self.num_resolutions-1]
- curr_res = resolution // 2**(self.num_resolutions-1)
- self.z_shape = (1,z_channels,curr_res,curr_res)
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
@@ -507,8 +507,8 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
- block_out = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks+1):
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
@@ -522,7 +522,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
+ self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
@@ -533,7 +533,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
padding=1)
def forward(self, z):
- #assert z.shape[1:] == self.z_shape[1:]
+ # assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
@@ -549,7 +549,7 @@ def forward(self, z):
# upsampling
for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
+ for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
@@ -572,17 +572,17 @@ class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
- ResnetBlock(in_channels=in_channels,
- out_channels=2 * in_channels,
- temb_channels=0, dropout=0.0),
- ResnetBlock(in_channels=2 * in_channels,
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0, dropout=0.0),
- ResnetBlock(in_channels=4 * in_channels,
+ ResnetBlock(in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
- nn.Conv2d(2*in_channels, in_channels, 1),
- Upsample(in_channels, with_conv=True)])
+ nn.Conv2d(2 * in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(in_channels,
@@ -593,7 +593,7 @@ def __init__(self, in_channels, out_channels, *args, **kwargs):
def forward(self, x):
for i, layer in enumerate(self.model):
- if i in [1,2,3]:
+ if i in [1, 2, 3]:
x = layer(x, None)
else:
x = layer(x)
@@ -606,7 +606,7 @@ def forward(self, x):
class UpsampleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
- ch_mult=(2,2), dropout=0.0):
+ ch_mult=(2, 2), dropout=0.0):
super().__init__()
# upsampling
self.temb_ch = 0
@@ -621,9 +621,9 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
@@ -681,7 +681,8 @@ def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
- x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = torch.nn.functional.interpolate(x, size=(
+ int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor))))
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
@@ -692,7 +693,7 @@ def forward(self, x):
class MergedRescaleEncoder(nn.Module):
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True,
- ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ ch_mult=(1, 2, 4, 8), rescale_factor=1.0, rescale_module_depth=1):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
@@ -709,10 +710,10 @@ def forward(self, x):
class MergedRescaleDecoder(nn.Module):
- def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1, 2, 4, 8),
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
super().__init__()
- tmp_chn = z_channels*ch_mult[-1]
+ tmp_chn = z_channels * ch_mult[-1]
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
ch_mult=ch_mult, resolution=resolution, ch=ch)
@@ -729,10 +730,11 @@ class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
- num_blocks = int(np.log2(out_size//in_size))+1
- factor_up = 1.+ (out_size % in_size)
- print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
- self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ num_blocks = int(np.log2(out_size // in_size)) + 1
+ factor_up = 1. + (out_size % in_size)
+ print(
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels,
out_channels=in_channels)
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
attn_resolutions=[], in_channels=None, ch=in_channels,
@@ -761,16 +763,17 @@ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
padding=1)
def forward(self, x, scale_factor=1.0):
- if scale_factor==1.0:
+ if scale_factor == 1.0:
return x
else:
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
return x
+
class FirstStagePostProcessor(nn.Module):
- def __init__(self, ch_mult:list, in_channels,
- pretrained_model:nn.Module=None,
+ def __init__(self, ch_mult: list, in_channels,
+ pretrained_model: nn.Module = None,
reshape=False,
n_channels=None,
dropout=0.,
@@ -788,22 +791,21 @@ def __init__(self, ch_mult:list, in_channels,
if n_channels is None:
n_channels = self.pretrained_model.encoder.ch
- self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
- self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
- stride=1,padding=1)
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
+ self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3,
+ stride=1, padding=1)
blocks = []
downs = []
ch_in = n_channels
for m in ch_mult:
- blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+ blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout))
ch_in = m * n_channels
downs.append(Downsample(ch_in, with_conv=False))
self.model = nn.ModuleList(blocks)
self.downsampler = nn.ModuleList(downs)
-
def instantiate_pretrained(self, config):
model = instantiate_from_config(config)
self.pretrained_model = model.eval()
@@ -811,25 +813,23 @@ def instantiate_pretrained(self, config):
for param in self.pretrained_model.parameters():
param.requires_grad = False
-
@torch.no_grad()
- def encode_with_pretrained(self,x):
+ def encode_with_pretrained(self, x):
c = self.pretrained_model.encode(x)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
- return c
+ return c
- def forward(self,x):
+ def forward(self, x):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
- for submodel, downmodel in zip(self.model,self.downsampler):
- z = submodel(z,temb=None)
+ for submodel, downmodel in zip(self.model, self.downsampler):
+ z = submodel(z, temb=None)
z = downmodel(z)
if self.do_reshape:
- z = rearrange(z,'b c h w -> b (h w) c')
+ z = rearrange(z, 'b c h w -> b (h w) c')
return z
-
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
index fcf95d1ea..493d98814 100644
--- a/ldm/modules/diffusionmodules/openaimodel.py
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -1,13 +1,13 @@
-from abc import abstractmethod
-from functools import partial
import math
-from typing import Iterable
-
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
+from abc import abstractmethod
+from functools import partial
+from typing import Iterable
+from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.util import (
checkpoint,
conv_nd,
@@ -17,13 +17,13 @@
normalization,
timestep_embedding,
)
-from ldm.modules.attention import SpatialTransformer
# dummy replace
def convert_module_to_f16(x):
pass
+
def convert_module_to_f32(x):
pass
@@ -35,11 +35,11 @@ class AttentionPool2d(nn.Module):
"""
def __init__(
- self,
- spacial_dim: int,
- embed_dim: int,
- num_heads_channels: int,
- output_dim: int = None,
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
@@ -118,16 +118,18 @@ def forward(self, x):
x = self.conv(x)
return x
+
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
+
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+ self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)
- def forward(self,x):
+ def forward(self, x):
return self.up(x)
@@ -140,7 +142,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -177,17 +179,17 @@ class ResBlock(TimestepBlock):
"""
def __init__(
- self,
- channels,
- emb_channels,
- dropout,
- out_channels=None,
- use_conv=False,
- use_scale_shift_norm=False,
- dims=2,
- use_checkpoint=False,
- up=False,
- down=False,
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
):
super().__init__()
self.channels = channels
@@ -251,7 +253,6 @@ def forward(self, x, emb):
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
-
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
@@ -283,12 +284,12 @@ class AttentionBlock(nn.Module):
"""
def __init__(
- self,
- channels,
- num_heads=1,
- num_head_channels=-1,
- use_checkpoint=False,
- use_new_attention_order=False,
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
):
super().__init__()
self.channels = channels
@@ -296,7 +297,7 @@ def __init__(
self.num_heads = num_heads
else:
assert (
- channels % num_head_channels == 0
+ channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
@@ -312,8 +313,9 @@ def __init__(
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
- return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
- #return pt_checkpoint(self._forward, x) # pytorch
+ return checkpoint(self._forward, (x,), self.parameters(),
+ True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ # return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
@@ -441,31 +443,31 @@ class UNetModel(nn.Module):
"""
def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- num_classes=None,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=-1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- use_spatial_transformer=False, # custom transformer support
- transformer_depth=1, # custom transformer support
- context_dim=None, # custom transformer support
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
- legacy=True,
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
):
super().__init__()
if use_spatial_transformer:
@@ -545,7 +547,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
- #num_heads = 1
+ # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
@@ -592,7 +594,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
- #num_heads = 1
+ # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
@@ -610,8 +612,8 @@ def __init__(
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
- ),
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
ResBlock(
ch,
time_embed_dim,
@@ -646,7 +648,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
- #num_heads = 1
+ # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
@@ -686,10 +688,10 @@ def __init__(
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
- normalization(ch),
- conv_nd(dims, model_channels, n_embed, 1),
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
- )
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
def convert_to_fp16(self):
"""
@@ -707,7 +709,7 @@ def convert_to_fp32(self):
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
- def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@@ -717,7 +719,7 @@ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
- self.num_classes is not None
+ self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
@@ -749,28 +751,28 @@ class EncoderUNetModel(nn.Module):
"""
def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- pool="adaptive",
- *args,
- **kwargs
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ pool="adaptive",
+ *args,
+ **kwargs
):
super().__init__()
@@ -958,4 +960,3 @@ def forward(self, x, timesteps):
else:
h = h.type(x.dtype)
return self.out(h)
-
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
index a952e6c40..acd44b851 100644
--- a/ldm/modules/diffusionmodules/util.py
+++ b/ldm/modules/diffusionmodules/util.py
@@ -8,11 +8,11 @@
# thanks!
-import os
import math
+import numpy as np
+import os
import torch
import torch.nn as nn
-import numpy as np
from einops import repeat
from ldm.util import instantiate_from_config
@@ -215,6 +215,7 @@ class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
+
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
@@ -264,4 +265,4 @@ def forward(self, c_concat, c_crossattn):
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
- return repeat_noise() if repeat else noise()
\ No newline at end of file
+ return repeat_noise() if repeat else noise()
diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py
index f2b8ef901..776277234 100644
--- a/ldm/modules/distributions/distributions.py
+++ b/ldm/modules/distributions/distributions.py
@@ -1,5 +1,5 @@
-import torch
import numpy as np
+import torch
class AbstractDistribution:
@@ -50,7 +50,7 @@ def kl(self, other=None):
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
- def nll(self, sample, dims=[1,2,3]):
+ def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
@@ -84,9 +84,9 @@ def normal_kl(mean1, logvar1, mean2, logvar2):
]
return 0.5 * (
- -1.0
- + logvar2
- - logvar1
- + torch.exp(logvar1 - logvar2)
- + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py
index c8c75af43..ea6d1d359 100644
--- a/ldm/modules/ema.py
+++ b/ldm/modules/ema.py
@@ -10,24 +10,24 @@ def __init__(self, model, decay=0.9999, use_num_upates=True):
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
- self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
- else torch.tensor(-1,dtype=torch.int))
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+ else torch.tensor(-1, dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
- #remove as '.'-character is not allowed in buffers
- s_name = name.replace('.','')
- self.m_name2s_name.update({name:s_name})
- self.register_buffer(s_name,p.clone().detach().data)
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.', '')
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
- def forward(self,model):
+ def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
- decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index ededbe43e..027ce9904 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,12 +1,13 @@
+import clip
+import kornia
import torch
import torch.nn as nn
-from functools import partial
-import clip
from einops import rearrange, repeat
+from functools import partial
from transformers import CLIPTokenizer, CLIPTextModel
-import kornia
-from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
+from ldm.modules.x_transformer import Encoder, \
+ TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
class AbstractEncoder(nn.Module):
@@ -17,7 +18,6 @@ def encode(self, *args, **kwargs):
raise NotImplementedError
-
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__()
@@ -35,6 +35,7 @@ def forward(self, batch, key=None):
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
+
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
super().__init__()
self.device = device
@@ -52,6 +53,7 @@ def encode(self, x):
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
+
def __init__(self, device="cuda", vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements
@@ -79,8 +81,9 @@ def decode(self, text):
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
+
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
- device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+ device="cuda", use_tokenizer=True, embedding_dropout=0.0):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
@@ -92,7 +95,7 @@ def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
def forward(self, text):
if self.use_tknz_fn:
- tokens = self.tknz_fn(text)#.to(self.device)
+ tokens = self.tknz_fn(text) # .to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
@@ -114,19 +117,18 @@ def __init__(self,
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
- assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
+ assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area']
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
- self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
+ self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
- def forward(self,x):
+ def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
-
if self.remap_output:
x = self.channel_mapper(x)
return x
@@ -134,8 +136,10 @@ def forward(self,x):
def encode(self, x):
return self(x)
+
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
+
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
@@ -166,6 +170,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
+
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
@@ -188,7 +193,7 @@ def forward(self, text):
def encode(self, text):
z = self(text)
- if z.ndim==2:
+ if z.ndim == 2:
z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
return z
@@ -198,13 +203,14 @@ class FrozenClipImageEmbedder(nn.Module):
"""
Uses the CLIP image encoder.
"""
+
def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
- ):
+ ):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
@@ -216,7 +222,7 @@ def __init__(
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
- interpolation='bicubic',align_corners=True,
+ interpolation='bicubic', align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
@@ -230,5 +236,6 @@ def forward(self, x):
if __name__ == "__main__":
from ldm.util import count_params
+
model = FrozenCLIPEmbedder()
- count_params(model, verbose=True)
\ No newline at end of file
+ count_params(model, verbose=True)
diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py
index 32ef56169..3b0eae8eb 100644
--- a/ldm/modules/image_degradation/bsrgan.py
+++ b/ldm/modules/image_degradation/bsrgan.py
@@ -10,18 +10,17 @@
# --------------------------------------------
"""
-import numpy as np
+import albumentations
import cv2
-import torch
-
-from functools import partial
+import numpy as np
import random
-from scipy import ndimage
import scipy
import scipy.stats as ss
+import torch
+from functools import partial
+from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth
-import albumentations
import ldm.modules.image_degradation.utils_image as util
@@ -609,7 +608,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
# add final JPEG compression noise
image = add_JPEG_noise(image)
image = util.single2uint(image)
- example = {"image":image}
+ example = {"image": image}
return example
@@ -702,29 +701,28 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc
if __name__ == '__main__':
- print("hey")
- img = util.imread_uint('utils/test.png', 3)
- print(img)
- img = util.uint2single(img)
- print(img)
- img = img[:448, :448]
- h = img.shape[0] // 4
- print("resizing to", h)
- sf = 4
- deg_fn = partial(degradation_bsrgan_variant, sf=sf)
- for i in range(20):
- print(i)
- img_lq = deg_fn(img)
- print(img_lq)
- img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
- print(img_lq.shape)
- print("bicubic", img_lq_bicubic.shape)
- print(img_hq.shape)
- lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
- util.imsave(img_concat, str(i) + '.png')
-
-
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ print(img)
+ img = util.uint2single(img)
+ print(img)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_lq = deg_fn(img)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py
index 9e1f82399..712191c54 100644
--- a/ldm/modules/image_degradation/bsrgan_light.py
+++ b/ldm/modules/image_degradation/bsrgan_light.py
@@ -1,16 +1,15 @@
# -*- coding: utf-8 -*-
-import numpy as np
+import albumentations
import cv2
-import torch
-
-from functools import partial
+import numpy as np
import random
-from scipy import ndimage
import scipy
import scipy.stats as ss
+import torch
+from functools import partial
+from scipy import ndimage
from scipy.interpolate import interp2d
from scipy.linalg import orth
-import albumentations
import ldm.modules.image_degradation.utils_image as util
@@ -326,8 +325,8 @@ def add_blur(img, sf=4):
wd2 = 4.0 + sf
wd = 2.0 + 0.2 * sf
- wd2 = wd2/4
- wd = wd/4
+ wd2 = wd2 / 4
+ wd = wd / 4
if random.random() < 0.5:
l1 = wd2 * random.random()
@@ -621,8 +620,6 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
return example
-
-
if __name__ == '__main__':
print("hey")
img = util.imread_uint('utils/test.png', 3)
@@ -637,7 +634,8 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None):
img_lq = deg_fn(img)["image"]
img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
print(img_lq)
- img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[
+ "image"]
print(img_lq.shape)
print("bicubic", img_lq_bicubic.shape)
print(img_hq.shape)
diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py
index 0175f155a..6c282e77b 100644
--- a/ldm/modules/image_degradation/utils_image.py
+++ b/ldm/modules/image_degradation/utils_image.py
@@ -1,16 +1,16 @@
-import os
+import cv2
import math
-import random
import numpy as np
+import os
+import random
import torch
-import cv2
-from torchvision.utils import make_grid
from datetime import datetime
-#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+from torchvision.utils import make_grid
+# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
-os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
'''
# --------------------------------------------
@@ -22,7 +22,6 @@
# --------------------------------------------
'''
-
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
@@ -49,11 +48,11 @@ def surf(Z, cmap='rainbow', figsize=None):
ax3 = plt.axes(projection='3d')
w, h = Z.shape[:2]
- xx = np.arange(0,w,1)
- yy = np.arange(0,h,1)
+ xx = np.arange(0, w, 1)
+ yy = np.arange(0, h, 1)
X, Y = np.meshgrid(xx, yy)
- ax3.plot_surface(X,Y,Z,cmap=cmap)
- #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ ax3.plot_surface(X, Y, Z, cmap=cmap)
+ # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
plt.show()
@@ -94,15 +93,15 @@ def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
w, h = img.shape[:2]
patches = []
if w > p_max and h > p_max:
- w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
- h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
- w1.append(w-p_size)
- h1.append(h-p_size)
-# print(w1)
-# print(h1)
+ w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int))
+ h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int))
+ w1.append(w - p_size)
+ h1.append(h - p_size)
+ # print(w1)
+ # print(h1)
for i in w1:
for j in h1:
- patches.append(img[i:i+p_size, j:j+p_size,:])
+ patches.append(img[i:i + p_size, j:j + p_size, :])
else:
patches.append(img)
@@ -118,7 +117,7 @@ def imssave(imgs, img_path):
for i, img in enumerate(imgs):
if img.ndim == 3:
img = img[:, :, [2, 1, 0]]
- new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ new_path = os.path.join(os.path.dirname(img_path), img_name + str('_s{:04d}'.format(i)) + '.png')
cv2.imwrite(new_path, img)
@@ -139,9 +138,10 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800,
# img_name, ext = os.path.splitext(os.path.basename(img_path))
img = imread_uint(img_path, n_channels=n_channels)
patches = patches_from_image(img, p_size, p_overlap, p_max)
- imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
- #if original_dataroot == taget_dataroot:
- #del img_path
+ imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
+ # if original_dataroot == taget_dataroot:
+ # del img_path
+
'''
# --------------------------------------------
@@ -206,6 +206,7 @@ def imsave(img, img_path):
img = img[:, :, [2, 1, 0]]
cv2.imwrite(img_path, img)
+
def imwrite(img, img_path):
img = np.squeeze(img)
if img.ndim == 3:
@@ -213,7 +214,6 @@ def imwrite(img, img_path):
cv2.imwrite(img_path, img)
-
# --------------------------------------------
# get single image of size HxWxn_channles (BGR)
# --------------------------------------------
@@ -247,23 +247,19 @@ def read_img(path):
def uint2single(img):
-
- return np.float32(img/255.)
+ return np.float32(img / 255.)
def single2uint(img):
-
- return np.uint8((img.clip(0, 1)*255.).round())
+ return np.uint8((img.clip(0, 1) * 255.).round())
def uint162single(img):
-
- return np.float32(img/65535.)
+ return np.float32(img / 65535.)
def single2uint16(img):
-
- return np.uint16((img.clip(0, 1)*65535.).round())
+ return np.uint16((img.clip(0, 1) * 65535.).round())
# --------------------------------------------
@@ -290,7 +286,7 @@ def tensor2uint(img):
img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
if img.ndim == 3:
img = np.transpose(img, (1, 2, 0))
- return np.uint8((img*255.0).round())
+ return np.uint8((img * 255.0).round())
# --------------------------------------------
@@ -316,6 +312,7 @@ def tensor2single(img):
return img
+
# convert torch tensor to single
def tensor2single3(img):
img = img.data.squeeze().float().cpu().numpy()
@@ -511,7 +508,7 @@ def shave(img_in, border=0):
# img_in: Numpy, HWC or HW
img = np.copy(img_in)
h, w = img.shape[:2]
- img = img[border:h-border, border:w-border]
+ img = img[border:h - border, border:w - border]
return img
@@ -620,17 +617,17 @@ def channel_convert(in_c, tar_type, img_list):
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
# img1 and img2 have range [0, 255]
- #img1 = img1.squeeze()
- #img2 = img2.squeeze()
+ # img1 = img1.squeeze()
+ # img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
- img1 = img1[border:h-border, border:w-border]
- img2 = img2[border:h-border, border:w-border]
+ img1 = img1[border:h - border, border:w - border]
+ img2 = img2[border:h - border, border:w - border]
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
- mse = np.mean((img1 - img2)**2)
+ mse = np.mean((img1 - img2) ** 2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
@@ -644,13 +641,13 @@ def calculate_ssim(img1, img2, border=0):
the same outputs as MATLAB's
img1, img2: [0, 255]
'''
- #img1 = img1.squeeze()
- #img2 = img2.squeeze()
+ # img1 = img1.squeeze()
+ # img2 = img2.squeeze()
if not img1.shape == img2.shape:
raise ValueError('Input images must have the same dimensions.')
h, w = img1.shape[:2]
- img1 = img1[border:h-border, border:w-border]
- img2 = img2[border:h-border, border:w-border]
+ img1 = img1[border:h - border, border:w - border]
+ img2 = img2[border:h - border, border:w - border]
if img1.ndim == 2:
return ssim(img1, img2)
@@ -658,7 +655,7 @@ def calculate_ssim(img1, img2, border=0):
if img1.shape[2] == 3:
ssims = []
for i in range(3):
- ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
return np.array(ssims).mean()
elif img1.shape[2] == 1:
return ssim(np.squeeze(img1), np.squeeze(img2))
@@ -667,8 +664,8 @@ def calculate_ssim(img1, img2, border=0):
def ssim(img1, img2):
- C1 = (0.01 * 255)**2
- C2 = (0.03 * 255)**2
+ C1 = (0.01 * 255) ** 2
+ C2 = (0.03 * 255) ** 2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
@@ -677,11 +674,11 @@ def ssim(img1, img2):
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
- mu1_sq = mu1**2
- mu2_sq = mu2**2
+ mu1_sq = mu1 ** 2
+ mu2_sq = mu2 ** 2
mu1_mu2 = mu1 * mu2
- sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
- sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
@@ -699,10 +696,10 @@ def ssim(img1, img2):
# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
absx = torch.abs(x)
- absx2 = absx**2
- absx3 = absx**3
- return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
- (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+ absx2 = absx ** 2
+ absx3 = absx ** 3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
@@ -913,4 +910,4 @@ def imresize_np(img, scale, antialiasing=True):
print('---')
# img = imread_uint('test.bmp', 3)
# img = uint2single(img)
-# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
+# img_bicubic = imresize_np(img, 1/4)
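
Note on the metric code reformatted above: calculate_psnr reduces to the textbook 20 * log10(255 / sqrt(MSE)) for images in the [0, 255] range. A minimal standalone sketch of that formula (the arrays below are made up for illustration and are not part of the patch):

import math
import numpy as np

# Standalone PSNR matching the formula in calculate_psnr above; inputs are
# HxW or HxWxC arrays in the [0, 255] range. Illustrative only.
def psnr_uint8(img1, img2):
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))

a = np.zeros((8, 8), dtype=np.uint8)
print(psnr_uint8(a, a))      # inf for identical images
print(psnr_uint8(a, a + 1))  # ~48.13 dB for a one-level offset
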
diff --git a/ldm/modules/losses/__init__.py b/ldm/modules/losses/__init__.py
index 876d7c5bd..d86294210 100644
--- a/ldm/modules/losses/__init__.py
+++ b/ldm/modules/losses/__init__.py
@@ -1 +1 @@
-from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
+from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py
index 672c1e32a..4f3463c06 100644
--- a/ldm/modules/losses/contperceptual.py
+++ b/ldm/modules/losses/contperceptual.py
@@ -53,7 +53,7 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
weighted_nll_loss = nll_loss
if weights is not None:
- weighted_nll_loss = weights*nll_loss
+ weighted_nll_loss = weights * nll_loss
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
kl_loss = posteriors.kl()
@@ -82,8 +82,10 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
- log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
- "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/logvar".format(split): self.logvar.detach(),
+ "{}/kl_loss".format(split): kl_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
"{}/rec_loss".format(split): rec_loss.detach().mean(),
"{}/d_weight".format(split): d_weight.detach(),
"{}/disc_factor".format(split): torch.tensor(disc_factor),
@@ -108,4 +110,3 @@ def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
"{}/logits_fake".format(split): logits_fake.detach().mean()
}
return d_loss, log
-
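
For readers skimming the re-wrapped log dictionary above: the generator-side objective it reports is weighted_nll_loss + kl_weight * kl_loss + d_weight * disc_factor * g_loss. A toy sketch of that composition with dummy scalars (the values are invented, only the arithmetic mirrors the hunk):

import torch

# Dummy values standing in for the real loss terms; only the composition matters.
weighted_nll_loss = torch.tensor(0.7)      # reconstruction / NLL term
kl_loss = torch.tensor(0.02)               # posterior KL term
g_loss = torch.tensor(1.3)                 # generator's discriminator-fooling loss
kl_weight, d_weight, disc_factor = 1e-6, torch.tensor(0.5), 1.0

loss = weighted_nll_loss + kl_weight * kl_loss + d_weight * disc_factor * g_loss
log = {
    "train/total_loss": loss.clone().detach().mean(),
    "train/kl_loss": kl_loss.detach().mean(),
    "train/nll_loss": weighted_nll_loss.detach().mean(),
}
print(log)
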
diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py
index f69981769..a30b87fcc 100644
--- a/ldm/modules/losses/vqperceptual.py
+++ b/ldm/modules/losses/vqperceptual.py
@@ -1,22 +1,22 @@
import torch
-from torch import nn
import torch.nn.functional as F
from einops import repeat
-
from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
+from torch import nn
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
- loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
- loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
+ loss_real = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3])
+ loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3])
loss_real = (weights * loss_real).sum() / weights.sum()
loss_fake = (weights * loss_fake).sum() / weights.sum()
d_loss = 0.5 * (loss_real + loss_fake)
return d_loss
+
def adopt_weight(weight, global_step, threshold=0, value=0.):
if global_step < threshold:
weight = value
@@ -32,12 +32,13 @@ def measure_perplexity(predicted_indices, n_embed):
cluster_use = torch.sum(avg_probs > 0)
return perplexity, cluster_use
+
def l1(x, y):
- return torch.abs(x-y)
+ return torch.abs(x - y)
def l2(x, y):
- return torch.pow((x-y), 2)
+ return torch.pow((x - y), 2)
class VQLPIPSWithDiscriminator(nn.Module):
@@ -99,7 +100,7 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
if not exists(codebook_loss):
codebook_loss = torch.tensor([0.]).to(inputs.device)
- #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
if self.perceptual_weight > 0:
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
@@ -108,7 +109,7 @@ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
p_loss = torch.tensor([0.0])
nll_loss = rec_loss
- #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
nll_loss = torch.mean(nll_loss)
# now the GAN part
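
The weighted hinge loss reformatted above is the standard hinge discriminator objective with per-example weights. A self-contained sketch (random logits and uniform weights, not data from the repo):

import torch
import torch.nn.functional as F

# Random stand-ins for discriminator outputs on real and generated images.
logits_real = torch.randn(4, 1, 30, 30)
logits_fake = torch.randn(4, 1, 30, 30)
weights = torch.ones(4)

loss_real = torch.mean(F.relu(1. - logits_real), dim=[1, 2, 3])
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1, 2, 3])
d_loss = 0.5 * ((weights * loss_real).sum() / weights.sum()
                + (weights * loss_fake).sum() / weights.sum())
print(d_loss)  # scalar discriminator loss
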
diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py
index 5fc15bf9c..76c74c304 100644
--- a/ldm/modules/x_transformer.py
+++ b/ldm/modules/x_transformer.py
@@ -1,11 +1,11 @@
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import torch
-from torch import nn, einsum
import torch.nn.functional as F
-from functools import partial
-from inspect import isfunction
from collections import namedtuple
from einops import rearrange, repeat, reduce
+from functools import partial
+from inspect import isfunction
+from torch import nn, einsum
# constants
@@ -64,18 +64,21 @@ def default(val, d):
def always(val):
def inner(*args, **kwargs):
return val
+
return inner
def not_equals(val):
def inner(x):
return x != val
+
return inner
def equals(val):
def inner(x):
return x == val
+
return inner
@@ -252,7 +255,7 @@ def __init__(
self.sparse_topk = sparse_topk
# entmax
- #self.attn_fn = entmax15 if use_entmax15 else F.softmax
+ # self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax
# add memory key / values
@@ -544,7 +547,6 @@ def __init__(self, **kwargs):
super().__init__(causal=False, **kwargs)
-
class TransformerWrapper(nn.Module):
def __init__(
self,
@@ -571,7 +573,7 @@ def __init__(
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
- use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+ use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -638,4 +640,3 @@ def forward(
return out, attn_maps
return out
-
diff --git a/ldm/util.py b/ldm/util.py
index 8ba38853e..bb4b853c9 100644
--- a/ldm/util.py
+++ b/ldm/util.py
@@ -1,16 +1,12 @@
import importlib
-
-import torch
-import numpy as np
-from collections import abc
-from einops import rearrange
-from functools import partial
-
import multiprocessing as mp
-from threading import Thread
+from collections import abc
+from inspect import isfunction
from queue import Queue
+from threading import Thread
-from inspect import isfunction
+import numpy as np
+import torch
from PIL import Image, ImageDraw, ImageFont
diff --git a/main.py b/main.py
index e8e18c18f..f1ef268a1 100644
--- a/main.py
+++ b/main.py
@@ -1,21 +1,25 @@
-import argparse, os, sys, datetime, glob, importlib, csv
+import argparse
+import csv
+import datetime
+import glob
+import importlib
import numpy as np
+import os
+import pytorch_lightning as pl
+import sys
import time
import torch
import torchvision
-import pytorch_lightning as pl
-
-from packaging import version
-from omegaconf import OmegaConf
-from torch.utils.data import random_split, DataLoader, Dataset, Subset
-from functools import partial
from PIL import Image
-
+from functools import partial
+from omegaconf import OmegaConf
+from packaging import version
from pytorch_lightning import seed_everything
-from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
-from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from torch.utils.data import random_split, DataLoader, Dataset, Subset
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
@@ -583,7 +587,7 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
else:
- modelckpt_cfg = OmegaConf.create()
+ modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse('1.4.0'):
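
The modelckpt_cfg hunk above only fixes indentation; the surrounding logic merges a user-supplied checkpoint config over the defaults with OmegaConf. A hedged sketch of that merge (the keys are illustrative, not the exact defaults defined in main.py):

from omegaconf import OmegaConf

# Illustrative defaults; values from the user config override them on merge.
default_modelckpt_cfg = OmegaConf.create({"params": {"monitor": "val/loss", "save_top_k": 3}})
user_cfg = OmegaConf.create({"params": {"save_top_k": 1}})

merged = OmegaConf.merge(default_modelckpt_cfg, user_cfg)
print(OmegaConf.to_yaml(merged))  # save_top_k: 1, monitor kept from the defaults
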
diff --git a/models/first_stage_models/kl-f16/config.yaml b/models/first_stage_models/kl-f16/config.yaml
index 661921cf7..5513addcb 100644
--- a/models/first_stage_models/kl-f16/config.yaml
+++ b/models/first_stage_models/kl-f16/config.yaml
@@ -18,14 +18,14 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 1
- - 2
- - 2
- - 4
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
num_res_blocks: 2
attn_resolutions:
- - 16
+ - 16
dropout: 0.0
data:
target: main.DataModuleFromConfig
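
This and the following model-config diffs only re-indent YAML sequence items and add spaces inside empty flow lists; both spellings parse to identical data, so the configs load unchanged. A quick check, assuming PyYAML is available (the snippet is illustrative and not part of the patch):

import yaml

old = """
ch_mult:
- 1
- 2
attn_resolutions: []
"""
new = """
ch_mult:
  - 1
  - 2
attn_resolutions: [ ]
"""
# Both forms load to the same Python objects.
assert yaml.safe_load(old) == yaml.safe_load(new)
print(yaml.safe_load(new))  # {'ch_mult': [1, 2], 'attn_resolutions': []}
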
diff --git a/models/first_stage_models/kl-f32/config.yaml b/models/first_stage_models/kl-f32/config.yaml
index 7b642b136..140045168 100644
--- a/models/first_stage_models/kl-f32/config.yaml
+++ b/models/first_stage_models/kl-f32/config.yaml
@@ -18,16 +18,16 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 1
- - 2
- - 2
- - 4
- - 4
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
+ - 4
num_res_blocks: 2
attn_resolutions:
- - 16
- - 8
+ - 16
+ - 8
dropout: 0.0
data:
target: main.DataModuleFromConfig
diff --git a/models/first_stage_models/kl-f4/config.yaml b/models/first_stage_models/kl-f4/config.yaml
index 85cfb3e94..091bfefc0 100644
--- a/models/first_stage_models/kl-f4/config.yaml
+++ b/models/first_stage_models/kl-f4/config.yaml
@@ -18,11 +18,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
data:
target: main.DataModuleFromConfig
diff --git a/models/first_stage_models/kl-f8/config.yaml b/models/first_stage_models/kl-f8/config.yaml
index 921aa4253..b45420abb 100644
--- a/models/first_stage_models/kl-f8/config.yaml
+++ b/models/first_stage_models/kl-f8/config.yaml
@@ -18,12 +18,12 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
- - 4
+ - 1
+ - 2
+ - 4
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
data:
target: main.DataModuleFromConfig
diff --git a/models/first_stage_models/vq-f16/config.yaml b/models/first_stage_models/vq-f16/config.yaml
index 91c745490..617996101 100644
--- a/models/first_stage_models/vq-f16/config.yaml
+++ b/models/first_stage_models/vq-f16/config.yaml
@@ -12,14 +12,14 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 1
- - 2
- - 2
- - 4
+ - 1
+ - 1
+ - 2
+ - 2
+ - 4
num_res_blocks: 2
attn_resolutions:
- - 16
+ - 16
dropout: 0.0
lossconfig:
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
diff --git a/models/first_stage_models/vq-f4-noattn/config.yaml b/models/first_stage_models/vq-f4-noattn/config.yaml
index f8e499fa2..dcc27351c 100644
--- a/models/first_stage_models/vq-f4-noattn/config.yaml
+++ b/models/first_stage_models/vq-f4-noattn/config.yaml
@@ -15,11 +15,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
diff --git a/models/first_stage_models/vq-f4/config.yaml b/models/first_stage_models/vq-f4/config.yaml
index 7d8cef325..5e982bf19 100644
--- a/models/first_stage_models/vq-f4/config.yaml
+++ b/models/first_stage_models/vq-f4/config.yaml
@@ -14,11 +14,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
diff --git a/models/first_stage_models/vq-f8-n256/config.yaml b/models/first_stage_models/vq-f8-n256/config.yaml
index 8519e13d6..eeacdcf0c 100644
--- a/models/first_stage_models/vq-f8-n256/config.yaml
+++ b/models/first_stage_models/vq-f8-n256/config.yaml
@@ -13,13 +13,13 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 2
- - 4
+ - 1
+ - 2
+ - 2
+ - 4
num_res_blocks: 2
attn_resolutions:
- - 32
+ - 32
dropout: 0.0
lossconfig:
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
diff --git a/models/first_stage_models/vq-f8/config.yaml b/models/first_stage_models/vq-f8/config.yaml
index efd6801ca..b87f6ee2b 100644
--- a/models/first_stage_models/vq-f8/config.yaml
+++ b/models/first_stage_models/vq-f8/config.yaml
@@ -13,13 +13,13 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 2
- - 4
+ - 1
+ - 2
+ - 2
+ - 4
num_res_blocks: 2
attn_resolutions:
- - 32
+ - 32
dropout: 0.0
lossconfig:
target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
diff --git a/models/ldm/bsr_sr/config.yaml b/models/ldm/bsr_sr/config.yaml
index 861692a8d..837d2764c 100644
--- a/models/ldm/bsr_sr/config.yaml
+++ b/models/ldm/bsr_sr/config.yaml
@@ -21,14 +21,14 @@ model:
out_channels: 3
model_channels: 160
attention_resolutions:
- - 16
- - 8
+ - 16
+ - 8
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 2
- - 4
+ - 1
+ - 2
+ - 2
+ - 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
@@ -44,11 +44,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/models/ldm/celeba256/config.yaml b/models/ldm/celeba256/config.yaml
index a12f4e9d3..b91149f65 100644
--- a/models/ldm/celeba256/config.yaml
+++ b/models/ldm/celeba256/config.yaml
@@ -22,15 +22,15 @@ model:
out_channels: 3
model_channels: 224
attention_resolutions:
- - 8
- - 4
- - 2
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 4
+ - 1
+ - 2
+ - 3
+ - 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
@@ -45,11 +45,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/models/ldm/cin256/config.yaml b/models/ldm/cin256/config.yaml
index 9bc1b4566..8aa44fbab 100644
--- a/models/ldm/cin256/config.yaml
+++ b/models/ldm/cin256/config.yaml
@@ -22,14 +22,14 @@ model:
out_channels: 4
model_channels: 256
attention_resolutions:
- - 4
- - 2
- - 1
+ - 4
+ - 2
+ - 1
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_head_channels: 32
use_spatial_transformer: true
transformer_depth: 1
@@ -47,13 +47,13 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 2
- - 4
+ - 1
+ - 2
+ - 2
+ - 4
num_res_blocks: 2
attn_resolutions:
- - 32
+ - 32
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/models/ldm/ffhq256/config.yaml b/models/ldm/ffhq256/config.yaml
index 0ddfd1b93..b52027c81 100644
--- a/models/ldm/ffhq256/config.yaml
+++ b/models/ldm/ffhq256/config.yaml
@@ -22,15 +22,15 @@ model:
out_channels: 3
model_channels: 224
attention_resolutions:
- - 8
- - 4
- - 2
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 4
+ - 1
+ - 2
+ - 3
+ - 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
@@ -45,11 +45,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/models/ldm/inpainting_big/config.yaml b/models/ldm/inpainting_big/config.yaml
index da5fd5ea5..99ff65a83 100644
--- a/models/ldm/inpainting_big/config.yaml
+++ b/models/ldm/inpainting_big/config.yaml
@@ -30,15 +30,15 @@ model:
out_channels: 3
model_channels: 256
attention_resolutions:
- - 8
- - 4
- - 2
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 4
+ - 1
+ - 2
+ - 3
+ - 4
num_heads: 8
resblock_updown: true
first_stage_config:
@@ -56,11 +56,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: ldm.modules.losses.contperceptual.DummyLoss
diff --git a/models/ldm/layout2img-openimages256/config.yaml b/models/ldm/layout2img-openimages256/config.yaml
index 9e1dc15fe..a5618a938 100644
--- a/models/ldm/layout2img-openimages256/config.yaml
+++ b/models/ldm/layout2img-openimages256/config.yaml
@@ -21,15 +21,15 @@ model:
out_channels: 3
model_channels: 128
attention_resolutions:
- - 8
- - 4
- - 2
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 4
+ - 1
+ - 2
+ - 3
+ - 4
num_head_channels: 32
use_spatial_transformer: true
transformer_depth: 3
@@ -48,11 +48,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/models/ldm/lsun_beds256/config.yaml b/models/ldm/lsun_beds256/config.yaml
index 1a50c766a..47b9e38ef 100644
--- a/models/ldm/lsun_beds256/config.yaml
+++ b/models/ldm/lsun_beds256/config.yaml
@@ -22,15 +22,15 @@ model:
out_channels: 3
model_channels: 224
attention_resolutions:
- - 8
- - 4
- - 2
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 4
+ - 1
+ - 2
+ - 3
+ - 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
@@ -45,11 +45,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/models/ldm/lsun_churches256/config.yaml b/models/ldm/lsun_churches256/config.yaml
index 424d0914c..9569453be 100644
--- a/models/ldm/lsun_churches256/config.yaml
+++ b/models/ldm/lsun_churches256/config.yaml
@@ -20,15 +20,15 @@ model:
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps:
- - 10000
+ - 10000
cycle_lengths:
- - 10000000000000
+ - 10000000000000
f_start:
- - 1.0e-06
+ - 1.0e-06
f_max:
- - 1.0
+ - 1.0
f_min:
- - 1.0
+ - 1.0
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
@@ -37,17 +37,17 @@ model:
out_channels: 4
model_channels: 192
attention_resolutions:
- - 1
- - 2
- - 4
- - 8
+ - 1
+ - 2
+ - 4
+ - 8
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 2
- - 4
- - 4
+ - 1
+ - 2
+ - 2
+ - 4
+ - 4
num_heads: 8
use_scale_shift_norm: true
resblock_updown: true
@@ -64,12 +64,12 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
- - 4
+ - 1
+ - 2
+ - 4
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/models/ldm/semantic_synthesis256/config.yaml b/models/ldm/semantic_synthesis256/config.yaml
index 1a721cfff..16505bbb5 100644
--- a/models/ldm/semantic_synthesis256/config.yaml
+++ b/models/ldm/semantic_synthesis256/config.yaml
@@ -21,14 +21,14 @@ model:
out_channels: 3
model_channels: 128
attention_resolutions:
- - 32
- - 16
- - 8
+ - 32
+ - 16
+ - 8
num_res_blocks: 2
channel_mult:
- - 1
- - 4
- - 8
+ - 1
+ - 4
+ - 8
num_heads: 8
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
@@ -43,11 +43,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/models/ldm/semantic_synthesis512/config.yaml b/models/ldm/semantic_synthesis512/config.yaml
index 8faded2ee..6059dcdf0 100644
--- a/models/ldm/semantic_synthesis512/config.yaml
+++ b/models/ldm/semantic_synthesis512/config.yaml
@@ -21,14 +21,14 @@ model:
out_channels: 3
model_channels: 128
attention_resolutions:
- - 32
- - 16
- - 8
+ - 32
+ - 16
+ - 8
num_res_blocks: 2
channel_mult:
- - 1
- - 4
- - 8
+ - 1
+ - 4
+ - 8
num_heads: 8
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
@@ -44,11 +44,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/models/ldm/text2img256/config.yaml b/models/ldm/text2img256/config.yaml
index 3f54a0151..98fffef93 100644
--- a/models/ldm/text2img256/config.yaml
+++ b/models/ldm/text2img256/config.yaml
@@ -22,15 +22,15 @@ model:
out_channels: 3
model_channels: 192
attention_resolutions:
- - 8
- - 4
- - 2
+ - 8
+ - 4
+ - 2
num_res_blocks: 2
channel_mult:
- - 1
- - 2
- - 3
- - 5
+ - 1
+ - 2
+ - 3
+ - 5
num_head_channels: 32
use_spatial_transformer: true
transformer_depth: 1
@@ -48,11 +48,11 @@ model:
out_ch: 3
ch: 128
ch_mult:
- - 1
- - 2
- - 4
+ - 1
+ - 2
+ - 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/notebook_helpers.py b/notebook_helpers.py
index 5d0ebd7e1..d67bf60ca 100644
--- a/notebook_helpers.py
+++ b/notebook_helpers.py
@@ -1,23 +1,24 @@
-from torchvision.datasets.utils import download_url
-from ldm.util import instantiate_from_config
-import torch
+import ipywidgets as widgets
import os
-# todo ?
-from google.colab import files
+import time
+import torch
+import torchvision
from IPython.display import Image as ipyimg
-import ipywidgets as widgets
from PIL import Image
-from numpy import asarray
from einops import rearrange, repeat
-import torch, torchvision
+# todo ?
+from google.colab import files
+from numpy import asarray
+from omegaconf import OmegaConf
+from torchvision.datasets.utils import download_url
+
from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.util import instantiate_from_config
from ldm.util import ismap
-import time
-from omegaconf import OmegaConf
def download_models(mode):
-
if mode == "superresolution":
# this is the small bsr light model
url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'
@@ -29,8 +30,8 @@ def download_models(mode):
download_url(url_conf, path_conf)
download_url(url_ckpt, path_ckpt)
- path_conf = path_conf + '/?dl=1' # fix it
- path_ckpt = path_ckpt + '/?dl=1' # fix it
+ path_conf = path_conf + '/?dl=1' # fix it
+ path_ckpt = path_ckpt + '/?dl=1' # fix it
return path_conf, path_ckpt
else:
@@ -62,7 +63,7 @@ def get_custom_cond(mode):
if mode == "superresolution":
uploaded_img = files.upload()
filename = next(iter(uploaded_img))
- name, filetype = filename.split(".") # todo assumes just one dot in name !
+ name, filetype = filename.split(".") # todo assumes just one dot in name !
os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")
elif mode == "text_conditional":
@@ -129,7 +130,6 @@ def visualize_cond_img(path):
def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None):
-
example = get_cond(task, selected_path)
save_intermediate_vid = False
@@ -173,13 +173,14 @@ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifi
logs = make_convolutional_sample(example, model,
mode=mode, custom_steps=custom_steps,
- eta=eta, swap_mode=False , masked=masked,
+ eta=eta, swap_mode=False, masked=masked,
invert_mask=invert_mask, quantize_x0=False,
custom_schedule=None, decode_interval=10,
resize_enabled=resize_enabled, custom_shape=custom_shape,
temperature=temperature, noise_dropout=0.,
- corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,
- make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred
+ corrector=guider, corrector_kwargs=ckwargs, x_T=x_T,
+ save_intermediate_vid=save_intermediate_vid,
+ make_progrow=make_progrow, ddim_use_x0_pred=ddim_use_x0_pred
)
return logs
@@ -190,7 +191,6 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
temperature=1., noise_dropout=0., score_corrector=None,
corrector_kwargs=None, x_T=None, log_every_t=None
):
-
ddim = DDIMSampler(model)
bs = shape[0] # dont know where this comes from but wayne
shape = shape[1:] # cut batch dim
@@ -208,7 +208,8 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False,
invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,
resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
- corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):
+ corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,
+ ddim_use_x0_pred=False):
log = dict()
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
@@ -237,7 +238,7 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_model:
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
- if model.cond_stage_key =='class_label':
+ if model.cond_stage_key == 'class_label':
log[model.cond_stage_key] = xc[model.cond_stage_key]
with model.ema_scope("Plotting"):
@@ -267,4 +268,4 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e
log["sample"] = x_sample
log["time"] = t1 - t0
- return log
\ No newline at end of file
+ return log
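
On the "todo assumes just one dot in name" comment touched above: os.path.splitext splits off only the final extension, so it would handle filenames that contain extra dots. A small illustration (the filename is hypothetical):

import os

filename = "my.custom_image.png"   # hypothetical upload name
name, ext = os.path.splitext(filename)
print(name, ext)                   # my.custom_image .png (note the leading dot)
# filename.split(".") gives ['my', 'custom_image', 'png'] here, which would break
# the two-value unpacking used in get_custom_cond above.
print(filename.split("."))
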
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index dcf790186..04bbd0b72 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -6,22 +6,24 @@
-- merci
"""
-import time, math
-from tqdm.auto import trange, tqdm
+import math
+from functools import partial
+
+import numpy as np
+import pytorch_lightning as pl
import torch
from einops import rearrange
+from pytorch_lightning.utilities.distributed import rank_zero_only
from tqdm import tqdm
-from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from tqdm.auto import trange, tqdm
+
from ldm.models.autoencoder import VQModelInterface
-import torch.nn as nn
-import numpy as np
-import pytorch_lightning as pl
-from functools import partial
-from pytorch_lightning.utilities.distributed import rank_zero_only
-from ldm.util import exists, default, instantiate_from_config
-from ldm.modules.diffusionmodules.util import make_beta_schedule
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+from ldm.util import exists, default, instantiate_from_config
+
+
# from samplers import CompVisDenoiser
def disabled_train(self):
@@ -36,7 +38,7 @@ def __init__(self,
timesteps=1000,
beta_schedule="linear",
ckpt_path=None,
- ignore_keys=[],
+ ignore_keys=None,
load_only_unet=False,
monitor="val/loss",
use_ema=True,
@@ -58,6 +60,8 @@ def __init__(self,
use_positional_encodings=False,
):
super().__init__()
+ if ignore_keys is None:
+ ignore_keys = []
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
@@ -83,7 +87,6 @@ def __init__(self,
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
-
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
@@ -108,6 +111,7 @@ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=
class FirstStage(DDPM):
"""main class"""
+
def __init__(self,
first_stage_config,
num_timesteps_cond=None,
@@ -140,14 +144,13 @@ def __init__(self,
self.instantiate_first_stage(first_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
- self.bbox_tokenizer = None
+ self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
-
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
@@ -164,7 +167,6 @@ def get_first_stage_encoding(self, encoder_posterior):
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
-
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
@@ -187,7 +189,6 @@ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
else:
return self.first_stage_model.decode(z)
-
@torch.no_grad()
def encode_first_stage(self, x):
if hasattr(self, "split_input_params"):
@@ -231,6 +232,7 @@ def encode_first_stage(self, x):
class CondStage(DDPM):
"""main class"""
+
def __init__(self,
cond_stage_config,
num_timesteps_cond=None,
@@ -262,7 +264,7 @@ def __init__(self,
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
- self.bbox_tokenizer = None
+ self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
@@ -303,6 +305,7 @@ def get_learned_conditioning(self, c):
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
+
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config):
super().__init__()
@@ -312,17 +315,19 @@ def forward(self, x, t, cc):
out = self.diffusion_model(x, t, context=cc)
return out
+
class DiffusionWrapperOut(pl.LightningModule):
def __init__(self, diff_model_config):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
- def forward(self, h,emb,tp,hs, cc):
- return self.diffusion_model(h,emb,tp,hs, context=cc)
+ def forward(self, h, emb, tp, hs, cc):
+ return self.diffusion_model(h, emb, tp, hs, context=cc)
class UNet(DDPM):
"""main class"""
+
def __init__(self,
unetConfigEncode,
unetConfigDecode,
@@ -333,7 +338,7 @@ def __init__(self,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
- unet_bs = 1,
+ unet_bs=1,
scale_by_std=False,
*args, **kwargs):
self.num_timesteps_cond = default(num_timesteps_cond, 1)
@@ -358,7 +363,7 @@ def __init__(self,
self.register_buffer('scale_factor', torch.tensor(scale_factor))
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
- self.bbox_tokenizer = None
+ self.bbox_tokenizer = None
self.model1 = DiffusionWrapper(self.unetConfigEncode)
self.model2 = DiffusionWrapperOut(self.unetConfigDecode)
self.model1.eval()
@@ -392,41 +397,38 @@ def on_train_batch_start(self, batch, batch_idx):
print(f"setting self.scale_factor to {self.scale_factor}")
print("### USING STD-RESCALING ###")
-
def apply_model(self, x_noisy, t, cond, return_ids=False):
-
- if(not self.turbo):
+
+ if not self.turbo:
self.model1.to(self.cdevice)
step = self.unet_bs
- h,emb,hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
+ h, emb, hs = self.model1(x_noisy[0:step], t[:step], cond[:step])
bs = cond.shape[0]
-
+
# assert bs%2 == 0
lenhs = len(hs)
- for i in range(step,bs,step):
- h_temp,emb_temp,hs_temp = self.model1(x_noisy[i:i+step], t[i:i+step], cond[i:i+step])
- h = torch.cat((h,h_temp))
- emb = torch.cat((emb,emb_temp))
+ for i in range(step, bs, step):
+ h_temp, emb_temp, hs_temp = self.model1(x_noisy[i:i + step], t[i:i + step], cond[i:i + step])
+ h = torch.cat((h, h_temp))
+ emb = torch.cat((emb, emb_temp))
for j in range(lenhs):
hs[j] = torch.cat((hs[j], hs_temp[j]))
-
- if(not self.turbo):
+ if not self.turbo:
self.model1.to("cpu")
self.model2.to(self.cdevice)
-
-        hs_temp = [hs[j][:step] for j in range(lenhs)]
-        x_recon = self.model2(h[:step],emb[:step],x_noisy.dtype,hs_temp,cond[:step])
-        for i in range(step,bs,step):
-            hs_temp = [hs[j][i:i+step] for j in range(lenhs)]
-            x_recon1 = self.model2(h[i:i+step],emb[i:i+step],x_noisy.dtype,hs_temp,cond[i:i+step])
+        hs_temp = [hs[j][:step] for j in range(lenhs)]
+        x_recon = self.model2(h[:step], emb[:step], x_noisy.dtype, hs_temp, cond[:step])
+        for i in range(step, bs, step):
+            hs_temp = [hs[j][i:i + step] for j in range(lenhs)]
+            x_recon1 = self.model2(h[i:i + step], emb[i:i + step], x_noisy.dtype, hs_temp, cond[i:i + step])
x_recon = torch.cat((x_recon, x_recon1))
- if(not self.turbo):
+ if not self.turbo:
self.model2.to("cpu")
if isinstance(x_recon, tuple) and not return_ids:
@@ -435,47 +437,43 @@ def apply_model(self, x_noisy, t, cond, return_ids=False):
return x_recon
def register_buffer1(self, name, attr):
- if type(attr) == torch.Tensor:
- if attr.device != torch.device(self.cdevice):
- attr = attr.to(torch.device(self.cdevice))
- setattr(self, name, attr)
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device(self.cdevice):
+ attr = attr.to(torch.device(self.cdevice))
+ setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
-
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
- num_ddpm_timesteps=self.num_timesteps,verbose=verbose)
+ num_ddpm_timesteps=self.num_timesteps, verbose=verbose)
-
assert self.alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
-
to_torch = lambda x: x.to(self.cdevice)
self.register_buffer1('betas', to_torch(self.betas))
self.register_buffer1('alphas_cumprod', to_torch(self.alphas_cumprod))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=self.alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
- eta=ddim_eta,verbose=verbose)
+ eta=ddim_eta, verbose=verbose)
self.register_buffer1('ddim_sigmas', ddim_sigmas)
self.register_buffer1('ddim_alphas', ddim_alphas)
self.register_buffer1('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer1('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
-
@torch.no_grad()
def sample(self,
S,
conditioning,
x0=None,
- shape = None,
- seed=1234,
+ shape=None,
+ seed=1234,
callback=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
- sampler = "plms",
+ sampler="plms",
temperature=1.,
noise_dropout=0.,
score_corrector=None,
@@ -486,9 +484,8 @@ def sample(self,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
):
-
- if(self.turbo):
+ if self.turbo:
self.model1.to(self.cdevice)
self.model2.to(self.cdevice)
@@ -496,39 +493,40 @@ def sample(self,
batch_size, b1, b2, b3 = shape
img_shape = (1, b1, b2, b3)
tens = []
- print("seeds used = ", [seed+s for s in range(batch_size)])
+ print("seeds used = ", [seed + s for s in range(batch_size)])
for _ in range(batch_size):
torch.manual_seed(seed)
tens.append(torch.randn(img_shape, device=self.cdevice))
- seed+=1
+ seed += 1
noise = torch.cat(tens)
del tens
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False)
x_latent = noise if x0 is None else x0
# sampling
-
+
if sampler == "plms":
print(f'Data shape for PLMS sampling is {shape}')
samples = self.plms_sampling(conditioning, batch_size, x_latent,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ )
elif sampler == "ddim":
- samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
+ samples = self.ddim_sampling(x_latent, conditioning, S,
+ unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
- mask = mask,init_latent=x_T,use_original_steps=False)
+ mask=mask, init_latent=x_T, use_original_steps=False)
# elif sampler == "euler":
# cvd = CompVisDenoiser(self.alphas_cumprod)
@@ -536,20 +534,20 @@ def sample(self,
# samples = self.heun_sampling(noise, sig, conditioning, unconditional_conditioning=unconditional_conditioning,
# unconditional_guidance_scale=unconditional_guidance_scale)
- if(self.turbo):
+ if self.turbo:
self.model1.to("cpu")
self.model2.to("cpu")
return samples
@torch.no_grad()
- def plms_sampling(self, cond,b, img,
+ def plms_sampling(self, cond, b, img,
ddim_use_original_steps=False,
callback=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
-
+ unconditional_guidance_scale=1., unconditional_conditioning=None, ):
+
device = self.betas.device
timesteps = self.ddim_timesteps
time_range = np.flip(timesteps)
@@ -607,7 +605,7 @@ def get_model_output(x, t):
return e_t
- alphas = self.ddim_alphas
+ alphas = self.ddim_alphas
alphas_prev = self.ddim_alphas_prev
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas
@@ -617,14 +615,14 @@ def get_x_prev_and_pred_x0(e_t, index):
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
@@ -651,9 +649,8 @@ def get_x_prev_and_pred_x0(e_t, index):
return x_prev, pred_x0, e_t
-
@torch.no_grad()
- def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=False, noise=None):
+ def stochastic_encode(self, x0, t, seed, ddim_eta, ddim_steps, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
self.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
@@ -663,11 +660,11 @@ def stochastic_encode(self, x0, t, seed, ddim_eta,ddim_steps,use_original_steps=
b0, b1, b2, b3 = x0.shape
img_shape = (1, b1, b2, b3)
tens = []
- print("seeds used = ", [seed+s for s in range(b0)])
+ print("seeds used = ", [seed + s for s in range(b0)])
for _ in range(b0):
torch.manual_seed(seed)
tens.append(torch.randn(img_shape, device=x0.device))
- seed+=1
+ seed += 1
noise = torch.cat(tens)
del tens
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
@@ -684,10 +681,9 @@ def add_noise(self, x0, t):
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(self.ddim_sqrt_one_minus_alphas, t, x0.shape) * noise)
-
@torch.no_grad()
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
- mask = None,init_latent=None,use_original_steps=False):
+ mask=None, init_latent=None, use_original_steps=False):
timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start]
@@ -700,23 +696,22 @@ def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.
x0 = init_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
- ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
if mask is not None:
# x0_noisy = self.add_noise(mask, torch.tensor([index] * x0.shape[0]).to(self.cdevice))
x0_noisy = x0
- x_dec = x0_noisy* mask + (1. - mask) * x_dec
+ x_dec = x0_noisy * mask + (1. - mask) * x_dec
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning)
-
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+
if mask is not None:
return x0 * mask + (1. - mask) * x_dec
return x_dec
-
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
@@ -744,21 +739,20 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev
-
def append_zero(self, x):
return torch.cat([x, x.new_zeros([1])])
@@ -770,31 +764,29 @@ def get_sigmas_karras(self, n, sigma_min, sigma_max, rho=7., device='cpu'):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return self.append_zero(sigmas).to(device)
-
- def get_sigmas_exponential(self,n, sigma_min, sigma_max, device='cpu'):
+ def get_sigmas_exponential(self, n, sigma_min, sigma_max, device='cpu'):
"""Constructs an exponential noise schedule."""
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
return self.append_zero(sigmas)
-
- def get_sigmas_vp(self,n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
+ def get_sigmas_vp(self, n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
return self.append_zero(sigmas)
- def to_d(self,x, sigma, denoised):
+ def to_d(self, x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / self.append_dims(sigma, x.ndim)
- def append_dims(self,x, target_dims):
+ def append_dims(self, x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
- def get_ancestral_step(self,sigma_from, sigma_to):
+ def get_ancestral_step(self, sigma_from, sigma_to):
"""Calculates the noise level (sigma_down) to step down to and the amount
of noise to add (sigma_up) when doing an ancestral sampling step."""
sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
@@ -802,7 +794,8 @@ def get_ancestral_step(self,sigma_from, sigma_to):
return sigma_down, sigma_up
@torch.no_grad()
- def euler_sampling(self, x, sigmas, cond,extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ def euler_sampling(self, x, sigmas, cond, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0.,
+ s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]]).half()
@@ -822,13 +815,15 @@ def euler_sampling(self, x, sigmas, cond,extra_args=None, callback=None, disable
return x
@torch.no_grad()
- def heun_sampling(self, x,sigmas, cond,unconditional_conditioning = None,unconditional_guidance_scale = 1, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+ def heun_sampling(self, x, sigmas, cond, unconditional_conditioning=None, unconditional_guidance_scale=1,
+ extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'),
+ s_noise=1.):
"""Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
extra_args = {} if extra_args is None else extra_args
print(sigmas)
# sigmas = self.get_sigmas_karras(steps, 0.01, 80, rho=7., device=x.device)
print(x[0])
- x = x*sigmas[0]
+ x = x * sigmas[0]
print("alu", x[0])
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
@@ -857,7 +852,7 @@ def heun_sampling(self, x,sigmas, cond,unconditional_conditioning = None,uncondi
else:
# Heun's method
x_2 = x + d * dt
- denoised_2 = self.apply_model(x_2, sigmas[i + 1] * s_in,cond)
+ denoised_2 = self.apply_model(x_2, sigmas[i + 1] * s_in, cond)
d_2 = self.to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
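
The reformatting in apply_model above leaves the core memory-saving idea intact: run the encoder half (model1) over the batch in slices of unet_bs, move it back to the CPU, then stream the cached activations through the decoder half (model2) the same way. A stripped-down sketch of that pattern with stand-in modules (nn.Linear is only a placeholder for the real UNet halves):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
unet_bs = 2
encoder, decoder = nn.Linear(8, 8), nn.Linear(8, 8)   # stand-ins for model1 / model2
x = torch.randn(6, 8)                                 # pretend noisy latents

encoder.to(device)                                    # only the half in use sits on the GPU
h = torch.cat([encoder(x[i:i + unet_bs].to(device)).cpu()
               for i in range(0, x.shape[0], unet_bs)])
encoder.to("cpu")

decoder.to(device)
out = torch.cat([decoder(h[i:i + unet_bs].to(device)).cpu()
                 for i in range(0, h.shape[0], unet_bs)])
decoder.to("cpu")
print(out.shape)                                      # torch.Size([6, 8])
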
diff --git a/optimizedSD/diffusers_txt2img.py b/optimizedSD/diffusers_txt2img.py
index 80fbb9723..fce7dba66 100644
--- a/optimizedSD/diffusers_txt2img.py
+++ b/optimizedSD/diffusers_txt2img.py
@@ -3,10 +3,11 @@
pipe = LDMTextToImagePipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True)
-prompt = "19th Century wooden engraving of Elon musk"
+prompt = "19th Century wooden engraving of Elon musk"
seed = torch.manual_seed(1024)
-images = pipe([prompt], batch_size=1, num_inference_steps=50, guidance_scale=7, generator=seed,torch_device="cpu" )["sample"]
+images = pipe([prompt], batch_size=1, num_inference_steps=50, guidance_scale=7, generator=seed, torch_device="cpu")[
+ "sample"]
# save images
for idx, image in enumerate(images):
diff --git a/optimizedSD/img2img_gradio.py b/optimizedSD/img2img_gradio.py
index 65d844d3b..b1773ac49 100644
--- a/optimizedSD/img2img_gradio.py
+++ b/optimizedSD/img2img_gradio.py
@@ -1,29 +1,28 @@
+import argparse
+import os
+import re
+import time
+from contextlib import nullcontext
+from itertools import islice
+from random import randint
+import mimetypes
import gradio as gr
import numpy as np
import torch
-from torchvision.utils import make_grid
-import os, re
from PIL import Image
-import torch
-import numpy as np
-from random import randint
+from einops import rearrange, repeat
from omegaconf import OmegaConf
-from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
-from einops import rearrange
-from torchvision.utils import make_grid
-import time
from pytorch_lightning import seed_everything
from torch import autocast
-from einops import rearrange, repeat
-from contextlib import nullcontext
-from ldm.util import instantiate_from_config
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
from transformers import logging
-import pandas as pd
+
+from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger
+
logging.set_verbosity_error()
-import mimetypes
+
mimetypes.init()
mimetypes.add_type("application/javascript", ".js")
@@ -43,7 +42,6 @@ def load_model_from_config(ckpt, verbose=False):
def load_img(image, h0, w0):
-
image = image.convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
@@ -59,61 +57,26 @@ def load_img(image, h0, w0):
image = torch.from_numpy(image)
return 2.0 * image - 1.0
-config = "optimizedSD/v1-inference.yaml"
-ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
-sd = load_model_from_config(f"{ckpt}")
-li, lo = [], []
-for key, v_ in sd.items():
- sp = key.split(".")
- if (sp[0]) == "model":
- if "input_blocks" in sp:
- li.append(key)
- elif "middle_block" in sp:
- li.append(key)
- elif "time_embed" in sp:
- li.append(key)
- else:
- lo.append(key)
-for key in li:
- sd["model1." + key[6:]] = sd.pop(key)
-for key in lo:
- sd["model2." + key[6:]] = sd.pop(key)
-
-config = OmegaConf.load(f"{config}")
-
-model = instantiate_from_config(config.modelUNet)
-_, _ = model.load_state_dict(sd, strict=False)
-model.eval()
-
-modelCS = instantiate_from_config(config.modelCondStage)
-_, _ = modelCS.load_state_dict(sd, strict=False)
-modelCS.eval()
-
-modelFS = instantiate_from_config(config.modelFirstStage)
-_, _ = modelFS.load_state_dict(sd, strict=False)
-modelFS.eval()
-del sd
def generate(
- image,
- prompt,
- strength,
- ddim_steps,
- n_iter,
- batch_size,
- Height,
- Width,
- scale,
- ddim_eta,
- unet_bs,
- device,
- seed,
- outdir,
- img_format,
- turbo,
- full_precision,
+ image,
+ prompt,
+ strength,
+ ddim_steps,
+ n_iter,
+ batch_size,
+ Height,
+ Width,
+ scale,
+ ddim_eta,
+ unet_bs,
+ device,
+ seed,
+ outdir,
+ img_format,
+ turbo,
+ full_precision,
):
-
if seed == "":
seed = randint(0, 1000000)
seed = int(seed)
@@ -121,7 +84,7 @@ def generate(
# Logging
sampler = "ddim"
- logger(locals(), log_csv = "logs/img2img_gradio_logs.csv")
+ logger(locals(), log_csv="logs/img2img_gradio_logs.csv")
init_image = load_img(image, Height, Width).to(device)
model.unet_bs = unet_bs
@@ -205,18 +168,17 @@ def generate(
)
# decode it
samples_ddim = model.sample(
- t_enc,
- c,
- z_enc,
- unconditional_guidance_scale=scale,
- unconditional_conditioning=uc,
- sampler = sampler
+ t_enc,
+ c,
+ z_enc,
+ unconditional_guidance_scale=scale,
+ unconditional_conditioning=uc,
+ sampler=sampler
)
modelFS.to(device)
print("saving images")
for i in range(batch_size):
-
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
all_samples.append(x_sample.to("cpu"))
@@ -247,37 +209,77 @@ def generate(
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
txt = (
- "Samples finished in "
- + str(round(time_taken, 3))
- + " minutes and exported to \n"
- + sample_path
- + "\nSeeds used = "
- + seeds[:-1]
+ "Samples finished in "
+ + str(round(time_taken, 3))
+ + " minutes and exported to \n"
+ + sample_path
+ + "\nSeeds used = "
+ + seeds[:-1]
)
return Image.fromarray(grid.astype(np.uint8)), txt
-demo = gr.Interface(
- fn=generate,
- inputs=[
- gr.Image(tool="editor", type="pil"),
- "text",
- gr.Slider(0, 1, value=0.75),
- gr.Slider(1, 1000, value=50),
- gr.Slider(1, 100, step=1),
- gr.Slider(1, 100, step=1),
- gr.Slider(64, 4096, value=512, step=64),
- gr.Slider(64, 4096, value=512, step=64),
- gr.Slider(0, 50, value=7.5, step=0.1),
- gr.Slider(0, 1, step=0.01),
- gr.Slider(1, 2, value=1, step=1),
- gr.Text(value="cuda"),
- "text",
- gr.Text(value="outputs/img2img-samples"),
- gr.Radio(["png", "jpg"], value='png'),
- "checkbox",
- "checkbox",
- ],
- outputs=["image", "text"],
-)
-demo.launch()
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='img2img using gradio')
+ parser.add_argument('--config_path', default="optimizedSD/v1-inference.yaml", type=str, help='config path')
+ parser.add_argument('--ckpt_path', default="models/ldm/stable-diffusion-v1/model.ckpt", type=str, help='ckpt path')
+ args = parser.parse_args()
+ config = args.config_path
+ ckpt = args.ckpt_path
+ sd = load_model_from_config(f"{ckpt}")
+ li, lo = [], []
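+    # split the checkpoint's "model." weights between the two UNet halves:
+    # input_blocks / middle_block / time_embed become "model1.*", the rest "model2.*"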
+ for key, v_ in sd.items():
+ sp = key.split(".")
+ if (sp[0]) == "model":
+ if "input_blocks" in sp:
+ li.append(key)
+ elif "middle_block" in sp:
+ li.append(key)
+ elif "time_embed" in sp:
+ li.append(key)
+ else:
+ lo.append(key)
+ for key in li:
+ sd["model1." + key[6:]] = sd.pop(key)
+ for key in lo:
+ sd["model2." + key[6:]] = sd.pop(key)
+
+ config = OmegaConf.load(f"{config}")
+
+ model = instantiate_from_config(config.modelUNet)
+ _, _ = model.load_state_dict(sd, strict=False)
+ model.eval()
+
+ modelCS = instantiate_from_config(config.modelCondStage)
+ _, _ = modelCS.load_state_dict(sd, strict=False)
+ modelCS.eval()
+
+ modelFS = instantiate_from_config(config.modelFirstStage)
+ _, _ = modelFS.load_state_dict(sd, strict=False)
+ modelFS.eval()
+ del sd
+
+ demo = gr.Interface(
+ fn=generate,
+ inputs=[
+ gr.Image(tool="editor", type="pil"),
+ "text",
+ gr.Slider(0, 1, value=0.75),
+ gr.Slider(1, 1000, value=50),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(0, 50, value=7.5, step=0.1),
+ gr.Slider(0, 1, step=0.01),
+ gr.Slider(1, 2, value=1, step=1),
+ gr.Text(value="cuda"),
+ "text",
+ gr.Text(value="outputs/img2img-samples"),
+ gr.Radio(["png", "jpg"], value='png'),
+ "checkbox",
+ "checkbox",
+ ],
+ outputs=["image", "text"],
+ )
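+    # share=True additionally exposes the UI through a temporary public Gradio share link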
+ demo.launch(share=True)
diff --git a/optimizedSD/inpaint_gradio.py b/optimizedSD/inpaint_gradio.py
index 4b5653bc2..d42930244 100644
--- a/optimizedSD/inpaint_gradio.py
+++ b/optimizedSD/inpaint_gradio.py
@@ -1,29 +1,29 @@
+import argparse
+import os
+import re
+import time
+from contextlib import nullcontext
+from itertools import islice
+from random import randint
+
import gradio as gr
import numpy as np
import torch
-from torchvision.utils import make_grid
-import os, re
from PIL import Image
-import torch
-import numpy as np
-from random import randint
+from einops import rearrange, repeat
from omegaconf import OmegaConf
-from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
-from einops import rearrange
-from torchvision.utils import make_grid
-import time
from pytorch_lightning import seed_everything
from torch import autocast
-from einops import rearrange, repeat
-from contextlib import nullcontext
-from ldm.util import instantiate_from_config
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
from transformers import logging
-import pandas as pd
+
+from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger
+
logging.set_verbosity_error()
import mimetypes
+
mimetypes.init()
mimetypes.add_type("application/javascript", ".js")
@@ -43,7 +43,6 @@ def load_model_from_config(ckpt, verbose=False):
def load_img(image, h0, w0):
-
image = image.convert("RGB")
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
@@ -60,85 +59,49 @@ def load_img(image, h0, w0):
return 2.0 * image - 1.0
-def load_mask(mask, h0, w0, invert=False):
-
+def load_mask(mask, h0, w0, newH, newW, invert=False):
image = mask.convert("RGB")
w, h = image.size
- print(f"loaded input image of size ({w}, {h})")
- if(h0 is not None and w0 is not None):
+ print(f"loaded input mask of size ({w}, {h})")
+ if h0 is not None and w0 is not None:
h, w = h0, w0
-
+
     w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
- print(f"New image size ({w}, {h})")
- image = image.resize((64, 64), resample = Image.LANCZOS)
+ print(f"New mask size ({w}, {h})")
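+    # resize the mask to the latent grid size supplied by the caller rather than a fixed 64x64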
+ image = image.resize((newW, newH), resample=Image.LANCZOS)
+ # image = image.resize((64, 64), resample=Image.LANCZOS)
image = np.array(image)
if invert:
print("inverted")
- where_0, where_1 = np.where(image == 0),np.where(image == 255)
+ where_0, where_1 = np.where(image == 0), np.where(image == 255)
image[where_0], image[where_1] = 255, 0
- image = image.astype(np.float32)/255.0
+ image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image
-config = "optimizedSD/v1-inference.yaml"
-ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
-sd = load_model_from_config(f"{ckpt}")
-li, lo = [], []
-for key, v_ in sd.items():
- sp = key.split(".")
- if (sp[0]) == "model":
- if "input_blocks" in sp:
- li.append(key)
- elif "middle_block" in sp:
- li.append(key)
- elif "time_embed" in sp:
- li.append(key)
- else:
- lo.append(key)
-for key in li:
- sd["model1." + key[6:]] = sd.pop(key)
-for key in lo:
- sd["model2." + key[6:]] = sd.pop(key)
-
-config = OmegaConf.load(f"{config}")
-
-model = instantiate_from_config(config.modelUNet)
-_, _ = model.load_state_dict(sd, strict=False)
-model.eval()
-
-modelCS = instantiate_from_config(config.modelCondStage)
-_, _ = modelCS.load_state_dict(sd, strict=False)
-modelCS.eval()
-
-modelFS = instantiate_from_config(config.modelFirstStage)
-_, _ = modelFS.load_state_dict(sd, strict=False)
-modelFS.eval()
-del sd
-
def generate(
- image,
- prompt,
- strength,
- ddim_steps,
- n_iter,
- batch_size,
- Height,
- Width,
- scale,
- ddim_eta,
- unet_bs,
- device,
- seed,
- outdir,
- img_format,
- turbo,
- full_precision,
+ image,
+ prompt,
+ strength,
+ ddim_steps,
+ n_iter,
+ batch_size,
+ Height,
+ Width,
+ scale,
+ ddim_eta,
+ unet_bs,
+ device,
+ seed,
+ outdir,
+ img_format,
+ turbo,
+ full_precision,
):
-
if seed == "":
seed = randint(0, 1000000)
seed = int(seed)
@@ -146,10 +109,9 @@ def generate(
sampler = "ddim"
# Logging
- logger(locals(), log_csv = "logs/inpaint_gradio_logs.csv")
+ logger(locals(), log_csv="logs/inpaint_gradio_logs.csv")
init_image = load_img(image['image'], Height, Width).to(device)
- mask = load_mask(image['mask'], Height, Width, True).to(device)
model.unet_bs = unet_bs
model.turbo = turbo
@@ -161,10 +123,7 @@ def generate(
modelCS.half()
modelFS.half()
init_image = init_image.half()
- mask.half()
-
- mask = mask[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0)
- mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+ # mask.half()
tic = time.time()
os.makedirs(outdir, exist_ok=True)
@@ -182,6 +141,10 @@ def generate(
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
init_latent = repeat(init_latent, "1 ... -> b ...", b=batch_size)
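+    # build the inpainting mask at the latent resolution, broadcast it over the
+    # four latent channels, and repeat it for every sample in the batch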
+ mask = load_mask(image['mask'], Height, Width, init_latent.shape[2], init_latent.shape[3], True).to(device)
+ mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
+ mask = repeat(mask, '1 ... -> b ...', b=batch_size)
+
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
@@ -237,7 +200,7 @@ def generate(
z_enc = model.stochastic_encode(
init_latent, torch.tensor([t_enc] * batch_size).to(device),
seed, ddim_eta, ddim_steps)
-
+
# decode it
samples_ddim = model.sample(
t_enc,
@@ -245,15 +208,14 @@ def generate(
z_enc,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
- mask = mask,
- x_T = init_latent,
- sampler = sampler,
+ mask=mask,
+ x_T=init_latent,
+ sampler=sampler,
)
modelFS.to(device)
print("saving images")
for i in range(batch_size):
-
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
all_samples.append(x_sample.to("cpu"))
@@ -284,37 +246,77 @@ def generate(
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
txt = (
- "Samples finished in "
- + str(round(time_taken, 3))
- + " minutes and exported to \n"
- + sample_path
- + "\nSeeds used = "
- + seeds[:-1]
+ "Samples finished in "
+ + str(round(time_taken, 3))
+ + " minutes and exported to \n"
+ + sample_path
+ + "\nSeeds used = "
+ + seeds[:-1]
+ )
+ return Image.fromarray(grid.astype(np.uint8)), image['mask'], txt
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='inpainting using gradio')
+ parser.add_argument('--config_path', default="optimizedSD/v1-inference.yaml", type=str, help='config path')
+ parser.add_argument('--ckpt_path', default="models/ldm/stable-diffusion-v1/model.ckpt", type=str, help='ckpt path')
+ args = parser.parse_args()
+ config = args.config_path
+ ckpt = args.ckpt_path
+ sd = load_model_from_config(f"{ckpt}")
+ li, lo = [], []
+ for key, v_ in sd.items():
+ sp = key.split(".")
+ if (sp[0]) == "model":
+ if "input_blocks" in sp:
+ li.append(key)
+ elif "middle_block" in sp:
+ li.append(key)
+ elif "time_embed" in sp:
+ li.append(key)
+ else:
+ lo.append(key)
+ for key in li:
+ sd["model1." + key[6:]] = sd.pop(key)
+ for key in lo:
+ sd["model2." + key[6:]] = sd.pop(key)
+
+ config = OmegaConf.load(f"{config}")
+
+ model = instantiate_from_config(config.modelUNet)
+ _, _ = model.load_state_dict(sd, strict=False)
+ model.eval()
+
+ modelCS = instantiate_from_config(config.modelCondStage)
+ _, _ = modelCS.load_state_dict(sd, strict=False)
+ modelCS.eval()
+
+ modelFS = instantiate_from_config(config.modelFirstStage)
+ _, _ = modelFS.load_state_dict(sd, strict=False)
+ modelFS.eval()
+ del sd
+
+ demo = gr.Interface(
+ fn=generate,
+ inputs=[
+ gr.Image(tool="sketch", type="pil"),
+ "text",
+ gr.Slider(0, 0.99, value=0.99, step=0.01),
+ gr.Slider(1, 1000, value=50),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(0, 50, value=7.5, step=0.1),
+ gr.Slider(0, 1, step=0.01),
+ gr.Slider(1, 2, value=1, step=1),
+ gr.Text(value="cuda"),
+ "text",
+ gr.Text(value="outputs/inpaint-samples"),
+ gr.Radio(["png", "jpg"], value='png'),
+ "checkbox",
+ "checkbox",
+ ],
+ outputs=["image", "image", "text"],
)
- return Image.fromarray(grid.astype(np.uint8)), image['mask'],txt
-
-
-demo = gr.Interface(
- fn=generate,
- inputs=[
- gr.Image(tool="sketch", type="pil"),
- "text",
- gr.Slider(0, 0.99, value=0.99, step = 0.01),
- gr.Slider(1, 1000, value=50),
- gr.Slider(1, 100, step=1),
- gr.Slider(1, 100, step=1),
- gr.Slider(64, 4096, value=512, step=64),
- gr.Slider(64, 4096, value=512, step=64),
- gr.Slider(0, 50, value=7.5, step=0.1),
- gr.Slider(0, 1, step=0.01),
- gr.Slider(1, 2, value=1, step=1),
- gr.Text(value="cuda"),
- "text",
- gr.Text(value="outputs/inpaint-samples"),
- gr.Radio(["png", "jpg"], value='png'),
- "checkbox",
- "checkbox",
- ],
- outputs=["image", "image", "text"],
-)
-demo.launch()
\ No newline at end of file
+ demo.launch()
diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py
index 2136ada44..2ecec6c62 100644
--- a/optimizedSD/openaimodelSplit.py
+++ b/optimizedSD/openaimodelSplit.py
@@ -1,9 +1,11 @@
-from abc import abstractmethod
import math
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
+from abc import abstractmethod
+
+from ldm.modules.attention import SpatialTransformer
from ldm.modules.diffusionmodules.util import (
checkpoint,
conv_nd,
@@ -13,7 +15,6 @@
normalization,
timestep_embedding,
)
-from ldm.modules.attention import SpatialTransformer
class AttentionPool2d(nn.Module):
@@ -22,11 +23,11 @@ class AttentionPool2d(nn.Module):
"""
def __init__(
- self,
- spacial_dim: int,
- embed_dim: int,
- num_heads_channels: int,
- output_dim: int = None,
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
@@ -105,16 +106,18 @@ def forward(self, x):
x = self.conv(x)
return x
+
class TransposedUpsample(nn.Module):
- 'Learned 2x upsampling without padding'
+ """Learned 2x upsampling without padding"""
+
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+ self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)
- def forward(self,x):
+ def forward(self, x):
return self.up(x)
@@ -127,7 +130,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -164,17 +167,17 @@ class ResBlock(TimestepBlock):
"""
def __init__(
- self,
- channels,
- emb_channels,
- dropout,
- out_channels=None,
- use_conv=False,
- use_scale_shift_norm=False,
- dims=2,
- use_checkpoint=False,
- up=False,
- down=False,
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
):
super().__init__()
self.channels = channels
@@ -238,7 +241,6 @@ def forward(self, x, emb):
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
-
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
@@ -270,12 +272,12 @@ class AttentionBlock(nn.Module):
"""
def __init__(
- self,
- channels,
- num_heads=1,
- num_head_channels=-1,
- use_checkpoint=False,
- use_new_attention_order=False,
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
):
super().__init__()
self.channels = channels
@@ -283,7 +285,7 @@ def __init__(
self.num_heads = num_heads
else:
assert (
- channels % num_head_channels == 0
+ channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
@@ -299,8 +301,9 @@ def __init__(
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
- return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
- #return pt_checkpoint(self._forward, x) # pytorch
+ return checkpoint(self._forward, (x,), self.parameters(),
+ True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ # return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
@@ -399,33 +402,32 @@ def count_flops(model, _x, y):
class UNetModelEncode(nn.Module):
-
def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- num_classes=None,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=-1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- use_spatial_transformer=False, # custom transformer support
- transformer_depth=1, # custom transformer support
- context_dim=None, # custom transformer support
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
- legacy=True,
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
):
super().__init__()
if use_spatial_transformer:
@@ -505,7 +507,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
- #num_heads = 1
+ # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
@@ -552,7 +554,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
- #num_heads = 1
+ # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
@@ -570,8 +572,8 @@ def __init__(
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
- ),
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
+ ),
ResBlock(
ch,
time_embed_dim,
@@ -593,7 +595,7 @@ def forward(self, x, timesteps=None, context=None, y=None):
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
- self.num_classes is not None
+ self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
@@ -608,39 +610,38 @@ def forward(self, x, timesteps=None, context=None, y=None):
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
-
+
return h, emb, hs
class UNetModelDecode(nn.Module):
-
def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- num_classes=None,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=-1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- use_spatial_transformer=False, # custom transformer support
- transformer_depth=1, # custom transformer support
- context_dim=None, # custom transformer support
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
- legacy=True,
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
):
super().__init__()
if use_spatial_transformer:
@@ -695,7 +696,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
- #num_heads = 1
+ # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self._feature_size += ch
@@ -714,7 +715,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
- #num_heads = 1
+ # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self._feature_size += ch
@@ -742,7 +743,7 @@ def __init__(
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
- #num_heads = 1
+ # num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
layers.append(
AttentionBlock(
@@ -782,12 +783,12 @@ def __init__(
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
- normalization(ch),
- conv_nd(dims, model_channels, n_embed, 1),
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
- )
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
- def forward(self, h,emb,tp,hs, context=None, y=None):
+ def forward(self, h, emb, tp, hs, context=None, y=None):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
@@ -796,7 +797,7 @@ def forward(self, h,emb,tp,hs, context=None, y=None):
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
-
+
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
@@ -804,4 +805,4 @@ def forward(self, h,emb,tp,hs, context=None, y=None):
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
- return self.out(h)
\ No newline at end of file
+ return self.out(h)
diff --git a/optimizedSD/optimUtils.py b/optimizedSD/optimUtils.py
index 18b996792..cefc7155f 100644
--- a/optimizedSD/optimUtils.py
+++ b/optimizedSD/optimUtils.py
@@ -1,4 +1,5 @@
import os
+
import pandas as pd
@@ -14,60 +15,61 @@ def split_weighted_subprompts(text):
weights = []
while remaining > 0:
if ":" in text:
- idx = text.index(":") # first occurrence from start
+ idx = text.index(":") # first occurrence from start
# grab up to index as sub-prompt
prompt = text[:idx]
remaining -= idx
# remove from main text
- text = text[idx+1:]
+ text = text[idx + 1:]
# find value for weight
if " " in text:
- idx = text.index(" ") # first occurence
- else: # no space, read to end
+                idx = text.index(" ")  # first occurrence
+ else: # no space, read to end
idx = len(text)
if idx != 0:
try:
weight = float(text[:idx])
- except: # couldn't treat as float
+ except: # couldn't treat as float
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
weight = 1.0
- else: # no value found
+ else: # no value found
weight = 1.0
# remove from main text
remaining -= idx
- text = text[idx+1:]
+ text = text[idx + 1:]
# append the sub-prompt and its weight
prompts.append(prompt)
weights.append(weight)
- else: # no : found
- if len(text) > 0: # there is still text though
+ else: # no : found
+ if len(text) > 0: # there is still text though
# take remainder as weight 1
prompts.append(text)
weights.append(1.0)
remaining = 0
return prompts, weights
+
def logger(params, log_csv):
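+    # append this run's parameters as one row of log_csv, creating the file
+    # (and any missing columns) the first time it is used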
os.makedirs('logs', exist_ok=True)
cols = [arg for arg, _ in params.items()]
if not os.path.exists(log_csv):
- df = pd.DataFrame(columns=cols)
+ df = pd.DataFrame(columns=cols)
df.to_csv(log_csv, index=False)
df = pd.read_csv(log_csv)
for arg in cols:
if arg not in df.columns:
df[arg] = ""
- df.to_csv(log_csv, index = False)
+ df.to_csv(log_csv, index=False)
li = {}
cols = [col for col in df.columns]
- data = {arg:value for arg, value in params.items()}
+ data = {arg: value for arg, value in params.items()}
for col in cols:
if col in data:
li[col] = data[col]
else:
li[col] = ''
- df = pd.DataFrame(li,index = [0])
- df.to_csv(log_csv,index=False, mode='a', header=False)
\ No newline at end of file
+ df = pd.DataFrame(li, index=[0])
+ df.to_csv(log_csv, index=False, mode='a', header=False)
diff --git a/optimizedSD/optimized_img2img.py b/optimizedSD/optimized_img2img.py
index 9ca304c2e..52b3a59ee 100644
--- a/optimizedSD/optimized_img2img.py
+++ b/optimizedSD/optimized_img2img.py
@@ -1,22 +1,26 @@
-import argparse, os, re
-import torch
+import argparse
import numpy as np
-from random import randint
-from omegaconf import OmegaConf
+import os
+import pandas as pd
+import re
+import time
+import torch
from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
+from contextlib import contextmanager, nullcontext
from einops import rearrange
-from torchvision.utils import make_grid
-import time
+from einops import rearrange, repeat
+from itertools import islice
+from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
+from random import randint
from torch import autocast
-from contextlib import contextmanager, nullcontext
-from einops import rearrange, repeat
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
+from transformers import logging
+
from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger
-from transformers import logging
-import pandas as pd
+
logging.set_verbosity_error()
@@ -35,7 +39,6 @@ def load_model_from_config(ckpt, verbose=False):
def load_img(path, h0, w0):
-
image = Image.open(path).convert("RGB")
w, h = image.size
@@ -186,7 +189,7 @@ def load_img(path, h0, w0):
seed_everything(opt.seed)
# Logging
-logger(vars(opt), log_csv = "logs/img2img_logs.csv")
+logger(vars(opt), log_csv="logs/img2img_logs.csv")
sd = load_model_from_config(f"{ckpt}")
li, lo = [], []
@@ -258,12 +261,10 @@ def load_img(path, h0, w0):
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
-
assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
t_enc = int(opt.strength * opt.ddim_steps)
print(f"target t_enc is {t_enc} steps")
-
if opt.precision == "autocast" and opt.device != "cpu":
precision_scope = autocast
else:
@@ -271,7 +272,6 @@ def load_img(path, h0, w0):
seeds = ""
with torch.no_grad():
-
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
@@ -322,13 +322,12 @@ def load_img(path, h0, w0):
z_enc,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
- sampler = opt.sampler
+ sampler=opt.sampler
)
modelFS.to(opt.device)
print("saving images")
for i in range(batch_size):
-
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
@@ -354,9 +353,9 @@ def load_img(path, h0, w0):
print(
(
- "Samples finished in {0:.2f} minutes and exported to "
- + sample_path
- + "\n Seeds used = "
- + seeds[:-1]
+ "Samples finished in {0:.2f} minutes and exported to "
+ + sample_path
+ + "\n Seeds used = "
+ + seeds[:-1]
).format(time_taken)
)
diff --git a/optimizedSD/optimized_txt2img.py b/optimizedSD/optimized_txt2img.py
index 6ead8618f..ef264ff67 100644
--- a/optimizedSD/optimized_txt2img.py
+++ b/optimizedSD/optimized_txt2img.py
@@ -1,20 +1,24 @@
-import argparse, os, re
-import torch
-import numpy as np
+import argparse
+import os
+import re
+import time
+from contextlib import nullcontext
+from itertools import islice
from random import randint
-from omegaconf import OmegaConf
+
+import numpy as np
+import torch
from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
from einops import rearrange
-from torchvision.utils import make_grid
-import time
+from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import autocast
-from contextlib import contextmanager, nullcontext
+from tqdm import tqdm, trange
+from transformers import logging
+
from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger
-from transformers import logging
+
# from samplers import CompVisDenoiser
logging.set_verbosity_error()
@@ -147,7 +151,7 @@ def load_model_from_config(ckpt, verbose=False):
    help="Reduces inference time at the expense of 1GB VRAM",
)
parser.add_argument(
- "--precision",
+ "--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
@@ -179,7 +183,7 @@ def load_model_from_config(ckpt, verbose=False):
seed_everything(opt.seed)
# Logging
-logger(vars(opt), log_csv = "logs/txt2img_logs.csv")
+logger(vars(opt), log_csv="logs/txt2img_logs.csv")
sd = load_model_from_config(f"{ckpt}")
li, lo = [], []
@@ -226,7 +230,6 @@ def load_model_from_config(ckpt, verbose=False):
if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=opt.device)
-
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
@@ -241,7 +244,6 @@ def load_model_from_config(ckpt, verbose=False):
data = batch_size * list(data)
data = list(chunk(sorted(data), batch_size))
-
if opt.precision == "autocast" and opt.device != "cpu":
precision_scope = autocast
else:
@@ -249,7 +251,6 @@ def load_model_from_config(ckpt, verbose=False):
seeds = ""
with torch.no_grad():
-
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
@@ -297,7 +298,7 @@ def load_model_from_config(ckpt, verbose=False):
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code,
- sampler = opt.sampler,
+ sampler=opt.sampler,
)
modelFS.to(opt.device)
@@ -305,7 +306,6 @@ def load_model_from_config(ckpt, verbose=False):
print(samples_ddim.shape)
print("saving images")
for i in range(batch_size):
-
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
@@ -330,9 +330,9 @@ def load_model_from_config(ckpt, verbose=False):
print(
(
- "Samples finished in {0:.2f} minutes and exported to "
- + sample_path
- + "\n Seeds used = "
- + seeds[:-1]
+ "Samples finished in {0:.2f} minutes and exported to "
+ + sample_path
+ + "\n Seeds used = "
+ + seeds[:-1]
).format(time_taken)
)
diff --git a/optimizedSD/txt2img_gradio.py b/optimizedSD/txt2img_gradio.py
index d9775735a..4c8365d51 100644
--- a/optimizedSD/txt2img_gradio.py
+++ b/optimizedSD/txt2img_gradio.py
@@ -1,29 +1,28 @@
+import argparse
+import os
+import re
+import time
+from contextlib import nullcontext
+from itertools import islice
+from random import randint
+
import gradio as gr
import numpy as np
import torch
-from torchvision.utils import make_grid
-from einops import rearrange
-import os, re
-from PIL import Image
-import torch
-import pandas as pd
-import numpy as np
-from random import randint
-from omegaconf import OmegaConf
from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
from einops import rearrange
-from torchvision.utils import make_grid
-import time
+from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import autocast
-from contextlib import nullcontext
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
+from transformers import logging
+import mimetypes
from ldm.util import instantiate_from_config
from optimUtils import split_weighted_subprompts, logger
-from transformers import logging
+
logging.set_verbosity_error()
-import mimetypes
+
mimetypes.init()
mimetypes.add_type("application/javascript", ".js")
@@ -41,61 +40,24 @@ def load_model_from_config(ckpt, verbose=False):
sd = pl_sd["state_dict"]
return sd
-config = "optimizedSD/v1-inference.yaml"
-ckpt = "models/ldm/stable-diffusion-v1/model.ckpt"
-sd = load_model_from_config(f"{ckpt}")
-li, lo = [], []
-for key, v_ in sd.items():
- sp = key.split(".")
- if (sp[0]) == "model":
- if "input_blocks" in sp:
- li.append(key)
- elif "middle_block" in sp:
- li.append(key)
- elif "time_embed" in sp:
- li.append(key)
- else:
- lo.append(key)
-for key in li:
- sd["model1." + key[6:]] = sd.pop(key)
-for key in lo:
- sd["model2." + key[6:]] = sd.pop(key)
-
-config = OmegaConf.load(f"{config}")
-
-model = instantiate_from_config(config.modelUNet)
-_, _ = model.load_state_dict(sd, strict=False)
-model.eval()
-
-modelCS = instantiate_from_config(config.modelCondStage)
-_, _ = modelCS.load_state_dict(sd, strict=False)
-modelCS.eval()
-
-modelFS = instantiate_from_config(config.modelFirstStage)
-_, _ = modelFS.load_state_dict(sd, strict=False)
-modelFS.eval()
-del sd
-
-
def generate(
- prompt,
- ddim_steps,
- n_iter,
- batch_size,
- Height,
- Width,
- scale,
- ddim_eta,
- unet_bs,
- device,
- seed,
- outdir,
- img_format,
- turbo,
- full_precision,
- sampler,
+ prompt,
+ ddim_steps,
+ n_iter,
+ batch_size,
+ Height,
+ Width,
+ scale,
+ ddim_eta,
+ unet_bs,
+ device,
+ seed,
+ outdir,
+ img_format,
+ turbo,
+ full_precision,
+ sampler,
):
-
C = 4
f = 8
start_code = None
@@ -122,7 +84,7 @@ def generate(
sample_path = os.path.join(outpath, "_".join(re.split(":| ", prompt)))[:150]
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
-
+
# n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
assert prompt is not None
data = [batch_size * [prompt]]
@@ -178,13 +140,12 @@ def generate(
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code,
- sampler = sampler,
+ sampler=sampler,
)
modelFS.to(device)
print("saving images")
for i in range(batch_size):
-
x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
all_samples.append(x_sample.to("cpu"))
@@ -215,36 +176,76 @@ def generate(
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
txt = (
- "Samples finished in "
- + str(round(time_taken, 3))
- + " minutes and exported to "
- + sample_path
- + "\nSeeds used = "
- + seeds[:-1]
+ "Samples finished in "
+ + str(round(time_taken, 3))
+ + " minutes and exported to "
+ + sample_path
+ + "\nSeeds used = "
+ + seeds[:-1]
)
return Image.fromarray(grid.astype(np.uint8)), txt
-demo = gr.Interface(
- fn=generate,
- inputs=[
- "text",
- gr.Slider(1, 1000, value=50),
- gr.Slider(1, 100, step=1),
- gr.Slider(1, 100, step=1),
- gr.Slider(64, 4096, value=512, step=64),
- gr.Slider(64, 4096, value=512, step=64),
- gr.Slider(0, 50, value=7.5, step=0.1),
- gr.Slider(0, 1, step=0.01),
- gr.Slider(1, 2, value=1, step=1),
- gr.Text(value="cuda"),
- "text",
- gr.Text(value="outputs/txt2img-samples"),
- gr.Radio(["png", "jpg"], value='png'),
- "checkbox",
- "checkbox",
- gr.Radio(["ddim", "plms"], value="plms"),
- ],
- outputs=["image", "text"],
-)
-demo.launch()
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='txt2img using gradio')
+ parser.add_argument('--config_path', default="optimizedSD/v1-inference.yaml", type=str, help='config path')
+ parser.add_argument('--ckpt_path', default="models/ldm/stable-diffusion-v1/model.ckpt", type=str, help='ckpt path')
+ args = parser.parse_args()
+ config = args.config_path
+ ckpt = args.ckpt_path
+ sd = load_model_from_config(f"{ckpt}")
+ li, lo = [], []
+ for key, v_ in sd.items():
+ sp = key.split(".")
+ if (sp[0]) == "model":
+ if "input_blocks" in sp:
+ li.append(key)
+ elif "middle_block" in sp:
+ li.append(key)
+ elif "time_embed" in sp:
+ li.append(key)
+ else:
+ lo.append(key)
+ for key in li:
+ sd["model1." + key[6:]] = sd.pop(key)
+ for key in lo:
+ sd["model2." + key[6:]] = sd.pop(key)
+
+ config = OmegaConf.load(f"{config}")
+
+ model = instantiate_from_config(config.modelUNet)
+ _, _ = model.load_state_dict(sd, strict=False)
+ model.eval()
+
+ modelCS = instantiate_from_config(config.modelCondStage)
+ _, _ = modelCS.load_state_dict(sd, strict=False)
+ modelCS.eval()
+
+ modelFS = instantiate_from_config(config.modelFirstStage)
+ _, _ = modelFS.load_state_dict(sd, strict=False)
+ modelFS.eval()
+ del sd
+
+ demo = gr.Interface(
+ fn=generate,
+ inputs=[
+ "text",
+ gr.Slider(1, 1000, value=50),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(1, 100, step=1),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(64, 4096, value=512, step=64),
+ gr.Slider(0, 50, value=7.5, step=0.1),
+ gr.Slider(0, 1, step=0.01),
+ gr.Slider(1, 2, value=1, step=1),
+ gr.Text(value="cuda"),
+ "text",
+ gr.Text(value="outputs/txt2img-samples"),
+ gr.Radio(["png", "jpg"], value='png'),
+ "checkbox",
+ "checkbox",
+ gr.Radio(["ddim", "plms"], value="plms"),
+ ],
+ outputs=["image", "text"],
+ )
+ demo.launch(share=True)
diff --git a/optimizedSD/v1-inference.yaml b/optimizedSD/v1-inference.yaml
index 2e535fcb4..989ac3745 100644
--- a/optimizedSD/v1-inference.yaml
+++ b/optimizedSD/v1-inference.yaml
@@ -24,9 +24,9 @@ modelUNet:
in_channels: 4
out_channels: 4
model_channels: 320
- attention_resolutions: [4, 2, 1]
+ attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
- channel_mult: [1, 2, 4, 4]
+ channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
@@ -41,9 +41,9 @@ modelUNet:
in_channels: 4
out_channels: 4
model_channels: 320
- attention_resolutions: [4, 2, 1]
+ attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
- channel_mult: [1, 2, 4, 4]
+ channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
@@ -86,7 +86,7 @@ modelFirstStage:
- 4
- 4
num_res_blocks: 2
- attn_resolutions: []
+ attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
diff --git a/optimized_colab.ipynb b/optimized_colab.ipynb
new file mode 100644
index 000000000..35f21c699
--- /dev/null
+++ b/optimized_colab.ipynb
@@ -0,0 +1,584 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+    "This Colab requires you to have model.ckpt on your Google Drive."
+ ],
+ "metadata": {
+ "collapsed": false,
+ "id": "f20_rzsJc1zU"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": true,
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "JtgQHPFgc1zZ",
+ "outputId": "0919675b-1b24-402c-9da7-6baed18d3c84"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'stable-diffusion'...\n",
+ "remote: Enumerating objects: 616, done.\u001b[K\n",
+ "remote: Counting objects: 100% (144/144), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (93/93), done.\u001b[K\n",
+ "remote: Total 616 (delta 51), reused 81 (delta 28), pack-reused 472\u001b[K\n",
+ "Receiving objects: 100% (616/616), 42.90 MiB | 41.76 MiB/s, done.\n",
+ "Resolving deltas: 100% (261/261), done.\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting gradio\n",
+ " Downloading gradio-3.2-py3-none-any.whl (6.1 MB)\n",
+ "\u001b[K |████████████████████████████████| 6.1 MB 25.2 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: albumentations in /usr/local/lib/python3.7/dist-packages (1.2.1)\n",
+ "Collecting diffusers\n",
+ " Downloading diffusers-0.2.4-py3-none-any.whl (112 kB)\n",
+ "\u001b[K |████████████████████████████████| 112 kB 90.4 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: opencv-python in /usr/local/lib/python3.7/dist-packages (4.6.0.66)\n",
+ "Collecting pudb\n",
+ " Downloading pudb-2022.1.2.tar.gz (219 kB)\n",
+ "\u001b[K |████████████████████████████████| 219 kB 93.3 MB/s \n",
+ "\u001b[?25hCollecting invisible-watermark\n",
+ " Downloading invisible_watermark-0.1.5-py3-none-any.whl (1.6 MB)\n",
+ "\u001b[K |████████████████████████████████| 1.6 MB 63.7 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: imageio in /usr/local/lib/python3.7/dist-packages (2.9.0)\n",
+ "Collecting imageio-ffmpeg\n",
+ " Downloading imageio_ffmpeg-0.4.7-py3-none-manylinux2010_x86_64.whl (26.9 MB)\n",
+ "\u001b[K |████████████████████████████████| 26.9 MB 76.7 MB/s \n",
+ "\u001b[?25hCollecting pytorch-lightning\n",
+ " Downloading pytorch_lightning-1.7.4-py3-none-any.whl (706 kB)\n",
+ "\u001b[K |████████████████████████████████| 706 kB 62.4 MB/s \n",
+ "\u001b[?25hCollecting omegaconf\n",
+ " Downloading omegaconf-2.2.3-py3-none-any.whl (79 kB)\n",
+ "\u001b[K |████████████████████████████████| 79 kB 10.1 MB/s \n",
+ "\u001b[?25hCollecting test-tube\n",
+ " Downloading test_tube-0.7.5.tar.gz (21 kB)\n",
+ "Collecting streamlit\n",
+ " Downloading streamlit-1.12.2-py2.py3-none-any.whl (9.1 MB)\n",
+ "\u001b[K |████████████████████████████████| 9.1 MB 57.7 MB/s \n",
+ "\u001b[?25hCollecting einops\n",
+ " Downloading einops-0.4.1-py3-none-any.whl (28 kB)\n",
+ "Collecting torch-fidelity\n",
+ " Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)\n",
+ "Collecting transformers\n",
+ " Downloading transformers-4.21.2-py3-none-any.whl (4.7 MB)\n",
+ "\u001b[K |████████████████████████████████| 4.7 MB 77.3 MB/s \n",
+ "\u001b[?25hCollecting torchmetrics\n",
+ " Downloading torchmetrics-0.9.3-py3-none-any.whl (419 kB)\n",
+ "\u001b[K |████████████████████████████████| 419 kB 69.9 MB/s \n",
+ "\u001b[?25hCollecting kornia\n",
+ " Downloading kornia-0.6.7-py2.py3-none-any.whl (565 kB)\n",
+ "\u001b[K |████████████████████████████████| 565 kB 80.7 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: fsspec in /usr/local/lib/python3.7/dist-packages (from gradio) (2022.7.1)\n",
+ "Collecting h11<0.13,>=0.11\n",
+ " Downloading h11-0.12.0-py3-none-any.whl (54 kB)\n",
+ "\u001b[K |████████████████████████████████| 54 kB 3.8 MB/s \n",
+ "\u001b[?25hCollecting pycryptodome\n",
+ " Downloading pycryptodome-3.15.0-cp35-abi3-manylinux2010_x86_64.whl (2.3 MB)\n",
+ "\u001b[K |████████████████████████████████| 2.3 MB 57.9 MB/s \n",
+ "\u001b[?25hCollecting ffmpy\n",
+ " Downloading ffmpy-0.3.0.tar.gz (4.8 kB)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from gradio) (2.23.0)\n",
+ "Collecting pydub\n",
+ " Downloading pydub-0.25.1-py2.py3-none-any.whl (32 kB)\n",
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from gradio) (3.8.1)\n",
+ "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from gradio) (3.2.2)\n",
+ "Requirement already satisfied: Jinja2 in /usr/local/lib/python3.7/dist-packages (from gradio) (2.11.3)\n",
+ "Collecting python-multipart\n",
+ " Downloading python-multipart-0.0.5.tar.gz (32 kB)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from gradio) (1.21.6)\n",
+ "Collecting analytics-python\n",
+ " Downloading analytics_python-1.4.0-py2.py3-none-any.whl (15 kB)\n",
+ "Collecting markdown-it-py[linkify,plugins]\n",
+ " Downloading markdown_it_py-2.1.0-py3-none-any.whl (84 kB)\n",
+ "\u001b[K |████████████████████████████████| 84 kB 3.3 MB/s \n",
+ "\u001b[?25hCollecting websockets\n",
+ " Downloading websockets-10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (112 kB)\n",
+ "\u001b[K |████████████████████████████████| 112 kB 73.4 MB/s \n",
+ "\u001b[?25hCollecting paramiko\n",
+ " Downloading paramiko-2.11.0-py2.py3-none-any.whl (212 kB)\n",
+ "\u001b[K |████████████████████████████████| 212 kB 77.0 MB/s \n",
+ "\u001b[?25hCollecting uvicorn\n",
+ " Downloading uvicorn-0.18.3-py3-none-any.whl (57 kB)\n",
+ "\u001b[K |████████████████████████████████| 57 kB 6.1 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from gradio) (1.3.5)\n",
+ "Collecting orjson\n",
+ " Downloading orjson-3.8.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (270 kB)\n",
+ "\u001b[K |████████████████████████████████| 270 kB 80.6 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: pydantic in /usr/local/lib/python3.7/dist-packages (from gradio) (1.9.2)\n",
+ "Collecting fastapi\n",
+ " Downloading fastapi-0.81.0-py3-none-any.whl (54 kB)\n",
+ "\u001b[K |████████████████████████████████| 54 kB 3.2 MB/s \n",
+ "\u001b[?25hCollecting httpx\n",
+ " Downloading httpx-0.23.0-py3-none-any.whl (84 kB)\n",
+ "\u001b[K |████████████████████████████████| 84 kB 4.8 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from gradio) (7.1.2)\n",
+ "Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from albumentations) (1.7.3)\n",
+ "Requirement already satisfied: scikit-image>=0.16.1 in /usr/local/lib/python3.7/dist-packages (from albumentations) (0.18.3)\n",
+ "Requirement already satisfied: opencv-python-headless>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from albumentations) (4.6.0.66)\n",
+ "Requirement already satisfied: qudida>=0.0.4 in /usr/local/lib/python3.7/dist-packages (from albumentations) (0.0.4)\n",
+ "Requirement already satisfied: PyYAML in /usr/local/lib/python3.7/dist-packages (from albumentations) (6.0)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from qudida>=0.0.4->albumentations) (4.1.1)\n",
+ "Requirement already satisfied: scikit-learn>=0.19.1 in /usr/local/lib/python3.7/dist-packages (from qudida>=0.0.4->albumentations) (1.0.2)\n",
+ "Requirement already satisfied: networkx>=2.0 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.16.1->albumentations) (2.6.3)\n",
+ "Requirement already satisfied: PyWavelets>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.16.1->albumentations) (1.3.0)\n",
+ "Requirement already satisfied: tifffile>=2019.7.26 in /usr/local/lib/python3.7/dist-packages (from scikit-image>=0.16.1->albumentations) (2021.11.2)\n",
+ "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (2.8.2)\n",
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (1.4.4)\n",
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (0.11.0)\n",
+ "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->gradio) (3.0.9)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib->gradio) (1.15.0)\n",
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.19.1->qudida>=0.0.4->albumentations) (3.1.0)\n",
+ "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.19.1->qudida>=0.0.4->albumentations) (1.1.0)\n",
+ "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from diffusers) (4.12.0)\n",
+ "Collecting huggingface-hub<1.0,>=0.8.1\n",
+ " Downloading huggingface_hub-0.9.1-py3-none-any.whl (120 kB)\n",
+ "\u001b[K |████████████████████████████████| 120 kB 81.1 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from diffusers) (2022.6.2)\n",
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from diffusers) (3.8.0)\n",
+ "Requirement already satisfied: torch>=1.4 in /usr/local/lib/python3.7/dist-packages (from diffusers) (1.12.1+cu113)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.8.1->diffusers) (4.64.0)\n",
+ "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.8.1->diffusers) (21.3)\n",
+ "Collecting urwid>=1.1.1\n",
+ " Downloading urwid-2.1.2.tar.gz (634 kB)\n",
+ "\u001b[K |████████████████████████████████| 634 kB 77.4 MB/s \n",
+ "\u001b[?25hCollecting pygments>=2.7.4\n",
+ " Downloading Pygments-2.13.0-py3-none-any.whl (1.1 MB)\n",
+ "\u001b[K |████████████████████████████████| 1.1 MB 52.6 MB/s \n",
+ "\u001b[?25hCollecting jedi<1,>=0.18\n",
+ " Downloading jedi-0.18.1-py2.py3-none-any.whl (1.6 MB)\n",
+ "\u001b[K |████████████████████████████████| 1.6 MB 61.7 MB/s \n",
+ "\u001b[?25hCollecting urwid_readline\n",
+ " Downloading urwid_readline-0.13.tar.gz (7.9 kB)\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /usr/local/lib/python3.7/dist-packages (from jedi<1,>=0.18->pudb) (0.8.3)\n",
+ "Collecting onnxruntime\n",
+ " Downloading onnxruntime-1.12.1-cp37-cp37m-manylinux_2_27_x86_64.whl (4.9 MB)\n",
+ "\u001b[K |████████████████████████████████| 4.9 MB 68.3 MB/s \n",
+ "\u001b[?25hCollecting onnx\n",
+ " Downloading onnx-1.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)\n",
+ "\u001b[K |████████████████████████████████| 13.1 MB 63.4 MB/s \n",
+ "\u001b[?25hCollecting tensorboard>=2.9.1\n",
+ " Downloading tensorboard-2.10.0-py3-none-any.whl (5.9 MB)\n",
+ "\u001b[K |████████████████████████████████| 5.9 MB 70.8 MB/s \n",
+ "\u001b[?25hCollecting pyDeprecate>=0.3.1\n",
+ " Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)\n",
+ "Requirement already satisfied: google-auth<3,>=1.6.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (1.35.0)\n",
+ "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (3.4.1)\n",
+ "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (1.8.1)\n",
+ "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (0.6.1)\n",
+ "Requirement already satisfied: protobuf<3.20,>=3.9.2 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (3.17.3)\n",
+ "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (1.0.1)\n",
+ "Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (57.4.0)\n",
+ "Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (1.2.0)\n",
+ "Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (1.47.0)\n",
+ "Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (0.37.1)\n",
+ "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from tensorboard>=2.9.1->pytorch-lightning) (0.4.6)\n",
+ "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.9.1->pytorch-lightning) (4.2.4)\n",
+ "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.9.1->pytorch-lightning) (4.9)\n",
+ "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.7/dist-packages (from google-auth<3,>=1.6.3->tensorboard>=2.9.1->pytorch-lightning) (0.2.8)\n",
+ "Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.7/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.9.1->pytorch-lightning) (1.3.1)\n",
+ "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->diffusers) (3.8.1)\n",
+ "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.7/dist-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.9.1->pytorch-lightning) (0.4.8)\n",
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->gradio) (3.0.4)\n",
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->gradio) (1.24.3)\n",
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->gradio) (2.10)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->gradio) (2022.6.15)\n",
+ "Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.7/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.9.1->pytorch-lightning) (3.2.0)\n",
+ "Collecting antlr4-python3-runtime==4.9.*\n",
+ " Downloading antlr4-python3-runtime-4.9.3.tar.gz (117 kB)\n",
+ "\u001b[K |████████████████████████████████| 117 kB 86.1 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: future in /usr/local/lib/python3.7/dist-packages (from test-tube) (0.16.0)\n",
+ "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->gradio) (2022.2.1)\n",
+ "Requirement already satisfied: tornado>=5.0 in /usr/local/lib/python3.7/dist-packages (from streamlit) (5.1.1)\n",
+ "Collecting blinker>=1.0.0\n",
+ " Downloading blinker-1.5-py2.py3-none-any.whl (12 kB)\n",
+ "Collecting gitpython!=3.1.19\n",
+ " Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)\n",
+ "\u001b[K |████████████████████████████████| 181 kB 76.6 MB/s \n",
+ "\u001b[?25hCollecting pympler>=0.9\n",
+ " Downloading Pympler-1.0.1-py3-none-any.whl (164 kB)\n",
+ "\u001b[K |████████████████████████████████| 164 kB 67.1 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: altair>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from streamlit) (4.2.0)\n",
+ "Requirement already satisfied: toml in /usr/local/lib/python3.7/dist-packages (from streamlit) (0.10.2)\n",
+ "Requirement already satisfied: pyarrow>=4.0 in /usr/local/lib/python3.7/dist-packages (from streamlit) (6.0.1)\n",
+ "Collecting pydeck>=0.1.dev5\n",
+ " Downloading pydeck-0.8.0b1-py2.py3-none-any.whl (4.7 MB)\n",
+ "\u001b[K |████████████████████████████████| 4.7 MB 70.5 MB/s \n",
+ "\u001b[?25hCollecting validators>=0.2\n",
+ " Downloading validators-0.20.0.tar.gz (30 kB)\n",
+ "Collecting semver\n",
+ " Downloading semver-2.13.0-py2.py3-none-any.whl (12 kB)\n",
+ "Requirement already satisfied: tzlocal>=1.1 in /usr/local/lib/python3.7/dist-packages (from streamlit) (1.5.1)\n",
+ "Collecting watchdog\n",
+ " Downloading watchdog-2.1.9-py3-none-manylinux2014_x86_64.whl (78 kB)\n",
+ "\u001b[K |████████████████████████████████| 78 kB 8.7 MB/s \n",
+ "\u001b[?25hCollecting rich>=10.11.0\n",
+ " Downloading rich-12.5.1-py3-none-any.whl (235 kB)\n",
+ "\u001b[K |████████████████████████████████| 235 kB 75.3 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: click>=7.0 in /usr/local/lib/python3.7/dist-packages (from streamlit) (7.1.2)\n",
+ "Requirement already satisfied: entrypoints in /usr/local/lib/python3.7/dist-packages (from altair>=3.2.0->streamlit) (0.4)\n",
+ "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.7/dist-packages (from altair>=3.2.0->streamlit) (4.3.3)\n",
+ "Requirement already satisfied: toolz in /usr/local/lib/python3.7/dist-packages (from altair>=3.2.0->streamlit) (0.12.0)\n",
+ "Collecting gitdb<5,>=4.0.1\n",
+ " Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)\n",
+ "\u001b[K |████████████████████████████████| 63 kB 2.0 MB/s \n",
+ "\u001b[?25hCollecting smmap<6,>=3.0.1\n",
+ " Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n",
+ "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema>=3.0->altair>=3.2.0->streamlit) (0.18.1)\n",
+ "Requirement already satisfied: importlib-resources>=1.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema>=3.0->altair>=3.2.0->streamlit) (5.9.0)\n",
+ "Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.7/dist-packages (from jsonschema>=3.0->altair>=3.2.0->streamlit) (22.1.0)\n",
+ "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from Jinja2->gradio) (2.0.1)\n",
+ "Collecting commonmark<0.10.0,>=0.9.0\n",
+ " Downloading commonmark-0.9.1-py2.py3-none-any.whl (51 kB)\n",
+ "\u001b[K |████████████████████████████████| 51 kB 8.2 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: decorator>=3.4.0 in /usr/local/lib/python3.7/dist-packages (from validators>=0.2->streamlit) (4.4.2)\n",
+ "Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from torch-fidelity) (0.13.1+cu113)\n",
+ "Collecting tokenizers!=0.11.3,<0.13,>=0.11.1\n",
+ " Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)\n",
+ "\u001b[K |████████████████████████████████| 6.6 MB 86.8 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->gradio) (1.8.1)\n",
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->gradio) (4.0.2)\n",
+ "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->gradio) (2.1.1)\n",
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->gradio) (1.2.0)\n",
+ "Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->gradio) (0.13.0)\n",
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->gradio) (1.3.1)\n",
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->gradio) (6.0.2)\n",
+ "Collecting monotonic>=1.5\n",
+ " Downloading monotonic-1.6-py2.py3-none-any.whl (8.2 kB)\n",
+ "Collecting backoff==1.10.0\n",
+ " Downloading backoff-1.10.0-py2.py3-none-any.whl (31 kB)\n",
+ "Collecting starlette==0.19.1\n",
+ " Downloading starlette-0.19.1-py3-none-any.whl (63 kB)\n",
+ "\u001b[K |████████████████████████████████| 63 kB 2.3 MB/s \n",
+ "\u001b[?25hCollecting anyio<5,>=3.4.0\n",
+ " Downloading anyio-3.6.1-py3-none-any.whl (80 kB)\n",
+ "\u001b[K |████████████████████████████████| 80 kB 12.3 MB/s \n",
+ "\u001b[?25hCollecting sniffio>=1.1\n",
+ " Downloading sniffio-1.3.0-py3-none-any.whl (10 kB)\n",
+ "Collecting rfc3986[idna2008]<2,>=1.3\n",
+ " Downloading rfc3986-1.5.0-py2.py3-none-any.whl (31 kB)\n",
+ "Collecting httpcore<0.16.0,>=0.15.0\n",
+ " Downloading httpcore-0.15.0-py3-none-any.whl (68 kB)\n",
+ "\u001b[K |████████████████████████████████| 68 kB 7.8 MB/s \n",
+ "\u001b[?25hCollecting mdurl~=0.1\n",
+ " Downloading mdurl-0.1.2-py3-none-any.whl (10.0 kB)\n",
+ "Collecting mdit-py-plugins\n",
+ " Downloading mdit_py_plugins-0.3.0-py3-none-any.whl (43 kB)\n",
+ "\u001b[K |████████████████████████████████| 43 kB 2.2 MB/s \n",
+ "\u001b[?25hCollecting linkify-it-py~=1.0\n",
+ " Downloading linkify_it_py-1.0.3-py3-none-any.whl (19 kB)\n",
+ "Collecting uc-micro-py\n",
+ " Downloading uc_micro_py-1.0.1-py3-none-any.whl (6.2 kB)\n",
+ "Collecting coloredlogs\n",
+ " Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n",
+ "\u001b[K |████████████████████████████████| 46 kB 4.7 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: flatbuffers in /usr/local/lib/python3.7/dist-packages (from onnxruntime->invisible-watermark) (2.0.7)\n",
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.7/dist-packages (from onnxruntime->invisible-watermark) (1.7.1)\n",
+ "Collecting humanfriendly>=9.1\n",
+ " Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n",
+ "\u001b[K |████████████████████████████████| 86 kB 6.9 MB/s \n",
+ "\u001b[?25hCollecting bcrypt>=3.1.3\n",
+ " Downloading bcrypt-4.0.0-cp36-abi3-manylinux_2_24_x86_64.whl (594 kB)\n",
+ "\u001b[K |████████████████████████████████| 594 kB 86.9 MB/s \n",
+ "\u001b[?25hCollecting pynacl>=1.0.1\n",
+ " Downloading PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl (856 kB)\n",
+ "\u001b[K |████████████████████████████████| 856 kB 70.3 MB/s \n",
+ "\u001b[?25hCollecting cryptography>=2.5\n",
+ " Downloading cryptography-37.0.4-cp36-abi3-manylinux_2_24_x86_64.whl (4.1 MB)\n",
+ "\u001b[K |████████████████████████████████| 4.1 MB 68.4 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.7/dist-packages (from cryptography>=2.5->paramiko->gradio) (1.15.1)\n",
+ "Requirement already satisfied: pycparser in /usr/local/lib/python3.7/dist-packages (from cffi>=1.12->cryptography>=2.5->paramiko->gradio) (2.21)\n",
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.7/dist-packages (from sympy->onnxruntime->invisible-watermark) (1.2.1)\n",
+ "Building wheels for collected packages: pudb, urwid, antlr4-python3-runtime, test-tube, validators, ffmpy, python-multipart, urwid-readline\n",
+ " Building wheel for pudb (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for pudb: filename=pudb-2022.1.2-py3-none-any.whl size=69098 sha256=f977e6058f9892ce0acb9db00038f7217c289ff944bd607049d8944e62760c73\n",
+ " Stored in directory: /root/.cache/pip/wheels/5e/92/c1/e3e89bc3921f89b17c3595c1cfc1ad485d580be00e6f2cf4b3\n",
+ " Building wheel for urwid (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for urwid: filename=urwid-2.1.2-cp37-cp37m-linux_x86_64.whl size=258319 sha256=f7045cfb1c3bf4015c815b2ace097c0437ac5ee669477e6b33dfdb0a22b7a531\n",
+ " Stored in directory: /root/.cache/pip/wheels/79/77/cf/cae9cf1cc3f1f777f9db531424bbd9e15aa38e4ca28dbe499e\n",
+ " Building wheel for antlr4-python3-runtime (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for antlr4-python3-runtime: filename=antlr4_python3_runtime-4.9.3-py3-none-any.whl size=144575 sha256=9779b7a4f96af8556f28c57dadc5c7d039f60e07fe57c4a1498a75bfc7edff52\n",
+ " Stored in directory: /root/.cache/pip/wheels/8b/8d/53/2af8772d9aec614e3fc65e53d4a993ad73c61daa8bbd85a873\n",
+ " Building wheel for test-tube (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for test-tube: filename=test_tube-0.7.5-py3-none-any.whl size=25356 sha256=df61a82154841c406b3855026f5a717df022fbc45f6a4a4cd63a9f9633f0fc59\n",
+ " Stored in directory: /root/.cache/pip/wheels/1c/50/0d/15b3236957cc18a5c39ec4d4d4d21624f4d4a876756ec17064\n",
+ " Building wheel for validators (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for validators: filename=validators-0.20.0-py3-none-any.whl size=19582 sha256=dc40f53fc2c056ef9dcf3e576338c78860738c5fad2ce01228a85b6bf2102c92\n",
+ " Stored in directory: /root/.cache/pip/wheels/5f/55/ab/36a76989f7f88d9ca7b1f68da6d94252bb6a8d6ad4f18e04e9\n",
+ " Building wheel for ffmpy (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for ffmpy: filename=ffmpy-0.3.0-py3-none-any.whl size=4712 sha256=253ca687885aee92d242b65fb41f40190a65f5a3fabf586acd234ce392997a05\n",
+ " Stored in directory: /root/.cache/pip/wheels/13/e4/6c/e8059816e86796a597c6e6b0d4c880630f51a1fcfa0befd5e6\n",
+ " Building wheel for python-multipart (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for python-multipart: filename=python_multipart-0.0.5-py3-none-any.whl size=31678 sha256=9ba14a5d4c59c7f069d45af15d397bd8cd352a8d22d98aeaf340b59f08efd6b2\n",
+ " Stored in directory: /root/.cache/pip/wheels/2c/41/7c/bfd1c180534ffdcc0972f78c5758f89881602175d48a8bcd2c\n",
+ " Building wheel for urwid-readline (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
+ " Created wheel for urwid-readline: filename=urwid_readline-0.13-py3-none-any.whl size=7576 sha256=4367729597eb6a6dfc10ca0dafcaa61d1d68f9589881973f4d01c649b1a4e646\n",
+ " Stored in directory: /root/.cache/pip/wheels/67/54/48/d63c65f1b0f25d47d10c1c062bc73d81b81ee2c25f76efe9c1\n",
+ "Successfully built pudb urwid antlr4-python3-runtime test-tube validators ffmpy python-multipart urwid-readline\n",
+ "Installing collected packages: sniffio, mdurl, uc-micro-py, smmap, rfc3986, markdown-it-py, humanfriendly, h11, anyio, urwid, starlette, pynacl, pygments, monotonic, mdit-py-plugins, linkify-it-py, httpcore, gitdb, cryptography, commonmark, coloredlogs, bcrypt, backoff, websockets, watchdog, validators, uvicorn, urwid-readline, torchmetrics, tokenizers, tensorboard, semver, rich, python-multipart, pympler, pydub, pyDeprecate, pydeck, pycryptodome, paramiko, orjson, onnxruntime, onnx, jedi, huggingface-hub, httpx, gitpython, ffmpy, fastapi, blinker, antlr4-python3-runtime, analytics-python, transformers, torch-fidelity, test-tube, streamlit, pytorch-lightning, pudb, omegaconf, kornia, invisible-watermark, imageio-ffmpeg, gradio, einops, diffusers\n",
+ " Attempting uninstall: pygments\n",
+ " Found existing installation: Pygments 2.6.1\n",
+ " Uninstalling Pygments-2.6.1:\n",
+ " Successfully uninstalled Pygments-2.6.1\n",
+ " Attempting uninstall: tensorboard\n",
+ " Found existing installation: tensorboard 2.8.0\n",
+ " Uninstalling tensorboard-2.8.0:\n",
+ " Successfully uninstalled tensorboard-2.8.0\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "tensorflow 2.8.2+zzzcolab20220719082949 requires tensorboard<2.9,>=2.8, but you have tensorboard 2.10.0 which is incompatible.\u001b[0m\n",
+ "Successfully installed analytics-python-1.4.0 antlr4-python3-runtime-4.9.3 anyio-3.6.1 backoff-1.10.0 bcrypt-4.0.0 blinker-1.5 coloredlogs-15.0.1 commonmark-0.9.1 cryptography-37.0.4 diffusers-0.2.4 einops-0.4.1 fastapi-0.81.0 ffmpy-0.3.0 gitdb-4.0.9 gitpython-3.1.27 gradio-3.2 h11-0.12.0 httpcore-0.15.0 httpx-0.23.0 huggingface-hub-0.9.1 humanfriendly-10.0 imageio-ffmpeg-0.4.7 invisible-watermark-0.1.5 jedi-0.18.1 kornia-0.6.7 linkify-it-py-1.0.3 markdown-it-py-2.1.0 mdit-py-plugins-0.3.0 mdurl-0.1.2 monotonic-1.6 omegaconf-2.2.3 onnx-1.12.0 onnxruntime-1.12.1 orjson-3.8.0 paramiko-2.11.0 pudb-2022.1.2 pyDeprecate-0.3.2 pycryptodome-3.15.0 pydeck-0.8.0b1 pydub-0.25.1 pygments-2.13.0 pympler-1.0.1 pynacl-1.5.0 python-multipart-0.0.5 pytorch-lightning-1.7.4 rfc3986-1.5.0 rich-12.5.1 semver-2.13.0 smmap-5.0.0 sniffio-1.3.0 starlette-0.19.1 streamlit-1.12.2 tensorboard-2.10.0 test-tube-0.7.5 tokenizers-0.12.1 torch-fidelity-0.3.0 torchmetrics-0.9.3 transformers-4.21.2 uc-micro-py-1.0.1 urwid-2.1.2 urwid-readline-0.13 uvicorn-0.18.3 validators-0.20.0 watchdog-2.1.9 websockets-10.3\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/vnd.colab-display-data+json": {
+ "pip_warning": {
+ "packages": [
+ "pydevd_plugins",
+ "pygments"
+ ]
+ }
+ }
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Obtaining taming-transformers from git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers\n",
+ " Cloning https://github.com/CompVis/taming-transformers.git (to revision master) to ./src/taming-transformers\n",
+ " Running command git clone -q https://github.com/CompVis/taming-transformers.git /content/src/taming-transformers\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from taming-transformers) (1.12.1+cu113)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from taming-transformers) (1.21.6)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from taming-transformers) (4.64.0)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->taming-transformers) (4.1.1)\n",
+ "Installing collected packages: taming-transformers\n",
+ " Running setup.py develop for taming-transformers\n",
+ "Successfully installed taming-transformers-0.0.1\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Obtaining clip from git+https://github.com/openai/CLIP.git@main#egg=clip\n",
+ " Cloning https://github.com/openai/CLIP.git (to revision main) to ./src/clip\n",
+ " Running command git clone -q https://github.com/openai/CLIP.git /content/src/clip\n",
+ "Collecting ftfy\n",
+ " Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)\n",
+ "\u001b[K |████████████████████████████████| 53 kB 1.9 MB/s \n",
+ "\u001b[?25hRequirement already satisfied: regex in /usr/local/lib/python3.7/dist-packages (from clip) (2022.6.2)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from clip) (4.64.0)\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from clip) (1.12.1+cu113)\n",
+ "Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from clip) (0.13.1+cu113)\n",
+ "Requirement already satisfied: wcwidth>=0.2.5 in /usr/local/lib/python3.7/dist-packages (from ftfy->clip) (0.2.5)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->clip) (4.1.1)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from torchvision->clip) (2.23.0)\n",
+ "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.7/dist-packages (from torchvision->clip) (7.1.2)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from torchvision->clip) (1.21.6)\n",
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->clip) (2.10)\n",
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->clip) (1.24.3)\n",
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->clip) (3.0.4)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->torchvision->clip) (2022.6.15)\n",
+ "Installing collected packages: ftfy, clip\n",
+ " Running setup.py develop for clip\n",
+ "Successfully installed clip-1.0 ftfy-6.1.1\n"
+ ]
+ }
+ ],
+ "source": [
+ "!git clone https://github.com/neonsecret/stable-diffusion.git\n",
+ "!pip install gradio albumentations diffusers opencv-python pudb invisible-watermark imageio imageio-ffmpeg pytorch-lightning omegaconf test-tube streamlit einops torch-fidelity transformers torchmetrics kornia\n",
+ "!pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers\n",
+ "!pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/content/stable-diffusion\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Obtaining file:///content/stable-diffusion\n",
+ "Requirement already satisfied: torch in /usr/local/lib/python3.7/dist-packages (from latent-diffusion==0.0.1) (1.12.1+cu113)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from latent-diffusion==0.0.1) (1.21.6)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from latent-diffusion==0.0.1) (4.64.0)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch->latent-diffusion==0.0.1) (4.1.1)\n",
+ "Installing collected packages: latent-diffusion\n",
+ " Running setup.py develop for latent-diffusion\n",
+ "Successfully installed latent-diffusion-0.0.1\n"
+ ]
+ }
+ ],
+ "source": [
+ "%cd stable-diffusion\n",
+ "!pip install -e ."
+ ],
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ksvqjvync1zb",
+ "outputId": "505ac7d3-cba4-4f54-f81f-4520af57a733"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "RJeVH0KtdVBh",
+ "outputId": "22fa9eba-6d35-464b-8fda-55ae8962d89e"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Mounted at /content/drive\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title If you don't have sd-v1-4.ckpt, download it here\n",
+ "save_path = '/content/drive/MyDrive/sd-v1-4.ckpt' #@param {type:\"string\"}\n",
+ "username = \"username\" #@param {type:\"string\"}\n",
+ "huggingface_token = \"token\" #@param {type:\"string\"}\n",
+ "!wget https://$username:$huggingface_token@huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt -O $save_path"
+ ],
+ "metadata": {
+ "id": "oUYCuhBrfO3G"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Loading model from /content/drive/MyDrive/sd-v1-4.ckpt\n",
+ "Global Step: 470000\n",
+ "UNet: Running in eps-prediction mode\n",
+ "CondStage: Running in eps-prediction mode\n",
+ "Downloading vocab.json: 100% 939k/939k [00:00<00:00, 51.9MB/s]\n",
+ "Downloading merges.txt: 100% 512k/512k [00:00<00:00, 20.0MB/s]\n",
+ "Downloading special_tokens_map.json: 100% 389/389 [00:00<00:00, 456kB/s]\n",
+ "Downloading tokenizer_config.json: 100% 905/905 [00:00<00:00, 885kB/s]\n",
+ "Downloading config.json: 100% 4.31k/4.31k [00:00<00:00, 4.29MB/s]\n",
+ "Downloading pytorch_model.bin: 100% 1.59G/1.59G [00:25<00:00, 67.9MB/s]\n",
+ "FirstStage: Running in eps-prediction mode\n",
+ "making attention of type 'vanilla' with 512 in_channels\n",
+ "Working with z of shape (1, 4, 32, 32) = 4096 dimensions.\n",
+ "making attention of type 'vanilla' with 512 in_channels\n",
+ "Running on local URL: http://127.0.0.1:7860/\n",
+ "\n",
+ "To create a public link, set `share=True` in `launch()`.\n",
+ "Keyboard interruption in main thread... closing server.\n",
+ "^C\n"
+ ]
+ }
+ ],
+ "source": [
+ "#@title Optimized img2img\n",
+ "ckpt_path = '/content/drive/MyDrive/sd-v1-4.ckpt' #@param {type:\"string\"}\n",
+ "!python optimizedSD/img2img_gradio.py --ckpt_path $ckpt_path"
+ ],
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0FgHkm_Ic1zb",
+ "outputId": "67f905ee-b7fc-4382-99b4-88ab2a8c3be7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "#@title Optimized txt2img\n",
+ "ckpt_path = '/content/drive/MyDrive/sd-v1-4.ckpt' #@param {type:\"string\"}\n",
+ "!python optimizedSD/txt2img_gradio.py --ckpt_path $ckpt_path"
+ ],
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ },
+ "id": "gnYhjq2Gc1zc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [
+ "#@title Optimized masked img2img (inpainting)\n",
+ "ckpt_path = '/content/drive/MyDrive/sd-v1-4.ckpt' #@param {type:\"string\"}\n",
+ "!python optimizedSD/inpaint_gradio.py --ckpt_path $ckpt_path"
+ ],
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ },
+ "id": "Yba4bSgCc1ze"
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ },
+ "colab": {
+ "provenance": [],
+ "collapsed_sections": []
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/scripts/img2img.py b/scripts/img2img.py
index 421e2151d..2e5d102d2 100644
--- a/scripts/img2img.py
+++ b/scripts/img2img.py
@@ -1,23 +1,26 @@
"""make variations of input image"""
-import argparse, os, sys, glob
import PIL
-import torch
+import argparse
+import glob
import numpy as np
-from omegaconf import OmegaConf
+import os
+import sys
+import time
+import torch
from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
-from einops import rearrange, repeat
-from torchvision.utils import make_grid
-from torch import autocast
from contextlib import nullcontext
-import time
+from einops import rearrange, repeat
+from itertools import islice
+from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
+from torch import autocast
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
-from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config
def chunk(it, size):
@@ -54,7 +57,7 @@ def load_img(path):
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
- return 2.*image - 1.
+ return 2. * image - 1.
def main():
@@ -256,10 +259,10 @@ def main():
c = model.get_learned_conditioning(prompts)
# encode (scaled latent)
- z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
+ z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
- unconditional_conditioning=uc,)
+ unconditional_conditioning=uc, )
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
diff --git a/scripts/inpaint.py b/scripts/inpaint.py
index d6e6387a9..3579afe16 100644
--- a/scripts/inpaint.py
+++ b/scripts/inpaint.py
@@ -1,32 +1,36 @@
-import argparse, os, sys, glob
-from omegaconf import OmegaConf
-from PIL import Image
-from tqdm import tqdm
+import argparse
+import glob
import numpy as np
+import os
+import sys
import torch
-from main import instantiate_from_config
+from PIL import Image
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
from ldm.models.diffusion.ddim import DDIMSampler
+from main import instantiate_from_config
def make_batch(image, mask, device):
image = np.array(Image.open(image).convert("RGB"))
- image = image.astype(np.float32)/255.0
- image = image[None].transpose(0,3,1,2)
+ image = image.astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
mask = np.array(Image.open(mask).convert("L"))
- mask = mask.astype(np.float32)/255.0
- mask = mask[None,None]
+ mask = mask.astype(np.float32) / 255.0
+ mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
- masked_image = (1-mask)*image
+ masked_image = (1 - mask) * image
batch = {"image": image, "mask": mask, "masked_image": masked_image}
for k in batch:
batch[k] = batch[k].to(device=device)
- batch[k] = batch[k]*2.0-1.0
+ batch[k] = batch[k] * 2.0 - 1.0
return batch
@@ -78,7 +82,7 @@ def make_batch(image, mask, device):
size=c.shape[-2:])
c = torch.cat((c, cc), dim=1)
- shape = (c.shape[1]-1,)+c.shape[2:]
+ shape = (c.shape[1] - 1,) + c.shape[2:]
samples_ddim, _ = sampler.sample(S=opt.steps,
conditioning=c,
batch_size=c.shape[0],
@@ -86,13 +90,13 @@ def make_batch(image, mask, device):
verbose=False)
x_samples_ddim = model.decode_first_stage(samples_ddim)
- image = torch.clamp((batch["image"]+1.0)/2.0,
+ image = torch.clamp((batch["image"] + 1.0) / 2.0,
min=0.0, max=1.0)
- mask = torch.clamp((batch["mask"]+1.0)/2.0,
+ mask = torch.clamp((batch["mask"] + 1.0) / 2.0,
min=0.0, max=1.0)
- predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
+ predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0,
min=0.0, max=1.0)
- inpainted = (1-mask)*image+mask*predicted_image
- inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
+ inpainted = (1 - mask) * image + mask * predicted_image
+ inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
diff --git a/scripts/knn2img.py b/scripts/knn2img.py
index e6eaaecab..df5793c78 100644
--- a/scripts/knn2img.py
+++ b/scripts/knn2img.py
@@ -1,22 +1,25 @@
-import argparse, os, sys, glob
+import argparse
import clip
+import glob
+import numpy as np
+import os
+import scann
+import sys
+import time
import torch
import torch.nn as nn
-import numpy as np
-from omegaconf import OmegaConf
from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
from einops import rearrange, repeat
-from torchvision.utils import make_grid
-import scann
-import time
+from itertools import islice
from multiprocessing import cpu_count
+from omegaconf import OmegaConf
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
-from ldm.util import instantiate_from_config, parallel_data_prefetch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
+from ldm.util import instantiate_from_config, parallel_data_prefetch
DATABASES = [
"openimages",
@@ -134,7 +137,7 @@ def load_searcher(self):
def search(self, x, k):
if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
- self.train_searcher(k) # quickly fit searcher on the fly for small databases
+ self.train_searcher(k) # quickly fit searcher on the fly for small databases
assert self.searcher is not None, 'Cannot search with uninitialized searcher'
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
diff --git a/scripts/sample_diffusion.py b/scripts/sample_diffusion.py
index 876fe3c36..877d0e7f5 100644
--- a/scripts/sample_diffusion.py
+++ b/scripts/sample_diffusion.py
@@ -1,17 +1,22 @@
-import argparse, os, sys, glob, datetime, yaml
-import torch
-import time
+import argparse
+import datetime
+import glob
import numpy as np
-from tqdm import trange
-
-from omegaconf import OmegaConf
+import os
+import sys
+import time
+import torch
+import yaml
from PIL import Image
+from omegaconf import OmegaConf
+from tqdm import trange
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
rescale = lambda x: (x + 1.) / 2.
+
def custom_to_pil(x):
x = x.detach().cpu()
x = torch.clamp(x, -1., 1.)
@@ -54,8 +59,6 @@ def logs2pil(logs, keys=["sample"]):
def convsample(model, shape, return_intermediates=True,
verbose=True,
make_prog_row=False):
-
-
if not make_prog_row:
return model.p_sample_loop(None, shape,
return_intermediates=return_intermediates, verbose=verbose)
@@ -71,14 +74,12 @@ def convsample_ddim(model, steps, shape, eta=1.0
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
- samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
+ samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False, )
return samples, intermediates
@torch.no_grad()
-def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
-
-
+def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0, ):
log = dict()
shape = [batch_size,
@@ -92,7 +93,7 @@ def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non
sample, progrow = convsample(model, shape,
make_prog_row=True)
else:
- sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
+ sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
eta=eta)
t1 = time.time()
@@ -105,15 +106,15 @@ def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non
print(f'Throughput for this batch: {log["throughput"]}')
return log
+
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
if vanilla:
print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
else:
print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
-
tstart = time.time()
- n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1
+ n_saved = len(glob.glob(os.path.join(logdir, '*.png'))) - 1
# path = logdir
if model.cond_stage_model is None:
all_images = []
@@ -135,7 +136,7 @@ def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None
np.savez(nppath, all_img)
else:
- raise NotImplementedError('Currently only sampling for unconditional models supported.')
+ raise NotImplementedError('Currently only sampling for unconditional models supported.')
print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
@@ -219,7 +220,7 @@ def get_parser():
def load_model_from_config(config, sd):
model = instantiate_from_config(config)
- model.load_state_dict(sd,strict=False)
+ model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model
@@ -305,9 +306,8 @@ def load_model(config, ckpt, gpu, eval_mode):
yaml.dump(sampling_conf, f, default_flow_style=False)
print(sampling_conf)
-
run(model, imglogdir, eta=opt.eta,
- vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
+ vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
batch_size=opt.batch_size, nplog=numpylogdir)
print("done.")
diff --git a/scripts/train_searcher.py b/scripts/train_searcher.py
index 1e7904889..83d219c6c 100644
--- a/scripts/train_searcher.py
+++ b/scripts/train_searcher.py
@@ -1,8 +1,9 @@
-import os, sys
-import numpy as np
-import scann
import argparse
import glob
+import numpy as np
+import os
+import scann
+import sys
from multiprocessing import cpu_count
from tqdm import tqdm
@@ -25,9 +26,8 @@ def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
reorder_k).build()
-def load_datapool(dpath):
-
+def load_datapool(dpath):
def load_single_file(saved_embeddings):
compressed = np.load(saved_embeddings)
database = {key: compressed[key] for key in compressed.files}
@@ -51,7 +51,8 @@ def load_multi_files(data_archive):
prefetched_data = parallel_data_prefetch(load_multi_files, data,
n_proc=min(len(data), cpu_count()), target_data_type='dict')
- data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
+ data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
+ prefetched_data[0].keys()}
else:
raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')
@@ -67,8 +68,7 @@ def train_searcher(opt,
aiq_thld=0.2,
dims_per_block=2,
num_leaves=None,
- num_leaves_to_search=None,):
-
+ num_leaves_to_search=None, ):
data_pool = load_datapool(opt.database)
k = opt.knn
@@ -77,7 +77,8 @@ def train_searcher(opt,
# normalize
# embeddings =
- searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
+ searcher = scann.scann_ops_pybind.builder(
+ data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
pool_size = data_pool['embedding'].shape[0]
print(*(['#'] * 100))
@@ -115,7 +116,7 @@ def train_searcher(opt,
print(f'num_leaves_to_search: {num_leaves_to_search}')
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
- partioning_trainsize, num_leaves, num_leaves_to_search)
+ partioning_trainsize, num_leaves, num_leaves_to_search)
print('Finish training searcher')
searcher_savedir = opt.target_path
@@ -123,6 +124,7 @@ def train_searcher(opt,
searcher.serialize(searcher_savedir)
print(f'Saved trained searcher under "{searcher_savedir}"')
+
if __name__ == '__main__':
sys.path.append(os.getcwd())
parser = argparse.ArgumentParser()
@@ -142,6 +144,6 @@ def train_searcher(opt,
type=int,
help='number of nearest neighbors, for which the searcher shall be optimized')
- opt, _ = parser.parse_known_args()
+ opt, _ = parser.parse_known_args()
- train_searcher(opt,)
\ No newline at end of file
+ train_searcher(opt, )
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index da77e1a03..9a41f258f 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -1,20 +1,23 @@
-import argparse, os, sys, glob
-import torch
+import argparse
+import glob
import numpy as np
-from omegaconf import OmegaConf
+import os
+import sys
+import time
+import torch
from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
+from contextlib import contextmanager, nullcontext
from einops import rearrange
-from torchvision.utils import make_grid
-import time
+from itertools import islice
+from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from torch import autocast
-from contextlib import contextmanager, nullcontext
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange
-from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config
def chunk(it, size):
@@ -220,7 +223,7 @@ def main():
if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
- precision_scope = autocast if opt.precision=="autocast" else nullcontext
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
diff --git a/setup.py b/setup.py
index a24d54167..cb1533632 100644
--- a/setup.py
+++ b/setup.py
@@ -10,4 +10,4 @@
'numpy',
'tqdm',
],
-)
\ No newline at end of file
+)