VariancePreservingDiffusion

class deepinv.sampling.VariancePreservingDiffusion(denoiser=None, beta_min=0.1, beta_max=20.0, alpha=0.0, scaled_linear=False, solver=None, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)

Bases: SongDiffusionSDE

Variance-Preserving Stochastic Differential Equation (VP-SDE).

This class implements the reverse-time SDE associated with the Variance-Preserving SDE (VP-SDE) of Song et al. [1].

The forward-time SDE is defined as follows:

\[d x_t = -\frac{1}{2} \beta(t) x_t dt + \sqrt{\beta(t)} d w_t \quad \mbox{ where } \beta(t) = \beta_{\mathrm{min}} + t \left( \beta_{\mathrm{max}} - \beta_{\mathrm{min}} \right)\]
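The linear schedule and the Gaussian transition it induces can be sketched in plain Python. This is a minimal sketch, assuming the linear (not scaled-linear) schedule and \(T = 1\); the helpers `beta`, `integrated_beta` and `vp_marginal` are hypothetical names, not part of deepinv. It uses the standard closed form \(x_t \mid x_0 \sim \mathcal{N}(m_t x_0, (1 - m_t^2) I)\) with \(m_t = \exp\left(-\frac{1}{2}\int_0^t \beta(s)\,ds\right)\), which is what makes the SDE variance preserving: if \(x_0\) has unit variance, so does every \(x_t\).

```python
import math

def beta(t, beta_min=0.1, beta_max=20.0):
    """Linear noise schedule beta(t) = beta_min + t * (beta_max - beta_min)."""
    return beta_min + t * (beta_max - beta_min)

def integrated_beta(t, beta_min=0.1, beta_max=20.0):
    """Closed-form integral of beta(s) for s in [0, t]."""
    return beta_min * t + 0.5 * t**2 * (beta_max - beta_min)

def vp_marginal(t, beta_min=0.1, beta_max=20.0):
    """Return (m_t, sigma_t^2) such that x_t | x_0 ~ N(m_t * x_0, sigma_t^2 * I).

    m_t = exp(-0.5 * int_0^t beta(s) ds) and sigma_t^2 = 1 - m_t^2.
    """
    m = math.exp(-0.5 * integrated_beta(t, beta_min, beta_max))
    return m, 1.0 - m * m

# With Var[x_0] = 1, Var[x_t] = m_t^2 + sigma_t^2 stays equal to 1 for all t.
for t in (0.0, 0.25, 0.5, 1.0):
    m, var = vp_marginal(t)
    print(f"t={t:.2f}  m_t={m:.5f}  sigma_t^2={var:.5f}  total={m * m + var:.5f}")
```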

The reverse-time SDE is defined as follows:

\[d x_t = -\left(\frac{1}{2} \beta(t) x_t + \frac{1 + \alpha(t)}{2} \beta(t) \nabla \log p_t(x_t) \right) dt + \sqrt{\alpha(t) \beta(t)} d w_t\]

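where \(\alpha(t) \geq 0\) weights the diffusion term. Setting \(\alpha(t) = 0\) gives the probability-flow ODE, and \(\alpha(t) = 1\) recovers the standard reverse-time SDE.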
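To make the role of \(\alpha(t)\) concrete, the reverse-time equation can be integrated with a plain Euler-Maruyama scheme. This is a minimal sketch, not deepinv's solver API, under these assumptions: a toy one-dimensional data distribution \(x_0 \sim \mathcal{N}(0, \sigma_0^2)\), for which \(p_t\) is Gaussian and the exact score \(-x / \mathrm{Var}[p_t]\) stands in for the denoiser; \(T = 1\); a small end time `t_end`; and hypothetical names (`sample_reverse_vp`, `marginal_var`). With either \(\alpha = 0\) or \(\alpha = 1\), the sample variance should come out close to \(\sigma_0^2\).

```python
import math
import random

BETA_MIN, BETA_MAX = 0.1, 20.0

def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)

def marginal_var(t, sigma0):
    """Variance of p_t when x_0 ~ N(0, sigma0^2) (toy data distribution)."""
    m2 = math.exp(-(BETA_MIN * t + 0.5 * t * t * (BETA_MAX - BETA_MIN)))
    return m2 * sigma0**2 + (1.0 - m2)

def sample_reverse_vp(n, sigma0, alpha=0.0, steps=500, t_end=1e-3, seed=0):
    """Euler-Maruyama integration of the reverse VP-SDE from t=1 down to t_end.

    alpha may be a float or a callable alpha(t): alpha=0 is the
    probability-flow ODE, alpha>0 a stochastic sampler.
    """
    alpha_fn = alpha if callable(alpha) else (lambda t: alpha)
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]  # x_T ~ N(0, 1)
    h = (1.0 - t_end) / steps
    t = 1.0
    for _ in range(steps):
        b, a = beta(t), alpha_fn(t)
        v = marginal_var(t, sigma0)
        noise_std = math.sqrt(a * b * h)
        new = []
        for x in xs:
            score = -x / v  # exact Gaussian score in place of the denoiser
            # dt coefficient of the reverse-time SDE; stepping t -> t - h
            # means dt = -h.
            drift = -(0.5 * b * x + 0.5 * (1.0 + a) * b * score)
            x = x - h * drift
            if noise_std > 0.0:
                x += noise_std * rng.gauss(0.0, 1.0)
            new.append(x)
        xs = new
        t -= h
    return xs

def variance(xs):
    mu = sum(xs) / len(xs)
    return sum((x - mu) ** 2 for x in xs) / (len(xs) - 1)

sigma0 = 0.5
for a in (0.0, 1.0):
    xs = sample_reverse_vp(1000, sigma0, alpha=a)
    print(f"alpha={a}: sample variance {variance(xs):.3f} (target {sigma0**2})")
```

Both settings target the same marginals; \(\alpha\) only changes how much fresh noise is injected along the way.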

This reverse-time SDE is the generative (sampling) process: integrating it from noise produces samples from the data distribution.

Note

This SDE must be solved backward in time, i.e., from \(t=T\) to \(t=0\).

Parameters:
  • denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the score \(\nabla \log p_t\) at time \(t\).

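  • beta_min (float) – the minimum noise level \(\beta_{\mathrm{min}}\) of the schedule \(\beta(t)\).

  • beta_max (float) – the maximum noise level \(\beta_{\mathrm{max}}\) of the schedule \(\beta(t)\).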

  • alpha (Callable, float) – a (possibly time-dependent) non-negative scalar weighting the diffusion term. A constant \(\alpha(t) = 0\) corresponds to ODE sampling, and \(\alpha(t) > 0\) corresponds to SDE sampling.

  • scaled_linear (bool) – whether to use the scaled linear beta schedule. If False, the standard linear schedule is used. Defaults to False.

  • solver (deepinv.sampling.BaseSDESolver) – the solver for solving the SDE.

  • dtype (torch.dtype) – data type of the computation, except for the denoiser, which always runs in torch.float32. We recommend torch.float64 for better stability and lower numerical error when solving the SDE in discrete time; the overhead is small because most of the computational cost comes from evaluating the denoiser, which is always computed in torch.float32.

  • device (torch.device) – device on which the computation is performed.

  • *args – additional arguments for the deepinv.sampling.DiffusionSDE.

  • **kwargs – additional keyword arguments for the deepinv.sampling.DiffusionSDE.
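The denoiser enters the sampler only through the score. One common way to turn a Gaussian denoiser into a VP score is Tweedie's formula, \(\nabla \log p_t(x_t) = \left(m_t\, \mathbb{E}[x_0 \mid x_t] - x_t\right) / \sigma_t^2\). The sketch below assumes that relation, a denoiser called as `denoiser(y, sigma)` on \(y = x_0 + \sigma n\), the linear schedule with \(T = 1\), and scalar inputs. It is illustrative and not necessarily how deepinv computes the score internally; `score_from_denoiser` and `gaussian_denoiser` are hypothetical names.

```python
import math

def score_from_denoiser(x_t, t, denoiser, beta_min=0.1, beta_max=20.0):
    """Approximate grad log p_t(x_t) for the VP-SDE from a Gaussian denoiser.

    Assumes x_t = m_t * x_0 + sigma_t * eps, so x_t / m_t is x_0 corrupted by
    Gaussian noise with standard deviation sigma_t / m_t.  Requires t > 0.
    """
    m = math.exp(-0.5 * (beta_min * t + 0.5 * t * t * (beta_max - beta_min)))
    sigma2 = 1.0 - m * m
    # Hypothetical (y, sigma) call signature for the denoiser.
    x0_hat = denoiser(x_t / m, math.sqrt(sigma2) / m)
    return (m * x0_hat - x_t) / sigma2  # Tweedie's formula

# Toy check: for x_0 ~ N(0, s0^2) the MMSE denoiser is a linear shrinkage,
# and the resulting score must equal the exact score -x / Var[x_t].
s0 = 0.5

def gaussian_denoiser(y, sigma):
    return s0**2 / (s0**2 + sigma**2) * y

t, x = 0.3, 0.8
m2 = math.exp(-(0.1 * t + 0.5 * t * t * 19.9))
exact = -x / (m2 * s0**2 + 1.0 - m2)
print(score_from_denoiser(x, t, gaussian_denoiser), exact)
```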


References:

Examples using VariancePreservingDiffusion:

Building your diffusion posterior sampling method using SDEs