DiffusionSDE
- class deepinv.sampling.DiffusionSDE(forward_drift, forward_diffusion, alpha=1.0, denoiser=None, solver=None, minus_one_one=True, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)
Bases: BaseSDE
Define the reverse-time diffusion stochastic differential equation (SDE).
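Given a forward-time SDE of the form:
\[d x_t = f(x_t, t) dt + g(t) d w_t,\]
this class defines the following reverse-time SDE:
\[d x_{t} = \left( f(x_t, t) - \frac{1 + \alpha}{2} g(t)^2 \nabla \log p_t(x_t) \right) dt + g(t) \sqrt{\alpha} d w_{t},\]
where \(p_t\) is the marginal density of \(x_t\) and \(w_t\) in the reverse-time equation is a Brownian motion running backward in time.
- Parameters: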
forward_drift (Callable) – a time-dependent drift function \(f(x, t)\) of the forward-time SDE.
forward_diffusion (Callable) – a time-dependent diffusion function \(g(t)\) of the forward-time SDE.
alpha (float) – a scalar weighting the diffusion term. \(\alpha = 0\) corresponds to ODE sampling and \(\alpha > 0\) to SDE sampling. Default: 1.0.
denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the score \(\nabla \log p_t\) at time \(t\).
solver (deepinv.sampling.BaseSDESolver) – the solver for solving the SDE.
minus_one_one (bool) – If True, wrap the denoiser so that SDE states x in [-1, 1] are converted to [0, 1] before denoising and mapped back afterward. Set True for denoisers trained on [0, 1] (all denoisers in deepinv.models.Denoiser); set False only if the denoiser natively expects [-1, 1]. This affects only the denoiser interface, and matching it to the denoiser's training range usually improves quality. Default: True.
dtype (torch.dtype) – data type of the computation, except for the denoiser, which always uses torch.float32. We recommend torch.float64 for better stability and smaller numerical error when solving the SDE in discrete time; since most of the computational cost comes from evaluating the denoiser, which always runs in torch.float32, the extra cost of float64 elsewhere is small.
device (torch.device) – device on which the computation is performed.
*args – additional arguments for deepinv.sampling.BaseSDE.
**kwargs – additional keyword arguments for deepinv.sampling.BaseSDE.
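To make the role of \(\alpha\) in the reverse-time SDE concrete, here is a minimal standalone sketch in plain Python. It does not call deepinv and makes no claim about how DiffusionSDE or its solvers are implemented; the helper names reverse_drift and sample_reverse, and the toy setting (data \(x_0 \sim \mathcal{N}(0, 1)\), forward SDE \(dx = \sqrt{2t}\, dw\), so that \(x_t \sim \mathcal{N}(0, 1 + t^2)\) with a closed-form score), are illustrative assumptions. The reverse equation is integrated from \(t = T\) down to \(0\) with Euler-Maruyama.

```python
import math
import random


def reverse_drift(x, t, f, g, score, alpha):
    """Drift of the reverse-time SDE: f(x, t) - (1 + alpha) / 2 * g(t)^2 * score(x, t)."""
    return f(x, t) - 0.5 * (1.0 + alpha) * g(t) ** 2 * score(x, t)


def sample_reverse(x_T, T, n_steps, f, g, score, alpha, rng):
    """Integrate the reverse-time SDE from t = T down to t = 0 with Euler-Maruyama."""
    dt = T / n_steps
    x = x_T
    for k in range(n_steps):
        t = T - k * dt
        # Time runs backward, so the drift is applied with a minus sign.
        step = -reverse_drift(x, t, f, g, score, alpha) * dt
        if alpha > 0:
            # Diffusion term g(t) * sqrt(alpha) * dw, with dw ~ N(0, dt).
            step += g(t) * math.sqrt(alpha * dt) * rng.gauss(0.0, 1.0)
        x = x + step
    return x


# Toy 1-D setting (an illustrative assumption, not part of deepinv):
# data x_0 ~ N(0, 1) and forward SDE dx = sqrt(2t) dw, so x_t ~ N(0, 1 + t^2)
# and the exact score -x / (1 + t^2) is available in closed form.
def f(x, t):
    return 0.0


def g(t):
    return math.sqrt(2.0 * t)


def score(x, t):
    return -x / (1.0 + t ** 2)


rng = random.Random(0)
T = 1.0

# alpha = 0: deterministic probability-flow ODE; x_T = 2 maps to roughly 2 / sqrt(2).
x0_ode = sample_reverse(2.0, T, 10_000, f, g, score, alpha=0.0, rng=rng)

# alpha = 1: reverse SDE; starting from N(0, 1 + T^2) it should end near N(0, 1).
samples = [
    sample_reverse(rng.gauss(0.0, math.sqrt(1.0 + T ** 2)), T, 200, f, g, score, 1.0, rng)
    for _ in range(2000)
]
empirical_var = sum(x * x for x in samples) / len(samples)
```

In continuous time every \(\alpha \ge 0\) gives a reverse process with the same marginals \(p_t\); \(\alpha\) only controls how much noise is injected along the way, which is why \(\alpha = 0\) yields a deterministic map while \(\alpha = 1\) yields fresh random samples.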
- scale_t(t)
The scale \(s(t)\) of the conditional distribution \(p(x_t \vert x_0) \sim \mathcal{N}(s(t)x_0, s(t)^2 \sigma_t^2 \mathrm{Id})\).
- Parameters:
t (torch.Tensor, float) – current time step
- Returns:
the scale \(s(t)\) at time step t (the mean of the conditional distribution is \(s(t)x_0\)).
- Return type:
- score(x, t, *args, **kwargs)
Approximate the score function \(\nabla \log p_t\) using the denoiser.
- Parameters:
x (torch.Tensor) – current state
t (torch.Tensor, float) – current time step
*args – additional arguments for the denoiser.
**kwargs – additional keyword arguments for the denoiser, e.g., class_labels for class-conditional models.
- Returns:
the score function \(\nabla \log p_t(x)\).
- Return type:
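One standard way to turn a denoiser into a score estimate is Tweedie's formula: if \(D(y, \sigma) \approx \mathbb{E}[x_0 \vert y]\) for \(y = x_0 + \sigma \epsilon\), then with \(s = s(t)\) and \(\sigma = \sigma(t)\), \(\nabla \log p_t(x) \approx \left(D(x/s, \sigma) - x/s\right) / (s \sigma^2)\). The sketch below implements that formula in plain Python together with a [0, 1] to [-1, 1] conversion in the spirit of minus_one_one. It is an assumption, not deepinv's source: the names wrap_minus_one_one, score_from_denoiser and gaussian_denoiser01 are hypothetical, halving \(\sigma\) is what the affine change of variables requires rather than something the text above states, and dtype casting and extra denoiser arguments are omitted. The check uses a Gaussian prior, for which the optimal denoiser and the exact score are known in closed form.

```python
def wrap_minus_one_one(denoiser01):
    """Adapt a denoiser trained on [0, 1] data to states x in [-1, 1].

    The affine map u = (x + 1) / 2 also halves the noise standard deviation,
    so sigma is halved before calling the [0, 1] denoiser, and the output is
    mapped back with x = 2 * u - 1.
    """
    def denoiser(x, sigma):
        return 2.0 * denoiser01((x + 1.0) / 2.0, sigma / 2.0) - 1.0

    return denoiser


def score_from_denoiser(x, t, denoiser, scale_t, sigma_t):
    """Estimate grad log p_t(x) via Tweedie's formula.

    Assumes p(x_t | x_0) = N(s x_0, s^2 sigma^2 Id) and that denoiser(y, sigma)
    approximates E[x_0 | y] for y = x_0 + sigma * noise.
    """
    s, sigma = scale_t(t), sigma_t(t)
    y = x / s  # undo the scaling: y = x_0 + sigma * noise
    # The chain rule through y = x / s contributes the extra factor 1 / s.
    return (denoiser(y, sigma) - y) / (s * sigma ** 2)


# Toy check (hypothetical setting): in [-1, 1] space x_0 ~ N(0.2, 0.09), which
# corresponds to u_0 ~ N(0.6, 0.0225) in [0, 1] space, where the MMSE denoiser is linear.
def gaussian_denoiser01(u, sigma_u, mean=0.6, var=0.0225):
    return mean + var / (var + sigma_u ** 2) * (u - mean)


denoiser = wrap_minus_one_one(gaussian_denoiser01)
s, sigma, x = 0.8, 0.5, 0.3
# Constant schedules: the time argument (0.5 here) is ignored.
estimate = score_from_denoiser(x, 0.5, denoiser, lambda t: s, lambda t: sigma)
# Exact score of p_t = N(s * 0.2, s^2 * (0.09 + sigma^2)) evaluated at x.
exact = -(x - s * 0.2) / (s ** 2 * (0.09 + sigma ** 2))
```

For this Gaussian prior the wrapped [0, 1] denoiser coincides with the optimal denoiser written directly in [-1, 1] coordinates, so the Tweedie estimate matches the exact score up to rounding.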
- sigma_t(t)
The noise level \(\sigma(t)\) (written \(\sigma_t\)) of the conditional distribution \(p(x_t \vert x_0) \sim \mathcal{N}(s(t)x_0, s(t)^2 \sigma_t^2 \mathrm{Id})\).
- Parameters:
t (torch.Tensor, float) – current time step
- Returns:
the noise level at time step t.
- Return type:
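scale_t and sigma_t together determine how a clean sample is noised: \(x_t = s(t)\left(x_0 + \sigma(t)\epsilon\right)\) with \(\epsilon \sim \mathcal{N}(0, \mathrm{Id})\) has exactly the conditional distribution above. As a hypothetical instance, the sketch below uses a variance-preserving schedule with a linear \(\beta(t)\), where \(s(t) = \exp\left(-\tfrac{1}{2}\int_0^t \beta(u)\, du\right)\) and \(\sigma(t) = \sqrt{1 - s(t)^2} / s(t)\). The values \(\beta_{\min} = 0.1\) and \(\beta_{\max} = 20\) and the names vp_scale, vp_sigma and sample_xt are illustrative assumptions; deepinv's own schedules may be parameterized differently.

```python
import math
import random


def vp_scale(t, beta_min=0.1, beta_max=20.0):
    """s(t) = exp(-1/2 * integral_0^t beta(u) du) for a linear beta(t) (illustrative VP schedule)."""
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2
    return math.exp(-0.5 * integral)


def vp_sigma(t, **kwargs):
    """sigma(t) chosen so that s(t)^2 * (1 + sigma(t)^2) = 1 (variance preserving for unit-variance data)."""
    s = vp_scale(t, **kwargs)
    return math.sqrt(1.0 - s ** 2) / s


def sample_xt(x0, t, rng):
    """Draw x_t ~ N(s(t) x_0, s(t)^2 sigma(t)^2) as x_t = s(t) * (x_0 + sigma(t) * eps)."""
    s, sigma = vp_scale(t), vp_sigma(t)
    return s * (x0 + sigma * rng.gauss(0.0, 1.0))


rng = random.Random(0)
t, x0 = 0.3, 0.5
draws = [sample_xt(x0, t, rng) for _ in range(20_000)]
# Empirical mean and standard deviation should approach s(t) * x0 and s(t) * sigma(t).
mean = sum(draws) / len(draws)
std = math.sqrt(sum((d - mean) ** 2 for d in draws) / (len(draws) - 1))
```

Checking \(s(t)^2\left(1 + \sigma(t)^2\right) = 1\) is a quick sanity test for such a schedule: for unit-variance data the marginal variance \(s(t)^2 + s(t)^2\sigma(t)^2\) then stays equal to 1 at every \(t\).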
Examples using DiffusionSDE:
Building your diffusion posterior sampling method using SDEs