PosteriorDiffusion

class deepinv.sampling.PosteriorDiffusion(data_fidelity=ZeroFidelity(), denoiser=None, sde=None, solver=None, rescale=False, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)

Bases: Reconstructor

Posterior sampling for inverse problems with diffusion models, by solving a reverse-time stochastic differential equation (SDE).

Consider the acquisition model:

\[y = N(A(x)),\]

where \(A\) denotes the forward operator and \(N\) the noise model.

This class defines the reverse-time SDE for the posterior distribution \(p(x|y)\) given the data \(y\):

\[d\, x_t = \left( f(x_t, t) - \frac{1 + \alpha}{2} g(t)^2 \nabla_{x_t} \log p_t(x_t | y) \right) d\,t + g(t) \sqrt{\alpha} d\, w_{t}\]

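where \(f\) is the drift term, \(g\) is the diffusion coefficient, \(w\) is a standard Brownian motion, and \(\alpha \geq 0\) controls the amount of stochasticity: \(\alpha = 1\) gives the standard reverse-time SDE, while \(\alpha = 0\) gives the deterministic probability-flow ODE. The drift term and the diffusion coefficient are given by the underlying (unconditional) forward-time SDE sde. The (conditional) score function \(\nabla_{x_t} \log p_t(x_t | y)\) can be decomposed using Bayes’ rule: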

\[\nabla_{x_t} \log p_t(x_t | y) = \nabla_{x_t} \log p_t(x_t) + \nabla_{x_t} \log p_t(y | x_t).\]

The first term is the score function of the unconditional SDE, which is typically approximated with an MMSE denoiser through Tweedie’s formula, while the second term is approximated by the (noisy) data-fidelity term. We implement various data-fidelity terms in deepinv.sampling.NoisyDataFidelity.
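The score decomposition and Tweedie’s formula can be checked on a toy problem where every quantity has a closed form. The following pure-Python sketch is not deepinv code; all names and constants are hypothetical. It assumes a scalar Gaussian prior \(x_0 \sim \mathcal{N}(0, s^2)\), the variance-exploding parametrization \(x_t = x_0 + \sigma n\), and a linear measurement \(y = a x_0 + \eta m\). For this parametrization Tweedie’s formula reads \(\nabla \log p_t(x) = (D(x, \sigma) - x)/\sigma^2\), where \(D\) is the MMSE denoiser. In practice \(D\) is a trained network and \(\nabla_{x_t} \log p_t(y | x_t)\) has no closed form, which is why it has to be approximated.

```python
import math

# Toy 1-D model (assumption, for illustration only):
#   prior        x0 ~ N(0, S2)
#   diffusion    x_t = x0 + sigma * n,   n ~ N(0, 1)   (variance-exploding form)
#   measurement  y = A * x0 + noise,     noise ~ N(0, ETA2)
S2, A, ETA2 = 2.0, 0.5, 0.1  # hypothetical prior variance, forward gain, noise variance


def mmse_denoiser(x, sigma):
    """Exact MMSE denoiser E[x0 | x_t = x] for the Gaussian prior."""
    return S2 / (S2 + sigma**2) * x


def tweedie_score(x, sigma):
    """Unconditional score grad log p_t(x) obtained from the denoiser via Tweedie."""
    return (mmse_denoiser(x, sigma) - x) / sigma**2


def likelihood_score(y, x, sigma):
    """Exact grad_x log p_t(y | x_t = x); y | x_t is Gaussian in this toy model."""
    gain = S2 / (S2 + sigma**2)                 # derivative of the denoiser w.r.t. x
    post_var = S2 * sigma**2 / (S2 + sigma**2)  # Var(x0 | x_t)
    mean = A * mmse_denoiser(x, sigma)          # E[y | x_t]
    var = A**2 * post_var + ETA2                # Var(y | x_t), independent of x
    return A * gain * (y - mean) / var


def exact_posterior_score(y, x, sigma):
    """grad_x log p_t(x_t = x | y), computed from the joint Gaussian of (x_t, y)."""
    var_y = A**2 * S2 + ETA2
    cov = A * S2                                # Cov(x_t, y)
    mean = cov / var_y * y
    var = S2 + sigma**2 - cov**2 / var_y
    return -(x - mean) / var


y, x, sigma = 0.7, -1.3, 0.8
conditional = tweedie_score(x, sigma) + likelihood_score(y, x, sigma)
print(conditional, exact_posterior_score(y, x, sigma))  # the two values agree up to rounding
```

The check works because Bayes’ rule holds exactly for any \(t\): \(\log p_t(x_t | y)\) and \(\log p_t(x_t) + \log p_t(y | x_t)\) differ only by \(\log p(y)\), which does not depend on \(x_t\).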

Parameters:
  • data_fidelity (deepinv.sampling.NoisyDataFidelity) – the noisy data-fidelity term, used to approximate the score \(\nabla_{x_t} \log p_t(y \vert x_t)\). Defaults to deepinv.optim.ZeroFidelity, i.e. a zero data-fidelity term, in which case sampling reduces to unconditional SDE sampling.

  • denoiser (deepinv.models.Denoiser) – a denoiser providing an approximation of the (unconditional) score \(\nabla \log p_t\) at time \(t\).

  • sde (deepinv.sampling.DiffusionSDE) – the forward-time SDE, which defines the drift and diffusion terms of the reverse-time SDE.

  • solver (deepinv.sampling.BaseSDESolver) – the solver for the SDE. If not specified, the solver from the sde will be used.

  • rescale (bool) – whether to rescale the input to the denoiser to [-1, 1].

  • dtype (torch.dtype) – the data type of the SDE solver. The denoiser is always evaluated in torch.float32, regardless of this setting. We recommend torch.float64 for better stability and lower numerical error when solving the SDE in discrete time; the extra cost is small, since most of the computation is spent evaluating the denoiser.

  • device (torch.device) – the device for the computations.
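The reasoning behind the rescale option can be illustrated with a small wrapper. This is a hypothetical sketch (the helper rescaled_denoiser is not part of deepinv), assuming clean images lie in [0, 1] and the denoiser was trained on [-1, 1] data with additive Gaussian noise of standard deviation sigma. Under the affine map \(x \mapsto 2x - 1\) the noise standard deviation also doubles, so the noise level passed to the denoiser has to be rescaled too; how deepinv handles this internally may differ in detail.

```python
def rescaled_denoiser(denoiser, x, sigma):
    """Hypothetical helper: apply a [-1, 1]-trained denoiser to a noisy input whose
    clean signal lies in [0, 1].

    If x = x0 + sigma * n with x0 in [0, 1], then
    2 * x - 1 = (2 * x0 - 1) + (2 * sigma) * n,
    i.e. the clean signal moves to [-1, 1] and the noise std doubles.
    """
    x_pm1 = 2.0 * x - 1.0                   # map the input to the [-1, 1] convention
    den_pm1 = denoiser(x_pm1, 2.0 * sigma)  # the noise level must be rescaled as well
    return (den_pm1 + 1.0) / 2.0            # map the estimate back to [0, 1]


# Usage with a toy denoiser that always returns 0, the midpoint of [-1, 1]:
print(rescaled_denoiser(lambda x, s: 0.0, x=0.8, sigma=0.1))  # 0.5
```

Forgetting to double sigma would make the denoiser underestimate the noise in the rescaled input, which typically leaves residual noise in its output.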

forward(y, physics, x_init=None, seed=None, timesteps=None, get_trajectory=False, *args, **kwargs)

Sample the posterior distribution \(p(x|y)\) given the data measurement \(y\).

Parameters:
  • y (torch.Tensor) – the data measurement.

  • physics (deepinv.physics.Physics) – the forward operator.

  • x_init (torch.Tensor, tuple) – the initial value for sampling: either a torch.Tensor, or a tuple (B, C, H, W) giving the shape of the initial point, consistent with physics and y. In the latter case, the initial value is drawn at random from the end-point distribution of sde.

  • seed (int) – the random seed.

  • timesteps (torch.Tensor) – the time steps for the solver. If None, the default time steps in the solver will be used.

  • get_trajectory (bool) – (optional) whether to return the full trajectory of the SDE or only the last sample. Defaults to False.

  • *args – additional positional arguments passed to the solver.

  • **kwargs – additional keyword arguments passed to the solver.

Returns:

the generated sample (torch.Tensor of shape (B, C, H, W)) if get_trajectory is False. Otherwise, a tuple (torch.Tensor, torch.Tensor) containing the sample, of shape (B, C, H, W), and the trajectory, of shape (N, B, C, H, W), where N is the number of steps.
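The contract of forward (a shape tuple or an array for x_init, a seed, and an optional trajectory) can be illustrated with a plain NumPy sketch of an Euler-Maruyama discretization of the reverse-time SDE above. This is not deepinv’s solver: the function sample_reverse_sde and its arguments score_fn, g2_fn and sigma_max are hypothetical, the forward drift \(f\) is assumed to be zero (variance-exploding case), and the end-point distribution is assumed to be \(\mathcal{N}(0, \sigma_{\max}^2 I)\). The trajectory here stores the state after each step, so N = len(timesteps) - 1; the library’s exact convention may differ.

```python
import numpy as np


def sample_reverse_sde(score_fn, g2_fn, x_init, timesteps, sigma_max,
                       alpha=1.0, seed=None, get_trajectory=False):
    """Euler-Maruyama sketch of the reverse-time SDE with zero forward drift (f = 0).

    timesteps must decrease from T to 0. With f = 0, a step from t to t - h reads
        x <- x + (1 + alpha) / 2 * g(t)^2 * score(x, t) * h + g(t) * sqrt(alpha * h) * z.
    """
    rng = np.random.default_rng(seed)
    if isinstance(x_init, tuple):
        # x_init is a shape such as (B, C, H, W): draw from the assumed end-point law.
        x = sigma_max * rng.standard_normal(x_init)
    else:
        x = np.array(x_init, dtype=np.float64)
    trajectory = []
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        h = t - t_next  # positive step size; time runs backwards
        g2 = g2_fn(t)
        x = (x + 0.5 * (1.0 + alpha) * g2 * score_fn(x, t) * h
             + np.sqrt(alpha * g2 * h) * rng.standard_normal(x.shape))
        trajectory.append(x)
    if get_trajectory:
        return x, np.stack(trajectory)  # shapes (B, C, H, W) and (N, B, C, H, W)
    return x


# Toy model (assumption): prior N(0, 1) and sigma(t) = t, so g(t)^2 = d sigma^2 / dt = 2t
# and the exact score of p_t is -x / (1 + t^2).
def toy_score(x, t):
    return -x / (1.0 + t**2)


def toy_g2(t):
    return 2.0 * t


T = 3.0
timesteps = np.linspace(T, 0.0, 501)
samples = sample_reverse_sde(toy_score, toy_g2, (20000, 1, 1, 1), timesteps,
                             sigma_max=np.sqrt(1.0 + T**2), seed=0)
print(samples.var())  # should be close to the prior variance 1
```

Setting alpha=0.0 turns the loop into an Euler solver for the probability-flow ODE; in this toy model its exact solution is \(x(0) = x(T)/\sqrt{1 + T^2}\), which gives a convenient deterministic check.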

score(y, physics, x, t, *args, **kwargs)

Approximate the conditional score \(\nabla_{x_t} \log p_t(x_t \vert y)\).

Parameters:
  • y (torch.Tensor) – the data measurement.

  • physics (deepinv.physics.Physics) – the forward operator.

  • x (torch.Tensor) – the current state.

  • t (torch.Tensor, float) – the current time step.

  • *args – additional arguments for the score function of the unconditional SDE.

  • **kwargs – additional keyword arguments for the score function of the unconditional SDE.

Returns:

the score function \(\nabla_{x_t} \log p_t(x_t \vert y)\).

Return type:

torch.Tensor

Examples using PosteriorDiffusion:

Posterior Sampling for Inverse Problems with Stochastic Differential Equations modeling.
