Note
New to DeepInverse? Get started with the basics with the 5 minute quickstart tutorial..
Uncertainty quantification with PnP-ULA.#
This code shows you how to use sampling algorithms to quantify uncertainty of a reconstruction from incomplete and noisy measurements.
ULA obtains samples by running the following iteration:
where \(z_k \sim \mathcal{N}(0, I)\) is a Gaussian random variable, \(\eta\) is the step size and \(\alpha\) is a parameter controlling the regularization.
The PnP-ULA method is described in the paper Laumont et al.[1].
import deepinv as dinv
from deepinv.utils.plotting import plot
import torch
from deepinv.utils import load_example
Load image from the internet#
This example uses an image of Messi.
device = dinv.utils.get_device()
x = load_example("messi.jpg", img_size=32).to(device)
Selected GPU 0 with 4759.25 MiB free memory
Define forward operator and noise model#
This example uses inpainting as the forward operator and Gaussian noise as the noise model.
sigma = 0.1 # noise level
physics = dinv.physics.Inpainting(mask=0.5, img_size=x.shape[1:], device=device)
physics.noise_model = dinv.physics.GaussianNoise(sigma=sigma)
# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)
<torch._C.Generator object at 0x7f05ba810710>
Define the likelihood#
Since the noise model is Gaussian, the negative log-likelihood is the L2 loss.
# load Gaussian Likelihood
likelihood = dinv.optim.data_fidelity.L2(sigma=sigma)
Define the prior#
The score a distribution can be approximated using Tweedieβs formula via the
deepinv.optim.ScorePrior class.
This example uses a pretrained DnCNN model.
From a Bayesian point of view, the score plays the role of the gradient of the
negative log prior
The hyperparameter sigma_denoiser (\(sigma\)) controls the strength of the prior.
In this example, we use a pretrained DnCNN model using the deepinv.loss.FNEJacobianSpectralNorm loss,
which makes sure that the denoiser is firmly non-expansive (see Terris et al.[2]), and helps to
stabilize the sampling algorithm.
sigma_denoiser = 2 / 255
prior = dinv.optim.ScorePrior(
denoiser=dinv.models.DnCNN(pretrained="download_lipschitz")
).to(device)
Create the MCMC sampler#
Here we use the Unadjusted Langevin Algorithm (ULA) to sample from the posterior defined in
deepinv.sampling.ULAIterator.
The hyperparameter step_size controls the step size of the MCMC sampler,
regularization controls the strength of the prior and
iterations controls the number of iterations of the sampler.
regularization = 0.9
step_size = 0.01 * (sigma**2)
iterations = int(5e3) if torch.cuda.is_available() else 10
params = {
"step_size": step_size,
"alpha": regularization,
"sigma": sigma_denoiser,
}
f = dinv.sampling.sampling_builder(
"ULA",
prior=prior,
data_fidelity=likelihood,
max_iter=iterations,
params_algo=params,
thinning=1,
verbose=True,
)
Generate the measurement#
We apply the forward model to generate the noisy measurement.
Run sampling algorithm and plot results#
The sampling algorithm returns the posterior mean and variance. We compare the posterior mean with a simple linear reconstruction.
mean, var = f.sample(y, physics)
# compute linear inverse
x_lin = physics.A_adjoint(y)
# compute PSNR
print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
print(f"Posterior mean PSNR: {dinv.metric.PSNR()(x, mean).item():.2f} dB")
# plot results
error = (mean - x).abs().sum(dim=1).unsqueeze(1) # per pixel average abs. error
std = var.sum(dim=1).unsqueeze(1).sqrt() # per pixel average standard dev.
imgs = [x_lin, x, mean, std / std.flatten().max(), error / error.flatten().max()]
plot(
imgs,
titles=["measurement", "ground truth", "post. mean", "post. std", "abs. error"],
)

