Uncertainty quantification with PnP-ULA.
This example shows how to use sampling algorithms to quantify the uncertainty of a reconstruction from incomplete and noisy measurements.
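ULA obtains samples by running the following iteration:
\[x_{k+1} = x_k + \eta \nabla \log p(y \mid A, x_k) + \eta \alpha \nabla \log p(x_k) + \sqrt{2\eta}\, z_k,\]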
where \(z_k \sim \mathcal{N}(0, I)\) is a Gaussian random variable, \(\eta\) is the step size and \(\alpha\) is a parameter controlling the regularization.
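The ULA update can be tried on a problem whose target is known in closed form. The sketch below is a toy assumption, not deepinv code: a scalar unknown with Gaussian prior \(\mathcal{N}(0, s^2)\) and likelihood \(y = a x + \text{noise}\), so both gradients are linear and the target \(p(y \mid x)\,p(x)^\alpha\) is itself Gaussian. The helper `ula_toy` is hypothetical; it runs many chains in parallel and compares their empirical mean and variance with the exact values.

```python
import numpy as np


def ula_toy(y, a=1.0, sigma=0.5, s=1.0, alpha=0.9, eta=1e-2,
            n_iter=5_000, n_chains=2_000, seed=0):
    """Run ULA on a scalar linear-Gaussian model (toy assumption).

    Likelihood: y = a * x + N(0, sigma^2); prior: x ~ N(0, s^2).
    Each chain iterates
        x <- x + eta * grad log p(y|x) + eta * alpha * grad log p(x) + sqrt(2 eta) z
    """
    rng = np.random.default_rng(seed)
    x = np.zeros(n_chains)                      # all chains start at 0
    for _ in range(n_iter):
        grad_lik = a * (y - a * x) / sigma**2   # gradient of the log-likelihood
        grad_prior = -x / s**2                  # score of the Gaussian prior
        z = rng.standard_normal(n_chains)       # fresh Gaussian noise per chain
        x = x + eta * grad_lik + eta * alpha * grad_prior + np.sqrt(2 * eta) * z
    return x.mean(), x.var()


# Exact target: p(y|x) p(x)^alpha is Gaussian with this precision and mean.
y, a, sigma, s, alpha = 1.0, 1.0, 0.5, 1.0, 0.9
precision = a**2 / sigma**2 + alpha / s**2
exact_mean = (a * y / sigma**2) / precision

est_mean, est_var = ula_toy(y, a, sigma, s, alpha)
print(f"ULA mean {est_mean:.3f} vs exact {exact_mean:.3f}")
print(f"ULA var  {est_var:.4f} vs exact {1 / precision:.4f}")
```

For this Gaussian target the ULA mean is unbiased, while the time discretization inflates the variance by a factor \(1/(1 - \eta P/2)\), where \(P\) is the precision (about 2.5% here).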
The PnP-ULA method is described in the paper “Bayesian imaging using Plug & Play priors: when Langevin meets Tweedie”.
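import deepinv as dinv
from deepinv.utils.plotting import plot
import torch
from deepinv.utils.demo import load_url_image

# run on a GPU when one is available (device is used throughout the example)
device = "cuda" if torch.cuda.is_available() else "cpu"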
Load image from the internet
This example uses an image of Lionel Messi from Wikipedia.
Define forward operator and noise model
This example uses inpainting as the forward operator and Gaussian noise as the noise model.
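A minimal NumPy sketch of this forward model follows, under assumptions that are ours rather than deepinv's: one random binary mask keeping about half of the pixels is shared by all channels, and Gaussian noise is added to every entry. Because the mask acts as a diagonal 0/1 operator, it is its own adjoint, and applying the adjoint to the measurement fills the missing pixels with zeros.

```python
import numpy as np

rng = np.random.default_rng(0)
shape = (3, 8, 8)  # (channels, height, width), toy size
# keep each pixel with probability 0.5; one mask shared across channels (assumption)
mask = (rng.random(shape[1:]) < 0.5).astype(float)


def A(x):
    """Inpainting: zero out the missing pixels (mask broadcast over channels)."""
    return mask * x


def A_adjoint(y):
    """A is a diagonal 0/1 operator, so it is self-adjoint."""
    return mask * y


sigma = 0.1
x = rng.random(shape)                           # stand-in image with values in [0, 1]
y = A(x) + sigma * rng.standard_normal(shape)   # noisy measurement
x_lin = A_adjoint(y)                            # zero-filled linear reconstruction
```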
sigma = 0.1 # noise level
physics = dinv.physics.Inpainting(mask=0.5, tensor_size=x.shape[1:], device=device)
physics.noise_model = dinv.physics.GaussianNoise(sigma=sigma)
# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)
Define the likelihood
Since the noise model is Gaussian, the negative log-likelihood is, up to an additive constant, the scaled L2 loss \(\frac{1}{2\sigma^2}\|y - Ax\|_2^2\).
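For a linear forward operator \(A\), this data-fidelity term has gradient \(A^\top(Ax - y)/\sigma^2\). The NumPy sketch below uses a random dense matrix as a stand-in for the forward operator (an assumption for illustration) and checks the analytic gradient against central finite differences. To our understanding dinv.optim.data_fidelity.L2(sigma=sigma) implements the same scaled loss, but this sketch does not call deepinv.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, sigma = 20, 10, 0.1
A = rng.standard_normal((m, n))  # generic linear forward operator (toy stand-in)
x_true = rng.standard_normal(n)
y = A @ x_true + sigma * rng.standard_normal(m)


def neg_log_lik(x):
    """-log p(y|x) up to an additive constant, for y = A x + N(0, sigma^2 I)."""
    r = A @ x - y
    return 0.5 * np.dot(r, r) / sigma**2


def grad_neg_log_lik(x):
    """Analytic gradient A^T (A x - y) / sigma^2."""
    return A.T @ (A @ x - y) / sigma**2


# Central finite-difference check of the gradient at a random point.
x0 = rng.standard_normal(n)
h = 1e-5
fd = np.array([(neg_log_lik(x0 + h * e) - neg_log_lik(x0 - h * e)) / (2 * h)
               for e in np.eye(n)])
print(np.max(np.abs(fd - grad_neg_log_lik(x0))))
```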
# load Gaussian Likelihood
likelihood = dinv.optim.data_fidelity.L2(sigma=sigma)
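Define the prior
The score \(\nabla \log p(x)\) of a distribution can be approximated with Tweedie's formula via the deepinv.optim.ScorePrior class: for a denoiser \(D_\sigma\) applied at noise level \(\sigma\), \(\nabla \log p_\sigma(x) \approx \left(D_\sigma(x) - x\right)/\sigma^2\), where \(p_\sigma\) is the prior smoothed by Gaussian noise of standard deviation \(\sigma\).
From a Bayesian point of view, the score is the gradient of the log prior, so its negative plays the role that the gradient of a regularizer plays in an optimization algorithm.
The hyperparameter sigma_denoiser (\(\sigma\) in Tweedie's formula, not to be confused with the measurement noise level sigma) controls the strength of the prior by setting how much it is smoothed.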
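Tweedie's formula can be verified in the one case where everything is known in closed form. Assume, for illustration only, a Gaussian prior \(\mathcal{N}(0, s^2)\): the MMSE denoiser at noise level \(\sigma\) is then linear, and the smoothed prior \(p_\sigma\) is \(\mathcal{N}(0, s^2 + \sigma^2)\). In PnP-ULA the closed-form denoiser is replaced by a learned one such as DnCNN.

```python
import numpy as np

s, sigma = 1.5, 0.3  # prior std and denoiser noise level (toy values)


def mmse_denoiser(x):
    """Exact MMSE denoiser for x0 ~ N(0, s^2) observed as x = x0 + N(0, sigma^2)."""
    return s**2 / (s**2 + sigma**2) * x


def tweedie_score(x):
    """Tweedie: grad log p_sigma(x) = (D_sigma(x) - x) / sigma^2."""
    return (mmse_denoiser(x) - x) / sigma**2


def exact_score(x):
    """p_sigma = N(0, s^2 + sigma^2), whose score is -x / (s^2 + sigma^2)."""
    return -x / (s**2 + sigma**2)


x = np.linspace(-3, 3, 7)
print(np.allclose(tweedie_score(x), exact_score(x)))  # prints True
```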
In this example, we use a DnCNN model pretrained with the deepinv.loss.FNEJacobianSpectralNorm loss, which encourages the denoiser to be firmly non-expansive (see “Building firmly nonexpansive convolutional neural networks”) and helps stabilize the sampling algorithm.
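An operator \(D\) is firmly non-expansive when \(\|D(x) - D(y)\|^2 \le \langle D(x) - D(y), x - y\rangle\) for all \(x, y\). The NumPy sketch below tests this inequality on random pairs for two toy operators chosen for illustration: soft-thresholding, a proximal operator and therefore firmly non-expansive, and the sign flip \(x \mapsto -x\), which is 1-Lipschitz but fails the inequality. A random-pair test can only find violations, not prove the property, and it does not reproduce the Jacobian-based training loss.

```python
import numpy as np


def soft_threshold(x, lam=0.5):
    """Proximal operator of lam * ||x||_1; proximal operators are firmly non-expansive."""
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)


def flip(x):
    """-x: non-expansive (1-Lipschitz) but NOT firmly non-expansive."""
    return -x


def fne_violation(D, n_pairs=1000, dim=16, seed=0):
    """Largest value of ||D(x)-D(y)||^2 - <D(x)-D(y), x-y> over random pairs.

    D is firmly non-expansive iff this quantity is <= 0 for all x, y.
    """
    rng = np.random.default_rng(seed)
    worst = -np.inf
    for _ in range(n_pairs):
        x, y = rng.standard_normal(dim), rng.standard_normal(dim)
        d = D(x) - D(y)
        worst = max(worst, np.dot(d, d) - np.dot(d, x - y))
    return worst


print(fne_violation(soft_threshold) <= 1e-12)  # prints True
print(fne_violation(flip) > 0)                 # prints True
```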
sigma_denoiser = 2 / 255
prior = dinv.optim.ScorePrior(
denoiser=dinv.models.DnCNN(pretrained="download_lipschitz")
).to(device)
Create the MCMC sampler
Here we use the Unadjusted Langevin Algorithm (ULA), implemented in deepinv.sampling.ULA, to sample from the posterior.
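The hyperparameter step_size controls the step size \(\eta\) of the MCMC sampler, regularization (\(\alpha\)) controls the strength of the prior, and iterations controls the number of iterations of the sampler. Scaling step_size with sigma**2 keeps \(\eta\) small relative to the \(1/\sigma^2\) curvature of the likelihood term. Without a GPU only 10 iterations are run, which keeps the example fast but is far too few for the chain to converge.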
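The need for a small step size shows up on a one-dimensional Gaussian target \(\mathcal{N}(0, 1/L)\), where ULA reads \(x_{k+1} = (1 - \eta L)x_k + \sqrt{2\eta}\,z_k\). This toy model (an assumption that ignores the prior term) is stable only for \(\eta L < 2\), and its stationary variance is \(1/(L(1 - \eta L/2))\) rather than \(1/L\). With \(L = 1/\sigma^2\), the curvature of the L2 likelihood when the forward operator has norm at most 1 (as a 0/1 inpainting mask does), \(\eta L\) equals the factor in front of sigma**2. The helper `ula_stationary_var` is hypothetical.

```python
import numpy as np


def ula_stationary_var(L, eta):
    """Closed-form stationary variance of ULA targeting N(0, 1/L).

    The chain x <- (1 - eta*L) x + sqrt(2 eta) z is stable iff 0 < eta*L < 2,
    and its variance then settles at 1 / (L * (1 - eta*L/2)) instead of 1/L.
    """
    assert 0 < eta * L < 2, "ULA diverges for eta * L >= 2"
    return 1.0 / (L * (1.0 - eta * L / 2.0))


sigma = 0.1
L = 1.0 / sigma**2               # likelihood curvature when ||A|| <= 1
for factor in (0.01, 0.5, 1.5):  # step_size = factor * sigma**2
    eta = factor * sigma**2
    rel_bias = ula_stationary_var(L, eta) * L - 1.0
    print(f"step_size = {factor} * sigma^2 -> relative variance bias {rel_bias:.3%}")

# Empirical check with many parallel chains at factor 0.5, where the bias is large.
rng = np.random.default_rng(0)
eta = 0.5 * sigma**2
x = np.zeros(100_000)
for _ in range(200):
    x = (1 - eta * L) * x + np.sqrt(2 * eta) * rng.standard_normal(x.size)
emp_var = x.var()
```

With the factor 0.01 used in this example, the discretization bias on the variance of this toy model is about half a percent.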
regularization = 0.9
step_size = 0.01 * (sigma**2)
iterations = int(5e3) if torch.cuda.is_available() else 10
f = dinv.sampling.ULA(
prior=prior,
data_fidelity=likelihood,
max_iter=iterations,
alpha=regularization,
step_size=step_size,
verbose=True,
sigma=sigma_denoiser,
)
Generate the measurement
We apply the forward model to the image x to generate the noisy measurement y.
y = physics(x)
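Run sampling algorithm and plot results
The sampling algorithm returns Monte Carlo estimates of the pixel-wise posterior mean and variance, computed from the samples of the chain. We compare the posterior mean with a simple linear reconstruction, the adjoint \(A^\top y\), which for inpainting sets the missing pixels to zero.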
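Posterior statistics like these can be accumulated while the chain runs, without storing every sample. The sketch below uses Welford's online algorithm on synthetic draws, through a hypothetical RunningMoments helper. It illustrates the idea only and does not claim to reproduce how deepinv.sampling.ULA computes its estimates (burn-in and thinning, for instance, are not modeled).

```python
import numpy as np


class RunningMoments:
    """Welford's online algorithm: running mean and variance without storing samples."""

    def __init__(self, shape):
        self.n = 0
        self.mean = np.zeros(shape)
        self.m2 = np.zeros(shape)  # running sum of squared deviations from the mean

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)

    @property
    def var(self):
        return self.m2 / self.n  # population variance (divide by n)


rng = np.random.default_rng(0)
samples = rng.normal(loc=2.0, scale=0.5, size=(500, 3, 4, 4))  # stand-in for MCMC draws
stats = RunningMoments(samples.shape[1:])
for sample in samples:
    stats.update(sample)
```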
mean, var = f(y, physics)
# compute linear inverse
x_lin = physics.A_adjoint(y)
# compute PSNR
print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
print(f"Posterior mean PSNR: {dinv.metric.PSNR()(x, mean).item():.2f} dB")
# plot results
error = (mean - x).abs().sum(dim=1, keepdim=True)  # per-pixel abs. error, summed over color channels
std = var.sum(dim=1, keepdim=True).sqrt()  # per-pixel std. dev. from the variance summed over channels
imgs = [x_lin, x, mean, std / std.flatten().max(), error / error.flatten().max()]
plot(
imgs,
titles=["measurement", "ground truth", "post. mean", "post. std", "abs. error"],
)
Monte Carlo sampling finished! elapsed time=0.10 seconds
Linear reconstruction PSNR: 8.79 dB
Posterior mean PSNR: 8.86 dB
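The PSNR values reported above compare each reconstruction with the ground truth x. A minimal NumPy version of the usual definition, \(10 \log_{10}(\text{max}^2/\text{MSE})\), is sketched below, assuming pixel values in [0, 1] so that the peak value is 1. Whether dinv.metric.PSNR uses exactly these defaults (for example when averaging over a batch) is an assumption here. The reported values come from the 10-iteration CPU run, which likely explains why the posterior mean barely improves on the linear reconstruction.

```python
import numpy as np


def psnr(x_hat, x, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB, assuming pixel values in [0, max_pixel]."""
    mse = np.mean((x_hat - x) ** 2)
    return 10.0 * np.log10(max_pixel**2 / mse)


rng = np.random.default_rng(0)
x = rng.random((3, 32, 32))                         # stand-in image
x_noisy = x + 0.1 * rng.standard_normal(x.shape)    # noise with std 0.1, so MSE is about 0.01
print(f"{psnr(x_noisy, x):.2f} dB")
```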