DPS
- class deepinv.sampling.DPS(model, data_fidelity, max_iter=1000, eta=1.0, verbose=False, device='cpu', save_iterates=False)
Bases: Reconstructor
Diffusion Posterior Sampling (DPS).
This class implements the Diffusion Posterior Sampling algorithm (DPS) described in https://arxiv.org/abs/2209.14687.
DPS approximates gradient-based posterior sampling and makes minimal assumptions about the forward model: the only requirement is that the measurement model be differentiable, which is generally the case.
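Because DPS only needs the gradient of a data-fidelity term with respect to the image, any forward operator that PyTorch can differentiate through is usable, including nonlinear ones. The minimal sketch below assumes a hypothetical phase-retrieval-style operator \(A(\mathbf{x}) = |\mathbf{M}\mathbf{x}|^2\) (not part of deepinv) and checks the autograd gradient of \(\tfrac12\|\mathbf{y} - A(\mathbf{x})\|^2\) against the closed form \(-2\,\mathbf{M}^\top\big((\mathbf{y} - A(\mathbf{x})) \odot \mathbf{M}\mathbf{x}\big)\).

```python
import torch

torch.manual_seed(0)

# Hypothetical nonlinear forward model: squared magnitude of a random linear map.
M = torch.randn(8, 4, dtype=torch.float64)

def forward_op(x):
    return (M @ x) ** 2

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
y = torch.randn(8, dtype=torch.float64)  # arbitrary measurements, only used for the check

# Data-fidelity term 0.5 * ||y - A(x)||^2, differentiated by autograd.
loss = 0.5 * torch.sum((y - forward_op(x)) ** 2)
(grad_autograd,) = torch.autograd.grad(loss, x)

# Closed-form gradient of the same term, for comparison.
with torch.no_grad():
    u = M @ x
    grad_closed = -2.0 * M.T @ ((y - u ** 2) * u)

print(torch.allclose(grad_autograd, grad_closed))  # True
```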
The algorithm proceeds as follows, for \(t\) decreasing from \(T\) to \(1\):
\[
\begin{aligned}
\widehat{\mathbf{x}}_{t} &= D_{\theta}\big(\mathbf{x}_t/\sqrt{\overline{\alpha}_t},\, \sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}\big) \\
\mathbf{g}_t &= \nabla_{\mathbf{x}_t} \log p\big( \mathbf{y} \,|\, \widehat{\mathbf{x}}_{t}(\mathbf{x}_t) \big) \\
\mathbf{\varepsilon}_t &\sim \mathcal{N}(0, \mathbf{I}) \\
\mathbf{x}_{t-1} &= a_t \, \mathbf{x}_t + b_t \, \widehat{\mathbf{x}}_t + \tilde{\sigma}_t \, \mathbf{\varepsilon}_t + \mathbf{g}_t,
\end{aligned}
\]
where \(D_{\theta}(\cdot, \sigma)\) is a denoising network for noise level \(\sigma\), \(\eta\) is a hyperparameter controlling the stochasticity of the update, and the constants \(\tilde{\sigma}_t, a_t, b_t\) are defined as
\[
\begin{aligned}
\tilde{\sigma}_t &= \eta \sqrt{ \Big(1 - \frac{\overline{\alpha}_t}{\overline{\alpha}_{t-1}}\Big) \frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_t}} \\
a_t &= \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2}/\sqrt{1-\overline{\alpha}_t} \\
b_t &= \sqrt{\overline{\alpha}_{t-1}} - \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2} \frac{\sqrt{\overline{\alpha}_{t}}}{\sqrt{1 - \overline{\alpha}_{t}}}.
\end{aligned}
\]
- Parameters:
model (torch.nn.Module) – a denoiser network that can handle different noise levels
data_fidelity (deepinv.optim.DataFidelity) – the data fidelity operator
max_iter (int) – the number of diffusion iterations the algorithm runs (default: 1000)
eta (float) – the DDIM hyperparameter \(\eta\), which controls the stochasticity of the update (default: 1.0)
verbose (bool) – if True, print progress
device (str) – the device to use for the computations
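As a sanity check on the constants \(\tilde{\sigma}_t, a_t, b_t\): if the denoiser were exact (\(\widehat{\mathbf{x}}_t = \mathbf{x}_0\)) and \(\mathbf{x}_t = \sqrt{\overline{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\overline{\alpha}_t}\,\boldsymbol{\epsilon}\), then \(a_t\mathbf{x}_t + b_t\widehat{\mathbf{x}}_t = \sqrt{\overline{\alpha}_{t-1}}\,\mathbf{x}_0 + \sqrt{1-\overline{\alpha}_{t-1}-\tilde{\sigma}_t^2}\,\boldsymbol{\epsilon}\), so adding \(\tilde{\sigma}_t\mathbf{\varepsilon}_t\) restores a total noise variance of \(1-\overline{\alpha}_{t-1}\). The scalar sketch below checks this identity numerically; the values of \(\overline{\alpha}_t\), \(\overline{\alpha}_{t-1}\) and \(\eta\) are arbitrary illustrative choices, and `ddim_coefficients` is a hypothetical helper, not a deepinv function.

```python
import math

def ddim_coefficients(abar_t, abar_prev, eta):
    """Return (sigma_tilde, a_t, b_t) following the definitions above."""
    sigma_tilde = eta * math.sqrt((1 - abar_t / abar_prev) * (1 - abar_prev) / (1 - abar_t))
    c = math.sqrt(1 - abar_prev - sigma_tilde ** 2)
    a_t = c / math.sqrt(1 - abar_t)
    b_t = math.sqrt(abar_prev) - c * math.sqrt(abar_t) / math.sqrt(1 - abar_t)
    return sigma_tilde, a_t, b_t

abar_t, abar_prev, eta = 0.5, 0.6, 1.0  # arbitrary values with abar_prev > abar_t
sigma_tilde, a_t, b_t = ddim_coefficients(abar_t, abar_prev, eta)

x0, eps = 0.7, -1.3  # a clean value and the noise that produced x_t
x_t = math.sqrt(abar_t) * x0 + math.sqrt(1 - abar_t) * eps

# With an exact denoiser (x_hat = x0), the deterministic part of the update
# equals sqrt(abar_prev) * x0 + sqrt(1 - abar_prev - sigma_tilde^2) * eps.
mean_part = a_t * x_t + b_t * x0
expected = math.sqrt(abar_prev) * x0 + math.sqrt(1 - abar_prev - sigma_tilde ** 2) * eps
print(math.isclose(mean_part, expected))  # True

# eta = 0 gives a deterministic (DDIM) update: no noise is injected.
print(ddim_coefficients(abar_t, abar_prev, 0.0)[0])  # 0.0
```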
- compute_alpha_betas()
Computes the \(\beta\) and \(\alpha\) sequences used by the algorithm; these are needed to map noise levels to timesteps.
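A minimal sketch of such a schedule, assuming the linear \(\beta\) schedule popularized by DDPM (\(\beta_1 = 10^{-4}\) to \(\beta_T = 0.02\) over \(T = 1000\) steps); the schedule actually produced by `compute_alpha_betas()` may differ. It computes \(\overline{\alpha}_t = \prod_{s \le t} (1 - \beta_s)\), the corresponding denoiser noise level \(\sigma_t = \sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}\), and maps a given noise level back to the nearest timestep. All helper names are hypothetical.

```python
import torch

def alpha_bar_schedule(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Assumed linear beta schedule and its cumulative product alpha_bar."""
    betas = torch.linspace(beta_start, beta_end, num_steps, dtype=torch.float64)
    alpha_bars = torch.cumprod(1.0 - betas, dim=0)
    return betas, alpha_bars

def noise_level(alpha_bars):
    """Noise level sigma_t = sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t) seen by the denoiser."""
    return torch.sqrt(1.0 - alpha_bars) / torch.sqrt(alpha_bars)

def timestep_for_sigma(sigma, alpha_bars):
    """Index of the timestep whose noise level is closest to sigma."""
    return int(torch.argmin(torch.abs(noise_level(alpha_bars) - sigma)))

betas, alpha_bars = alpha_bar_schedule()
sigmas = noise_level(alpha_bars)
print(sigmas[0].item(), sigmas[-1].item())  # small at t = 1 (index 0), large at t = T
print(timestep_for_sigma(1.0, alpha_bars))   # index where sigma_t is closest to 1
```

Since \(\overline{\alpha}_t\) decreases monotonically with \(t\), \(\sigma_t\) increases monotonically, which is what makes the nearest-timestep lookup well defined.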
- forward(y, physics: Physics, seed=None, x_init=None)
Runs the diffusion to obtain a random sample of the posterior distribution.
- Parameters:
y (torch.Tensor) – the measurements.
physics (deepinv.physics.LinearPhysics) – the physics operator.
seed (int) – the seed for the random number generator.
x_init (torch.Tensor) – the initial guess for the reconstruction.
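The sampling loop can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the deepinv implementation: the denoiser is a hypothetical closed-form MMSE denoiser for a standard Gaussian prior, \(D(\mathbf{z}, \sigma) = \mathbf{z}/(1+\sigma^2)\); the measurement operator is a random matrix; the schedule is the linear DDPM one; timesteps are strided so that `max_iter` updates are run; and the guidance term uses the step-size heuristic of the DPS paper, \(\mathbf{g}_t = -\zeta\,\nabla_{\mathbf{x}_t}\|\mathbf{y} - A\widehat{\mathbf{x}}_t\|\), with a hypothetical step size `zeta`. The denoiser input is rescaled by \(1/\sqrt{\overline{\alpha}_t}\) so that its noise level is \(\sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}\).

```python
import torch

def dps_sample(y, A, denoiser, max_iter=100, eta=1.0, zeta=1.0, seed=None, num_train_steps=1000):
    """Sketch of DPS with DDIM-style updates (schedule, step size and striding are assumptions)."""
    if seed is not None:
        torch.manual_seed(seed)
    betas = torch.linspace(1e-4, 0.02, num_train_steps, dtype=torch.float64)
    alpha_bars = torch.cumprod(1.0 - betas, dim=0)

    def abar(i):
        # alpha_bar at timestep index i; index -1 stands for the clean signal (alpha_bar = 1).
        return alpha_bars[i] if i >= 0 else torch.tensor(1.0, dtype=torch.float64)

    steps = list(range(0, num_train_steps, num_train_steps // max_iter))
    prev_steps = [-1] + steps[:-1]

    x = torch.randn(A.shape[1], dtype=torch.float64)  # x_T ~ N(0, I)
    for t, s in zip(reversed(steps), reversed(prev_steps)):
        at, aprev = abar(t), abar(s)
        x = x.detach().requires_grad_(True)

        # 1. Denoise: rescale x_t so that its noise level is sqrt(1 - at) / sqrt(at).
        sigma = torch.sqrt(1 - at) / torch.sqrt(at)
        x_hat = denoiser(x / torch.sqrt(at), sigma)

        # 2. Guidance: gradient of the residual norm w.r.t. x_t, backpropagated through the denoiser.
        residual_norm = torch.linalg.vector_norm(y - A @ x_hat)
        (grad,) = torch.autograd.grad(residual_norm, x)

        # 3. DDIM coefficients and update, with g_t = -zeta * grad.
        with torch.no_grad():
            sigma_tilde = eta * torch.sqrt((1 - at / aprev) * (1 - aprev) / (1 - at))
            c = torch.sqrt(1 - aprev - sigma_tilde ** 2)
            a_t = c / torch.sqrt(1 - at)
            b_t = torch.sqrt(aprev) - c * torch.sqrt(at) / torch.sqrt(1 - at)
            x = a_t * x + b_t * x_hat + sigma_tilde * torch.randn_like(x) - zeta * grad
    return x.detach()

# Usage on a toy problem (all quantities are illustrative).
torch.manual_seed(0)
A = torch.randn(8, 16, dtype=torch.float64) / 4.0            # hypothetical measurement matrix
x_true = torch.randn(16, dtype=torch.float64)
y = A @ x_true + 0.01 * torch.randn(8, dtype=torch.float64)  # noisy measurements

def gaussian_mmse_denoiser(z, sigma):
    # Exact MMSE denoiser when the clean signal is N(0, I) and z = x + sigma * noise.
    return z / (1.0 + sigma ** 2)

sample = dps_sample(y, A, gaussian_mmse_denoiser, max_iter=100, seed=1)
print(sample.shape)  # torch.Size([16])
```

For simplicity the sketch starts from pure noise, ignores any initial guess, and works directly on vectors rather than images; passing the same `seed` reproduces the same sample.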