DPS
- class deepinv.sampling.DPS(model, data_fidelity, max_iter=1000, eta=1.0, verbose=False, device='cpu', save_iterates=False)
Bases: Reconstructor
Diffusion Posterior Sampling (DPS).
This class implements the Diffusion Posterior Sampling algorithm (DPS) described in https://arxiv.org/abs/2209.14687.
DPS approximates a gradient-based posterior sampling algorithm and makes minimal assumptions on the forward model: the only requirement is that the measurement model be differentiable, which is generally the case.
For \(t\) decreasing from \(T\) to \(1\), the algorithm iterates:
\[\begin{aligned} \widehat{\mathbf{x}}_{t} &= D_{\theta}(\mathbf{x}_t, \sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}) \\ \mathbf{g}_t &= \nabla_{\mathbf{x}_t} \log p( \mathbf{y} | \widehat{\mathbf{x}}_{t}(\mathbf{x}_t) ) \\ \mathbf{\varepsilon}_t &\sim \mathcal{N}(0, \mathbf{I}) \\ \mathbf{x}_{t-1} &= a_t \, \mathbf{x}_t + b_t \, \widehat{\mathbf{x}}_t + \tilde{\sigma}_t \, \mathbf{\varepsilon}_t + \mathbf{g}_t, \end{aligned}\]
where \(D_{\theta}(\cdot, \sigma)\) is a denoising network for noise level \(\sigma\), \(\eta\) is a hyperparameter, and the constants \(\tilde{\sigma}_t, a_t, b_t\) are defined as
\[\begin{aligned} \tilde{\sigma}_t &= \eta \sqrt{ \left(1 - \frac{\overline{\alpha}_t}{\overline{\alpha}_{t-1}}\right) \frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_t}} \\ a_t &= \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2}/\sqrt{1-\overline{\alpha}_t} \\ b_t &= \sqrt{\overline{\alpha}_{t-1}} - \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2} \frac{\sqrt{\overline{\alpha}_{t}}}{\sqrt{1 - \overline{\alpha}_{t}}}. \end{aligned}\]
- Parameters:
model (torch.nn.Module) – a denoiser network that can handle different noise levels
data_fidelity (deepinv.optim.DataFidelity) – the data fidelity operator
max_iter (int) – the number of diffusion iterations to run the algorithm (default: 1000)
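eta (float) – DDIM hyperparameter which controls the stochasticity of the sampling (default: 1.0)
verbose (bool) – if True, print progress (default: False)
device (str) – the device to use for the computations (default: 'cpu')
save_iterates (bool) – if True, save the intermediate iterates (default: False)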
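As an illustration of the constants \(\tilde{\sigma}_t, a_t, b_t\) defined above, the following is a minimal pure-Python sketch, not the library's implementation. The helper names `dps_coefficients` and `dps_update` are hypothetical, and the denoised estimate `x_hat` and the gradient `g_t` are assumed to be supplied by the caller as plain lists of floats.

```python
import math


def dps_coefficients(alpha_bar_t, alpha_bar_prev, eta=1.0):
    """Return (sigma_tilde, a_t, b_t) for one reverse step (hypothetical helper).

    alpha_bar_t and alpha_bar_prev stand for the cumulative products
    \\bar{alpha}_t and \\bar{alpha}_{t-1}; alpha_bar_t must be < 1.
    """
    sigma_tilde = eta * math.sqrt(
        (1.0 - alpha_bar_t / alpha_bar_prev)
        * (1.0 - alpha_bar_prev)
        / (1.0 - alpha_bar_t)
    )
    # Clamp at zero to guard against round-off (or eta > 1) making this negative.
    c = math.sqrt(max(0.0, 1.0 - alpha_bar_prev - sigma_tilde**2))
    a_t = c / math.sqrt(1.0 - alpha_bar_t)
    b_t = math.sqrt(alpha_bar_prev) - c * math.sqrt(alpha_bar_t) / math.sqrt(
        1.0 - alpha_bar_t
    )
    return sigma_tilde, a_t, b_t


def dps_update(x_t, x_hat, g_t, eps, coeffs):
    """Apply x_{t-1} = a_t x_t + b_t x_hat + sigma_tilde eps + g_t elementwise."""
    sigma_tilde, a_t, b_t = coeffs
    return [
        a_t * x + b_t * xh + sigma_tilde * e + g
        for x, xh, e, g in zip(x_t, x_hat, eps, g_t)
    ]
```

With `eta=0` the noise term vanishes and the step is deterministic, as in DDIM; with `eta=1` the coefficients `a_t` and `b_t` reduce to those of the DDPM posterior mean.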
- compute_alpha_betas()
Get the beta and alpha sequences for the algorithm. This is necessary for mapping noise levels to timesteps.
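To make the mapping between timesteps and noise levels concrete, here is a small sketch under stated assumptions. It uses a linear beta schedule with the DDPM endpoints (1e-4 to 2e-2), which may differ from what `compute_alpha_betas` actually returns. The function names are hypothetical. The noise level \(\sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}\) is the one passed to the denoiser in the algorithm above.

```python
import math


def linear_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Linear beta schedule (an assumption) and its cumulative alpha products.

    Requires num_steps >= 2.
    """
    betas = [
        beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        for i in range(num_steps)
    ]
    alpha_bars = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta  # alpha_bar_t = prod_{s <= t} (1 - beta_s)
        alpha_bars.append(running)
    return betas, alpha_bars


def noise_level(alpha_bar):
    """Denoiser noise level sigma = sqrt(1 - alpha_bar) / sqrt(alpha_bar)."""
    return math.sqrt(1.0 - alpha_bar) / math.sqrt(alpha_bar)


def nearest_timestep(sigma, alpha_bars):
    """Index of the timestep whose noise level is closest to sigma."""
    return min(
        range(len(alpha_bars)),
        key=lambda t: abs(noise_level(alpha_bars[t]) - sigma),
    )
```

Because the cumulative products decrease monotonically, the noise level grows with the timestep, so each noise level corresponds to a unique timestep.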
- forward(y, physics: Physics, seed=None, x_init=None)
Applies the reconstruction model to the measurements \(y\), given the forward operator \(A\).
- Parameters:
y (torch.Tensor) – measurements.
physics (deepinv.physics.Physics) – forward model \(A\).
- Returns:
(torch.Tensor) reconstructed tensor.