Uncertainty quantification with PnP-ULA.#

This code shows you how to use sampling algorithms to quantify uncertainty of a reconstruction from incomplete and noisy measurements.

ULA obtains samples by running the following iteration:

\[x_{k+1} = x_k + \alpha \eta \nabla \log p_{\sigma}(x_k) + \eta \nabla \log p(y|x_k) + \sqrt{2 \eta} z_k\]

where \(z_k \sim \mathcal{N}(0, I)\) is a Gaussian random variable, \(\eta\) is the step size and \(\alpha\) is a parameter controlling the regularization.

The PnP-ULA method is described in the paper Laumont et al.[1].

import deepinv as dinv
from deepinv.utils.plotting import plot
import torch
from deepinv.utils import load_example

Load image from the internet#

This example uses an image of Messi.

device = dinv.utils.get_device()

x = load_example("messi.jpg", img_size=32).to(device)
Selected GPU 0 with 4939.25 MiB free memory

Define forward operator and noise model#

This example uses inpainting as the forward operator and Gaussian noise as the noise model.

sigma = 0.1  # noise level
physics = dinv.physics.Inpainting(mask=0.5, img_size=x.shape[1:], device=device)
physics.noise_model = dinv.physics.GaussianNoise(sigma=sigma)

# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)
<torch._C.Generator object at 0x7f5de9b54090>

Define the likelihood#

Since the noise model is Gaussian, the negative log-likelihood is the L2 loss.

\[-\log p(y|x) \propto \frac{1}{2\sigma^2} \|y-Ax\|^2\]
# load Gaussian Likelihood
likelihood = dinv.optim.data_fidelity.L2(sigma=sigma)

Define the prior#

The score a distribution can be approximated using Tweedie’s formula via the deepinv.optim.ScorePrior class.

\[\nabla \log p_{\sigma}(x) \approx \frac{1}{\sigma^2} \left(D(x,\sigma)-x\right)\]

This example uses a pretrained DnCNN model. From a Bayesian point of view, the score plays the role of the gradient of the negative log prior The hyperparameter sigma_denoiser (\(sigma\)) controls the strength of the prior.

In this example, we use a pretrained DnCNN model using the deepinv.loss.FNEJacobianSpectralNorm loss, which makes sure that the denoiser is firmly non-expansive (see Terris et al.[2]), and helps to stabilize the sampling algorithm.

sigma_denoiser = 2 / 255
prior = dinv.optim.ScorePrior(
    denoiser=dinv.models.DnCNN(pretrained="download_lipschitz")
).to(device)

Create the MCMC sampler#

Here we use the Unadjusted Langevin Algorithm (ULA) to sample from the posterior defined in deepinv.sampling.ULAIterator. The hyperparameter step_size controls the step size of the MCMC sampler, regularization controls the strength of the prior and iterations controls the number of iterations of the sampler.

regularization = 0.9
step_size = 0.01 * (sigma**2)
iterations = int(5e3) if torch.cuda.is_available() else 10
params = {
    "step_size": step_size,
    "alpha": regularization,
    "sigma": sigma_denoiser,
}
f = dinv.sampling.sampling_builder(
    "ULA",
    prior=prior,
    data_fidelity=likelihood,
    max_iter=iterations,
    params_algo=params,
    thinning=1,
    verbose=True,
)

Generate the measurement#

We apply the forward model to generate the noisy measurement.

y = physics(x)

Run sampling algorithm and plot results#

The sampling algorithm returns the posterior mean and variance. We compare the posterior mean with a simple linear reconstruction.

mean, var = f.sample(y, physics)

# compute linear inverse
x_lin = physics.A_adjoint(y)

# compute PSNR
print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
print(f"Posterior mean PSNR: {dinv.metric.PSNR()(x, mean).item():.2f} dB")

