DPS

class deepinv.sampling.DPS(model, data_fidelity, max_iter=1000, eta=1.0, verbose=False, device='cpu', save_iterates=False)

Bases: Reconstructor

Diffusion Posterior Sampling (DPS).

This class implements the Diffusion Posterior Sampling algorithm (DPS) described in https://arxiv.org/abs/2209.14687.

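DPS approximates a gradient-based posterior sampling algorithm and makes minimal assumptions on the forward model. The only restriction is that the measurement model must be differentiable, because the guidance term \(\mathbf{g}_t\) below is obtained by differentiating the log-likelihood through both the forward operator and the denoiser. This holds for most forward models used in practice.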

The algorithm proceeds as follows, for \(t\) decreasing from \(T\) to \(1\):

\[\begin{split}\begin{equation*} \begin{aligned} \widehat{\mathbf{x}}_{t} &= D_{\theta}(\mathbf{x}_t, \sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}) \\ \mathbf{g}_t &= \nabla_{\mathbf{x}_t} \log p( \mathbf{y} | \widehat{\mathbf{x}}_{t}(\mathbf{x}_t) ) \\ \mathbf{\varepsilon}_t &\sim \mathcal{N}(0, \mathbf{I}) \\ \mathbf{x}_{t-1} &= a_t \,\, \mathbf{x}_t + b_t \, \, \widehat{\mathbf{x}}_t + \tilde{\sigma}_t \, \, \mathbf{\varepsilon}_t + \mathbf{g}_t, \end{aligned} \end{equation*}\end{split}\]

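where \(D_{\theta}(\cdot, \sigma)\) is a denoising network for noise level \(\sigma\), \(\overline{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s)\) is the cumulative product associated with the noise schedule \(\beta_1, \dots, \beta_T\), \(\eta\) is a hyperparameter controlling the stochasticity of the sampler (for \(\eta = 0\), \(\tilde{\sigma}_t = 0\) and the updates are deterministic), and the constants \(\tilde{\sigma}_t, a_t, b_t\) are defined as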

\[\begin{split}\begin{equation*} \begin{aligned} \tilde{\sigma}_t &= \eta \sqrt{ (1 - \frac{\overline{\alpha}_t}{\overline{\alpha}_{t-1}}) \frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_t}} \\ a_t &= \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2}/\sqrt{1-\overline{\alpha}_t} \\ b_t &= \sqrt{\overline{\alpha}_{t-1}} - \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2} \frac{\sqrt{\overline{\alpha}_{t}}}{\sqrt{1 - \overline{\alpha}_{t}}}. \end{aligned} \end{equation*}\end{split}\]
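The update can be traced end to end on a toy problem where every quantity has a closed form. The sketch below is a minimal pure-Python illustration, not the deepinv implementation, and it rests on several assumptions: a standard Gaussian prior \(\mathbf{x}_0 \sim \mathcal{N}(0, \mathbf{I})\), for which the exact denoiser is \(D(\mathbf{x}_t, \sigma) = \mathbf{x}_t/\sqrt{1+\sigma^2} = \sqrt{\overline{\alpha}_t}\,\mathbf{x}_t\); an element-wise mask as forward operator with Gaussian noise of standard deviation \(\sigma_y\); a linear \(\beta\) schedule from \(10^{-4}\) to \(0.02\) with the convention \(\overline{\alpha}_0 = 1\); and one diffusion step per iteration. The names linear_alpha_bars, toy_denoiser and toy_dps are hypothetical. Where a real implementation would use automatic differentiation to obtain \(\mathbf{g}_t\), the chain rule is written out by hand here, and \(\mathbf{g}_t\) is added without any extra step size, exactly as in the update above.

```python
import math
import random


def linear_alpha_bars(num_steps, beta_start=1e-4, beta_end=0.02):
    """Return [alpha_bar_0, ..., alpha_bar_T] for an assumed linear beta schedule.

    alpha_bar_t = prod_{s=1}^{t} (1 - beta_s), with the convention alpha_bar_0 = 1.
    """
    alpha_bars = [1.0]
    for t in range(1, num_steps + 1):
        beta = beta_start + (beta_end - beta_start) * (t - 1) / max(num_steps - 1, 1)
        alpha_bars.append(alpha_bars[-1] * (1.0 - beta))
    return alpha_bars


def toy_denoiser(x_t, sigma):
    # Hypothetical stand-in for D_theta: the exact MMSE denoiser for x0 ~ N(0, I),
    # D(x_t, sigma) = x_t / sqrt(1 + sigma^2).
    return [v / math.sqrt(1.0 + sigma ** 2) for v in x_t]


def toy_dps(y, mask, sigma_y, max_iter=100, eta=1.0, seed=None):
    """Toy DPS for y = mask * x0 + N(0, sigma_y^2 I), one diffusion step per iteration."""
    rng = random.Random(seed)
    alpha_bars = linear_alpha_bars(max_iter)
    x = [rng.gauss(0.0, 1.0) for _ in y]  # x_T ~ N(0, I)
    for t in range(max_iter, 0, -1):
        ab_t, ab_prev = alpha_bars[t], alpha_bars[t - 1]
        sigma = math.sqrt(1.0 - ab_t) / math.sqrt(ab_t)
        x_hat = toy_denoiser(x, sigma)
        # g_t = grad_{x_t} log p(y | x_hat(x_t)): the likelihood gradient w.r.t. x_hat,
        # times the denoiser's Jacobian, which for this toy is the scalar sqrt(ab_t).
        g = [math.sqrt(ab_t) * m * (yi - m * xh) / sigma_y ** 2
             for yi, m, xh in zip(y, mask, x_hat)]
        sigma_tilde = eta * math.sqrt(
            (1.0 - ab_t / ab_prev) * (1.0 - ab_prev) / (1.0 - ab_t))
        c = math.sqrt(max(1.0 - ab_prev - sigma_tilde ** 2, 0.0))  # clamp rounding error
        a_t = c / math.sqrt(1.0 - ab_t)
        b_t = math.sqrt(ab_prev) - c * math.sqrt(ab_t) / math.sqrt(1.0 - ab_t)
        # x_{t-1} = a_t x_t + b_t x_hat + sigma_tilde eps + g_t
        x = [a_t * xi + b_t * xh + sigma_tilde * rng.gauss(0.0, 1.0) + gi
             for xi, xh, gi in zip(x, x_hat, g)]
    return x


y = [0.5, -1.0, 0.0, 2.0]
mask = [1.0, 1.0, 0.0, 1.0]  # the third entry is not observed
print(toy_dps(y, mask, sigma_y=1.0, max_iter=100, eta=1.0, seed=0))
```

In this toy, the unobserved entry receives no guidance (\(\mathbf{g}_t\) is zero there) and is generated by the prior alone, while the observed entries are pulled towards \(\mathbf{y}\).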
Parameters:
  • model (torch.nn.Module) – a denoiser network that can handle different noise levels

  • data_fidelity (deepinv.optim.DataFidelity) – the data fidelity operator

  • max_iter (int) – the number of diffusion iterations to run the algorithm (default: 1000)

  • eta (float) – DDIM hyperparameter which controls the stochasticity of the sampling (default: 1.0)

  • verbose (bool) – if True, print progress (default: False)

  • device (str) – the device to use for the computations (default: 'cpu')
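The effect of eta can be checked numerically. The following sketch evaluates \(\tilde{\sigma}_t, a_t, b_t\) from the formulas above for a hypothetical pair \(\overline{\alpha}_t = 0.5\), \(\overline{\alpha}_{t-1} = 0.6\); the function name ddim_coefficients is illustrative and not part of deepinv.

```python
import math


def ddim_coefficients(alpha_bar_t, alpha_bar_prev, eta):
    """Hypothetical helper: (sigma_tilde_t, a_t, b_t) for the step t -> t-1."""
    sigma_tilde = eta * math.sqrt(
        (1.0 - alpha_bar_t / alpha_bar_prev) * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)
    )
    # sqrt(1 - alpha_bar_{t-1} - sigma_tilde^2), clamped against rounding error
    c = math.sqrt(max(1.0 - alpha_bar_prev - sigma_tilde ** 2, 0.0))
    a = c / math.sqrt(1.0 - alpha_bar_t)
    b = math.sqrt(alpha_bar_prev) - c * math.sqrt(alpha_bar_t) / math.sqrt(1.0 - alpha_bar_t)
    return sigma_tilde, a, b


# Hypothetical cumulative products at two consecutive timesteps.
alpha_bar_t, alpha_bar_prev = 0.5, 0.6
for eta in (0.0, 0.5, 1.0):
    sigma_tilde, a, b = ddim_coefficients(alpha_bar_t, alpha_bar_prev, eta)
    print(f"eta={eta}: sigma_tilde={sigma_tilde:.4f}, a={a:.4f}, b={b:.4f}")
```

Setting eta to 0 removes the noise term (\(\tilde{\sigma}_t = 0\)), and increasing it shifts weight from \(\mathbf{x}_t\) to \(\widehat{\mathbf{x}}_t\) while injecting more fresh noise. As a sanity check on the coefficients: if \(\mathbf{x}_t = \sqrt{\overline{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\overline{\alpha}_t}\,\boldsymbol{\varepsilon}\) and the denoiser returned \(\mathbf{x}_0\) exactly, then \(a_t \mathbf{x}_t + b_t \widehat{\mathbf{x}}_t = \sqrt{\overline{\alpha}_{t-1}}\,\mathbf{x}_0 + \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2}\,\boldsymbol{\varepsilon}\) for any eta.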

compute_alpha_betas()

Get the beta and alpha sequences for the algorithm. This is necessary for mapping noise levels to timesteps.
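As an illustration of what such a mapping involves, here is a sketch assuming the linear DDPM schedule (\(\beta\) from \(10^{-4}\) to \(0.02\) over 1000 steps); the schedule actually used by compute_alpha_betas may differ. Because \(\overline{\alpha}_t\) decreases with \(t\), the noise level \(\sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}\) increases with \(t\), so a binary search finds the nearest timestep. The helpers toy_alpha_betas, noise_levels and sigma_to_timestep are hypothetical.

```python
import bisect
import math


def toy_alpha_betas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Hypothetical linear beta schedule and the cumulative products alpha_bar_t.

    Both lists are indexed by t - 1 for t = 1..num_steps.
    """
    betas = [beta_start + (beta_end - beta_start) * i / (num_steps - 1)
             for i in range(num_steps)]
    alpha_bars = []
    prod = 1.0
    for beta in betas:
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return betas, alpha_bars


def noise_levels(alpha_bars):
    # sigma_t = sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t), strictly increasing in t
    return [math.sqrt(1.0 - ab) / math.sqrt(ab) for ab in alpha_bars]


def sigma_to_timestep(sigma, sigmas):
    """Return the timestep t (1-based) whose noise level is closest to sigma."""
    i = bisect.bisect_left(sigmas, sigma)
    candidates = [j for j in (i - 1, i) if 0 <= j < len(sigmas)]
    best = min(candidates, key=lambda j: abs(sigmas[j] - sigma))
    return best + 1


betas, alpha_bars = toy_alpha_betas()
sigmas = noise_levels(alpha_bars)
print(sigma_to_timestep(0.5, sigmas))
```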

forward(y, physics: Physics, seed=None, x_init=None)

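Applies the reconstruction model \(\operatorname{R}(\mathbf{y}, A)\), i.e. runs DPS to draw an approximate posterior sample given the measurements \(\mathbf{y}\).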

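Parameters:
  • y (torch.Tensor) – the measurements

  • physics (deepinv.physics.Physics) – the physics (forward operator) associated with the measurements

  • seed (int) – optional random seed, for reproducible sampling

  • x_init (torch.Tensor) – optional initialization of the sampling chain

Returns: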

(torch.Tensor) reconstructed tensor.

Examples using DPS:

Implementing DPS
