Note
New to DeepInverse? Get started with the basics with the 5 minute quickstart tutorial..
Uncertainty quantification with PnP-ULA.#
This code shows you how to use sampling algorithms to quantify uncertainty of a reconstruction from incomplete and noisy measurements.
ULA obtains samples by running the following iteration:
where \(z_k \sim \mathcal{N}(0, I)\) is a Gaussian random variable, \(\eta\) is the step size and \(\alpha\) is a parameter controlling the regularization.
The PnP-ULA method is described in the paper Laumont et al.[1].
import deepinv as dinv
from deepinv.utils.plotting import plot
import torch
from deepinv.utils import load_example
Load image from the internet#
This example uses an image of Messi.
device = dinv.utils.get_device()
x = load_example("messi.jpg", img_size=32).to(device)
Selected GPU 0 with 4605.25 MiB free memory
Define forward operator and noise model#
This example uses inpainting as the forward operator and Gaussian noise as the noise model.
sigma = 0.1 # noise level
physics = dinv.physics.Inpainting(mask=0.5, img_size=x.shape[1:], device=device)
physics.noise_model = dinv.physics.GaussianNoise(sigma=sigma)
# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)
<torch._C.Generator object at 0x7fc5b53fca30>
Define the likelihood#
Since the noise model is Gaussian, the negative log-likelihood is the L2 loss.
# load Gaussian Likelihood
likelihood = dinv.optim.data_fidelity.L2(sigma=sigma)
Define the prior#
The score a distribution can be approximated using Tweedieβs formula via the
deepinv.optim.ScorePrior class.
This example uses a pretrained DnCNN model.
From a Bayesian point of view, the score plays the role of the gradient of the
negative log prior
The hyperparameter sigma_denoiser (\(sigma\)) controls the strength of the prior.
In this example, we use a pretrained DnCNN model using the deepinv.loss.FNEJacobianSpectralNorm loss,
which makes sure that the denoiser is firmly non-expansive (see Terris et al.[2]), and helps to
stabilize the sampling algorithm.
sigma_denoiser = 2 / 255
prior = dinv.optim.ScorePrior(
denoiser=dinv.models.DnCNN(pretrained="download_lipschitz")
).to(device)
Create the MCMC sampler#
Here we use the Unadjusted Langevin Algorithm (ULA) to sample from the posterior defined in
deepinv.sampling.ULAIterator.
The hyperparameter step_size controls the step size of the MCMC sampler,
regularization controls the strength of the prior and
iterations controls the number of iterations of the sampler.
regularization = 0.9
step_size = 0.01 * (sigma**2)
iterations = int(5e3) if torch.cuda.is_available() else 10
params = {
"step_size": step_size,
"alpha": regularization,
"sigma": sigma_denoiser,
}
f = dinv.sampling.sampling_builder(
"ULA",
prior=prior,
data_fidelity=likelihood,
max_iter=iterations,
params_algo=params,
thinning=1,
verbose=True,
)
Generate the measurement#
We apply the forward model to generate the noisy measurement.
Run sampling algorithm and plot results#
The sampling algorithm returns the posterior mean and variance. We compare the posterior mean with a simple linear reconstruction.
mean, var = f.sample(y, physics)
# compute linear inverse
x_lin = physics.A_adjoint(y)
# compute PSNR
print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
print(f"Posterior mean PSNR: {dinv.metric.PSNR()(x, mean).item():.2f} dB")
# plot results
error = (mean - x).abs().sum(dim=1).unsqueeze(1) # per pixel average abs. error
std = var.sum(dim=1).unsqueeze(1).sqrt() # per pixel average standard dev.
imgs = [x_lin, x, mean, std / std.flatten().max(), error / error.flatten().max()]
plot(
imgs,
titles=["measurement", "ground truth", "post. mean", "post. std", "abs. error"],
)

