Commit 14a0d21 (1 parent: ebf581e)

[Community Pipeline] Diffusion Posterior Sampling for General Noisy Inverse Problems (#5939)

* [community pipeline] dps impl
* add type checking
* pass ruff check
* ruff formatter

File tree: 2 files changed, +609 −1 lines

examples/community/README.md (143 additions, 1 deletion)

@@ -2480,4 +2480,146 @@ images = pipe(
).images
images[0].save("controlnet_and_adapter_inpaint.png")
```
## Diffusion Posterior Sampling Pipeline
* Reference paper
```
@article{chung2022diffusion,
  title={Diffusion posterior sampling for general noisy inverse problems},
  author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul},
  journal={arXiv preprint arXiv:2209.14687},
  year={2022}
}
```
* This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given an observation $y$, an unconditional generative model $p(x)$, and a differentiable operator $y=f(x)$.
* For example, $f(\cdot)$ can be a downsampling operator; $y$ is then a downsampled image, and the pipeline acts as a super-resolution pipeline. A sketch of the guidance step DPS performs appears below.
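* Concretely, at every reverse-diffusion step DPS first takes an unconditional denoising step and then nudges the sample along the gradient of a measurement loss. The following is a minimal, hypothetical sketch of that step: `dps_step` is an illustrative name, its arguments match the objects set up later in this section, and the actual implementation lives in dps_pipeline.py:
```python
import torch

def dps_step(model, scheduler, x_t, t, measurement, operator, loss_fn, zeta):
    # one DPS reverse step (illustrative sketch, not the pipeline's exact code)
    x_t = x_t.detach().requires_grad_(True)
    noise_pred = model(x_t, t).sample                  # predict noise eps(x_t, t)
    step = scheduler.step(noise_pred, t, x_t)          # unconditional DDPM update
    x0_hat = step.pred_original_sample                 # estimate of x_0 from x_t
    distance = loss_fn(operator(x0_hat), measurement)  # ||y - f(x0_hat)||
    grad = torch.autograd.grad(distance, x_t)[0]       # gradient w.r.t. x_t
    return step.prev_sample - zeta * grad              # measurement-guided x_{t-1}
```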
* To use this pipeline, you need to know your operator $f(\cdot)$ and the corrupted image $y$, and pass both during the call. For example, as in the main function of dps_pipeline.py, you first define the Gaussian blurring operator $f(\cdot)$. The operator should be a callable nn.Module with all parameter gradients disabled:
```python
import numpy as np
import scipy
import torch
from torch import nn

# define the Gaussian blurring operator first
class GaussialBlurOperator(nn.Module):
    def __init__(self, kernel_size, intensity):
        super().__init__()

        class Blurkernel(nn.Module):
            def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0):
                super().__init__()
                self.blur_type = blur_type
                self.kernel_size = kernel_size
                self.std = std
                # depthwise convolution that applies the same kernel to each channel
                self.seq = nn.Sequential(
                    nn.ReflectionPad2d(self.kernel_size // 2),
                    nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
                )
                self.weights_init()

            def forward(self, x):
                return self.seq(x)

            def weights_init(self):
                if self.blur_type == "gaussian":
                    # build the Gaussian kernel by filtering a unit impulse
                    n = np.zeros((self.kernel_size, self.kernel_size))
                    n[self.kernel_size // 2, self.kernel_size // 2] = 1
                    k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
                    k = torch.from_numpy(k)
                    self.k = k
                    for name, f in self.named_parameters():
                        f.data.copy_(k)
                elif self.blur_type == "motion":
                    # requires an external motion-blur Kernel class (e.g. from the motionblur package)
                    k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
                    k = torch.from_numpy(k)
                    self.k = k
                    for name, f in self.named_parameters():
                        f.data.copy_(k)

            def update_weights(self, k):
                if not torch.is_tensor(k):
                    k = torch.from_numpy(k)
                for name, f in self.named_parameters():
                    f.data.copy_(k)

            def get_kernel(self):
                return self.k

        self.kernel_size = kernel_size
        self.conv = Blurkernel(blur_type='gaussian',
                               kernel_size=kernel_size,
                               std=intensity)
        self.kernel = self.conv.get_kernel()
        self.conv.update_weights(self.kernel.type(torch.float32))

        # freeze the operator: DPS differentiates through it but never trains it
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, data, **kwargs):
        return self.conv(data)

    def transpose(self, data, **kwargs):
        return data

    def get_kernel(self):
        return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
```
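* Since DPS backpropagates the measurement loss through the operator, it is worth checking that gradients reach the input even though the operator's own parameters are frozen. A quick check, assuming the block above has been run:
```python
op = GaussialBlurOperator(kernel_size=61, intensity=3.0)
x = torch.randn(1, 3, 256, 256, requires_grad=True)
op(x).sum().backward()
print(x.grad.shape)  # torch.Size([1, 3, 256, 256]): the operator is differentiable
```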
* Next, you need the corrupted image $y$ produced by the operator. In this example, we generate $y$ from a source image $x$; in practice, however, having the operator $f(\cdot)$ and the corrupted image $y$ is enough:
```python
import numpy as np
import torch
from PIL import Image
from torchvision.utils import save_image

# set up source image
src = Image.open('sample.png')
# read image into [1,3,H,W]
src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2, 0, 1)[None]
# normalize image to [-1,1]
src = (src / 127.5) - 1.0
src = src.to("cuda")

# set up operator and measurement
operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
measurement = operator(src)

# save the source and corrupted images
save_image((src + 1.0) / 2.0, "dps_src.png")
save_image((measurement + 1.0) / 2.0, "dps_mea.png")
```
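* The paper also treats noisy measurements $y = f(x) + n$. To reproduce that setting, one plausible variant (an assumption, not part of the shipped example; the noise level `sigma` is an illustrative choice) simply adds Gaussian noise to the measurement:
```python
# optional (assumption): simulate the paper's noisy setting y = f(x) + n
sigma = 0.05
measurement = measurement + sigma * torch.randn_like(measurement)
```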
* We provide an example pair of source and corrupted images, generated with the Gaussian blur operator above.
* Source image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/4d2a1216-08d1-4aeb-9ce3-7a2d87561d65)
* Gaussian blurred image:
* ![ddpm_generated_image](https://github.com/tongdaxu/Images/assets/22267548/65076258-344b-4ed8-b704-a04edaade8ae)
* You can download these images to run the example on your own.
* Next, we need to define a loss function for diffusion posterior sampling. For most cases, the loss below works well; despite its name, it computes the square root of the summed squared error, i.e. the L2 norm of the residual:
```python
def RMSELoss(yhat, y):
    # L2 norm of the residual between prediction and measurement
    return torch.sqrt(torch.sum((yhat - y) ** 2))
```
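* Any differentiable function with the same signature works as `loss_fn`. For instance, a hypothetical L1 variant (not part of the shipped example) is more robust to outliers in the measurement:
```python
def L1Loss(yhat, y):
    # sum of absolute residuals instead of the L2 norm
    return torch.sum(torch.abs(yhat - y))
```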
* Next, as with any other diffusion pipeline, we need a score estimator and a scheduler. Since we are working with $256 \times 256$ face images, we use google/ddpm-celebahq-256:
```python
from diffusers import DDPMScheduler, UNet2DModel

# set up scheduler
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(1000)

# set up model
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
```
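* As an optional sanity check (not part of the original example), you can confirm that the unconditional model generates plausible faces on its own before adding DPS guidance, using the stock DDPMPipeline:
```python
from diffusers import DDPMPipeline

# unconditional sampling with the same checkpoint (optional check)
pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256").to("cuda")
pipe().images[0].save("ddpm_unconditional.png")
```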
* And finally, run the pipeline:
```python
# finally, the pipeline; DPSPipeline is defined in examples/community/dps_pipeline.py
dpspipe = DPSPipeline(model, scheduler)
image = dpspipe(
    measurement=measurement,
    operator=operator,
    loss_fn=RMSELoss,
    zeta=1.0,
).images[0]
image.save("dps_generated_image.png")
```
* `zeta` is a hyperparameter in the range $[0,1]$ that scales the measurement-guidance gradient; it needs to be tuned for best results. Setting `zeta=1.0` should give the reconstruction below:
* Reconstructed image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/0ceb5575-d42e-4f0b-99c0-50e69c982209)
* The reconstruction is perceptually similar to the source image but differs in details.
* In dps_pipeline.py, we also provide a super-resolution example (a sketch of such an operator follows the images below), which should produce:
* Downsampled image:
* ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13)
* Reconstructed image:
* ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f)
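* For reference, here is a hypothetical sketch of a downsampling operator for the super-resolution case; `DownsampleOperator` and its `scale` parameter are illustrative names, and the actual example ships in dps_pipeline.py:
```python
import torch.nn.functional as F
from torch import nn

class DownsampleOperator(nn.Module):
    def __init__(self, scale=4):
        super().__init__()
        self.scale = scale

    def forward(self, data, **kwargs):
        # bicubic resizing is differentiable, so DPS can guide through it
        return F.interpolate(data, scale_factor=1.0 / self.scale, mode='bicubic')
```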
