DiffusionSDE
- class deepinv.sampling.DiffusionSDE(forward_drift, forward_diffusion, alpha=1.0, denoiser=None, rescale=False, solver=None, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)
Bases: BaseSDE
Reverse-time Diffusion Stochastic Differential Equation defined by
\[d\, x_{t} = \left( f(x_t, t) - \frac{1 + \alpha}{2} g(t)^2 \nabla \log p_t(x_t) \right) d\,t + g(t) \sqrt{\alpha} d\, w_{t}.\]
- Parameters:
forward_drift (Callable) – a time-dependent drift function \(f(x, t)\) of the forward-time SDE.
forward_diffusion (Callable) – a time-dependent diffusion function \(g(t)\) of the forward-time SDE.
alpha (float) – a scalar weighting the diffusion term. \(\alpha = 0\) corresponds to ODE sampling and \(\alpha > 0\) corresponds to SDE sampling. Defaults to 1.0.
denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the score \(\nabla \log p_t\) at time \(t\).
rescale (bool) – whether to rescale the input to the denoiser to \([-1, 1]\). Defaults to False.
solver (deepinv.sampling.BaseSDESolver) – the solver used to solve the SDE.
dtype (torch.dtype) – data type of the computation, except for the denoiser, which always uses torch.float32. We recommend torch.float64 for better stability and lower numerical error when solving the SDE in discrete time; the extra cost is small, since most of the computation comes from evaluating the denoiser, which is always computed in torch.float32.
device (torch.device) – device on which the computation is performed.
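To make the role of \(\alpha\) concrete, the following is a minimal sketch of an Euler-Maruyama discretization of the reverse-time SDE above, written with the Python standard library only. It does not use deepinv and is not the library's solver: the function `reverse_sde_euler`, the toy functions `toy_f`, `toy_g`, `toy_score`, and the constants `MU` and `S` are all hypothetical names introduced for illustration. The toy setting assumes scalar data \(x_0 \sim \mathcal{N}(\mu, s^2)\) and a forward SDE with \(f = 0\) and \(g(t) = \sqrt{2t}\), for which the score is known in closed form.

```python
import math
import random


def reverse_sde_euler(x_T, score, f, g, alpha=1.0, T=1.0, n_steps=1000, seed=0):
    """Euler-Maruyama integration of the reverse-time SDE from t = T down to t = 0."""
    rng = random.Random(seed)
    h = T / n_steps  # positive step size; time runs backward
    x = x_T
    for k in range(n_steps):
        t = T - k * h
        # Drift of the reverse-time SDE: f - (1 + alpha) / 2 * g^2 * score.
        drift = f(x, t) - 0.5 * (1.0 + alpha) * g(t) ** 2 * score(x, t)
        # Noise term g * sqrt(alpha) * dw; it vanishes when alpha = 0 (ODE sampling).
        noise = g(t) * math.sqrt(alpha * h) * rng.gauss(0.0, 1.0)
        # Stepping from t to t - h, hence the minus sign in front of the drift.
        x = x - h * drift + noise
    return x


# Toy setting (assumption): x_0 ~ N(MU, S^2), f = 0 and g(t) = sqrt(2 t),
# so that p_t = N(MU, S^2 + t^2) and the score is available analytically.
MU, S = 2.0, 0.5


def toy_f(x, t):
    return 0.0


def toy_g(t):
    return math.sqrt(2.0 * t)


def toy_score(x, t):
    return -(x - MU) / (S ** 2 + t ** 2)


x_ode = reverse_sde_euler(3.0, toy_score, toy_f, toy_g, alpha=0.0)  # deterministic
x_sde = reverse_sde_euler(3.0, toy_score, toy_f, toy_g, alpha=1.0)  # stochastic
```

In this toy setting the \(\alpha = 0\) trajectory is deterministic and should approach \(\mu + (x_T - \mu)\, s / \sqrt{s^2 + T^2}\) as the number of steps grows, whereas the \(\alpha = 1\) result depends on the random seed.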
- score(x, t, *args, **kwargs)
Approximates the score function \(\nabla \log p_t\) using the denoiser.
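A common way to obtain a score from a denoiser is Tweedie's formula, which holds for additive Gaussian noise \(x = x_0 + \sigma \epsilon\): \(\nabla \log p_\sigma(x) = (D(x, \sigma) - x) / \sigma^2\), where \(D\) is the MMSE denoiser. The sketch below checks this identity on a scalar Gaussian prior using plain Python. It is an illustration only: whether this method applies exactly this formula, and how it interacts with `rescale` or a non-zero drift, is not stated here, and the names `gaussian_mmse_denoiser` and `score_from_denoiser` are hypothetical.

```python
def gaussian_mmse_denoiser(x, sigma, mu=2.0, s=0.5):
    """Posterior mean E[x_0 | x] for x = x_0 + sigma * eps, with x_0 ~ N(mu, s^2)."""
    return (sigma ** 2 * mu + s ** 2 * x) / (s ** 2 + sigma ** 2)


def score_from_denoiser(x, sigma, denoiser):
    """Tweedie's formula: grad log p_sigma(x) = (D(x, sigma) - x) / sigma^2."""
    return (denoiser(x, sigma) - x) / sigma ** 2


x, sigma = 3.0, 0.8
approx = score_from_denoiser(x, sigma, gaussian_mmse_denoiser)
# Analytic score of the noisy marginal N(mu, s^2 + sigma^2), with mu = 2.0, s = 0.5.
exact = -(x - 2.0) / (0.5 ** 2 + sigma ** 2)
```

For this Gaussian toy prior the two values agree up to floating-point rounding; with a learned denoiser the formula only gives an approximation of the score.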
- Parameters:
x (torch.Tensor) – current state
t (torch.Tensor, float) – current time step
*args – additional arguments for the denoiser.
**kwargs – additional keyword arguments for the denoiser, e.g., class_labels for class-conditional models.
- Returns:
the score function \(\nabla \log p_t(x)\).
- Return type:
- sigma_t(t)
The standard deviation of the conditional distribution \(p(x_t \vert x_0) \sim \mathcal{N}(..., \sigma_t^2 \mathrm{Id})\).
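As an illustration of how such a noise level relates to the diffusion coefficient, consider the special case of a forward SDE with zero drift, \(f = 0\), where \(\sigma_t^2 = \int_0^t g(s)^2 \, ds\). The sketch below evaluates this integral numerically for a geometric noise schedule and compares it with the closed form. It is a stand-alone example under that zero-drift assumption, not the library's implementation; `sigma_from_diffusion`, `g_geometric`, `SIGMA_MIN`, and `SIGMA_MAX` are hypothetical names and values.

```python
import math

# Hypothetical schedule parameters, chosen only for illustration.
SIGMA_MIN, SIGMA_MAX = 0.02, 100.0
RATIO = SIGMA_MAX / SIGMA_MIN


def g_geometric(t):
    """Diffusion coefficient g(t) of a geometric variance-exploding schedule."""
    return SIGMA_MIN * RATIO ** t * math.sqrt(2.0 * math.log(RATIO))


def sigma_from_diffusion(g, t, n=10_000):
    """sigma_t = sqrt(integral_0^t g(s)^2 ds) by the midpoint rule (assumes f = 0)."""
    h = t / n
    total = sum(g((i + 0.5) * h) ** 2 for i in range(n))
    return math.sqrt(total * h)


t = 0.5
numeric = sigma_from_diffusion(g_geometric, t)
# Closed form for this schedule: sigma_t^2 = SIGMA_MIN^2 * (RATIO^(2 t) - 1).
closed_form = SIGMA_MIN * math.sqrt(RATIO ** (2.0 * t) - 1.0)
```

When the forward drift is non-zero, the conditional mean is also scaled and the expression for \(\sigma_t\) changes accordingly, so this integral should not be reused as-is in that case.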
- Parameters:
t (torch.Tensor, float) – current time step
- Returns:
the noise level at time step t.
- Return type:
Examples using DiffusionSDE:

Posterior Sampling for Inverse Problems with Stochastic Differential Equations modeling.