DiffusionSDE

class deepinv.sampling.DiffusionSDE(forward_drift, forward_diffusion, alpha=1.0, denoiser=None, rescale=False, solver=None, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)

Bases: BaseSDE

Reverse-time Diffusion Stochastic Differential Equation defined by

\[d\, x_{t} = \left( f(x_t, t) - \frac{1 + \alpha}{2} g(t)^2 \nabla \log p_t(x_t) \right) d\,t + g(t) \sqrt{\alpha} d\, w_{t}.\]
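As a rough illustration of the equation above, the following NumPy sketch integrates the reverse-time SDE with a plain Euler-Maruyama scheme. It is independent of deepinv and is not its solver. It assumes a toy setting where the score is known exactly: Gaussian data \(x_0 \sim \mathcal{N}(0, s_0^2)\) and a forward SDE with zero drift and constant \(g\), so that \(p_t = \mathcal{N}(0, s_0^2 + g^2 t)\). The helper names (forward_drift, forward_diffusion, exact_score, reverse_drift, sample) are hypothetical. Both \(\alpha = 0\) (ODE) and \(\alpha = 1\) (SDE) should map samples from \(p_T\) back to approximately \(\mathcal{N}(0, s_0^2)\).

```python
import numpy as np

# Toy setting (an assumption for illustration): data x_0 ~ N(0, S0^2) and a forward
# SDE dx = G dw (zero drift, constant diffusion), so p_t = N(0, S0^2 + G^2 t).
S0, G, T = 0.5, 2.0, 1.0


def forward_drift(x, t):
    return np.zeros_like(x)  # f(x, t) = 0


def forward_diffusion(t):
    return G  # g(t) = G for all t


def exact_score(x, t):
    # grad log p_t(x) for p_t = N(0, S0^2 + G^2 t); a denoiser would approximate this
    return -x / (S0**2 + G**2 * t)


def reverse_drift(x, t, alpha):
    # f(x, t) - (1 + alpha) / 2 * g(t)^2 * grad log p_t(x)
    return forward_drift(x, t) - 0.5 * (1.0 + alpha) * forward_diffusion(t) ** 2 * exact_score(x, t)


def sample(alpha, n_samples=20_000, n_steps=1_000, seed=0):
    rng = np.random.default_rng(seed)
    h = T / n_steps
    # Initialize from the marginal p_T.
    x = rng.normal(0.0, np.sqrt(S0**2 + G**2 * T), size=n_samples)
    for k in range(n_steps, 0, -1):
        t = k * h
        noise = rng.standard_normal(n_samples)
        # Euler-Maruyama step backward in time, from t to t - h.
        x = x - reverse_drift(x, t, alpha) * h + forward_diffusion(t) * np.sqrt(alpha * h) * noise
    return x


x_ode = sample(alpha=0.0)  # alpha = 0: deterministic once the initial draw is fixed
x_sde = sample(alpha=1.0)  # alpha = 1: stochastic sampler
print(x_ode.std(), x_sde.std())  # both should be close to S0 = 0.5
```

Because the equation runs backward in time, each step goes from \(t\) to \(t - h\), so the drift enters with a minus sign while the noise term scales as \(g(t)\sqrt{\alpha h}\).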
Parameters:
  • forward_drift (Callable) – a time-dependent drift function \(f(x, t)\) of the forward-time SDE.

  • forward_diffusion (Callable) – a time-dependent diffusion function \(g(t)\) of the forward-time SDE.

  • alpha (float) – a scalar weighting the diffusion term; defaults to 1.0. \(\alpha = 0\) corresponds to ODE sampling (the probability flow ODE) and \(\alpha > 0\) to SDE sampling.

  • denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the score \(\nabla \log p_t\) at time \(t\).

  • rescale (bool) – whether to rescale the input of the denoiser to \([-1, 1]\); defaults to False.

  • solver (deepinv.sampling.BaseSDESolver) – the solver used to integrate the SDE.

  • dtype (torch.dtype) – data type of the computation, except for the denoiser, which always runs in torch.float32. We recommend torch.float64 for better stability and lower numerical error when solving the SDE in discrete time. Since most of the computational cost comes from evaluating the denoiser (always in torch.float32), using torch.float64 for the rest adds little overhead.

  • device (torch.device) – device on which the computation is performed.
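The float64 recommendation for dtype is easier to appreciate with a generic example of rounding error accumulating over many small time steps. The NumPy sketch below is not deepinv code, and accumulate_time is a hypothetical helper: it advances a clock by \(10^5\) increments of \(10^{-5}\) in both precisions and measures the distance to the exact end time 1.0.

```python
import numpy as np


def accumulate_time(dtype, n_steps=100_000):
    # Advance a clock by n_steps increments of 1 / n_steps; the exact result is 1.0.
    dt = dtype(1.0) / dtype(n_steps)
    t = dtype(0.0)
    for _ in range(n_steps):
        t = t + dt  # each addition rounds to the nearest representable value of dtype
    return t


err32 = abs(float(accumulate_time(np.float32)) - 1.0)
err64 = abs(float(accumulate_time(np.float64)) - 1.0)
print(f"float32 error: {err32:.2e}, float64 error: {err64:.2e}")
```

The float32 error is several orders of magnitude larger than the float64 one; a discretized SDE solver performs the same kind of repeated small updates on its state.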

score(x, t, *args, **kwargs)

Approximates the score function \(\nabla \log p_t\) using the denoiser.

Parameters:
  • x (torch.Tensor) – current state

  • t (torch.Tensor, float) – current time step

  • *args – additional arguments for the denoiser.

  • **kwargs – additional keyword arguments for the denoiser, e.g., class_labels for class-conditional models.

Returns:

the score function \(\nabla \log p_t(x)\).

Return type:

torch.Tensor
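A standard way to turn a denoiser into a score estimate is Tweedie's formula, \(\nabla \log p_t(x) \approx (D(x, \sigma_t) - x) / \sigma_t^2\), which holds exactly when \(x_t = x_0 + \sigma_t n\) with \(n \sim \mathcal{N}(0, \mathrm{Id})\) and \(D\) is the MMSE denoiser. The NumPy sketch below assumes this unscaled setting and Gaussian data, for which the MMSE denoiser has a closed form; mmse_denoiser and score_from_denoiser are hypothetical helpers, and this is not claimed to be exactly what score does internally (for instance when rescale=True).

```python
import numpy as np

S0 = 0.5  # std of the toy data distribution x_0 ~ N(0, S0^2) (assumption)


def mmse_denoiser(x, sigma):
    # Exact MMSE denoiser E[x_0 | x_0 + sigma * n = x] for Gaussian data.
    return S0**2 / (S0**2 + sigma**2) * x


def score_from_denoiser(denoiser, x, sigma):
    # Tweedie's formula: grad log p_sigma(x) = (D(x, sigma) - x) / sigma^2
    return (denoiser(x, sigma) - x) / sigma**2


x = np.linspace(-3.0, 3.0, 7)
sigma = 0.8
approx = score_from_denoiser(mmse_denoiser, x, sigma)
exact = -x / (S0**2 + sigma**2)  # score of the noisy marginal N(0, S0^2 + sigma^2)
print(np.max(np.abs(approx - exact)))
```

With a learned denoiser, the same formula only approximates the score, and its accuracy depends on how close the denoiser is to the MMSE estimator at noise level \(\sigma_t\).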

sigma_t(t)

The standard deviation of the conditional distribution \(p(x_t \vert x_0) \sim \mathcal{N}(..., \sigma_t^2 \mathrm{Id})\).

Parameters:

t (torch.Tensor, float) – current time step

Returns:

the noise level at time step t.

Return type:

torch.Tensor
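When the forward SDE has zero drift, \(x_t = x_0 + \int_0^t g(s)\, d w_s\), so \(\sigma_t^2 = \int_0^t g(s)^2\, ds\). A minimal NumPy sketch under that assumption computes \(\sigma_t\) with the trapezoidal rule and compares it with the closed form for a geometric schedule \(g(s) = \sigma_{\min} r^s \sqrt{2 \log r}\), \(r = \sigma_{\max} / \sigma_{\min}\). The helper sigma_t_from_diffusion and the schedule values are hypothetical and are not deepinv's implementation of sigma_t.

```python
import numpy as np


def sigma_t_from_diffusion(g, t, n_steps=10_000):
    """Std of p(x_t | x_0) for a forward SDE dx = g(s) dw with zero drift (assumption).

    In that case sigma_t^2 = int_0^t g(s)^2 ds, approximated with the trapezoidal rule.
    """
    s = np.linspace(0.0, t, n_steps + 1)
    g2 = g(s) ** 2
    return np.sqrt(np.sum(0.5 * (g2[1:] + g2[:-1]) * np.diff(s)))


# Hypothetical geometric schedule with r = sigma_max / sigma_min.
sigma_min, sigma_max = 0.01, 50.0
r = sigma_max / sigma_min


def g(s):
    return sigma_min * r**s * np.sqrt(2.0 * np.log(r))


numeric = sigma_t_from_diffusion(g, 1.0)
closed_form = sigma_min * np.sqrt(r**2 - 1.0)  # integral of g^2 from 0 to 1, square-rooted
print(numeric, closed_form)
```

For this schedule \(\sigma_t = \sigma_{\min}\sqrt{r^{2t} - 1}\), which is close to \(\sigma_{\min} r^t\) once \(r^{2t} \gg 1\).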

Examples using DiffusionSDE:

Posterior Sampling for Inverse Problems with Stochastic Differential Equations modeling.
