VarianceExplodingDiffusion
- class deepinv.sampling.VarianceExplodingDiffusion(denoiser=None, rescale=False, sigma_min=0.02, sigma_max=100, alpha=1.0, solver=None, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)
Bases: DiffusionSDE
Variance-Exploding Stochastic Differential Equation (VE-SDE)
The forward-time SDE is defined as follows:
\[d\, x_t = \sigma(t)\, d\, w_t \quad \mbox{where } \sigma(t) = \sigma_{\mathrm{min}} \left( \frac{\sigma_{\mathrm{max}}}{\sigma_{\mathrm{min}}} \right)^t\]
This class implements the reverse-time SDE of the VE-SDE, which serves as the generation process.
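A minimal NumPy sketch of how a reverse-time process of this kind generates samples. Several assumptions are made that do not come from the class itself: the data are a 1-D toy Gaussian \(x_0 \sim \mathcal{N}(0, s^2)\), for which the score of the forward SDE above is known in closed form, so no denoiser is needed; the reverse-time SDE is written as \(d x = -\frac{1+\alpha}{2} \sigma(t)^2 \nabla \log p_t(x)\, dt + \sqrt{\alpha}\, \sigma(t)\, d\bar{w}\), one common way to weight the diffusion term (\(\alpha = 1\) gives the standard reverse SDE, \(\alpha = 0\) the probability-flow ODE); and the chain starts from the exact \(t = 1\) marginal of the toy model rather than from \(\mathcal{N}(0, \sigma_{\mathrm{max}}^2 \operatorname{Id})\). The helper names `sigma`, `marginal_var` and `reverse_sample` are hypothetical.

```python
import numpy as np


def sigma(t, sigma_min=0.02, sigma_max=100.0):
    """Diffusion coefficient sigma(t) = sigma_min * (sigma_max / sigma_min) ** t."""
    return sigma_min * (sigma_max / sigma_min) ** t


def marginal_var(t, s2, sigma_min=0.02, sigma_max=100.0):
    """Variance of x_t for x_0 ~ N(0, s2) under dx_t = sigma(t) dw_t.

    Uses int_0^t sigma(u)^2 du = sigma_min^2 (r^(2t) - 1) / (2 log r), r = sigma_max / sigma_min.
    """
    r = sigma_max / sigma_min
    return s2 + sigma_min**2 * (r ** (2.0 * t) - 1.0) / (2.0 * np.log(r))


def reverse_sample(n_samples, s2=1.0, alpha=1.0, n_steps=1000,
                   sigma_min=0.02, sigma_max=100.0, seed=0):
    """Euler-Maruyama integration of the assumed reverse-time SDE from t = 1 down to t = 0."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / n_steps
    # Start from the exact t = 1 marginal of the toy Gaussian model.
    x = np.sqrt(marginal_var(1.0, s2, sigma_min, sigma_max)) * rng.standard_normal(n_samples)
    for k in range(n_steps, 0, -1):
        t = k * dt
        g = sigma(t, sigma_min, sigma_max)
        # Exact score of N(0, v(t)); a trained model would supply this through a denoiser.
        score = -x / marginal_var(t, s2, sigma_min, sigma_max)
        # Drift step, taken backwards in time.
        x = x + 0.5 * (1.0 + alpha) * g**2 * score * dt
        # Diffusion step weighted by sqrt(alpha); alpha = 0 makes the update deterministic.
        if alpha > 0:
            x = x + np.sqrt(alpha * dt) * g * rng.standard_normal(n_samples)
    return x
```

With `s2=1.0`, the empirical standard deviation of `reverse_sample(20000)` should come out close to 1 for both `alpha=1.0` and `alpha=0.0`, since both processes transport the \(t = 1\) marginal back to the data distribution; they differ only in how much fresh noise is injected along the way.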
- Parameters:
denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the score \(\nabla \log p_t\) at time \(t\).
rescale (bool) – whether to rescale the input to the denoiser to [-1, 1].
sigma_min (float) – the minimum noise level.
sigma_max (float) – the maximum noise level.
alpha (float) – the weighting factor of the diffusion term.
solver (deepinv.sampling.BaseSDESolver) – the solver for solving the SDE.
dtype (torch.dtype) – data type of the computation, except for the denoiser, which always uses torch.float32. We recommend torch.float64 for better stability and smaller numerical error when solving the SDE in discrete time; since most of the computational cost comes from evaluating the denoiser, which is always computed in torch.float32, the extra precision elsewhere is cheap.
device (torch.device) – device on which the computation is performed.
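The `dtype` and `rescale` parameters can be illustrated with a small helper that keeps the SDE state in double precision but evaluates the denoiser in single precision. This is a NumPy sketch under assumptions: `denoise_in_float32` is a hypothetical name, the denoiser is any callable `denoiser(x, sigma)`, and `rescale=True` is taken to mean an affine map of data in [0, 1] to [-1, 1] (which doubles the noise standard deviation) followed by the inverse map on the output. The class's own handling may differ in detail.

```python
import numpy as np


def denoise_in_float32(denoiser, x, sigma, rescale=False):
    """Evaluate a float32 denoiser on a float64 state and return a float64 result."""
    # Cast the (float64) SDE state down for the expensive denoiser call.
    x32 = x.astype(np.float32)
    sigma32 = np.float32(sigma)
    if rescale:
        # Assumed convention: data in [0, 1] -> [-1, 1]; the affine map doubles the noise std.
        x32 = 2.0 * x32 - 1.0
        sigma32 = np.float32(2.0 * sigma)
    out = denoiser(x32, sigma32)
    if rescale:
        # Map the denoised output back from [-1, 1] to [0, 1].
        out = (out + 1.0) / 2.0
    # The solver keeps accumulating its updates in float64.
    return out.astype(np.float64)
```

Only the denoiser call runs in float32; the many small solver updates are accumulated in float64, which is where rounding error would otherwise build up over the discretization steps.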
- sample_init(shape, rng)
Sample from the initial distribution of the reverse-time diffusion SDE, which is a Gaussian with zero mean and covariance matrix \(\sigma_{\mathrm{max}}^2 \operatorname{Id}\).
- Parameters:
shape (tuple) – The shape of the sample to generate
rng (torch.Generator) – Random number generator for reproducibility
- Returns:
A sample from the prior distribution
- Return type:
torch.Tensor
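A minimal NumPy stand-in for `sample_init`: draw \(x_T \sim \mathcal{N}(0, \sigma_{\mathrm{max}}^2 \operatorname{Id})\) from a seeded generator. Here `np.random.Generator` plays the role of `torch.Generator`, and `sample_init_sketch` is a hypothetical name; reusing the same seed reproduces the same initial sample.

```python
import numpy as np


def sample_init_sketch(shape, rng, sigma_max=100.0):
    """Draw x_T ~ N(0, sigma_max^2 Id) using the given generator."""
    # Standard normal noise scaled by sigma_max has covariance sigma_max^2 * Id.
    return sigma_max * rng.standard_normal(shape)


# Same seed, same sample: this is what makes the initialization reproducible.
x_a = sample_init_sketch((2, 3, 8, 8), np.random.default_rng(42))
x_b = sample_init_sketch((2, 3, 8, 8), np.random.default_rng(42))
```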
- score(x, t, *args, **kwargs)
Approximate the score function \(\nabla \log p_t\) using the denoiser.
- Parameters:
x (torch.Tensor) – current state
t (torch.Tensor, float) – current time step
*args – additional arguments for the denoiser.
**kwargs – additional keyword arguments for the denoiser, e.g., class_labels for class-conditional models.
- Returns:
the score function \(\nabla \log p_t(x)\).
- Return type:
torch.Tensor
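One standard way to turn a denoiser into a score estimate is Tweedie's formula, \(\nabla \log p_t(x) \approx (D(x, \sigma_t) - x) / \sigma_t^2\), where \(D\) is an MMSE denoiser at noise level \(\sigma_t\). The sketch below assumes this relation (the class may additionally apply rescaling), with a hypothetical helper `score_from_denoiser`. It checks the formula on a scalar Gaussian prior \(x_0 \sim \mathcal{N}(0, s^2)\), for which both the MMSE denoiser \(D(x) = \frac{s^2}{s^2 + \sigma^2} x\) and the true score \(-x / (s^2 + \sigma^2)\) are known in closed form.

```python
import numpy as np


def score_from_denoiser(denoiser, x, sigma):
    """Tweedie estimate of grad log p_t(x) from a denoiser at noise level sigma."""
    return (denoiser(x, sigma) - x) / sigma**2


# Toy prior x_0 ~ N(0, s^2): the MMSE denoiser is linear shrinkage toward 0.
s = 2.0


def gaussian_mmse_denoiser(x, sigma):
    return s**2 / (s**2 + sigma**2) * x


x = np.linspace(-5.0, 5.0, 11)
sigma = 3.0
estimated = score_from_denoiser(gaussian_mmse_denoiser, x, sigma)
exact = -x / (s**2 + sigma**2)  # score of the noisy marginal N(0, s^2 + sigma^2)
```

For this prior the two agree exactly; with a learned denoiser the score is only as accurate as the denoiser's approximation of the MMSE estimate.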
- sigma_t(t)
The standard deviation of the conditional distribution \(p(x_t \vert x_0) = \mathcal{N}(..., \sigma_t^2 \mathrm{Id})\).
- Parameters:
t (torch.Tensor, float) – current time step
- Returns:
the noise level at time step \(t\).
- Return type:
torch.Tensor
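Because the forward SDE has no drift term, \(x_t\) given \(x_0\) is \(x_0\) plus zero-mean Gaussian noise whose standard deviation is the value returned by `sigma_t`. The NumPy sketch below draws such a noisy sample, e.g. to build \((x_0, x_t)\) pairs; `sample_xt_given_x0` is a hypothetical name, and the schedule passed in is a placeholder callable rather than the class's own `sigma_t`.

```python
import numpy as np


def sample_xt_given_x0(x0, t, sigma_t, rng):
    """Draw x_t ~ N(x0, sigma_t(t)^2 Id) for a zero-drift (variance-exploding) forward SDE."""
    return x0 + sigma_t(t) * rng.standard_normal(x0.shape)


rng = np.random.default_rng(0)
x0 = np.full(200_000, 3.0)


def placeholder_sigma_t(t):
    # Placeholder noise schedule, for illustration only.
    return 0.5 + t


xt = sample_xt_given_x0(x0, 1.0, placeholder_sigma_t, rng)
```

The empirical mean of `xt` stays at 3.0 (no drift) while its spread matches the placeholder noise level of 1.5 at \(t = 1\).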
Examples using VarianceExplodingDiffusion:

Posterior Sampling for Inverse Problems with Stochastic Differential Equations modeling.