DiffusionSDE
- class deepinv.sampling.DiffusionSDE(forward_drift, forward_diffusion, alpha=1.0, denoiser=None, solver=None, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)
Bases: BaseSDE
Defines the reverse-time diffusion stochastic differential equation (SDE).
Given a forward-time SDE of the form:
\[d x_t = f(x_t, t) dt + g(t) d w_t,\]
this class defines the following reverse-time SDE:
\[d x_{t} = \left( f(x_t, t) - \frac{1 + \alpha}{2} g(t)^2 \nabla \log p_t(x_t) \right) dt + g(t) \sqrt{\alpha} d w_{t},\]
where \(p_t\) denotes the marginal density of \(x_t\) under the forward SDE.
- Parameters:
forward_drift (Callable) – a time-dependent drift function \(f(x, t)\) of the forward-time SDE.
forward_diffusion (Callable) – a time-dependent diffusion function \(g(t)\) of the forward-time SDE.
alpha (float) – a scalar weighting the diffusion term. \(\alpha = 0\) corresponds to ODE sampling and \(\alpha > 0\) corresponds to SDE sampling.
denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the score \(\nabla \log p_t\) at time \(t\).
solver (deepinv.sampling.BaseSDESolver) – the solver used to integrate the SDE.
dtype (torch.dtype) – data type of the computation, except for the denoiser, which always runs in torch.float32. We recommend torch.float64 for better stability and lower numerical error when solving the SDE in discrete time; the extra cost is small because most of the computation time is spent evaluating the denoiser, which stays in torch.float32.
device (torch.device) – device on which the computation is performed.
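To make the role of \(\alpha\) concrete, here is a minimal NumPy sketch that integrates the reverse-time SDE above with the Euler-Maruyama method. It is not the deepinv implementation (which delegates integration to solver); the function name sample_reverse_sde is hypothetical, and the toy forward process (\(f = 0\), \(g(t) = \sqrt{2t}\), data \(x_0 \sim \mathcal{N}(0, 1)\)) is an assumption chosen so that \(p_t = \mathcal{N}(0, 1 + t^2)\) and the exact score \(-x / (1 + t^2)\) can stand in for the denoiser.

```python
import numpy as np


def sample_reverse_sde(drift, diffusion, score, x_T, T, n_steps=1000, alpha=1.0, seed=0):
    """Euler-Maruyama integration of the reverse-time SDE from t = T down to t = 0:

        dx = [f(x, t) - (1 + alpha) / 2 * g(t)**2 * score(x, t)] dt + g(t) * sqrt(alpha) dw

    alpha = 0 removes the noise term (probability-flow ODE).
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x_T, dtype=np.float64)
    times = np.linspace(T, 0.0, n_steps + 1)
    for t, t_next in zip(times[:-1], times[1:]):
        dt = t_next - t  # negative: time runs backward
        g = diffusion(t)
        reverse_drift = drift(x, t) - 0.5 * (1.0 + alpha) * g**2 * score(x, t)
        # Brownian increment over |dt|, scaled by g(t) * sqrt(alpha)
        noise = g * np.sqrt(alpha * -dt) * rng.standard_normal(x.shape)
        x = x + reverse_drift * dt + noise
    return x


# Hypothetical toy forward process: x_0 ~ N(0, 1), f(x, t) = 0, g(t) = sqrt(2 t),
# so p_t = N(0, 1 + t^2) and the exact score -x / (1 + t^2) replaces the denoiser.
def drift(x, t):
    return np.zeros_like(x)


def diffusion(t):
    return np.sqrt(2.0 * t)


def score(x, t):
    return -x / (1.0 + t**2)


T = 10.0
x_T = np.sqrt(1.0 + T**2) * np.random.default_rng(1).standard_normal(20_000)
x0_sde = sample_reverse_sde(drift, diffusion, score, x_T, T, alpha=1.0)
x0_ode = sample_reverse_sde(drift, diffusion, score, x_T, T, alpha=0.0)
print(x0_sde.var(), x0_ode.var())  # both should be close to 1, the variance of p_0
```

With alpha=1.0 the loop injects fresh noise at every step; with alpha=0.0 the noise term vanishes and the same loop integrates the deterministic probability-flow ODE. In this toy setting both variants should end with samples whose variance is close to 1.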
- scale_t(t)
The scale \(s(t)\) of the conditional distribution \(p(x_t \vert x_0) \sim \mathcal{N}(s(t)x_0, s(t)^2 \sigma(t)^2 \mathrm{Id})\).
- Parameters:
t (torch.Tensor, float) – current time step
- Returns:
the scale \(s(t)\) at time step t, i.e., the factor multiplying \(x_0\) in the mean of the conditional distribution.
- Return type:
torch.Tensor
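As an illustration of what scale_t returns, the sketch below assumes a variance-preserving forward SDE with linear drift \(f(x, t) = -\tfrac{1}{2}\beta(t) x\), for which \(s(t) = \exp\left(-\tfrac{1}{2}\int_0^t \beta(u)\,du\right)\). The linear schedule \(\beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min})\) on \([0, 1]\) and the values BETA_MIN and BETA_MAX are hypothetical choices, not deepinv defaults.

```python
import numpy as np

# Hypothetical linear schedule beta(t) = BETA_MIN + t * (BETA_MAX - BETA_MIN), t in [0, 1].
BETA_MIN, BETA_MAX = 0.1, 20.0


def scale_t(t):
    """s(t) = exp(-1/2 * int_0^t beta(u) du) for the drift f(x, t) = -beta(t) * x / 2."""
    t = np.asarray(t, dtype=np.float64)
    # Closed-form integral of the linear schedule from 0 to t
    integral = BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2
    return np.exp(-0.5 * integral)


s = scale_t([0.0, 0.5, 1.0])
print(s)  # starts at 1 and decays towards 0 as t grows
```

Since \(s(0) = 1\), the conditional mean starts at \(x_0\) and shrinks towards zero as the forward process progresses.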
- score(x, t, *args, **kwargs)
Approximates the score function \(\nabla \log p_t\) using the denoiser.
- Parameters:
x (torch.Tensor) – current state
t (torch.Tensor, float) – current time step
*args – additional arguments for the denoiser.
**kwargs – additional keyword arguments for the denoiser, e.g., class_labels for class-conditional models.
- Returns:
the score function \(\nabla \log p_t(x)\).
- Return type:
torch.Tensor
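One common way to turn a denoiser into a score estimate, which we assume (without having checked the source) is close to what score does, is Tweedie's formula applied to the rescaled variable \(y = x_t / s(t) = x_0 + \sigma(t)\varepsilon\): \(\nabla \log p_t(x_t) = \frac{s(t) D(x_t / s(t), \sigma(t)) - x_t}{s(t)^2 \sigma(t)^2}\), where \(D(y, \sigma) \approx \mathbb{E}[x_0 \mid y]\). The NumPy sketch below checks this on Gaussian data, for which the optimal denoiser \(D(y, \sigma) = y / (1 + \sigma^2)\) is known in closed form. The names score_from_denoiser and gaussian_denoiser are hypothetical and not part of the deepinv API.

```python
import numpy as np


def score_from_denoiser(denoiser, x_t, s_t, sigma_t):
    """Score of p_t at x_t, assuming x_t = s_t * (x_0 + sigma_t * noise).

    Tweedie's formula on y = x_t / s_t gives
    grad log p_t(x_t) = (s_t * D(x_t / s_t, sigma_t) - x_t) / (s_t**2 * sigma_t**2).
    """
    x0_hat = denoiser(x_t / s_t, sigma_t)  # estimate of E[x_0 | x_t]
    return (s_t * x0_hat - x_t) / (s_t**2 * sigma_t**2)


# Hypothetical denoiser: the exact MMSE denoiser when x_0 ~ N(0, I).
def gaussian_denoiser(y, sigma):
    return y / (1.0 + sigma**2)


x_t = np.random.default_rng(0).standard_normal(5)
s_t, sigma_t = 0.8, 1.5
approx = score_from_denoiser(gaussian_denoiser, x_t, s_t, sigma_t)
# For x_0 ~ N(0, I), p_t = N(0, s^2 (1 + sigma^2) I), so the exact score is:
exact = -x_t / (s_t**2 * (1.0 + sigma_t**2))
print(np.allclose(approx, exact))  # True
```

With a learned denoiser the result is only an approximation of the score, and its quality depends on how close the denoiser is to the MMSE estimator at noise level \(\sigma(t)\).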
- sigma_t(t)
The noise level \(\sigma(t)\) of the conditional distribution \(p(x_t \vert x_0) \sim \mathcal{N}(s(t)x_0, s(t)^2 \sigma(t)^2 \mathrm{Id})\).
- Parameters:
t (torch.Tensor, float) – current time step
- Returns:
the noise level \(\sigma(t)\) at time step t.
- Return type:
torch.Tensor
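The two quantities \(s(t)\) and \(\sigma(t)\) together define the perturbation kernel, so \(x_t\) can be drawn directly from \(x_0\) as \(x_t = s(t)\left(x_0 + \sigma(t)\varepsilon\right)\) with \(\varepsilon \sim \mathcal{N}(0, \mathrm{Id})\). The sketch below assumes, hypothetically, the same kind of variance-preserving process with a linear \(\beta(t)\) schedule, where the conditional variance \(s(t)^2\sigma(t)^2\) equals \(1 - s(t)^2\); the helper names and schedule values are illustrative only.

```python
import numpy as np

# Hypothetical variance-preserving schedule with linear beta(t) on [0, 1].
BETA_MIN, BETA_MAX = 0.1, 20.0


def scale_t(t):
    return np.exp(-0.5 * (BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2))


def sigma_t(t):
    # VP process: s(t)^2 * sigma(t)^2 = 1 - s(t)^2, hence sigma(t) = sqrt(1 / s(t)^2 - 1).
    s = scale_t(t)
    return np.sqrt(1.0 / s**2 - 1.0)


def sample_xt_given_x0(x0, t, rng):
    """Draw x_t ~ N(s(t) x0, s(t)^2 sigma(t)^2 Id)."""
    s, sig = scale_t(t), sigma_t(t)
    return s * (x0 + sig * rng.standard_normal(x0.shape))


rng = np.random.default_rng(0)
x0 = rng.standard_normal(100_000)  # toy data with unit variance
x_t = sample_xt_given_x0(x0, t=0.5, rng=rng)
print(x_t.var())  # close to 1: a VP process keeps unit-variance data at unit variance
```

Because \(\mathrm{Var}(x_t) = s(t)^2\left(\mathrm{Var}(x_0) + \sigma(t)^2\right) = s(t)^2 \cdot s(t)^{-2} = 1\) for unit-variance data, the empirical variance offers a quick sanity check of the two functions.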
Examples using DiffusionSDE:

Building your diffusion posterior sampling method using SDEs