Building your custom sampling algorithm.
This example shows how to build a custom sampling kernel. We build a preconditioned unadjusted Langevin algorithm (PreconULA) that exploits the singular value decomposition of the forward operator to accelerate the sampling.
import deepinv as dinv
from deepinv.utils.plotting import plot
import torch
from deepinv.sampling import ULA
import numpy as np
from deepinv.utils.demo import load_url_image

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
Load image from the internet
This example uses an image of Lionel Messi from Wikipedia.
Define forward operator and noise model
We use a 5x5 box blur as the forward operator and Gaussian noise as the noise model.
sigma = 0.001  # noise level
physics = dinv.physics.BlurFFT(
    img_size=(3, 32, 32),
    filter=torch.ones((1, 1, 5, 5), device=device) / 25,
    device=device,
    noise_model=dinv.physics.GaussianNoise(sigma=sigma),
)
Generate the measurement
Apply the forward model to generate the noisy measurement.
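The image-loading and measurement code is not reproduced above; conceptually the measurement is \(y = Ax + n\), with \(A\) a circular box blur and \(n\) Gaussian noise. As a stand-in (the synthetic image and the FFT-based blur below are our own sketch of what `dinv.physics.BlurFFT` does internally, not deepinv code), the operation can be mimicked in NumPy:

```python
import numpy as np

rng = np.random.default_rng(0)

# synthetic 32x32 grayscale image as a stand-in for the downloaded photo
x = rng.random((32, 32))

# 5x5 box blur with circular boundary conditions, applied via the FFT
kernel = np.zeros((32, 32))
kernel[:5, :5] = 1.0 / 25.0
kernel = np.roll(kernel, shift=(-2, -2), axis=(0, 1))  # center the kernel at the origin

sigma = 0.001  # noise level
y = np.real(np.fft.ifft2(np.fft.fft2(x) * np.fft.fft2(kernel)))
y = y + sigma * rng.standard_normal(y.shape)  # additive Gaussian noise
```

Since the blur kernel sums to one, the blurred image keeps the mean intensity of the original; only the high frequencies are attenuated.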
Define the sampling iteration
To define a custom sampling kernel (possibly a Markov kernel that depends on the previous sample), we only need to define an iterator that takes the current sample and returns the next one.
Here we define a preconditioned ULA iterator (for a Gaussian likelihood), which takes into account the singular value decomposition of the forward operator, \(A=USV^{\top}\), in order to accelerate the sampling.
We modify the standard ULA iteration (see deepinv.sampling.ULA), defined as

\[x_{k+1} = x_k + \eta \nabla \log p(y|x_k) + \eta \alpha \nabla \log p(x_k) + \sqrt{2\eta}\, z_{k+1}, \quad z_{k+1} \sim \mathcal{N}(0, I),\]

by using a matrix-valued step size \(\eta = \eta_0 VRV^{\top}\), where \(R\) is a diagonal matrix with entries \(R_{i,i} = \frac{1}{S_{i,i}^2 + \epsilon}\). The parameter \(\epsilon\) is used to avoid numerical issues when \(S_{i,i}^2\) is close to zero. After some algebra, we obtain the following iteration

\[x_{k+1} = x_k + V\left(\frac{\eta_0 R}{\sigma^2}\left(S U^{\top} y - S^2 V^{\top} x_k\right) + \eta_0 \alpha R V^{\top} \nabla \log p(x_k) + \sqrt{2\eta_0 R}\, z_{k+1}\right).\]
We exploit the methods of deepinv.physics.DecomposablePhysics to compute the matrix-vector products with \(V\) and \(V^{\top}\) efficiently. Note that computing the matrix-vector products with \(R\) and \(S\) is trivial since they are diagonal matrices.
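To make the algebra concrete, here is a minimal NumPy sketch (our own illustration, not deepinv code) of the preconditioned likelihood gradient for a small explicit operator \(A = USV^{\top}\), checking that computing it in the spectral domain agrees with applying \(VRV^{\top}\) to the plain gradient:

```python
import numpy as np

rng = np.random.default_rng(1)
n, sigma, eps = 8, 0.1, 0.01

A = rng.standard_normal((n, n))
U, S, Vt = np.linalg.svd(A)
V = Vt.T
x = rng.standard_normal(n)
y = A @ x + sigma * rng.standard_normal(n)

# gradient of the Gaussian data-fidelity term f(x) = ||Ax - y||^2 / (2 sigma^2)
grad = A.T @ (A @ x - y) / sigma**2

# the same gradient computed in the spectral domain, preconditioned by
# R = diag(1 / (S_ii^2 + eps)), as in the iterator below
R = 1.0 / (S**2 + eps)
x_bar = Vt @ x   # V^T x
y_bar = U.T @ y  # U^T y
precond = V @ (R * (S**2 * x_bar - S * y_bar) / sigma**2)

# agrees with applying the matrix-valued preconditioner V R V^T directly
direct = V @ (R * (Vt @ grad))
print(np.allclose(precond, direct))  # True
```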
class PULAIterator(torch.nn.Module):
    def __init__(self, step_size, sigma, alpha=1, epsilon=0.01):
        super().__init__()
        self.step_size = step_size
        self.alpha = alpha
        self.sigma = sigma
        self.epsilon = epsilon

    def forward(self, x, y, physics, likelihood, prior):
        # move the current sample and the measurement to the spectral domain
        x_bar = physics.V_adjoint(x)
        y_bar = physics.U_adjoint(y)
        # diagonal matrix-valued step size eta_0 * R
        step_size = self.step_size / (self.epsilon + physics.mask.pow(2))
        noise = torch.randn_like(x_bar)
        sigma2_noise = 1 / likelihood.norm
        # gradient of the Gaussian log-likelihood in the spectral domain
        lhood = -(physics.mask.pow(2) * x_bar - physics.mask * y_bar) / sigma2_noise
        # plug-and-play approximation of the prior score, pulled back by V^T
        lprior = -physics.V_adjoint(prior.grad(x, self.sigma)) * self.alpha
        return x + physics.V(
            step_size * (lhood + lprior) + (2 * step_size).sqrt() * noise
        )
Build Sampler class
Using our custom iterator, we can build a sampler class by inheriting from the base class deepinv.sampling.MonteCarlo. The base class takes care of the sampling procedure (computing the mean and variance, handling sample thinning and burn-in iterations, etc.), providing a convenient interface to the user.
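As an illustration of the bookkeeping the base class performs (a sketch of the general idea, not the deepinv implementation; the random-walk kernel is a placeholder for PULAIterator), posterior statistics with burn-in and thinning can be computed from a chain of samples like this:

```python
import numpy as np

rng = np.random.default_rng(2)

max_iter, burnin_ratio, thinning = 1000, 0.1, 5
burnin = int(max_iter * burnin_ratio)

kept = []
sample = np.zeros(4)
for it in range(max_iter):
    # placeholder kernel: a random-walk step standing in for the real iterator
    sample = sample + 0.1 * rng.standard_normal(4)
    # discard burn-in iterations, keep every `thinning`-th sample afterwards
    if it >= burnin and (it - burnin) % thinning == 0:
        kept.append(sample.copy())

kept = np.stack(kept)
posterior_mean = kept.mean(axis=0)
posterior_var = kept.var(axis=0)
print(kept.shape)  # (180, 4)
```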
class PreconULA(dinv.sampling.MonteCarlo):
    def __init__(
        self,
        prior,
        data_fidelity,
        sigma,
        step_size,
        max_iter=1e3,
        thinning=1,
        burnin_ratio=0.1,
        clip=(-1, 2),
        verbose=True,
    ):
        # generate an iterator
        iterator = PULAIterator(step_size=step_size, sigma=sigma)
        # set the params of the base class
        super().__init__(
            iterator,
            prior,
            data_fidelity,
            max_iter=max_iter,
            thinning=thinning,
            burnin_ratio=burnin_ratio,
            clip=clip,
            verbose=verbose,
        )
Define the prior
The score of a distribution can be approximated using a plug-and-play denoiser via the deepinv.optim.ScorePrior class. This example uses a simple median filter as the plug-and-play denoiser. The hyperparameter \(\sigma_d\) controls the strength of the prior.
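The score approximation behind a plug-and-play prior is \(\nabla \log p_{\sigma_d}(x) \approx (D(x) - x)/\sigma_d^2\), where \(D\) is the denoiser. A minimal NumPy sketch (the 1D median filter here is our own stand-in, not deepinv's MedianFilter):

```python
import numpy as np

def median_filter_1d(x, size=3):
    # simple 1D median filter with edge padding (stand-in denoiser)
    pad = size // 2
    xp = np.pad(x, pad, mode="edge")
    return np.array([np.median(xp[i:i + size]) for i in range(len(x))])

rng = np.random.default_rng(3)
sigma_d = 0.1

clean = np.sin(np.linspace(0, 2 * np.pi, 50))
noisy = clean + 0.2 * rng.standard_normal(50)

# score approximation: the score points towards the denoised signal,
# so a full step of size sigma_d^2 lands exactly on the denoiser output
score = (median_filter_1d(noisy) - noisy) / sigma_d**2
step = noisy + sigma_d**2 * score

print(np.allclose(step, median_filter_1d(noisy)))  # True
```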
prior = dinv.optim.ScorePrior(denoiser=dinv.models.MedianFilter())
Create the preconditioned and standard ULA samplers
We create the preconditioned and standard ULA samplers using the same hyperparameters (step size, number of iterations, etc.).
step_size = 0.5 * (sigma**2)
iterations = int(1e2) if torch.cuda.is_available() else 10
g_param = 0.1
# load Gaussian Likelihood
likelihood = dinv.optim.data_fidelity.L2(sigma=sigma)
pula = PreconULA(
    prior=prior,
    data_fidelity=likelihood,
    max_iter=iterations,
    step_size=step_size,
    thinning=1,
    burnin_ratio=0.1,
    verbose=True,
    sigma=g_param,
)

ula = ULA(
    prior=prior,
    data_fidelity=likelihood,
    max_iter=iterations,
    step_size=step_size,
    thinning=1,
    burnin_ratio=0.1,
    verbose=True,
    sigma=g_param,
)
Run sampling algorithms and plot results
Each sampling algorithm returns the posterior mean and variance. We compare the posterior mean of each algorithm with a simple linear reconstruction.
The preconditioned step size of the new sampler provides a significant acceleration over standard ULA, which is evident in the PSNR of the posterior mean.
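PSNR, the metric used below via dinv.metric.PSNR, can be sketched in NumPy (assuming signals in \([0, 1]\), so the peak value is 1):

```python
import numpy as np

def psnr(x, x_hat, max_val=1.0):
    # peak signal-to-noise ratio in dB
    mse = np.mean((x - x_hat) ** 2)
    return 10 * np.log10(max_val**2 / mse)

x = np.linspace(0, 1, 100)
x_hat = x + 0.01  # constant error of 0.01 -> MSE = 1e-4
print(round(psnr(x, x_hat), 1))  # 40.0
```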
Note
The preconditioned ULA sampler requires a forward operator with an easily computable singular value decomposition (e.g., one that inherits from deepinv.physics.DecomposablePhysics) and Gaussian noise, whereas ULA is more general.
ula_mean, ula_var = ula(y, physics)
pula_mean, pula_var = pula(y, physics)
# compute linear inverse
x_lin = physics.A_adjoint(y)
# compute PSNR
print(f"Linear reconstruction PSNR: {dinv.metric.PSNR()(x, x_lin).item():.2f} dB")
print(f"ULA posterior mean PSNR: {dinv.metric.PSNR()(x, ula_mean).item():.2f} dB")
print(
    f"PreconULA posterior mean PSNR: {dinv.metric.PSNR()(x, pula_mean).item():.2f} dB"
)
# plot results
imgs = [x_lin, x, ula_mean, pula_mean]
plot(imgs, titles=["measurement", "ground truth", "ULA", "PreconULA"])
Monte Carlo sampling finished! elapsed time=0.02 seconds
Iteration 9, current converge crit. = 3.91E-04, objective = 1.00E-03
Monte Carlo sampling finished! elapsed time=0.02 seconds
Linear reconstruction PSNR: 15.66 dB
ULA posterior mean PSNR: 15.66 dB
PreconULA posterior mean PSNR: 23.28 dB
Total running time of the script: (0 minutes 0.274 seconds)