title={Diffusion posterior sampling for general noisy inverse problems},
author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul},
journal={arXiv preprint arXiv:2209.14687},
year={2022}
}
```
* This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given an observation $y$, an unconditional generative model $p(x)$, and a differentiable operator $y=f(x)$.
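* Concretely, DPS corrects each unconditional ancestral sampling step $x'_{t-1}$ with a measurement-consistency gradient (a rough sketch of Algorithm 1 from the paper above, where $\hat{x}_0(x_t)$ denotes the posterior-mean estimate of the clean image obtained from the score model):

$$x_{t-1} = x'_{t-1} - \zeta \, \nabla_{x_t} \lVert y - f(\hat{x}_0(x_t)) \rVert_2$$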
* For example, $f(.)$ can be a downsampling operator; then $y$ is a downsampled image, and the pipeline becomes a super-resolution pipeline.
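* For instance, a minimal sketch of such a downsampling setup (illustrative only; any differentiable operator works):

```python
import torch
import torch.nn as nn

# f(.): 4x downsampling by average pooling; differentiable, no learnable parameters
downsample = nn.AvgPool2d(kernel_size=4)

x = torch.rand(1, 3, 256, 256)  # source image x
y = downsample(x)               # observation y: a 64x64 downsampled image
```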
* To use this pipeline, you need to know your operator $f(.)$ and corrupted image $y$, and pass them during the call. For example, as in the main function of dps_pipeline.py, you need to first define the Gaussian blurring operator $f(.)$. The operator should be a callable nn.Module, with all parameter gradients disabled:
* Next, you should obtain the corrupted image $y$ by applying the operator. In this example, we generate $y$ from the source image $x$. However, in practice, having the operator $f(.)$ and the corrupted image $y$ is enough:
* You can download these images to run the example on your own.
* Next, we need to define a loss function to guide diffusion posterior sampling. For most cases, the RMSE-style loss below is fine:
```python
import torch

def RMSELoss(yhat, y):
    # L2 norm of the residual (despite the name, this is not mean-reduced)
    return torch.sqrt(torch.sum((yhat - y) ** 2))
```
* And next, as with any other diffusion model, we need a score estimator and a scheduler. Since we are working with $256 \times 256$ face images, we use ddpm-celebahq-256:

```python
from diffusers import DDPMScheduler, UNet2DModel

model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
```
* And finally, run the pipeline:

```python
# finally, the pipeline
dpspipe = DPSPipeline(model, scheduler)
image = dpspipe(
    measurement=measurement,
    operator=operator,
    loss_fn=RMSELoss,
    zeta=1.0,
).images[0]
image.save("dps_generated_image.png")
```
* The zeta parameter is a hyperparameter in the range $[0,1]$; it needs to be tuned for the best effect. By setting zeta=1, you should be able to obtain the reconstructed result: