PosteriorDiffusion
- class deepinv.sampling.PosteriorDiffusion(data_fidelity=ZeroFidelity(), denoiser=None, sde=None, solver=None, rescale=False, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)
Bases: Reconstructor
Posterior distribution sampling for inverse problems using diffusion models, by solving a reverse-time Stochastic Differential Equation (SDE).
Consider the acquisition model:
\[y = \noise{\forw{x}}.\]
This class defines the reverse-time SDE for the posterior distribution \(p(x|y)\) given the data \(y\):
\[d\, x_t = \left( f(x_t, t) - \frac{1 + \alpha}{2} g(t)^2 \nabla_{x_t} \log p_t(x_t | y) \right) d\,t + g(t) \sqrt{\alpha} d\, w_{t}\]
where \(f\) is the drift term, \(g\) is the diffusion coefficient and \(w\) is the standard Brownian motion. The drift term and the diffusion coefficient are defined by the underlying (unconditional) forward-time SDE sde. The (conditional) score function \(\nabla_{x_t} \log p_t(x_t | y)\) can be decomposed using Bayes’ rule:
\[\nabla_{x_t} \log p_t(x_t | y) = \nabla_{x_t} \log p_t(x_t) + \nabla_{x_t} \log p_t(y | x_t).\]
The first term is the score function of the unconditional SDE, which is typically approximated by an MMSE denoiser using Tweedie’s formula, while the second term is approximated by the (noisy) data-fidelity term. Various data-fidelity terms are implemented in deepinv.sampling.NoisyDataFidelity.
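The score decomposition above can be checked by hand in a toy setting. The following is a minimal sketch, not deepinv code: it assumes a scalar Gaussian prior \(x \sim \mathcal{N}(0, s_0^2)\), a variance-exploding perturbation \(x_t = x + \sigma_t \epsilon\), and a denoising measurement \(y = x + \sigma_y n\), so that every term has a closed form. All function names and parameters (mmse_denoiser, prior_score_tweedie, likelihood_score, posterior_score_exact, s0, sigma_t, sigma_y) are hypothetical and chosen for illustration only.

```python
def mmse_denoiser(x_t, sigma_t, s0):
    """Exact MMSE denoiser E[x | x_t] for the Gaussian toy prior N(0, s0^2)."""
    return s0**2 / (s0**2 + sigma_t**2) * x_t


def prior_score_tweedie(x_t, sigma_t, s0):
    """Unconditional score via Tweedie's formula: (D(x_t) - x_t) / sigma_t^2."""
    return (mmse_denoiser(x_t, sigma_t, s0) - x_t) / sigma_t**2


def likelihood_score(x_t, y, sigma_t, s0, sigma_y):
    """Gradient of log p_t(y | x_t) w.r.t. x_t (closed form in this toy model)."""
    gain = s0**2 / (s0**2 + sigma_t**2)   # d E[x | x_t] / d x_t
    mean = gain * x_t                     # E[y | x_t] = E[x | x_t]
    var = gain * sigma_t**2 + sigma_y**2  # Var[x | x_t] + measurement noise
    return gain * (y - mean) / var


def posterior_score_exact(x_t, y, sigma_t, s0, sigma_y):
    """Score of p_t(x_t | y), computed directly from the Gaussian posterior."""
    v_post = 1.0 / (1.0 / s0**2 + 1.0 / sigma_y**2)  # Var[x | y]
    m_post = v_post * y / sigma_y**2                 # E[x | y]
    return -(x_t - m_post) / (v_post + sigma_t**2)


# Bayes' rule for scores: prior score + likelihood score = posterior score.
x_t, y, sigma_t, s0, sigma_y = 0.7, -1.2, 0.5, 2.0, 0.3
decomposed = prior_score_tweedie(x_t, sigma_t, s0) + likelihood_score(
    x_t, y, sigma_t, s0, sigma_y
)
exact = posterior_score_exact(x_t, y, sigma_t, s0, sigma_y)
print(decomposed, exact)
```

In practice neither term is available in closed form: the denoiser is a learned network, and the likelihood term is only approximated by the chosen noisy data-fidelity, so the identity holds approximately rather than exactly.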
- Parameters:
data_fidelity (deepinv.sampling.NoisyDataFidelity) – the noisy data-fidelity term, used to approximate the score \(\nabla_{x_t} \log p_t(y \vert x_t)\). Defaults to deepinv.optim.ZeroFidelity, the zero data-fidelity term, in which case sampling reduces to unconditional SDE sampling.
denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the (unconditional) score \(\nabla \log p_t\) at time \(t\).
sde (deepinv.sampling.DiffusionSDE) – the forward-time SDE, which defines the drift and diffusion terms of the reverse-time SDE.
solver (deepinv.sampling.BaseSDESolver) – the solver for the SDE. If not specified, the solver from the sde will be used.
rescale (bool) – whether to rescale the input to the denoiser to [-1, 1].
dtype (torch.dtype) – the data type of the sampling solver, except for the denoiser, which always uses torch.float32. We recommend torch.float64 for better stability and lower numerical error when solving the SDE in discrete time. The added cost is limited, since most of the computation is spent evaluating the denoiser, which is always computed in torch.float32.
device (torch.device) – the device for the computations.
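The dtype recommendation can be motivated with a small experiment. The sketch below is not deepinv code and uses only the Python standard library: it accumulates many small increments, as a discrete-time solver does, once with the running total rounded to single precision after every step and once in Python's native double precision. The helper names to_float32 and accumulate are hypothetical.

```python
import struct


def to_float32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate(step, num_steps, single_precision):
    """Add `step` to a running total `num_steps` times.

    With single_precision=True, the step and the running total are rounded
    to float32 after every addition, emulating a float32 solver state.
    """
    total = 0.0
    if single_precision:
        step = to_float32(step)
    for _ in range(num_steps):
        total = total + step
        if single_precision:
            total = to_float32(total)
    return total


# The exact result of 10,000 steps of size 1e-4 is 1.0.
err32 = abs(accumulate(1e-4, 10_000, single_precision=True) - 1.0)
err64 = abs(accumulate(1e-4, 10_000, single_precision=False) - 1.0)
print(f"float32 error: {err32:.2e}, float64 error: {err64:.2e}")
```

Rounding errors of this kind build up over the solver iterations, which is why keeping the solver state in torch.float64 while leaving the denoiser in torch.float32 is a reasonable trade-off.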
- forward(y, physics, x_init=None, seed=None, timesteps=None, get_trajectory=False, *args, **kwargs)
Sample the posterior distribution \(p(x|y)\) given the data measurement \(y\).
- Parameters:
y (torch.Tensor) – the data measurement.
physics (deepinv.physics.Physics) – the forward operator.
x_init (torch.Tensor, tuple) – the initial value for the sampling. Can be a torch.Tensor, or a tuple (B, C, H, W) giving the shape of the initial point, which must match the shape of physics and y. In the tuple case, the initial value is drawn at random from the end-point distribution of the sde.
seed (int) – the random seed.
timesteps (torch.Tensor) – the time steps for the solver. If None, the default time steps in the solver will be used.
get_trajectory (bool) – whether to return the full trajectory of the SDE or only the last sample, optional. Defaults to False.
*args – the additional arguments for the solver.
**kwargs – the additional keyword arguments for the solver.
- Returns:
the generated sample (torch.Tensor of shape (B, C, H, W)) if get_trajectory is False. Otherwise, returns a tuple (torch.Tensor, torch.Tensor) of shapes (B, C, H, W) and (N, B, C, H, W), where N is the number of steps.
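To make the sampling procedure concrete, here is a minimal standard-library sketch of an Euler-Maruyama discretization of the reverse-time SDE from the class description. It is not the deepinv solver: it assumes a scalar toy problem with data distribution \(\mathcal{N}(0, s_0^2)\), zero drift \(f = 0\), noise level \(\sigma(t) = \sigma_{\max} t\), and hence \(g(t)^2 = 2 \sigma_{\max}^2 t\), so that the score is known in closed form. The names reverse_sde_sample, S0 and SIGMA_MAX are hypothetical, and the seed and get_trajectory arguments only mimic the conventions of forward.

```python
import math
import random

S0 = 1.0         # std of the toy Gaussian data distribution (assumption)
SIGMA_MAX = 2.0  # noise level reached at t = 1 (assumption)


def diffusion(t):
    """Diffusion coefficient g(t), with g(t)^2 = d sigma(t)^2 / dt."""
    return math.sqrt(2.0 * SIGMA_MAX**2 * t)


def score(x, t):
    """Exact score of p_t = N(0, S0^2 + sigma(t)^2), sigma(t) = SIGMA_MAX * t."""
    return -x / (S0**2 + (SIGMA_MAX * t) ** 2)


def reverse_sde_sample(x_T, alpha=1.0, num_steps=200, seed=None,
                       get_trajectory=False):
    """Integrate the reverse-time SDE from t = 1 down to t = 0.

    alpha = 0 gives a deterministic ODE; alpha = 1 gives the usual
    stochastic reverse-time SDE. The drift f is zero in this toy model.
    """
    rng = random.Random(seed)
    ts = [1.0 - k / num_steps for k in range(num_steps + 1)]  # 1 -> 0
    x = x_T
    trajectory = [x]  # this toy version also stores the starting point
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t_cur  # negative: time runs backwards
        g = diffusion(t_cur)
        # Deterministic part: (f - (1 + alpha) / 2 * g^2 * score) dt, with f = 0.
        x = x - 0.5 * (1.0 + alpha) * g**2 * score(x, t_cur) * dt
        # Stochastic part: g * sqrt(alpha) dw, with dw ~ N(0, |dt|).
        if alpha > 0:
            x = x + g * math.sqrt(alpha * abs(dt)) * rng.gauss(0.0, 1.0)
        trajectory.append(x)
    return (x, trajectory) if get_trajectory else x


# Same seed, same sample; the trajectory ends at the returned sample.
sample, path = reverse_sde_sample(3.0, seed=0, get_trajectory=True)
print(sample, len(path))
```

For alpha = 0 the toy ODE has the exact solution \(x_0 = x_T \, s_0 / \sqrt{s_0^2 + \sigma_{\max}^2}\), which gives a simple way to check the discretization. In the actual class, the closed-form score is replaced by the denoiser-based score plus the data-fidelity term, and the state is a batch of tensors rather than a scalar.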
- score(y, physics, x, t, *args, **kwargs)
Approximate the conditional score \(\nabla_{x_t} \log p_t(x_t \vert y)\).
- Parameters:
y (torch.Tensor) – the data measurement.
physics (deepinv.physics.Physics) – the forward operator.
x (torch.Tensor) – the current state.
t (torch.Tensor, float) – the current time step.
*args – additional arguments for the score function of the unconditional SDE.
**kwargs – additional keyword arguments for the score function of the unconditional SDE.
- Returns:
the score function \(\nabla_{x_t} \log p_t(x_t \vert y)\).
- Return type:
Examples using PosteriorDiffusion:

Posterior Sampling for Inverse Problems with Stochastic Differential Equations modeling.