0%| | 0/5000 [00:00<?, ?it/s]
1%| | 30/5000 [00:00<00:16, 295.80it/s]
1%| | 61/5000 [00:00<00:16, 301.87it/s]
2%|β | 92/5000 [00:00<00:16, 303.75it/s]
2%|β | 123/5000 [00:00<00:15, 304.83it/s]
3%|β | 154/5000 [00:00<00:15, 305.29it/s]
4%|β | 185/5000 [00:00<00:15, 305.39it/s]
4%|β | 216/5000 [00:00<00:15, 305.44it/s]
5%|β | 247/5000 [00:00<00:15, 305.17it/s]
6%|β | 278/5000 [00:00<00:15, 305.35it/s]
6%|β | 309/5000 [00:01<00:15, 305.42it/s]
7%|β | 340/5000 [00:01<00:15, 305.29it/s]
7%|β | 371/5000 [00:01<00:15, 305.64it/s]
8%|β | 402/5000 [00:01<00:15, 305.65it/s]
9%|β | 433/5000 [00:01<00:14, 305.79it/s]
9%|β | 464/5000 [00:01<00:14, 306.09it/s]
10%|β | 495/5000 [00:01<00:14, 305.76it/s]
11%|β | 526/5000 [00:01<00:14, 305.88it/s]
11%|β | 557/5000 [00:01<00:14, 306.04it/s]
12%|ββ | 588/5000 [00:01<00:14, 306.20it/s]
12%|ββ | 619/5000 [00:02<00:14, 306.06it/s]
13%|ββ | 650/5000 [00:02<00:14, 306.18it/s]
14%|ββ | 681/5000 [00:02<00:14, 305.85it/s]
14%|ββ | 712/5000 [00:02<00:14, 305.80it/s]
15%|ββ | 743/5000 [00:02<00:13, 306.05it/s]
15%|ββ | 774/5000 [00:02<00:13, 306.01it/s]
16%|ββ | 805/5000 [00:02<00:13, 306.15it/s]
17%|ββ | 836/5000 [00:02<00:13, 306.39it/s]
17%|ββ | 867/5000 [00:02<00:13, 306.60it/s]
18%|ββ | 898/5000 [00:02<00:13, 306.69it/s]
19%|ββ | 929/5000 [00:03<00:13, 306.31it/s]
19%|ββ | 960/5000 [00:03<00:13, 306.06it/s]
20%|ββ | 991/5000 [00:03<00:13, 305.99it/s]
20%|ββ | 1022/5000 [00:03<00:13, 302.00it/s]
21%|ββ | 1053/5000 [00:03<00:13, 298.23it/s]
22%|βββ | 1083/5000 [00:03<00:13, 295.53it/s]
22%|βββ | 1113/5000 [00:03<00:13, 293.23it/s]
23%|βββ | 1143/5000 [00:03<00:13, 291.83it/s]
23%|βββ | 1173/5000 [00:03<00:13, 290.87it/s]
24%|βββ | 1203/5000 [00:03<00:13, 290.15it/s]
25%|βββ | 1233/5000 [00:04<00:12, 289.80it/s]
25%|βββ | 1262/5000 [00:04<00:12, 289.55it/s]
26%|βββ | 1291/5000 [00:04<00:12, 288.94it/s]
26%|βββ | 1320/5000 [00:04<00:12, 288.82it/s]
27%|βββ | 1349/5000 [00:04<00:12, 288.88it/s]
28%|βββ | 1378/5000 [00:04<00:12, 289.06it/s]
28%|βββ | 1407/5000 [00:04<00:12, 288.51it/s]
29%|βββ | 1436/5000 [00:04<00:12, 288.62it/s]
29%|βββ | 1465/5000 [00:04<00:12, 288.94it/s]
30%|βββ | 1494/5000 [00:04<00:12, 288.99it/s]
30%|βββ | 1523/5000 [00:05<00:12, 288.52it/s]
31%|βββ | 1552/5000 [00:05<00:11, 288.37it/s]
32%|ββββ | 1581/5000 [00:05<00:11, 288.05it/s]
32%|ββββ | 1610/5000 [00:05<00:11, 288.19it/s]
33%|ββββ | 1639/5000 [00:05<00:11, 288.28it/s]
33%|ββββ | 1668/5000 [00:05<00:11, 288.22it/s]
34%|ββββ | 1697/5000 [00:05<00:11, 287.83it/s]
35%|ββββ | 1726/5000 [00:05<00:11, 288.19it/s]
35%|ββββ | 1755/5000 [00:05<00:11, 288.47it/s]
36%|ββββ | 1784/5000 [00:05<00:11, 288.70it/s]
36%|ββββ | 1813/5000 [00:06<00:11, 286.45it/s]
37%|ββββ | 1842/5000 [00:06<00:11, 287.03it/s]
37%|ββββ | 1871/5000 [00:06<00:10, 287.31it/s]
38%|ββββ | 1900/5000 [00:06<00:10, 287.24it/s]
39%|ββββ | 1929/5000 [00:06<00:10, 287.32it/s]
39%|ββββ | 1958/5000 [00:06<00:10, 287.15it/s]
40%|ββββ | 1987/5000 [00:06<00:10, 287.44it/s]
40%|ββββ | 2016/5000 [00:06<00:10, 287.72it/s]
41%|ββββ | 2045/5000 [00:06<00:10, 287.85it/s]
41%|βββββ | 2074/5000 [00:06<00:10, 287.77it/s]
42%|βββββ | 2103/5000 [00:07<00:10, 288.24it/s]
43%|βββββ | 2132/5000 [00:07<00:09, 288.22it/s]
43%|βββββ | 2161/5000 [00:07<00:09, 287.91it/s]
44%|βββββ | 2190/5000 [00:07<00:09, 287.91it/s]
44%|βββββ | 2219/5000 [00:07<00:09, 287.86it/s]
45%|βββββ | 2248/5000 [00:07<00:09, 287.54it/s]
46%|βββββ | 2277/5000 [00:07<00:09, 287.74it/s]
46%|βββββ | 2306/5000 [00:07<00:09, 287.99it/s]
47%|βββββ | 2335/5000 [00:07<00:09, 288.15it/s]
47%|βββββ | 2364/5000 [00:08<00:09, 288.05it/s]
48%|βββββ | 2393/5000 [00:08<00:09, 288.23it/s]
48%|βββββ | 2422/5000 [00:08<00:08, 288.08it/s]
49%|βββββ | 2451/5000 [00:08<00:08, 287.81it/s]
50%|βββββ | 2480/5000 [00:08<00:08, 288.14it/s]
50%|βββββ | 2509/5000 [00:08<00:08, 288.23it/s]
51%|βββββ | 2538/5000 [00:08<00:08, 287.90it/s]
51%|ββββββ | 2567/5000 [00:08<00:08, 287.97it/s]
52%|ββββββ | 2596/5000 [00:08<00:08, 288.17it/s]
52%|ββββββ | 2625/5000 [00:08<00:08, 288.30it/s]
53%|ββββββ | 2654/5000 [00:09<00:08, 288.23it/s]
54%|ββββββ | 2683/5000 [00:09<00:08, 288.08it/s]
54%|ββββββ | 2712/5000 [00:09<00:07, 287.89it/s]
55%|ββββββ | 2741/5000 [00:09<00:07, 288.16it/s]
55%|ββββββ | 2770/5000 [00:09<00:07, 288.25it/s]
56%|ββββββ | 2799/5000 [00:09<00:07, 288.14it/s]
57%|ββββββ | 2828/5000 [00:09<00:07, 287.86it/s]
57%|ββββββ | 2857/5000 [00:09<00:07, 287.84it/s]
58%|ββββββ | 2886/5000 [00:09<00:07, 288.09it/s]
58%|ββββββ | 2915/5000 [00:09<00:07, 288.26it/s]
59%|ββββββ | 2944/5000 [00:10<00:07, 288.04it/s]
59%|ββββββ | 2973/5000 [00:10<00:07, 287.97it/s]
60%|ββββββ | 3002/5000 [00:10<00:06, 287.99it/s]
61%|ββββββ | 3031/5000 [00:10<00:06, 288.26it/s]
61%|ββββββ | 3060/5000 [00:10<00:06, 288.38it/s]
62%|βββββββ | 3089/5000 [00:10<00:06, 288.34it/s]
62%|βββββββ | 3118/5000 [00:10<00:06, 287.87it/s]
63%|βββββββ | 3147/5000 [00:10<00:06, 288.21it/s]
64%|βββββββ | 3176/5000 [00:10<00:06, 288.32it/s]
64%|βββββββ | 3205/5000 [00:10<00:06, 288.56it/s]
65%|βββββββ | 3234/5000 [00:11<00:06, 288.06it/s]
65%|βββββββ | 3263/5000 [00:11<00:06, 288.16it/s]
66%|βββββββ | 3292/5000 [00:11<00:05, 288.29it/s]
66%|βββββββ | 3321/5000 [00:11<00:05, 288.58it/s]
67%|βββββββ | 3350/5000 [00:11<00:05, 288.49it/s]
68%|βββββββ | 3379/5000 [00:11<00:05, 288.47it/s]
68%|βββββββ | 3408/5000 [00:11<00:05, 288.22it/s]
69%|βββββββ | 3437/5000 [00:11<00:05, 288.34it/s]
69%|βββββββ | 3466/5000 [00:11<00:05, 288.57it/s]
70%|βββββββ | 3495/5000 [00:11<00:05, 288.55it/s]
70%|βββββββ | 3524/5000 [00:12<00:05, 288.32it/s]
71%|βββββββ | 3553/5000 [00:12<00:05, 288.45it/s]
72%|ββββββββ | 3582/5000 [00:12<00:04, 287.16it/s]
72%|ββββββββ | 3611/5000 [00:12<00:04, 287.07it/s]
73%|ββββββββ | 3640/5000 [00:12<00:04, 287.51it/s]
73%|ββββββββ | 3669/5000 [00:12<00:04, 287.96it/s]
74%|ββββββββ | 3698/5000 [00:12<00:04, 287.96it/s]
75%|ββββββββ | 3727/5000 [00:12<00:04, 288.21it/s]
75%|ββββββββ | 3756/5000 [00:12<00:04, 288.39it/s]
76%|ββββββββ | 3785/5000 [00:12<00:04, 288.47it/s]
76%|ββββββββ | 3814/5000 [00:13<00:04, 288.03it/s]
77%|ββββββββ | 3843/5000 [00:13<00:04, 288.05it/s]
77%|ββββββββ | 3872/5000 [00:13<00:03, 287.89it/s]
78%|ββββββββ | 3901/5000 [00:13<00:03, 288.06it/s]
79%|ββββββββ | 3930/5000 [00:13<00:03, 288.16it/s]
79%|ββββββββ | 3959/5000 [00:13<00:03, 288.19it/s]
80%|ββββββββ | 3988/5000 [00:13<00:03, 287.51it/s]
80%|ββββββββ | 4017/5000 [00:13<00:03, 287.69it/s]
81%|ββββββββ | 4046/5000 [00:13<00:03, 287.91it/s]
82%|βββββββββ | 4075/5000 [00:13<00:03, 288.06it/s]
82%|βββββββββ | 4104/5000 [00:14<00:03, 288.17it/s]
83%|βββββββββ | 4133/5000 [00:14<00:03, 288.18it/s]
83%|βββββββββ | 4162/5000 [00:14<00:02, 288.04it/s]
84%|βββββββββ | 4191/5000 [00:14<00:02, 287.83it/s]
84%|βββββββββ | 4220/5000 [00:14<00:02, 287.74it/s]
85%|βββββββββ | 4249/5000 [00:14<00:02, 287.73it/s]
86%|βββββββββ | 4278/5000 [00:14<00:02, 287.51it/s]
86%|βββββββββ | 4307/5000 [00:14<00:02, 287.64it/s]
87%|βββββββββ | 4336/5000 [00:14<00:02, 287.88it/s]
87%|βββββββββ | 4365/5000 [00:14<00:02, 288.02it/s]
88%|βββββββββ | 4394/5000 [00:15<00:02, 287.68it/s]
88%|βββββββββ | 4423/5000 [00:15<00:02, 287.63it/s]
89%|βββββββββ | 4452/5000 [00:15<00:01, 287.68it/s]
90%|βββββββββ | 4481/5000 [00:15<00:01, 287.60it/s]
90%|βββββββββ | 4510/5000 [00:15<00:01, 287.98it/s]
91%|βββββββββ | 4539/5000 [00:15<00:01, 287.91it/s]
91%|ββββββββββ| 4568/5000 [00:15<00:01, 287.57it/s]
92%|ββββββββββ| 4597/5000 [00:15<00:01, 287.84it/s]
93%|ββββββββββ| 4626/5000 [00:15<00:01, 288.16it/s]
93%|ββββββββββ| 4655/5000 [00:15<00:01, 288.08it/s]
94%|ββββββββββ| 4684/5000 [00:16<00:01, 288.14it/s]
94%|ββββββββββ| 4713/5000 [00:16<00:00, 288.30it/s]
95%|ββββββββββ| 4742/5000 [00:16<00:00, 287.41it/s]
95%|ββββββββββ| 4771/5000 [00:16<00:00, 287.32it/s]
96%|ββββββββββ| 4800/5000 [00:16<00:00, 287.61it/s]
97%|ββββββββββ| 4829/5000 [00:16<00:00, 287.65it/s]
97%|ββββββββββ| 4858/5000 [00:16<00:00, 287.36it/s]
98%|ββββββββββ| 4887/5000 [00:16<00:00, 287.57it/s]
98%|ββββββββββ| 4916/5000 [00:16<00:00, 287.67it/s]
99%|ββββββββββ| 4945/5000 [00:16<00:00, 287.87it/s]
99%|ββββββββββ| 4974/5000 [00:17<00:00, 288.14it/s]
100%|ββββββββββ| 5000/5000 [00:17<00:00, 291.40it/s]
Iteration 4999, current converge crit. = 1.43E-05, objective = 1.00E-03
Iteration 4999, current converge crit. = 3.42E-04, objective = 1.00E-03
Linear reconstruction PSNR: 8.55 dB
Posterior mean PSNR: 22.31 dB
- References:
Total running time of the script: (0 minutes 17.893 seconds)