DiffPIR
- class deepinv.sampling.DiffPIR(model, data_fidelity, sigma=0.05, max_iter=100, zeta=1.0, lambda_=7.0, verbose=False, device='cpu')
Bases: Module
Diffusion PnP Image Restoration (DiffPIR).
This class implements the Diffusion PnP image restoration algorithm (DiffPIR) described in https://arxiv.org/abs/2305.08995.
The DiffPIR algorithm is inspired by the half-quadratic splitting (HQS) plug-and-play algorithm: the denoiser is a conditional diffusion denoiser, and the HQS iterations are interleaved with a diffusion process. For \(t\) decreasing from \(T\) to \(1\), the algorithm reads:
\[\begin{split}\begin{equation*} \begin{aligned} x_{0}^{t} &= D_{\theta}\left(x_t, \frac{\sqrt{1-\overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}}\right) \\ \widehat{x}_{0}^{t} &= \operatorname{prox}_{2 f(y, \cdot) /{\rho_t}}(x_{0}^{t}) \\ \widehat{\varepsilon} &= \left(x_t - \sqrt{\overline{\alpha}_t} \,\, \widehat{x}_{0}^t\right)/\sqrt{1-\overline{\alpha}_t} \\ \varepsilon_t &\sim \mathcal{N}(0, \mathbf{I}) \\ x_{t-1} &= \sqrt{\overline{\alpha}_{t-1}} \,\, \widehat{x}_{0}^t + \sqrt{1-\overline{\alpha}_{t-1}} \left(\sqrt{1-\zeta} \,\, \widehat{\varepsilon} + \sqrt{\zeta} \,\, \varepsilon_t\right), \end{aligned} \end{equation*}\end{split}\]
where \(D_\theta(\cdot,\sigma)\) is a Gaussian denoiser network with noise level \(\sigma\), \(f(y, \cdot)\) is the data fidelity term, and \(\overline{\alpha}_t\) is the cumulative product of the diffusion schedule coefficients \(\alpha_s = 1 - \beta_s\) up to step \(t\).
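The update can be sketched in plain Python on a toy problem. This is only an illustration under stated assumptions, not the library's implementation: `toy_denoiser` and `diffpir_step` are hypothetical names, the physics is the identity (pure denoising), the data fidelity is \(f(y, x) = \tfrac{1}{2}\|y - x\|^2\) so that the proximal step has a closed form, and the re-noising step uses \(\overline{\alpha}_{t-1}\) as in the DiffPIR paper so that the iterate moves to the next, lower noise level.

```python
import math
import random


def toy_denoiser(x, sigma, prior_mean=0.5):
    """Hypothetical stand-in for D_theta: MMSE denoiser for a N(prior_mean, 1) prior."""
    return [(v + sigma**2 * prior_mean) / (1.0 + sigma**2) for v in x]


def diffpir_step(x_t, y, alpha_bar_t, alpha_bar_prev, rho_t, zeta, rng):
    """One reverse step for the toy problem y = x + noise (identity physics)."""
    # Noise level passed to the denoiser at this timestep
    sigma_t = math.sqrt(1.0 - alpha_bar_t) / math.sqrt(alpha_bar_t)

    # 1. Denoising: x_0^t = D_theta(x_t, sigma_t)
    x0 = toy_denoiser(x_t, sigma_t)

    # 2. Data fidelity: prox of (2 / rho_t) * f with f(y, x) = 0.5 * ||y - x||^2,
    #    which has the closed form (x0 + gamma * y) / (1 + gamma)
    gamma = 2.0 / rho_t
    x0_hat = [(a + gamma * b) / (1.0 + gamma) for a, b in zip(x0, y)]

    # 3. Effective noise implied by x_t and the corrected estimate
    eps_hat = [
        (a - math.sqrt(alpha_bar_t) * b) / math.sqrt(1.0 - alpha_bar_t)
        for a, b in zip(x_t, x0_hat)
    ]

    # 4. Fresh Gaussian noise
    eps = [rng.gauss(0.0, 1.0) for _ in x_t]

    # 5. Re-noise to the previous timestep, mixing predicted and fresh noise
    scale = math.sqrt(1.0 - alpha_bar_prev)
    return [
        math.sqrt(alpha_bar_prev) * b
        + scale * (math.sqrt(1.0 - zeta) * e_hat + math.sqrt(zeta) * e)
        for b, e_hat, e in zip(x0_hat, eps_hat, eps)
    ]


rng = random.Random(0)
y = [0.4, 0.6, 0.5]                          # toy measurements
x_t = [rng.gauss(0.0, 1.0) for _ in y]       # current noisy iterate
x_prev = diffpir_step(x_t, y, alpha_bar_t=0.5, alpha_bar_prev=0.6,
                      rho_t=1.0, zeta=0.3, rng=rng)
print(x_prev)
```

With \(\zeta = 0\) the step reuses only the predicted noise \(\widehat{\varepsilon}\) and is deterministic; with \(\zeta = 1\) it injects only fresh noise.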
Note
The algorithm may require careful tuning of the hyperparameters \(\lambda\) and \(\zeta\) to obtain optimal results.
- Parameters:
model (torch.nn.Module) – a conditional noise estimation model
sigma (float) – the noise level of the data
data_fidelity (deepinv.optim.DataFidelity) – the data fidelity operator
max_iter (int) – the number of iterations to run the algorithm (default: 100)
zeta (float) – hyperparameter \(\zeta\) for the sampling step (must be between 0 and 1). Default: 1.0.
lambda_ (float) – hyperparameter \(\lambda\) for the data fidelity step (\(\rho_t = \lambda \frac{\sigma_n^2}{\bar{\sigma}_t^2}\) in the paper, where the optimal value ranges between 3.0 and 25.0 depending on the problem). Default: 7.0.
verbose (bool) – if True, print progress.
device (str) – the device to use for the computations.
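To see how \(\lambda\) and the measurement noise level enter the data fidelity step, the weights \(\rho_t\) can be computed directly from the formula above. The sketch below is illustrative only: `rho_schedule` is a hypothetical helper, the \(\overline{\alpha}_t\) values are made up, and it assumes \(\bar{\sigma}_t = \sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}\), following the paper's notation.

```python
def rho_schedule(alpha_bars, sigma_n, lambda_):
    """rho_t = lambda * sigma_n^2 / sigma_bar_t^2, with sigma_bar_t^2 = (1 - a) / a."""
    rhos = []
    for a in alpha_bars:
        sigma_bar_sq = (1.0 - a) / a   # squared diffusion noise level at this step
        rhos.append(lambda_ * sigma_n**2 / sigma_bar_sq)
    return rhos


# Illustrative values: low, medium and high diffusion noise
alpha_bars = [0.99, 0.5, 0.01]
rhos = rho_schedule(alpha_bars, sigma_n=0.05, lambda_=7.0)
for a, rho in zip(alpha_bars, rhos):
    print(f"alpha_bar={a:.2f}  rho_t={rho:.6f}  prox step 2/rho_t={2.0 / rho:.2f}")
```

Since the proximal step size is \(2/\rho_t\), a small \(\rho_t\) (high diffusion noise) lets the data fidelity term dominate, while a large \(\rho_t\) (low diffusion noise) keeps the iterate close to the denoiser output. Increasing \(\lambda\) shifts this balance toward the denoiser at every step.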
- Examples:
DiffPIR restoration using a pretrained DRUNet denoiser:
>>> import torch
>>> import deepinv as dinv
>>> from deepinv.sampling import DiffPIR
>>> device = dinv.utils.get_freer_gpu(verbose=False) if torch.cuda.is_available() else 'cpu'
>>> x = 0.5 * torch.ones(1, 3, 32, 32, device=device)  # Define a plain gray 32x32 image
>>> physics = dinv.physics.Inpainting(
...     mask=0.5, tensor_size=(3, 32, 32),
...     noise_model=dinv.physics.GaussianNoise(0.1),
...     device=device
... )
>>> y = physics(x)  # Measurements
>>> denoiser = dinv.models.DRUNet(pretrained="download").to(device)
>>> model = DiffPIR(
...     model=denoiser,
...     data_fidelity=dinv.optim.L2()
... )  # Define the DiffPIR model
>>> xhat = model(y, physics)  # Run the DiffPIR algorithm
>>> dinv.metric.PSNR()(xhat, x) > dinv.metric.PSNR()(y, x)  # Should be closer to the original
tensor([True])
- forward(y, physics: LinearPhysics, seed=None, x_init=None)
Runs the diffusion to obtain a random sample of the posterior distribution.
- Parameters:
y (torch.Tensor) – the measurements.
physics (deepinv.physics.LinearPhysics) – the physics operator.
sigma (float) – the noise level of the data.
seed (int) – the seed for the random number generator.
x_init (torch.Tensor) – the initial guess for the reconstruction.
- get_alpha_beta()
Get the alpha and beta sequences for the algorithm. This is necessary for mapping noise levels to timesteps.
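As an illustration of why these sequences are needed, the sketch below builds a beta schedule, derives the cumulative products \(\overline{\alpha}_t\), and maps a noise level to the closest timestep. It is a sketch under assumptions rather than the method's actual code: `linear_alpha_beta` and `nearest_timestep` are hypothetical helpers, and the linear schedule from 1e-4 to 2e-2 over 1000 steps is the common DDPM default, which may differ from the values returned by `get_alpha_beta`.

```python
import math


def linear_alpha_beta(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Linear beta schedule and the cumulative products alpha_bar_t (assumed defaults)."""
    betas = [
        beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        for i in range(num_steps)
    ]
    alpha_bars = []
    prod = 1.0
    for b in betas:
        prod *= 1.0 - b          # alpha_t = 1 - beta_t
        alpha_bars.append(prod)  # alpha_bar_t = prod_{s <= t} alpha_s
    return betas, alpha_bars


def nearest_timestep(alpha_bars, sigma):
    """Index t whose noise level sqrt((1 - alpha_bar_t) / alpha_bar_t) is closest to sigma."""
    sigmas = [math.sqrt((1.0 - a) / a) for a in alpha_bars]
    return min(range(len(sigmas)), key=lambda t: abs(sigmas[t] - sigma))


betas, alpha_bars = linear_alpha_beta()
for sigma in (0.05, 0.5, 5.0):
    print(f"sigma={sigma}: nearest timestep {nearest_timestep(alpha_bars, sigma)}")
```

Because \(\overline{\alpha}_t\) decreases with \(t\), the associated noise level increases monotonically, so larger noise levels map to later timesteps.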