DiffPIR

class deepinv.sampling.DiffPIR(model, data_fidelity, sigma=0.05, max_iter=100, zeta=1.0, lambda_=7.0, verbose=False, device='cpu')

Bases: Module

Diffusion PnP Image Restoration (DiffPIR).

This class implements the Diffusion PnP image restoration algorithm (DiffPIR) described in https://arxiv.org/abs/2305.08995.

DiffPIR combines a diffusion sampling process with a half-quadratic splitting (HQS) plug-and-play scheme in which the denoiser is a conditional diffusion denoiser. For \(t\) decreasing from \(T\) to \(1\), each iteration reads:

\[\begin{aligned} x_{0}^{t} &= D_{\theta}\left(x_t, \frac{\sqrt{1-\overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}}\right) \\ \widehat{x}_{0}^{t} &= \operatorname{prox}_{2 f(y, \cdot) /{\rho_t}}(x_{0}^{t}) \\ \widehat{\varepsilon} &= \left(x_t - \sqrt{\overline{\alpha}_t} \,\, \widehat{x}_{0}^t\right)/\sqrt{1-\overline{\alpha}_t} \\ \varepsilon_t &\sim \mathcal{N}(0, \mathbf{I}) \\ x_{t-1} &= \sqrt{\overline{\alpha}_{t-1}} \,\, \widehat{x}_{0}^t + \sqrt{1-\overline{\alpha}_{t-1}} \left(\sqrt{1-\zeta} \,\, \widehat{\varepsilon} + \sqrt{\zeta} \,\, \varepsilon_t\right), \end{aligned}\]

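where \(D_\theta(\cdot,\sigma)\) is a Gaussian denoiser network with noise level \(\sigma\), \(f(y, \cdot)\) is the data fidelity term, \(\rho_t\) is the data fidelity weight controlled by the hyperparameter \(\lambda\), and \(\overline{\alpha}_t = \prod_{s=1}^{t} (1-\beta_s)\) is the cumulative product of the diffusion noise schedule \(\beta_s\).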
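The update equations can be transcribed into a single iteration. The following is a minimal NumPy sketch, not the deepinv implementation. It assumes an inpainting problem with a binary `mask` and data fidelity \(f(y, x) = \frac{1}{2}\|y - \text{mask} \odot x\|^2\), whose proximal operator has an elementwise closed form, and it replaces \(D_\theta\) with a hypothetical `toy_denoiser` that merely shrinks toward the image mean. Following the DiffPIR paper, the re-noising step uses \(\overline{\alpha}_{t-1}\) (passed as `abar_prev`).

```python
import numpy as np


def toy_denoiser(x, sigma):
    """Hypothetical stand-in for D_theta(x, sigma): shrinks x toward its mean,
    more strongly for larger sigma. For illustration only."""
    w = 1.0 / (1.0 + sigma**2)
    return w * x + (1.0 - w) * x.mean()


def prox_masked_l2(z, y, mask, gamma):
    """prox of gamma * f with f(x) = 0.5 * ||mask * x - y||^2.

    Since mask is binary, the minimizer of
    gamma/2 * ||mask * x - y||^2 + 1/2 * ||x - z||^2 is computed elementwise.
    """
    return (z + gamma * mask * y) / (1.0 + gamma * mask)


def diffpir_step(x_t, y, mask, abar_t, abar_prev, rho_t, zeta, rng):
    """One DiffPIR iteration mapping x_t to x_{t-1} (sketch)."""
    # 1. Denoising step at noise level sqrt(1 - abar_t) / sqrt(abar_t)
    sigma_bar_t = np.sqrt(1.0 - abar_t) / np.sqrt(abar_t)
    x0 = toy_denoiser(x_t, sigma_bar_t)

    # 2. Data fidelity step: prox of 2 f(y, .) / rho_t
    x0_hat = prox_masked_l2(x0, y, mask, gamma=2.0 / rho_t)

    # 3. Noise implied by the current iterate and the corrected estimate
    eps_hat = (x_t - np.sqrt(abar_t) * x0_hat) / np.sqrt(1.0 - abar_t)

    # 4. Mix with fresh Gaussian noise (zeta=0: deterministic, zeta=1: fully random)
    eps_t = rng.standard_normal(x_t.shape)
    noise = np.sqrt(1.0 - zeta) * eps_hat + np.sqrt(zeta) * eps_t

    # 5. Re-noise the estimate to the lower noise level of step t-1
    return np.sqrt(abar_prev) * x0_hat + np.sqrt(1.0 - abar_prev) * noise
```

A full sampler would call `diffpir_step` for \(t = T, \dots, 1\), for instance starting from Gaussian noise; with \(\zeta = 0\) the step is deterministic given \(x_t\).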

Note

The algorithm may require careful tuning of the hyperparameters \(\lambda\) and \(\zeta\) to obtain optimal results.

Parameters:
  • model (torch.nn.Module) – a conditional denoising model \(D_\theta(x, \sigma)\) that takes an image and a noise level as input

  • data_fidelity (deepinv.optim.DataFidelity) – the data fidelity operator

  • sigma (float) – the standard deviation of the measurement noise. Default: 0.05.

  • max_iter (int) – the number of iterations to run the algorithm. Default: 100.

  • zeta (float) – hyperparameter \(\zeta\) for the sampling step (must be between 0 and 1). Default: 1.0.

  • lambda_ (float) – hyperparameter \(\lambda\) for the data fidelity step, which sets \(\rho_t = \lambda \frac{\sigma_n^2}{\bar{\sigma}_t^2}\) as in the paper, where \(\sigma_n\) is the measurement noise level and \(\bar{\sigma}_t\) the noise level of the diffusion at step \(t\). The optimal value typically ranges between 3.0 and 25.0 depending on the problem. Default: 7.0.

  • verbose (bool) – if True, print progress. Default: False.

  • device (str) – the device to use for the computations. Default: 'cpu'.
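To see how \(\lambda\) acts across iterations, the sketch below evaluates \(\rho_t = \lambda \sigma_n^2 / \bar{\sigma}_t^2\) over all timesteps. It assumes a linear \(\beta\) schedule from 0.0001 to 0.02 over 1000 steps (the defaults shown for `get_alpha_prod`) and \(\bar{\sigma}_t = \sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}\), the denoiser noise level from the update equations. The helper name `rho_schedule` is hypothetical.

```python
import numpy as np


def rho_schedule(lambda_=7.0, sigma_n=0.05, beta_start=1e-4, beta_end=0.02, num_steps=1000):
    """Hypothetical helper: rho_t = lambda * sigma_n^2 / sigma_bar_t^2 for t = 1..T."""
    # Assumed linear beta schedule and its cumulative product alpha_bar_t
    betas = np.linspace(beta_start, beta_end, num_steps)
    alphas_cumprod = np.cumprod(1.0 - betas)
    # Noise level of the diffusion at each timestep
    sigma_bar = np.sqrt(1.0 - alphas_cumprod) / np.sqrt(alphas_cumprod)
    return lambda_ * sigma_n**2 / sigma_bar**2


rho = rho_schedule()
print(f"rho at t=1: {rho[0]:.1f}, rho at t=T: {rho[-1]:.2e}")
```

Since \(\bar{\sigma}_t\) decreases with \(t\), \(\rho_t\) increases as the iterations proceed, so the step size \(2/\rho_t\) of the proximal data fidelity step shrinks; increasing \(\lambda\) shrinks it at every step.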


Examples:

Noisy inpainting with DiffPIR, using a pretrained DRUNet denoiser:

>>> import torch
>>> import deepinv as dinv
>>> from deepinv.sampling import DiffPIR
>>> device = dinv.utils.get_freer_gpu(verbose=False) if torch.cuda.is_available() else 'cpu'
>>> x = 0.5 * torch.ones(1, 3, 32, 32, device=device) # Define a plain gray 32x32 image
>>> physics = dinv.physics.Inpainting(
...   mask=0.5, tensor_size=(3, 32, 32),
...   noise_model=dinv.physics.GaussianNoise(0.1),
...   device=device
... )
>>> y = physics(x) # Measurements
>>> denoiser = dinv.models.DRUNet(pretrained="download").to(device)
>>> model = DiffPIR(
...   model=denoiser,
...   data_fidelity=dinv.optim.data_fidelity.L2(),
...   sigma=0.1,  # match the measurement noise level
...   device=device,
... ) # Define the DiffPIR model
>>> xhat = model(y, physics) # Run the DiffPIR algorithm
>>> dinv.metric.PSNR()(xhat, x) > dinv.metric.PSNR()(y, x) # Should be closer to the original
tensor([True])

find_nearest(array, value)

Return the index of the entry of array that is nearest to value.
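Read literally, this amounts to an argmin over absolute differences. A minimal NumPy sketch of that behavior follows; it is an illustration only, and the library version may differ in accepted input types or return type.

```python
import numpy as np


def find_nearest(array, value):
    """Return the index of the entry of `array` closest to `value` (sketch)."""
    array = np.asarray(array)
    # On ties, np.argmin returns the first matching index
    return int(np.argmin(np.abs(array - value)))


# Index of the noise level closest to 0.45
noise_levels = [0.1, 0.3, 0.5, 0.9]
print(find_nearest(noise_levels, 0.45))  # prints 2
```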

forward(y, physics: LinearPhysics, seed=None, x_init=None)

Runs the diffusion to obtain a random sample of the posterior distribution.

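Parameters:
  • y (torch.Tensor) – the measurements

  • physics (deepinv.physics.LinearPhysics) – the forward operator

  • seed (int) – the seed of the random number generator (optional)

  • x_init (torch.Tensor) – an initial estimate from which the sampling starts (optional)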

get_alpha_beta()

Get the alpha and beta sequences for the algorithm. This is necessary for mapping noise levels to timesteps.

get_alpha_prod(beta_start=0.0001, beta_end=0.02, num_train_timesteps=1000)

Get the alpha sequences; this is necessary for mapping noise levels to timesteps when performing pure denoising.
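Both helpers relate a noise level to a diffusion timestep. The sketch below assumes the standard DDPM linear schedule implied by the default arguments of `get_alpha_prod` (\(\beta\) from 0.0001 to 0.02 over 1000 steps), with \(\alpha_t = 1-\beta_t\) and \(\overline{\alpha}_t = \prod_{s \le t} \alpha_s\). The function names are hypothetical, and the sketch ignores any image rescaling (for example to \([-1, 1]\)) that the actual implementation may apply to the noise level.

```python
import numpy as np


def linear_alpha_beta(beta_start=1e-4, beta_end=0.02, num_train_timesteps=1000):
    """Hypothetical helper: linear beta schedule, alphas and their cumulative product."""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)
    return alphas, betas, alphas_cumprod


def timestep_for_noise_level(sigma, alphas_cumprod):
    """Hypothetical helper: index t whose noise level sqrt(1 - abar_t) / sqrt(abar_t)
    is closest to sigma. This level increases with t, so the match is well defined."""
    sigma_bar = np.sqrt(1.0 - alphas_cumprod) / np.sqrt(alphas_cumprod)
    return int(np.argmin(np.abs(sigma_bar - sigma)))


_, _, abar = linear_alpha_beta()
# e.g. to choose where to start the reverse diffusion when denoising at level 0.1
t_start = timestep_for_noise_level(0.1, abar)
```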

get_noise_schedule(sigma)

Get the noise schedule for the algorithm.
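The algorithm runs `max_iter` iterations while the noise schedule is defined over the training timesteps, so a schedule has to select which timesteps to visit. The exact choice made by `get_noise_schedule` is not described here. The sketch below assumes evenly spaced timesteps and, since \(\rho_t\) depends on the measurement noise level, also returns \(\rho_t\) at each visited step; both the spacing and the contents of the returned schedule are assumptions, and the helper name `timestep_sequence` is hypothetical.

```python
import numpy as np


def timestep_sequence(sigma, lambda_=7.0, max_iter=100, num_train_timesteps=1000,
                      beta_start=1e-4, beta_end=0.02):
    """Hypothetical noise schedule: evenly spaced timesteps from T-1 down to 0,
    with the diffusion noise level and data fidelity weight at each one."""
    betas = np.linspace(beta_start, beta_end, num_train_timesteps)
    alphas_cumprod = np.cumprod(1.0 - betas)
    sigma_bar = np.sqrt(1.0 - alphas_cumprod) / np.sqrt(alphas_cumprod)
    # Assumption: uniform spacing of the visited timesteps
    steps = np.round(np.linspace(num_train_timesteps - 1, 0, max_iter)).astype(int)
    rhos = lambda_ * sigma**2 / sigma_bar[steps] ** 2
    return steps, sigma_bar[steps], rhos


steps, sigmas, rhos = timestep_sequence(sigma=0.05)
```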

Examples using DiffPIR:

Implementing DiffPIR