Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 143 additions & 1 deletion examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2480,4 +2480,146 @@ images = pipe(
).images
images[0].save("controlnet_and_adapter_inpaint.png")

```
```

## Diffusion Posterior Sampling Pipeline
* Reference paper
```
@article{chung2022diffusion,
title={Diffusion posterior sampling for general noisy inverse problems},
author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul},
journal={arXiv preprint arXiv:2209.14687},
year={2022}
}
```
* This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given observation on $y$, unconditional generative model $p(x)$ and differentiable operator $y=f(x)$.
* For example, $f(.)$ can be downsample operator, then $y$ is a downsampled image, and the pipeline becomes a super-resolution pipeline.
* To use this pipeline, you need to know your operator $f(.)$ and corrupted image $y$, and pass them during the call. For example, as in the main function of dps_pipeline.py, you need to first define the Gaussian blurring operator $f(.)$. The operator should be a callable nn.Module, with all the parameter gradient disabled:
```python
import torch.nn.functional as F
import scipy
from torch import nn

# define the Gaussian blurring operator first
class GaussialBlurOperator(nn.Module):
def __init__(self, kernel_size, intensity):
super().__init__()

class Blurkernel(nn.Module):
def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0):
super().__init__()
self.blur_type = blur_type
self.kernel_size = kernel_size
self.std = std
self.seq = nn.Sequential(
nn.ReflectionPad2d(self.kernel_size//2),
nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
)
self.weights_init()

def forward(self, x):
return self.seq(x)

def weights_init(self):
if self.blur_type == "gaussian":
n = np.zeros((self.kernel_size, self.kernel_size))
n[self.kernel_size // 2,self.kernel_size // 2] = 1
k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
k = torch.from_numpy(k)
self.k = k
for name, f in self.named_parameters():
f.data.copy_(k)
elif self.blur_type == "motion":
k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
k = torch.from_numpy(k)
self.k = k
for name, f in self.named_parameters():
f.data.copy_(k)

def update_weights(self, k):
if not torch.is_tensor(k):
k = torch.from_numpy(k)
for name, f in self.named_parameters():
f.data.copy_(k)

def get_kernel(self):
return self.k

self.kernel_size = kernel_size
self.conv = Blurkernel(blur_type='gaussian',
kernel_size=kernel_size,
std=intensity)
self.kernel = self.conv.get_kernel()
self.conv.update_weights(self.kernel.type(torch.float32))

for param in self.parameters():
param.requires_grad=False

def forward(self, data, **kwargs):
return self.conv(data)

def transpose(self, data, **kwargs):
return data

def get_kernel(self):
return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
```
* Next, you should obtain the corrupted image $y$ by the operator. In this example, we generate $y$ from the source image $x$. However in practice, having the operator $f(.)$ and corrupted image $y$ is enough:
```python
# set up source image
src = Image.open('sample.png')
# read image into [1,3,H,W]
src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2,0,1)[None]
# normalize image to [-1,1]
src = (src / 127.5) - 1.0
src = src.to("cuda")

# set up operator and measurement
operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
measurement = operator(src)

# save the source and corrupted images
save_image((src+1.0)/2.0, "dps_src.png")
save_image((measurement+1.0)/2.0, "dps_mea.png")
```
* We provide an example pair of saved source and corrupted images, using the Gaussian blur operator above
* Source image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/4d2a1216-08d1-4aeb-9ce3-7a2d87561d65)
* Gaussian blurred image:
* ![ddpm_generated_image](https://github.com/tongdaxu/Images/assets/22267548/65076258-344b-4ed8-b704-a04edaade8ae)
* You can download those image to run the example on your own.
* Next, we need to define a loss function used for diffusion posterior sample. For most of the cases, the RMSE is fine:
```python
def RMSELoss(yhat, y):
return torch.sqrt(torch.sum((yhat-y)**2))
```
* And next, as any other diffusion models, we need the score estimator and scheduler. As we are working with $256x256$ face images, we use ddmp-celebahq-256:
```python
# set up scheduler
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(1000)

# set up model
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
```
* And finally, run the pipeline:
```python
# finally, the pipeline
dpspipe = DPSPipeline(model, scheduler)
image = dpspipe(
measurement = measurement,
operator = operator,
loss_fn = RMSELoss,
zeta = 1.0,
).images[0]
image.save("dps_generated_image.png")
```
* The zeta is a hyperparameter that is in range of $[0,1]$. It need to be tuned for best effect. By setting zeta=1, you should be able to have the reconstructed result:
* Reconstructed image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/0ceb5575-d42e-4f0b-99c0-50e69c982209)
* The reconstruction is perceptually similar to the source image, but different in details.
* In dps_pipeline.py, we also provide a super-resolution example, which should produce:
* Downsampled image:
* ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13)
* Reconstructed image:
* ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f)
Loading