DEAL denoising and reconstruction
This example shows how to use the Deep Equilibrium Attention Least Squares (DEAL) model in DeepInverse for both denoising and a simple reconstruction setting.
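The reconstruction is obtained by solving
\[
\hat{x} = \arg\min_{x}\ \tfrac{1}{2}\|Ax - y\|_2^2 + g_{\theta}(x),
\]
where \(A\) is the forward operator, \(y\) are the measurements, and \(g_{\theta}\) is the learned adaptive regularizer.
In DEAL, the regularizer is induced by a masked linear operator
\[
L_{\theta}(u) = M_{\theta}(u)\, W_{\theta},
\]
where \(W_{\theta}\) are learned filters and \(M_{\theta}(u)\) are spatially varying masks predicted by the network. This operator depends on an auxiliary image \(u\), resulting in
\[
g_{\theta}(x; u) = \tfrac{\lambda}{2}\|L_{\theta}(u)\, x\|_2^2 .
\]
The fixed-point iterations used by the solver are
\[
x^{k+1} = \arg\min_{x}\ \tfrac{1}{2}\|Ax - y\|_2^2 + \tfrac{\lambda}{2}\|L_{\theta}(x^{k})\, x\|_2^2 .
\]
Each subproblem is a linear least-squares problem and is solved approximately with conjugate gradient.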
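To see why conjugate gradient fits here, assume (as a sketch) that for a fixed auxiliary image the regularizer is quadratic, \(g_{\theta}(x; u) = \tfrac{\lambda}{2}\|L_{\theta}(u)\,x\|_2^2\) with a weight \(\lambda > 0\) and a masked linear operator \(L_{\theta}(u)\); this notation is introduced for illustration. Setting the gradient of one subproblem to zero then gives a symmetric linear system:

```latex
\nabla_x\Big(\tfrac{1}{2}\|Ax - y\|_2^2 + \tfrac{\lambda}{2}\|L_{\theta}(x^{k})\,x\|_2^2\Big) = 0
\quad\Longleftrightarrow\quad
\Big(A^{\top}A + \lambda\, L_{\theta}(x^{k})^{\top} L_{\theta}(x^{k})\Big)\,x^{k+1} = A^{\top}y .
```

Since \(x^{\top}(A^{\top}A + \lambda L^{\top}L)\,x = \|Ax\|_2^2 + \lambda\|Lx\|_2^2\), the system matrix is positive definite whenever the null spaces of \(A\) and \(L_{\theta}(x^{k})\) intersect only at zero, which is the setting where conjugate gradient is guaranteed to converge.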
DEAL solves inverse problems by minimizing a data-fidelity term together with a learned adaptive regularizer. In the implementation, this regularizer is induced by a masked linear operator, where learned filters are modulated by spatially varying masks predicted by the network. The reconstruction is then refined through iterative least-squares updates, where the regularizer is recomputed from the current iterate.
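The loop described above (recompute the masks from the current iterate, then solve a least-squares problem with conjugate gradient) can be sketched in plain PyTorch. This is a toy stand-in, not the DeepInverse implementation: the learned filters are replaced by horizontal and vertical finite differences, and the learned masks by a hand-crafted edge-stopping weight \(1/\sqrt{1 + (\nabla u/\varepsilon)^2}\). The helpers `grad2d`, `grad2d_adjoint`, `conjugate_gradient` and `refine`, as well as the values of `lam` and `eps`, are hypothetical choices for illustration.

```python
import torch


def grad2d(x):
    # Forward finite differences along width (dx) and height (dy);
    # the last column/row of each difference is set to zero.
    dx = torch.zeros_like(x)
    dy = torch.zeros_like(x)
    dx[..., :, :-1] = x[..., :, 1:] - x[..., :, :-1]
    dy[..., :-1, :] = x[..., 1:, :] - x[..., :-1, :]
    return torch.stack([dx, dy], dim=0)


def grad2d_adjoint(g):
    # Exact adjoint of grad2d: <grad2d(x), g> == <x, grad2d_adjoint(g)>.
    dx, dy = g[0], g[1]
    out = torch.zeros_like(dx)
    out[..., :, :-1] -= dx[..., :, :-1]
    out[..., :, 1:] += dx[..., :, :-1]
    out[..., :-1, :] -= dy[..., :-1, :]
    out[..., 1:, :] += dy[..., :-1, :]
    return out


def conjugate_gradient(apply_op, b, x0, n_iter=200, tol=1e-8):
    # Standard CG for apply_op(x) = b with a symmetric positive definite operator.
    x = x0.clone()
    r = b - apply_op(x)
    p = r.clone()
    rs_old = torch.sum(r * r)
    for _ in range(n_iter):
        if torch.sqrt(rs_old) < tol:
            break
        Ap = apply_op(p)
        alpha = rs_old / torch.sum(p * Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = torch.sum(r * r)
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def refine(y, keep, lam=1.0, eps=0.5, n_outer=5):
    # Inpainting: A x = keep * x, so A^T y = keep * y is both the start point and the right-hand side.
    b = keep * y
    x = b.clone()
    for _ in range(n_outer):
        # Hand-crafted stand-in for the learned masks, recomputed from the current iterate.
        w = 1.0 / torch.sqrt(1.0 + (grad2d(x) / eps) ** 2)

        def normal_op(v, w=w):
            # (A^T A + lam * L^T L) v with L v = w * grad2d(v)
            return keep * v + lam * grad2d_adjoint(w**2 * grad2d(v))

        x = conjugate_gradient(normal_op, b, x)  # warm start from the previous iterate
    return x


# Toy usage: a bright square with about half of the pixels removed.
torch.manual_seed(0)
x_true = torch.zeros(32, 32, dtype=torch.float64)
x_true[8:24, 8:24] = 1.0
keep = (torch.rand(32, 32) > 0.5).double()
y = keep * x_true
x_rec = refine(y, keep)
mse_meas = torch.mean((y - x_true) ** 2).item()
mse_rec = torch.mean((x_rec - x_true) ** 2).item()
print(f"MSE of the measurement: {mse_meas:.4f}, MSE after refinement: {mse_rec:.4f}")
```

Recomputing the weights from the current iterate is what makes each outer step a fixed-point update: edges found in \(x^{k}\) receive smaller weights, so they are smoothed less in \(x^{k+1}\). In DEAL, the filters and masks are learned instead of hand-crafted.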
This implementation is adapted from the official DEAL repository.
Here, the model is illustrated first for Gaussian denoising, and then for a simple inpainting reconstruction problem.
Import packages and load a grayscale example image.
import torch
from deepinv.loss.metric import PSNR
from deepinv.models import DEAL
from deepinv.physics import Denoising, GaussianNoise, Inpainting
from deepinv.utils import load_example, plot
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
# Load image (grayscale)
x = load_example("butterfly.png", img_size=128, device=device, grayscale=True)
Noise level in the normalized [0,1] convention used by DeepInverse.
sigma = 0.1
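As a quick sanity check on this noise level, assuming PSNR is computed with a peak value of 1 and the noisy image is not clipped, additive Gaussian noise with standard deviation \(\sigma\) has a mean squared error of about \(\sigma^2\), so the noisy PSNR should be close to \(10\log_{10}(1/\sigma^2)\). The snippet below uses a random image as a stand-in for the butterfly.

```python
import math

import torch

sigma = 0.1
# For images in [0, 1], PSNR = 10 * log10(1 / MSE), and Gaussian noise gives MSE ~ sigma**2.
expected_psnr = 10 * math.log10(1.0 / sigma**2)
print(f"Expected PSNR of the noisy image: {expected_psnr:.2f} dB")  # 20.00 dB
print(f"Equivalent sigma on the 0-255 scale: {255 * sigma:.1f}")  # 25.5

# Empirical check on a random stand-in image.
torch.manual_seed(0)
x_demo = torch.rand(1, 1, 128, 128)
y_demo = x_demo + sigma * torch.randn_like(x_demo)
empirical_psnr = 10 * torch.log10(1.0 / torch.mean((y_demo - x_demo) ** 2)).item()
print(f"Empirical PSNR: {empirical_psnr:.2f} dB")
```

This is consistent with the noisy PSNR of about 20 dB reported in the denoising experiment.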
Load the pretrained DEAL model.
model = DEAL(
pretrained="download",
sigma_denoiser=sigma,
lambda_reg=10.0,
max_iter=10,
auto_scale=False,
color=False,
device=device,
clamp_output=True,
)
model.eval()
n_params = sum(p.numel() for p in model.parameters())
print(f"DEAL number of parameters: {n_params:,}")
psnr = PSNR()
DEAL number of parameters: 468,570
Denoising with DEAL
We first illustrate Gaussian denoising with DEAL. The model is applied to measurements generated by a Denoising physics operator (identity forward operator with additive Gaussian noise), and the plotted mask shows the spatially varying regularization weights from the last iteration.
physics_denoise = Denoising(GaussianNoise(sigma=sigma)).to(device)
y_denoise = physics_denoise(x)
with torch.no_grad():
    x_hat_denoise = model(y_denoise, physics_denoise)
    mask_denoise = model.mask.mean(dim=1, keepdim=True)
psnr_noisy = psnr(y_denoise, x).item()
psnr_denoise = psnr(x_hat_denoise, x).item()
print(f"[Denoising] PSNR noisy: {psnr_noisy:.2f} dB")
print(f"[Denoising] PSNR DEAL: {psnr_denoise:.2f} dB")
# The mask corresponds to the spatially varying weights learned by DEAL.
# It controls how strongly the regularizer acts at each pixel.
plot(
[x, y_denoise, x_hat_denoise, mask_denoise],
titles=[
"Ground truth",
"Noisy input",
"DEAL denoising",
"DEAL mask",
],
subtitles=[
"",
f"PSNR: {psnr_noisy:.2f} dB",
f"PSNR: {psnr_denoise:.2f} dB",
"Mean over channels",
],
figsize=(11, 3),
)

[Denoising] PSNR noisy: 19.99 dB
[Denoising] PSNR DEAL: 29.34 dB
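Because PSNR is logarithmic in the mean squared error, a gain of \(\Delta\) dB corresponds to dividing the MSE by \(10^{\Delta/10}\). The short computation below applies this to the two PSNR values printed above; it only reads off the reported numbers and does not rerun the model.

```python
# PSNR values reported by the denoising run above.
psnr_noisy = 19.99
psnr_deal = 29.34

gain_db = psnr_deal - psnr_noisy
# PSNR = 10 * log10(1 / MSE), so a difference in dB is a log-ratio of MSEs.
mse_ratio = 10 ** (gain_db / 10)
print(f"PSNR gain: {gain_db:.2f} dB, so the MSE is reduced by a factor of about {mse_ratio:.1f}")  # about 8.6
```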
Reconstruction example with inpainting
We next illustrate a simple inpainting problem where a random subset of pixels is removed. DEAL combines data fidelity and its learned adaptive regularization to recover the missing content.
# Create random mask (50% missing pixels)
mask = (torch.rand(1, 1, 128, 128, device=device) > 0.5).float()
physics_inpaint = Inpainting(
img_size=(1, 128, 128),
mask=mask,
noise_model=GaussianNoise(sigma=0.0),
).to(device)
y_inpaint = physics_inpaint(x)
with torch.no_grad():
    x_lin = physics_inpaint.A_adjoint(y_inpaint)
    x_hat_inpaint = model(y_inpaint, physics_inpaint)
    mask_inpaint = model.mask.mean(dim=1, keepdim=True)
psnr_meas_inpaint = psnr(y_inpaint, x).item()
psnr_lin_inpaint = psnr(x_lin, x).item()
psnr_deal_inpaint = psnr(x_hat_inpaint, x).item()
print(f"[Inpainting] PSNR measurement: {psnr_meas_inpaint:.2f} dB")
print(f"[Inpainting] PSNR linear: {psnr_lin_inpaint:.2f} dB")
print(f"[Inpainting] PSNR DEAL: {psnr_deal_inpaint:.2f} dB")
# Again, the mask shows the learned spatial weighting of the regularizer.
plot(
[x, y_inpaint, x_lin, x_hat_inpaint, mask_inpaint],
titles=[
"Ground truth",
"Masked measurement",
"Adjoint baseline",
"DEAL reconstruction",
"DEAL mask",
],
subtitles=[
"",
f"PSNR: {psnr_meas_inpaint:.2f} dB",
f"PSNR: {psnr_lin_inpaint:.2f} dB",
f"PSNR: {psnr_deal_inpaint:.2f} dB",
"Mean over channels",
],
figsize=(13, 3),
)

[Inpainting] PSNR measurement: 6.94 dB
[Inpainting] PSNR linear: 6.94 dB
[Inpainting] PSNR DEAL: 29.52 dB
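The measurement and the adjoint baseline report the same PSNR (6.94 dB). A plausible explanation, assuming the inpainting operator acts as pointwise multiplication by the binary `mask` passed to `Inpainting`, is that such an operator is its own adjoint and applying it twice changes nothing, so \(A^{\top}y = y\) exactly. The plain-PyTorch stand-ins `A` and `A_adjoint` below are hypothetical helpers written for illustration, not DeepInverse functions.

```python
import torch

torch.manual_seed(0)
x = torch.rand(1, 1, 128, 128)
keep = (torch.rand(1, 1, 128, 128) > 0.5).float()  # about 50% of pixels observed


def A(v):
    # Stand-in inpainting forward operator: keep observed pixels, zero out the rest.
    return keep * v


def A_adjoint(v):
    # A diagonal 0/1 operator is symmetric, so its adjoint is the same multiplication.
    return keep * v


y = A(x)
x_lin = A_adjoint(y)
print(torch.equal(x_lin, y))  # True, since keep * keep == keep for a binary mask
print(f"Fraction of observed pixels: {keep.mean().item():.3f}")
```

The adjoint therefore only leaves the missing pixels at zero; recovering them requires the regularizer, which is where DEAL provides the gain from about 7 dB to about 29.5 dB.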