Note
New to DeepInverse? Get started with the basics in the 5-minute quickstart tutorial.
Uncertainty quantification with PnP-ULA
This example shows how to use sampling algorithms to quantify the uncertainty of a reconstruction from incomplete and noisy measurements.
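ULA obtains samples by running the following iteration:
\[x_{k+1} = x_k + \eta \nabla \log p(y \mid x_k) + \eta \alpha \nabla \log p(x_k) + \sqrt{2\eta}\, z_k,\]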
where \(z_k \sim \mathcal{N}(0, I)\) is a Gaussian random variable, \(\eta\) is the step size and \(\alpha\) is a parameter controlling the regularization.
The PnP-ULA method is described in Laumont et al. [1].
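To make the iteration concrete, here is a minimal sketch in plain Python (no deepinv) on a hypothetical one-dimensional problem where everything is Gaussian: \(y = a x + n\) with \(n \sim \mathcal{N}(0, \sigma^2)\) and a prior \(x \sim \mathcal{N}(0, \tau^2)\). The names ula_toy_1d, a and tau are illustrative assumptions. Because the target \(p(y \mid x)\, p(x)^{\alpha}\) is then Gaussian with precision \(P = a^2/\sigma^2 + \alpha/\tau^2\), the chain's statistics can be compared with closed-form values.

```python
import math
import random


def ula_toy_1d(y, a=1.0, sigma=0.5, tau=1.0, alpha=0.9, eta=1e-2,
               n_iter=200_000, burn_in=10_000, seed=0):
    """ULA on a hypothetical 1D Gaussian toy problem (not deepinv code).

    Likelihood: y = a * x + n with n ~ N(0, sigma^2).
    Prior:      x ~ N(0, tau^2); its score is weighted by alpha.
    Returns the empirical mean and variance of the chain after burn-in.
    """
    rng = random.Random(seed)
    x = 0.0
    samples = []
    for k in range(n_iter):
        grad_log_lik = a * (y - a * x) / sigma**2  # gradient of log p(y | x)
        grad_log_prior = -x / tau**2               # gradient of log p(x)
        z = rng.gauss(0.0, 1.0)                    # z_k ~ N(0, 1)
        x = (x + eta * grad_log_lik + eta * alpha * grad_log_prior
             + math.sqrt(2 * eta) * z)
        if k >= burn_in:
            samples.append(x)
    mean = sum(samples) / len(samples)
    var = sum((s - mean) ** 2 for s in samples) / len(samples)
    return mean, var


# The toy target is Gaussian with precision P = a^2 / sigma^2 + alpha / tau^2,
# so its exact mean and variance are known in closed form.
P = 1.0**2 / 0.5**2 + 0.9 / 1.0**2
exact_mean = (1.0 * 1.0 / 0.5**2) / P
m, v = ula_toy_1d(y=1.0)
print(f"ULA mean {m:.3f} (exact {exact_mean:.3f}), ULA var {v:.3f} (exact {1 / P:.3f})")
```

Since ULA has no Metropolis correction, its stationary distribution is slightly biased by an amount that shrinks with \(\eta\); in this Gaussian toy the mean is unbiased, while the sampled variance comes out a little above \(1/P\).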
import deepinv as dinv
from deepinv.utils.plotting import plot
import torch
from deepinv.utils import load_example
Load image from the internet
This example uses an image of Messi.
device = dinv.utils.get_device()
x = load_example("messi.jpg", img_size=32).to(device)
Selected GPU 0 with 1946.125 MiB free memory
Define the forward operator and noise model
This example uses inpainting as the forward operator and Gaussian noise as the noise model.
sigma = 0.1 # noise level
physics = dinv.physics.Inpainting(mask=0.5, img_size=x.shape[1:], device=device)
physics.noise_model = dinv.physics.GaussianNoise(sigma=sigma)
# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)
Define the likelihood
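Since the noise model is Gaussian, the negative log-likelihood is, up to an additive constant, the L2 loss \(\frac{1}{2\sigma^2}\|y - Ax\|_2^2\), where \(A\) is the forward operator.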
# load Gaussian Likelihood
likelihood = dinv.optim.data_fidelity.L2(sigma=sigma)
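The claim that the Gaussian negative log-likelihood equals the L2 loss up to a constant can be checked numerically. The sketch below uses plain Python lists and a hypothetical four-pixel inpainting mask. It assumes the \(\frac{1}{2\sigma^2}\) scaling for the L2 term, which is how we read deepinv.optim.data_fidelity.L2(sigma=...); the exact normalization used by the library should be checked in its documentation.

```python
import math


def gaussian_nll(y, Ax, sigma):
    """-log N(y; Ax, sigma^2 I) for equal-length lists of floats."""
    sq = sum((yi - ai) ** 2 for yi, ai in zip(y, Ax))
    return sq / (2 * sigma**2) + 0.5 * len(y) * math.log(2 * math.pi * sigma**2)


def l2_fidelity(y, Ax, sigma):
    """1/(2 sigma^2) * ||Ax - y||^2, i.e. the NLL without its constant."""
    return sum((ai - yi) ** 2 for yi, ai in zip(y, Ax)) / (2 * sigma**2)


def apply_mask(mask, x):
    """Inpainting-style operator: keep pixels where mask == 1, zero the rest."""
    return [m * xi for m, xi in zip(mask, x)]


sigma = 0.1
mask = [1, 0, 1, 1]     # hypothetical mask
y = [0.2, 0.0, 0.9, 0.4]  # hypothetical measurement
for x in ([0.0, 0.0, 0.0, 0.0], [0.25, 0.7, 0.8, 0.5]):
    Ax = apply_mask(mask, x)
    gap = gaussian_nll(y, Ax, sigma) - l2_fidelity(y, Ax, sigma)
    print(f"NLL - L2 = {gap:.6f}")  # the same value for every x
```

Because the difference does not depend on \(x\), both expressions have the same gradient, so the constant can be dropped in the sampler.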
Define the prior
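The score of a distribution can be approximated using Tweedie's formula via the
deepinv.optim.ScorePrior class:
\[\nabla \log p_{\sigma}(x) \approx \frac{1}{\sigma^2}\left(D(x, \sigma) - x\right),\]
where \(D(\cdot, \sigma)\) is a denoiser for Gaussian noise of standard deviation \(\sigma\) and \(p_\sigma\) is the prior smoothed by noise of that level.
From a Bayesian point of view, the negative score plays the role of the gradient of the
negative log prior.
The hyperparameter sigma_denoiser (\(\sigma\)) is the noise level at which the denoiser is evaluated and controls the strength of the prior.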
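Tweedie's formula is exact when \(D\) is the MMSE denoiser. The sketch below checks this in a hypothetical case where the prior is \(\mathcal{N}(0, \tau^2)\): the MMSE denoiser is then linear, and the smoothed prior \(p_\sigma = \mathcal{N}(0, \tau^2 + \sigma^2)\) has a closed-form score. The value of tau and the function names are illustrative assumptions; a trained network such as DnCNN only approximates the MMSE denoiser, hence the \(\approx\) in the formula.

```python
def mmse_denoiser(x, sigma, tau):
    """Exact MMSE denoiser for x ~ N(0, tau^2) observed with N(0, sigma^2) noise."""
    return tau**2 / (tau**2 + sigma**2) * x


def tweedie_score(x, sigma, tau):
    """Tweedie estimate of grad log p_sigma(x): (D(x, sigma) - x) / sigma^2."""
    return (mmse_denoiser(x, sigma, tau) - x) / sigma**2


def exact_score(x, sigma, tau):
    """Score of the smoothed prior p_sigma = N(0, tau^2 + sigma^2)."""
    return -x / (tau**2 + sigma**2)


sigma, tau = 2 / 255, 0.3  # denoiser noise level as in the example; tau is hypothetical
for x in (-1.0, 0.1, 0.5):
    print(f"x={x:+.1f}  Tweedie {tweedie_score(x, sigma, tau):.6f}  "
          f"exact {exact_score(x, sigma, tau):.6f}")
```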
In this example, we use a DnCNN model pretrained with the deepinv.loss.FNEJacobianSpectralNorm loss,
which encourages the denoiser to be firmly non-expansive (see Terris et al. [2]) and helps to
stabilize the sampling algorithm.
sigma_denoiser = 2 / 255
prior = dinv.optim.ScorePrior(
denoiser=dinv.models.DnCNN(pretrained="download_lipschitz")
).to(device)
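A map \(D\) is firmly non-expansive if \(\|D(x) - D(y)\|^2 \le \langle x - y, D(x) - D(y)\rangle\) for all \(x, y\). The sketch below tests this inequality on random pairs; check_firmly_nonexpansive is a hypothetical helper, not part of deepinv, and passing it on samples is evidence rather than proof. Soft thresholding, a proximal operator and therefore firmly non-expansive, passes, while the doubling map fails.

```python
import math
import random


def dot(u, v):
    """Euclidean inner product of two equal-length lists."""
    return sum(ui * vi for ui, vi in zip(u, v))


def check_firmly_nonexpansive(D, dim=8, n_pairs=1000, seed=0):
    """Test ||D(x) - D(y)||^2 <= <x - y, D(x) - D(y)> on random pairs.

    Returns False as soon as one pair violates the inequality.
    """
    rng = random.Random(seed)
    for _ in range(n_pairs):
        x = [rng.uniform(-1, 1) for _ in range(dim)]
        y = [rng.uniform(-1, 1) for _ in range(dim)]
        d = [a - b for a, b in zip(D(x), D(y))]
        diff = [a - b for a, b in zip(x, y)]
        if dot(d, d) > dot(diff, d) + 1e-12:  # small tolerance for rounding
            return False
    return True


def soft_threshold(x, t=0.1):
    """Proximal operator of t * ||.||_1, which is firmly non-expansive."""
    return [math.copysign(max(abs(v) - t, 0.0), v) for v in x]


def amplify(x):
    """Doubling map: 2-Lipschitz, so not even non-expansive."""
    return [2.0 * v for v in x]


print(check_firmly_nonexpansive(soft_threshold))  # True
print(check_firmly_nonexpansive(amplify))         # False
```

One consequence relevant to sampling: if \(D\) is firmly non-expansive, then so is \(x \mapsto x - D(x)\), which is therefore 1-Lipschitz, so the Tweedie score estimate \((D(x,\sigma) - x)/\sigma^2\) is \(1/\sigma^2\)-Lipschitz.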
Create the MCMC sampler
Here we use the Unadjusted Langevin Algorithm (ULA), implemented in
deepinv.sampling.ULAIterator, to sample from the posterior.
The hyperparameter step_size sets the step size \(\eta\) of the MCMC sampler,
regularization sets \(\alpha\), which controls the strength of the prior, and
iterations sets the number of iterations of the sampler (5000 when a CUDA GPU is available, otherwise only 10).
regularization = 0.9
step_size = 0.01 * (sigma**2)
iterations = int(5e3) if torch.cuda.is_available() else 10
params = {
"step_size": step_size,
"alpha": regularization,
"sigma": sigma_denoiser,
}
f = dinv.sampling.sampling_builder(
"ULA",
prior=prior,
data_fidelity=likelihood,
max_iter=iterations,
params_algo=params,
thinning=1,
verbose=True,
)
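The choice step_size = 0.01 * sigma**2 can be related to the Lipschitz constants of the two gradient terms in the ULA update. The arithmetic below rests on two assumptions about this setup rather than on values reported by deepinv: the inpainting operator has norm 1 (its mask entries are 0 or 1), so the likelihood gradient is \(1/\sigma^2\)-Lipschitz; and the prior gradient is \((x - D(x,\sigma))/\sigma^2\) with \(x - D(x)\) 1-Lipschitz, which holds for a firmly non-expansive denoiser.

```python
sigma = 0.1               # measurement noise level (from the example)
sigma_denoiser = 2 / 255  # denoiser noise level (from the example)
regularization = 0.9      # alpha (from the example)
step_size = 0.01 * sigma**2

# Assumption: ||A|| = 1 for a 0/1 inpainting mask, so the likelihood
# gradient A^T (A x - y) / sigma^2 is (1 / sigma^2)-Lipschitz.
lip_likelihood = 1 / sigma**2
# Assumption: the prior gradient is (x - D(x)) / sigma_denoiser^2 with
# x - D(x) 1-Lipschitz (true if D is firmly non-expansive).
lip_prior = 1 / sigma_denoiser**2

print(f"step_size                   = {step_size:.1e}")
print(f"step_size * L_likelihood    = {step_size * lip_likelihood:.3f}")
print(f"step_size * alpha * L_prior = {step_size * regularization * lip_prior:.3f}")
```

With these values the likelihood term gives \(\eta L \approx 0.01\) while the prior term gives about 1.46, so under these assumptions it is the prior, with its small sigma_denoiser, rather than the likelihood that constrains the step size.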
Generate the measurement
We apply the forward model to generate the noisy measurement.
y = physics(x)
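For inpainting with Gaussian noise, the noisy measurement is \(y = m \odot x + \sigma n\) with \(n \sim \mathcal{N}(0, I)\), where \(m\) is a binary mask. The list-based sketch below assumes each pixel is kept independently with probability 0.5, which is our reading of mask=0.5, and adds noise to every entry; the function names are hypothetical. It also shows why the linear reconstruction physics.A_adjoint(y) is a zero-filled image: the mask operator is diagonal with 0/1 entries, so it is its own adjoint.

```python
import random


def inpainting_A(x, mask):
    """Noiseless inpainting operator A: keep pixels where mask == 1, zero the rest."""
    return [m * xi for m, xi in zip(mask, x)]


def inpainting_A_adjoint(y, mask):
    """A^T y = mask * y: a diagonal 0/1 operator is its own adjoint."""
    return [m * yi for m, yi in zip(mask, y)]


def measure(x, mask, sigma, rng):
    """Noisy measurement y = A x + sigma * n with n ~ N(0, I)."""
    return [ax + sigma * rng.gauss(0.0, 1.0) for ax in inpainting_A(x, mask)]


rng = random.Random(0)
n = 16                                                     # hypothetical 16-pixel "image"
x = [rng.random() for _ in range(n)]
mask = [1 if rng.random() < 0.5 else 0 for _ in range(n)]  # keep each pixel w.p. 0.5
y = measure(x, mask, sigma=0.1, rng=rng)
x_lin = inpainting_A_adjoint(y, mask)                      # missing pixels are set to 0
print(f"kept {sum(mask)} of {n} pixels")
```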
Run the sampling algorithm and plot results
The sampling algorithm returns the posterior mean and variance. We compare the posterior mean with a simple linear reconstruction.
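The posterior mean and variance are Monte Carlo estimates over the iterates of the chain. One standard way to accumulate them without storing every sample is Welford's online algorithm, sketched below for a scalar chain with optional burn-in and thinning (with thinning=1, as in this example, every kept iterate would contribute). This is an illustrative sketch: we do not claim it matches how deepinv computes its statistics, and RunningMeanVar and chain_statistics are hypothetical names.

```python
import random


class RunningMeanVar:
    """Welford's online algorithm for the mean and (population) variance."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the current mean

    def update(self, value):
        self.n += 1
        delta = value - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (value - self.mean)

    @property
    def var(self):
        return self.m2 / self.n if self.n > 0 else float("nan")


def chain_statistics(chain, burn_in=0, thinning=1):
    """Mean and variance of chain[burn_in::thinning] without storing the samples."""
    stats = RunningMeanVar()
    for k, sample in enumerate(chain):
        if k >= burn_in and (k - burn_in) % thinning == 0:
            stats.update(sample)
    return stats.mean, stats.var


# Hypothetical scalar "chain" standing in for one pixel of the sampler's output.
rng = random.Random(1)
chain = [rng.gauss(2.0, 3.0) for _ in range(1000)]
print(chain_statistics(chain, burn_in=100, thinning=3))
```

Applied per pixel, the square root of the variance gives the posterior standard deviation map shown in the plot.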
mean, var = f.sample(y, physics)
# compute linear inverse
x_lin = physics.A_adjoint(y)
# compute PSNR
print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
print(f"Posterior mean PSNR: {dinv.metric.PSNR()(x, mean).item():.2f} dB")
# plot results
error = (mean - x).abs().sum(dim=1).unsqueeze(1)  # per-pixel absolute error, summed over color channels
std = var.sum(dim=1).unsqueeze(1).sqrt()  # per-pixel standard deviation (sqrt of the channel-summed variance)
imgs = [x_lin, x, mean, std / std.flatten().max(), error / error.flatten().max()]
plot(
imgs,
titles=["measurement", "ground truth", "post. mean", "post. std", "abs. error"],
)

100%|██████████| 5000/5000 [00:06<00:00, 808.13it/s]
Iteration 4999, current converge crit. = 1.43E-05, objective = 1.00E-03
Iteration 4999, current converge crit. = 3.42E-04, objective = 1.00E-03
Linear reconstruction PSNR: 8.55 dB
Posterior mean PSNR: 22.31 dB
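PSNR compares a reconstruction with the ground truth through \(10 \log_{10}(\mathrm{max}^2 / \mathrm{MSE})\). The sketch below assumes a peak value of 1 (images in \([0, 1]\)), which we believe is the default of dinv.metric.PSNR but have not verified here. Because PSNR is logarithmic, a gap in dB converts directly into a ratio of mean squared errors.

```python
import math


def psnr(x, x_hat, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_pixel^2 / MSE)."""
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    return 10 * math.log10(max_pixel**2 / mse)


def mse_ratio(psnr_low, psnr_high):
    """How many times smaller the MSE is at psnr_high than at psnr_low."""
    return 10 ** ((psnr_high - psnr_low) / 10)


print(f"{psnr([0.0] * 4, [0.1] * 4):.2f} dB")     # an MSE of 0.01 gives 20 dB
print(f"MSE ratio: {mse_ratio(8.55, 22.31):.1f}")  # PSNR values printed above
```

Applied to the two values printed above (8.55 dB and 22.31 dB), the posterior mean has roughly 24 times lower MSE than the linear reconstruction.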
- References:
Total running time of the script: (0 minutes 6.591 seconds)