# plot results
error = (mean - x).abs().sum(dim=1).unsqueeze(1)  # per pixel average abs. error
std = var.sum(dim=1).unsqueeze(1).sqrt()  # per pixel average standard dev.
imgs = [x_lin, x, mean, std / std.flatten().max(), error / error.flatten().max()]
plot(
    imgs,
    titles=["measurement", "ground truth", "post. mean", "post. std", "abs. error"],
)
measurement, ground truth, post. mean, post. std, abs. error
  0%|          | 0/5000 [00:00<?, ?it/s]
  2%|▏         | 76/5000 [00:00<00:06, 759.15it/s]
  3%|β–Ž         | 165/5000 [00:00<00:05, 833.42it/s]
  5%|▍         | 249/5000 [00:00<00:05, 811.67it/s]
  7%|β–‹         | 338/5000 [00:00<00:05, 841.43it/s]
  9%|β–Š         | 427/5000 [00:00<00:05, 856.89it/s]
 10%|β–ˆ         | 515/5000 [00:00<00:05, 864.34it/s]
 12%|β–ˆβ–        | 604/5000 [00:00<00:05, 870.48it/s]
 14%|β–ˆβ–        | 693/5000 [00:00<00:04, 874.90it/s]
 16%|β–ˆβ–Œ        | 781/5000 [00:00<00:04, 875.78it/s]
 17%|β–ˆβ–‹        | 869/5000 [00:01<00:04, 866.92it/s]
 19%|β–ˆβ–‰        | 956/5000 [00:01<00:04, 853.78it/s]
 21%|β–ˆβ–ˆ        | 1042/5000 [00:01<00:04, 844.27it/s]
 23%|β–ˆβ–ˆβ–Ž       | 1127/5000 [00:01<00:04, 832.53it/s]
 24%|β–ˆβ–ˆβ–       | 1211/5000 [00:01<00:04, 823.87it/s]
 26%|β–ˆβ–ˆβ–Œ       | 1294/5000 [00:01<00:04, 816.68it/s]
 28%|β–ˆβ–ˆβ–Š       | 1376/5000 [00:01<00:04, 812.28it/s]
 29%|β–ˆβ–ˆβ–‰       | 1458/5000 [00:01<00:04, 810.03it/s]
 31%|β–ˆβ–ˆβ–ˆ       | 1540/5000 [00:01<00:04, 808.98it/s]
 32%|β–ˆβ–ˆβ–ˆβ–      | 1621/5000 [00:01<00:04, 798.91it/s]
 34%|β–ˆβ–ˆβ–ˆβ–      | 1701/5000 [00:02<00:04, 789.22it/s]
 36%|β–ˆβ–ˆβ–ˆβ–Œ      | 1780/5000 [00:02<00:04, 774.61it/s]
 37%|β–ˆβ–ˆβ–ˆβ–‹      | 1858/5000 [00:02<00:04, 715.06it/s]
 39%|β–ˆβ–ˆβ–ˆβ–Š      | 1935/5000 [00:02<00:04, 728.41it/s]
 40%|β–ˆβ–ˆβ–ˆβ–ˆ      | 2011/5000 [00:02<00:04, 736.56it/s]
 42%|β–ˆβ–ˆβ–ˆβ–ˆβ–     | 2089/5000 [00:02<00:03, 748.25it/s]
 43%|β–ˆβ–ˆβ–ˆβ–ˆβ–Ž     | 2169/5000 [00:02<00:03, 762.68it/s]
 45%|β–ˆβ–ˆβ–ˆβ–ˆβ–Œ     | 2250/5000 [00:02<00:03, 775.26it/s]
 47%|β–ˆβ–ˆβ–ˆβ–ˆβ–‹     | 2331/5000 [00:02<00:03, 783.41it/s]
 48%|β–ˆβ–ˆβ–ˆβ–ˆβ–Š     | 2412/5000 [00:02<00:03, 789.93it/s]
 50%|β–ˆβ–ˆβ–ˆβ–ˆβ–‰     | 2492/5000 [00:03<00:03, 791.43it/s]
 51%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 2573/5000 [00:03<00:03, 794.76it/s]
 53%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž    | 2654/5000 [00:03<00:02, 796.79it/s]
 55%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–    | 2735/5000 [00:03<00:02, 799.70it/s]
 56%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹    | 2816/5000 [00:03<00:02, 800.20it/s]
 58%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š    | 2897/5000 [00:03<00:02, 801.33it/s]
 60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰    | 2978/5000 [00:03<00:02, 802.16it/s]
 61%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    | 3059/5000 [00:03<00:02, 803.21it/s]
 63%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž   | 3140/5000 [00:03<00:02, 804.21it/s]
 64%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–   | 3221/5000 [00:03<00:02, 803.37it/s]
 66%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ   | 3302/5000 [00:04<00:02, 803.55it/s]
 68%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š   | 3383/5000 [00:04<00:02, 803.94it/s]
 69%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰   | 3464/5000 [00:04<00:01, 803.19it/s]
 71%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   | 3545/5000 [00:04<00:01, 803.74it/s]
 73%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž  | 3626/5000 [00:04<00:01, 803.96it/s]
 74%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–  | 3707/5000 [00:04<00:01, 796.54it/s]
 76%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ  | 3787/5000 [00:04<00:01, 787.39it/s]
 77%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹  | 3866/5000 [00:04<00:01, 781.85it/s]
 79%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰  | 3945/5000 [00:04<00:01, 777.26it/s]
 80%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  | 4023/5000 [00:05<00:01, 772.31it/s]
 82%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ– | 4101/5000 [00:05<00:01, 771.67it/s]
 84%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž | 4181/5000 [00:05<00:01, 777.67it/s]
 85%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Œ | 4262/5000 [00:05<00:00, 784.56it/s]
 87%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹ | 4343/5000 [00:05<00:00, 790.69it/s]
 88%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š | 4424/5000 [00:05<00:00, 794.56it/s]
 90%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ | 4505/5000 [00:05<00:00, 796.85it/s]
 92%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 4586/5000 [00:05<00:00, 798.79it/s]
 93%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Ž| 4667/5000 [00:05<00:00, 801.42it/s]
 95%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–| 4748/5000 [00:05<00:00, 801.60it/s]
 97%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‹| 4829/5000 [00:06<00:00, 803.67it/s]
 98%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š| 4910/5000 [00:06<00:00, 803.41it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‰| 4991/5000 [00:06<00:00, 803.60it/s]
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5000/5000 [00:06<00:00, 801.15it/s]
Iteration 4999, current converge crit. = 1.43E-05, objective = 1.00E-03
Iteration 4999, current converge crit. = 3.42E-04, objective = 1.00E-03
Linear reconstruction PSNR: 8.55 dB
Posterior mean PSNR: 22.31 dB
References:

Total running time of the script: (0 minutes 6.678 seconds)

Gallery generated by Sphinx-Gallery