0%| | 0/5000 [00:00<?, ?it/s]
0%| | 25/5000 [00:00<00:20, 244.90it/s]
1%| | 56/5000 [00:00<00:17, 281.22it/s]
2%|β | 86/5000 [00:00<00:17, 288.78it/s]
2%|β | 117/5000 [00:00<00:16, 296.23it/s]
3%|β | 148/5000 [00:00<00:16, 300.54it/s]
4%|β | 179/5000 [00:00<00:15, 302.73it/s]
4%|β | 210/5000 [00:00<00:15, 304.34it/s]
5%|β | 241/5000 [00:00<00:15, 305.26it/s]
5%|β | 272/5000 [00:00<00:15, 305.81it/s]
6%|β | 303/5000 [00:01<00:15, 304.82it/s]
7%|β | 334/5000 [00:01<00:15, 305.30it/s]
7%|β | 365/5000 [00:01<00:15, 305.39it/s]
8%|β | 396/5000 [00:01<00:15, 305.23it/s]
9%|β | 427/5000 [00:01<00:14, 304.96it/s]
9%|β | 458/5000 [00:01<00:14, 305.70it/s]
10%|β | 489/5000 [00:01<00:14, 306.14it/s]
10%|β | 520/5000 [00:01<00:14, 306.27it/s]
11%|β | 551/5000 [00:01<00:14, 306.43it/s]
12%|ββ | 582/5000 [00:01<00:14, 306.46it/s]
12%|ββ | 613/5000 [00:02<00:14, 306.90it/s]
13%|ββ | 644/5000 [00:02<00:14, 307.27it/s]
14%|ββ | 675/5000 [00:02<00:14, 307.54it/s]
14%|ββ | 706/5000 [00:02<00:13, 306.73it/s]
15%|ββ | 737/5000 [00:02<00:13, 306.90it/s]
15%|ββ | 768/5000 [00:02<00:13, 306.85it/s]
16%|ββ | 799/5000 [00:02<00:13, 306.76it/s]
17%|ββ | 830/5000 [00:02<00:13, 306.54it/s]
17%|ββ | 861/5000 [00:02<00:13, 304.53it/s]
18%|ββ | 892/5000 [00:02<00:13, 305.25it/s]
18%|ββ | 923/5000 [00:03<00:13, 305.84it/s]
19%|ββ | 954/5000 [00:03<00:13, 306.21it/s]
20%|ββ | 985/5000 [00:03<00:13, 306.30it/s]
20%|ββ | 1016/5000 [00:03<00:13, 303.06it/s]
21%|ββ | 1047/5000 [00:03<00:13, 299.44it/s]
22%|βββ | 1077/5000 [00:03<00:13, 297.11it/s]
22%|βββ | 1107/5000 [00:03<00:13, 295.46it/s]
23%|βββ | 1137/5000 [00:03<00:13, 293.87it/s]
23%|βββ | 1167/5000 [00:03<00:13, 293.02it/s]
24%|βββ | 1197/5000 [00:03<00:13, 292.53it/s]
25%|βββ | 1227/5000 [00:04<00:12, 292.29it/s]
25%|βββ | 1257/5000 [00:04<00:12, 291.49it/s]
26%|βββ | 1287/5000 [00:04<00:12, 291.44it/s]
26%|βββ | 1317/5000 [00:04<00:12, 291.24it/s]
27%|βββ | 1347/5000 [00:04<00:12, 291.46it/s]
28%|βββ | 1377/5000 [00:04<00:12, 291.63it/s]
28%|βββ | 1407/5000 [00:04<00:12, 291.34it/s]
29%|βββ | 1437/5000 [00:04<00:12, 291.43it/s]
29%|βββ | 1467/5000 [00:04<00:12, 291.65it/s]
30%|βββ | 1497/5000 [00:04<00:12, 291.84it/s]
31%|βββ | 1527/5000 [00:05<00:11, 291.79it/s]
31%|βββ | 1557/5000 [00:05<00:11, 292.03it/s]
32%|ββββ | 1587/5000 [00:05<00:11, 291.51it/s]
32%|ββββ | 1617/5000 [00:05<00:11, 291.48it/s]
33%|ββββ | 1647/5000 [00:05<00:11, 291.76it/s]
34%|ββββ | 1677/5000 [00:05<00:11, 291.92it/s]
34%|ββββ | 1707/5000 [00:05<00:11, 292.14it/s]
35%|ββββ | 1737/5000 [00:05<00:11, 291.95it/s]
35%|ββββ | 1767/5000 [00:05<00:11, 287.84it/s]
36%|ββββ | 1796/5000 [00:06<00:11, 283.68it/s]
36%|ββββ | 1825/5000 [00:06<00:11, 280.77it/s]
37%|ββββ | 1854/5000 [00:06<00:11, 278.75it/s]
38%|ββββ | 1882/5000 [00:06<00:11, 277.19it/s]
38%|ββββ | 1910/5000 [00:06<00:11, 275.89it/s]
39%|ββββ | 1938/5000 [00:06<00:11, 276.44it/s]
39%|ββββ | 1968/5000 [00:06<00:10, 280.76it/s]
40%|ββββ | 1998/5000 [00:06<00:10, 284.19it/s]
41%|ββββ | 2028/5000 [00:06<00:10, 286.33it/s]
41%|ββββ | 2058/5000 [00:06<00:10, 287.99it/s]
42%|βββββ | 2088/5000 [00:07<00:10, 288.73it/s]
42%|βββββ | 2118/5000 [00:07<00:09, 289.67it/s]
43%|βββββ | 2148/5000 [00:07<00:09, 290.44it/s]
44%|βββββ | 2178/5000 [00:07<00:09, 290.92it/s]
44%|βββββ | 2208/5000 [00:07<00:09, 291.00it/s]
45%|βββββ | 2238/5000 [00:07<00:09, 291.16it/s]
45%|βββββ | 2268/5000 [00:07<00:09, 291.33it/s]
46%|βββββ | 2298/5000 [00:07<00:09, 291.40it/s]
47%|βββββ | 2328/5000 [00:07<00:09, 291.46it/s]
47%|βββββ | 2358/5000 [00:07<00:09, 291.37it/s]
48%|βββββ | 2388/5000 [00:08<00:08, 291.15it/s]
48%|βββββ | 2418/5000 [00:08<00:08, 291.32it/s]
49%|βββββ | 2448/5000 [00:08<00:08, 291.45it/s]
50%|βββββ | 2478/5000 [00:08<00:08, 291.52it/s]
50%|βββββ | 2508/5000 [00:08<00:08, 291.80it/s]
51%|βββββ | 2538/5000 [00:08<00:08, 291.92it/s]
51%|ββββββ | 2568/5000 [00:08<00:08, 292.18it/s]
52%|ββββββ | 2598/5000 [00:08<00:08, 291.98it/s]
53%|ββββββ | 2628/5000 [00:08<00:08, 291.74it/s]
53%|ββββββ | 2658/5000 [00:09<00:08, 291.52it/s]
54%|ββββββ | 2688/5000 [00:09<00:07, 291.40it/s]
54%|ββββββ | 2718/5000 [00:09<00:07, 291.56it/s]
55%|ββββββ | 2748/5000 [00:09<00:07, 291.15it/s]
56%|ββββββ | 2778/5000 [00:09<00:07, 290.45it/s]
56%|ββββββ | 2808/5000 [00:09<00:07, 290.02it/s]
57%|ββββββ | 2838/5000 [00:09<00:07, 290.41it/s]
57%|ββββββ | 2868/5000 [00:09<00:07, 290.80it/s]
58%|ββββββ | 2898/5000 [00:09<00:07, 290.96it/s]
59%|ββββββ | 2928/5000 [00:09<00:07, 291.02it/s]
59%|ββββββ | 2958/5000 [00:10<00:07, 290.75it/s]
60%|ββββββ | 2988/5000 [00:10<00:06, 291.05it/s]
60%|ββββββ | 3018/5000 [00:10<00:06, 291.17it/s]
61%|ββββββ | 3048/5000 [00:10<00:06, 290.93it/s]
62%|βββββββ | 3078/5000 [00:10<00:06, 291.13it/s]
62%|βββββββ | 3108/5000 [00:10<00:06, 291.11it/s]
63%|βββββββ | 3138/5000 [00:10<00:06, 291.28it/s]
63%|βββββββ | 3168/5000 [00:10<00:06, 291.42it/s]
64%|βββββββ | 3198/5000 [00:10<00:06, 291.69it/s]
65%|βββββββ | 3228/5000 [00:10<00:06, 291.57it/s]
65%|βββββββ | 3258/5000 [00:11<00:05, 291.67it/s]
66%|βββββββ | 3288/5000 [00:11<00:05, 291.87it/s]
66%|βββββββ | 3318/5000 [00:11<00:05, 291.72it/s]
67%|βββββββ | 3348/5000 [00:11<00:05, 291.70it/s]
68%|βββββββ | 3378/5000 [00:11<00:05, 291.10it/s]
68%|βββββββ | 3408/5000 [00:11<00:05, 290.93it/s]
69%|βββββββ | 3438/5000 [00:11<00:05, 291.11it/s]
69%|βββββββ | 3468/5000 [00:11<00:05, 291.44it/s]
70%|βββββββ | 3498/5000 [00:11<00:05, 291.61it/s]
71%|βββββββ | 3528/5000 [00:12<00:05, 291.56it/s]
71%|βββββββ | 3558/5000 [00:12<00:04, 291.61it/s]
72%|ββββββββ | 3588/5000 [00:12<00:04, 291.84it/s]
72%|ββββββββ | 3618/5000 [00:12<00:04, 291.79it/s]
73%|ββββββββ | 3648/5000 [00:12<00:04, 291.88it/s]
74%|ββββββββ | 3678/5000 [00:12<00:04, 291.87it/s]
74%|ββββββββ | 3708/5000 [00:12<00:04, 292.02it/s]
75%|ββββββββ | 3738/5000 [00:12<00:04, 292.12it/s]
75%|ββββββββ | 3768/5000 [00:12<00:04, 292.01it/s]
76%|ββββββββ | 3798/5000 [00:12<00:04, 291.82it/s]
77%|ββββββββ | 3828/5000 [00:13<00:04, 290.21it/s]
77%|ββββββββ | 3858/5000 [00:13<00:03, 290.85it/s]
78%|ββββββββ | 3888/5000 [00:13<00:03, 291.52it/s]
78%|ββββββββ | 3918/5000 [00:13<00:03, 291.26it/s]
79%|ββββββββ | 3948/5000 [00:13<00:03, 291.56it/s]
80%|ββββββββ | 3978/5000 [00:13<00:03, 287.54it/s]
80%|ββββββββ | 4008/5000 [00:13<00:03, 288.90it/s]
81%|ββββββββ | 4038/5000 [00:13<00:03, 289.72it/s]
81%|βββββββββ | 4068/5000 [00:13<00:03, 290.47it/s]
82%|βββββββββ | 4098/5000 [00:13<00:03, 290.91it/s]
83%|βββββββββ | 4128/5000 [00:14<00:02, 291.10it/s]
83%|βββββββββ | 4158/5000 [00:14<00:02, 291.07it/s]
84%|βββββββββ | 4188/5000 [00:14<00:02, 290.98it/s]
84%|βββββββββ | 4218/5000 [00:14<00:02, 291.19it/s]
85%|βββββββββ | 4248/5000 [00:14<00:02, 291.17it/s]
86%|βββββββββ | 4278/5000 [00:14<00:02, 291.17it/s]
86%|βββββββββ | 4308/5000 [00:14<00:02, 291.33it/s]
87%|βββββββββ | 4338/5000 [00:14<00:02, 291.64it/s]
87%|βββββββββ | 4368/5000 [00:14<00:02, 291.89it/s]
88%|βββββββββ | 4398/5000 [00:14<00:02, 292.16it/s]
89%|βββββββββ | 4428/5000 [00:15<00:01, 292.23it/s]
89%|βββββββββ | 4458/5000 [00:15<00:01, 292.41it/s]
90%|βββββββββ | 4488/5000 [00:15<00:01, 292.12it/s]
90%|βββββββββ | 4518/5000 [00:15<00:01, 292.01it/s]
91%|βββββββββ | 4548/5000 [00:15<00:01, 291.94it/s]
92%|ββββββββββ| 4578/5000 [00:15<00:01, 292.02it/s]
92%|ββββββββββ| 4608/5000 [00:15<00:01, 292.32it/s]
93%|ββββββββββ| 4638/5000 [00:15<00:01, 292.30it/s]
93%|ββββββββββ| 4668/5000 [00:15<00:01, 292.49it/s]
94%|ββββββββββ| 4698/5000 [00:16<00:01, 292.50it/s]
95%|ββββββββββ| 4728/5000 [00:16<00:00, 292.46it/s]
95%|ββββββββββ| 4758/5000 [00:16<00:00, 292.36it/s]
96%|ββββββββββ| 4788/5000 [00:16<00:00, 292.07it/s]
96%|ββββββββββ| 4818/5000 [00:16<00:00, 292.27it/s]
97%|ββββββββββ| 4848/5000 [00:16<00:00, 292.33it/s]
98%|ββββββββββ| 4878/5000 [00:16<00:00, 292.23it/s]
98%|ββββββββββ| 4908/5000 [00:16<00:00, 292.14it/s]
99%|ββββββββββ| 4938/5000 [00:16<00:00, 292.39it/s]
99%|ββββββββββ| 4968/5000 [00:16<00:00, 292.46it/s]
100%|ββββββββββ| 4998/5000 [00:17<00:00, 292.35it/s]
100%|ββββββββββ| 5000/5000 [00:17<00:00, 293.23it/s]
Iteration 4999, current converge crit. = 1.43E-05, objective = 1.00E-03
Iteration 4999, current converge crit. = 3.42E-04, objective = 1.00E-03
Linear reconstruction PSNR: 8.55 dB
Posterior mean PSNR: 22.31 dB
- References:
Total running time of the script: (0 minutes 17.791 seconds)