Note
New to DeepInverse? Get started with the basics in the 5-minute quickstart tutorial.
Use iterative reconstruction algorithms#
Follow this example to reconstruct images using an iterative algorithm.
The library provides a flexible framework to define your own iterative reconstruction algorithm, which is generally written as the solution of the following optimization problem:
\[
\hat{x} = \underset{x}{\arg\min} \;\; \datafid{x}{y} + \lambda \reg{x}
\]
where \(\datafid{x}{y}\) is the data fidelity term, \(\reg{x}\) is the (explicit or implicit) regularization term, and \(\lambda\) is a regularization parameter. In this example, we demonstrate:
How to define your own iterative algorithm
How to package it as a reconstructor model
How to use predefined optimization algorithms, which are implemented as subclasses of BaseOptim
1. Defining your own iterative algorithm#
import deepinv as dinv
import torch
from deepinv.optim import PGD
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
Define the physics of the problem#
Here we define a simple inpainting problem, where we want to reconstruct an image from partial measurements. We also load an image of a butterfly to use as ground truth.
x = dinv.utils.load_example("butterfly.png", device=device, img_size=(128, 128))
# Forward operator, here inpainting with a mask of 50% of the pixels
physics = dinv.physics.Inpainting(img_size=(3, 128, 128), mask=0.5, device=device)
# Generate measurements
y = physics(x)
dinv.utils.plot([x, y], titles=["Ground truth", "Measurements"])

Define the data fidelity term and prior#
The library provides a set of data fidelity terms and priors that can be used in the optimization problem. Here we use the \(\ell_2\) data fidelity term and the Total Variation (TV) prior.
These classes provide all the necessary methods for the optimization problem, such as the evaluation of the term, the gradient, and the proximal operator.
data_fidelity = dinv.optim.L2() # Data fidelity term
prior = dinv.optim.TVPrior() # Prior term
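For instance, both objects expose these operations directly. The snippet below is a minimal illustration reusing the x, y and physics defined above; the calls are the same ones used in the iterations later in this example:
f_val = data_fidelity(x, y, physics)  # evaluate the data fidelity term
f_grad = data_fidelity.grad(x, y, physics)  # gradient of the data fidelity with respect to x
g_val = prior(x)  # evaluate the TV regularization term
x_prox = prior.prox(x, gamma=0.1)  # proximal operator of gamma * TV at x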
Define the iterative algorithm#
We will use the Proximal Gradient Descent (PGD) algorithm to solve the optimization problem defined above, whose iterations are given by
\[
x_{k+1} = \operatorname{prox}_{\gamma \lambda \regname}\left( x_k - \gamma \nabla \datafidname(x_k, y) \right)
\]
where \(\operatorname{prox}_{\gamma \lambda \regname}\) is the proximal operator of the regularization term, \(\nabla \datafidname(x_k, y)\) is the gradient of the data fidelity term, \(\gamma\) is the stepsize, and \(\lambda\) is the regularization parameter.
We can choose the stepsize as \(\gamma < \frac{2}{\|A\|^2}\), where \(A\) is the forward operator, in order to ensure convergence of the algorithm.
lambd = 0.05 # Regularization parameter
# Compute the squared norm of the operator A
norm_A2 = physics.compute_sqnorm(y, tol=1e-4, verbose=False).item()
stepsize = 1.9 / norm_A2 # stepsize for the PGD algorithm
# PGD algorithm
max_iter = 20 # number of iterations
x_k = torch.zeros_like(x, device=device) # initial guess
# To store the cost at each iteration:
cost_history = torch.zeros(max_iter, device=device)
with torch.no_grad():  # disable autodifferentiation
    for it in range(max_iter):
        u = x_k - stepsize * data_fidelity.grad(x_k, y, physics)  # Gradient step
        x_k = prior.prox(u, gamma=lambd * stepsize)  # Proximal step
        cost = data_fidelity(x_k, y, physics) + lambd * prior(x_k)  # Compute the cost
        cost_history[it] = cost  # Store the cost
Plot the cost history
import matplotlib.pyplot as plt
plt.figure(figsize=(8, 4))
plt.plot(cost_history.detach().cpu().numpy(), marker="o")
plt.title("Cost history")
plt.xlabel("Iteration")
plt.ylabel("Cost")
plt.grid()
plt.show()

Plot the results and metrics
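Here we compare the reconstruction with the ground truth using the peak signal-to-noise ratio (PSNR, in dB). The metric below is an assumption consistent with the dB values shown in the plot titles, and it is reused in the plots that follow:
metric = dinv.metric.PSNR()  # PSNR metric in dB (assumed; reused in the remaining plots)
dinv.utils.plot(
    {
        "Ground truth": x,
        f"Measurements\n {metric(y, x).item():.2f} dB": y,
        f"Recon w/ TV prior\n {metric(x_k, x).item():.2f} dB": x_k,
    }
)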

Use a pretrained denoiser as prior#
We can improve the reconstruction by using a pretrained denoiser as prior, by replacing the proximal operator with a denoising step. The library provides a collection of classical and pretrained denoisers that can be used in iterative algorithms.
Note
Plug-and-play algorithms can be sensitive to the choice of initialization. Here we use the TV estimate as the initial guess.
x_k = x_k.clone()
denoiser = dinv.models.DRUNet(device=device) # Load a pretrained denoiser
with torch.no_grad():  # disable autodifferentiation
    for it in range(max_iter):
        u = x_k - stepsize * data_fidelity.grad(x_k, y, physics)  # Gradient step
        x_k = denoiser(u, sigma=0.05)  # replace the proximal step by a denoising step
dinv.utils.plot(
    {
        "Ground truth": x,
        f"Measurements\n {metric(y, x).item():.2f} dB": y,
        f"Recon w/ PnP prior\n {metric(x_k, x).item():.2f} dB": x_k,
    }
)

2. Package your algorithm as a Reconstructor#
The iterative algorithm we defined above can be packaged as a Reconstructor.
This allows you to test it on different physics and datasets, and to use it in a more flexible way,
including unfolding it and learning some of its parameters.
class MyPGD(dinv.models.Reconstructor):
    def __init__(self, data_fidelity, prior, stepsize, lambd, max_iter):
        super().__init__()
        self.data_fidelity = data_fidelity
        self.prior = prior
        self.stepsize = stepsize
        self.lambd = lambd
        self.max_iter = max_iter

    def forward(self, y, physics, **kwargs):
        """Algorithm forward pass.

        :param torch.Tensor y: measurements.
        :param dinv.physics.Physics physics: measurement operator.
        :return: torch.Tensor: reconstructed image.
        """
        x_k = torch.zeros_like(y, device=y.device)  # initial guess
        # Disable autodifferentiation; remove this if you want to unfold
        with torch.no_grad():
            for _ in range(self.max_iter):
                u = x_k - self.stepsize * self.data_fidelity.grad(
                    x_k, y, physics
                )  # Gradient step
                x_k = self.prior.prox(
                    u, gamma=self.lambd * self.stepsize
                )  # Proximal step
        return x_k
tv_algo = MyPGD(data_fidelity, prior, stepsize, lambd, max_iter)
# Standard reconstructor forward pass
x_hat = tv_algo(y, physics)
dinv.utils.plot(
    {
        "Ground truth": x,
        f"Measurements\n {metric(y, x).item():.2f} dB": y,
        f"Recon w/ custom PGD\n {metric(x_hat, x).item():.2f} dB": x_hat,
    }
)
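Since MyPGD follows the standard reconstructor interface, the same class can be reused on other operators without changing its code; only the stepsize needs to be recomputed for the new physics. The deblurring operator and its parameters below are a hypothetical choice, shown purely as a sketch:
# Sketch: apply the same reconstructor to a (hypothetical) deblurring problem
physics_blur = dinv.physics.BlurFFT(
    img_size=(3, 128, 128),
    filter=dinv.physics.blur.gaussian_blur(sigma=(2, 2)),
    device=device,
)
y_blur = physics_blur(x)
stepsize_blur = 1.9 / physics_blur.compute_sqnorm(y_blur, tol=1e-4, verbose=False).item()
blur_algo = MyPGD(data_fidelity, prior, stepsize_blur, lambd, max_iter)
x_blur_hat = blur_algo(y_blur, physics_blur)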

3. Using a predefined optimization algorithm#
The library also comes with common optimization algorithms (PGD, ADMM, HQS, FISTA, Primal-Dual, etc.)
already implemented as Reconstructors.
They can be instantiated in one line of code.
For example, the above PnP algorithm can be defined as follows:
See also
For more examples with other predefined algorithms, see the ADMM example and the DRS example.
prior = dinv.optim.PnP(denoiser=denoiser) # prior with prox via denoising step
model = PGD(
    prior=prior,
    data_fidelity=data_fidelity,
    stepsize=stepsize,
    sigma_denoiser=0.05,
    max_iter=max_iter,
)
x_hat = model(y, physics, init=tv_algo(y, physics))
dinv.utils.plot(
    {
        "Ground truth": x,
        f"Measurements\n {metric(y, x).item():.2f} dB": y,
        f"Reconstruction\n {metric(x_hat, x).item():.2f} dB": x_hat,
    }
)

🎉 Well done, you now know how to define your own iterative reconstruction algorithm!
What’s next?#
Check out more about optimization algorithms in the optimization user guide.
Check out diffusion and MCMC iterative algorithms in the sampling user guide.
Check out more iterative algorithms examples.
Check out how to try the algorithm on a whole dataset by following the bring your own dataset tutorial.
Check out how to train your plug-and-play algorithm by unfolding its iterations in the vanilla unfolded tutorial.
Total running time of the script: (1 minutes 0.731 seconds)