VariancePreservingDiffusion
- class deepinv.sampling.VariancePreservingDiffusion(denoiser=None, beta_min=0.1, beta_max=20.0, alpha=1.0, solver=None, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)
Bases: DiffusionSDE
Variance-Preserving Stochastic Differential Equation (VP-SDE)
The forward-time SDE is defined as follows:
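\[d x_t = -\frac{1}{2} \beta(t) x_t \, dt + \sqrt{\beta(t)} \, d w_t, \quad \text{where } \beta(t) = \beta_{\mathrm{min}} + t \left( \beta_{\mathrm{max}} - \beta_{\mathrm{min}} \right), \quad t \in [0, 1].\]
The reverse-time SDE, which runs backward in time, is defined as follows:
\[d x_t = -\left(\frac{1}{2} \beta(t) x_t + \frac{1 + \alpha}{2} \beta(t) \nabla \log p_t(x_t) \right) dt + \sqrt{\alpha \beta(t)} \, d w_t,\]
where \(\alpha \in [0,1]\) is a constant weighting the diffusion term. Setting \(\alpha = 1\) gives the standard reverse-time SDE, while \(\alpha = 0\) removes the noise term and gives the deterministic probability-flow ODE.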
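Because the forward drift is linear in \(x_t\), the forward SDE can be solved in closed form: with \(B(t) = \int_0^t \beta(s)\,ds\), the transition is \(x_t \mid x_0 \sim \mathcal{N}\big(e^{-B(t)/2} x_0, (1 - e^{-B(t)}) I\big)\). The NumPy sketch below checks the variance-preserving property numerically: unit-variance data keeps unit variance at every \(t\). It assumes \(t \in [0, 1]\) and the default \(\beta_{\mathrm{min}} = 0.1\), \(\beta_{\mathrm{max}} = 20\) from the signature; the function names are illustrative and not part of deepinv.

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0  # defaults taken from the class signature


def beta(t):
    """Linear schedule beta(t) = beta_min + t * (beta_max - beta_min), t in [0, 1]."""
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def integrated_beta(t):
    """B(t) = integral of beta(s) from 0 to t, in closed form for the linear schedule."""
    return BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)


def forward_marginal(x0, t, rng):
    """Sample x_t given x_0 for the forward VP-SDE.

    Solving dx = -1/2 beta(t) x dt + sqrt(beta(t)) dw gives
    x_t = exp(-B(t)/2) * x_0 + sqrt(1 - exp(-B(t))) * noise.
    """
    mean_coef = np.exp(-0.5 * integrated_beta(t))
    std = np.sqrt(1.0 - np.exp(-integrated_beta(t)))
    return mean_coef * x0 + std * rng.standard_normal(x0.shape)


rng = np.random.default_rng(0)
x0 = rng.standard_normal(200_000)  # toy unit-variance "data"
for t in (0.1, 0.5, 1.0):
    xt = forward_marginal(x0, t, rng)
    print(f"t={t}: var(x_t) = {xt.var():.3f}")  # stays close to 1 for every t
```

With the default schedule \(B(1) = 10.05\), so \(e^{-B(1)/2} \approx 0.0066\): at \(t = 1\) the sample is already close to a standard Gaussian whatever \(x_0\) was.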
This class implements the reverse-time SDE of the VP-SDE, which serves as the generation process.
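As a rough illustration of the generation process, the following sketch integrates the reverse-time SDE from \(t = 1\) down to a small \(t_{\mathrm{end}} > 0\) with the Euler-Maruyama method. It is not the library's solver (the class delegates integration to its `solver` argument), and instead of a denoiser it uses the exact score of toy Gaussian data \(\mathcal{N}(0, 0.5^2)\), for which \(p_t = \mathcal{N}(0, s_t^2)\) with \(s_t^2 = 0.5^2 e^{-B(t)} + 1 - e^{-B(t)}\). All function names are hypothetical.

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def integrated_beta(t):
    return BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)


def gaussian_score(x, t, data_std):
    """Exact score of p_t when the data distribution is N(0, data_std**2)."""
    decay = np.exp(-integrated_beta(t))
    var_t = data_std**2 * decay + (1.0 - decay)
    return -x / var_t


def reverse_vp_sde(score_fn, n_samples, alpha=1.0, n_steps=1000, t_end=1e-3, seed=0):
    """Euler-Maruyama on the reverse-time VP-SDE, integrating t from 1 down to t_end."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(n_samples)  # x_1 ~ N(0, I), the initial distribution
    ts = np.linspace(1.0, t_end, n_steps + 1)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        h = t_cur - t_next  # positive step size, moving backward in time
        b = beta(t_cur)
        # dt = -h, so the step adds +h * (1/2 beta x + (1 + alpha)/2 beta score).
        drift = 0.5 * b * x + 0.5 * (1.0 + alpha) * b * score_fn(x, t_cur)
        x = x + h * drift
        if alpha > 0:  # alpha = 0 is the probability-flow ODE: no noise injected
            x = x + np.sqrt(alpha * b * h) * rng.standard_normal(n_samples)
    return x


data_std = 0.5
for alpha in (0.0, 0.5, 1.0):
    samples = reverse_vp_sde(lambda x, t: gaussian_score(x, t, data_std), 50_000, alpha=alpha)
    print(f"alpha={alpha}: std = {samples.std():.3f}")  # each should be close to 0.5
```

For every \(\alpha\) tried, the samples should end with a standard deviation close to the data's 0.5; \(\alpha\) changes how the path gets there, not the target distribution. Stopping at a small \(t_{\mathrm{end}} > 0\) rather than exactly at 0 is a common precaution, assumed here rather than taken from the class.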
- Parameters:
denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the score \(\nabla \log p_t\) at time \(t\).
beta_min (float) – the minimum noise level, i.e. \(\beta(0) = \beta_{\mathrm{min}}\).
beta_max (float) – the maximum noise level, i.e. \(\beta(1) = \beta_{\mathrm{max}}\).
alpha (float) – the weighting factor \(\alpha \in [0,1]\) of the diffusion term in the reverse-time SDE.
solver (deepinv.sampling.BaseSDESolver) – the solver for solving the SDE.
dtype (torch.dtype) – data type of the computation, except for the denoiser, which always uses torch.float32. We recommend torch.float64 for better stability and smaller numerical error when solving the SDE in discrete time; since most of the computational cost comes from evaluating the denoiser, which always runs in torch.float32, the extra precision adds little overhead.
device (torch.device) – device on which the computation is performed.
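The `denoiser` parameter supplies the score through a denoising step. One standard way to do this, assumed here without claiming it matches deepinv's internals, is Tweedie's formula: with \(a_t = e^{-B(t)}\), the rescaled state \(x_t/\sqrt{a_t}\) equals a clean sample plus Gaussian noise of level \(\sigma_t = \sqrt{(1 - a_t)/a_t}\), and \(\nabla \log p_t(x_t) \approx \big(\sqrt{a_t}\, D(x_t/\sqrt{a_t}, \sigma_t) - x_t\big)/(1 - a_t)\). The NumPy sketch also mirrors the `dtype` note: the hypothetical denoiser computes in float32 while the SDE state stays in float64. Using the exact MMSE denoiser for \(\mathcal{N}(0, 1)\) data, the recovered score should match the exact score \(-x\).

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0


def integrated_beta(t):
    return BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)


def gaussian_mmse_denoiser(y, sigma):
    """Hypothetical stand-in denoiser: MMSE estimate of x0 ~ N(0, 1) from y = x0 + sigma * n.

    Computes in float32 to mimic a network that only runs in single precision.
    """
    y32 = y.astype(np.float32)
    return y32 / np.float32(1.0 + sigma**2)


def score_from_denoiser(denoiser, x, t):
    """Approximate grad log p_t(x) from a denoiser via Tweedie's formula (VP scaling)."""
    a = np.exp(-integrated_beta(t))
    sigma = np.sqrt((1.0 - a) / a)  # noise level seen by the denoiser after rescaling
    x0_hat = denoiser(x / np.sqrt(a), sigma)  # float32 output
    x0_hat = x0_hat.astype(x.dtype)  # cast back to the SDE dtype before combining with x
    return (np.sqrt(a) * x0_hat - x) / (1.0 - a)


x = np.linspace(-3.0, 3.0, 7)  # float64 state, analogous to dtype=torch.float64
score = score_from_denoiser(gaussian_mmse_denoiser, x, t=0.5)
print(score.dtype)  # float64
print(np.max(np.abs(score + x)))  # small: for N(0, 1) data the exact score is -x
```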
- sample_init(shape, rng)
Sample from the initial distribution of the reverse-time diffusion SDE, which is the standard Gaussian distribution.
- Parameters:
shape (tuple) – The shape of the sample to generate
rng (torch.Generator) – Random number generator for reproducibility
- Returns:
A sample from the prior distribution
- Return type:
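`sample_init` only needs to draw a standard Gaussian of the requested shape from the given generator. A NumPy analogue is sketched below; the function is a hypothetical stand-in, and while a comparable PyTorch call would be `torch.randn(shape, generator=rng)`, we have not verified that this is exactly what the method does. Reusing a generator seeded with the same value reproduces the same initial sample.

```python
import numpy as np


def sample_init(shape, rng):
    """Hypothetical NumPy analogue of sample_init: draw x_1 ~ N(0, I).

    With the default schedule, the forward VP-SDE maps data at t = 1 to
    (almost) a standard Gaussian, so the reverse-time SDE starts from N(0, I).
    """
    return rng.standard_normal(shape)


a = sample_init((2, 3, 8, 8), np.random.default_rng(42))
b = sample_init((2, 3, 8, 8), np.random.default_rng(42))
print(a.shape, np.array_equal(a, b))  # (2, 3, 8, 8) True
```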
Examples using VariancePreservingDiffusion

Building your diffusion posterior sampling method using SDEs