VarianceExplodingDiffusion

class deepinv.sampling.VarianceExplodingDiffusion(denoiser=None, rescale=False, sigma_min=0.02, sigma_max=100, alpha=1.0, solver=None, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)

Bases: DiffusionSDE

Variance-Exploding Stochastic Differential Equation (VE-SDE)

The forward-time SDE is defined as follows:

\[d\, x_t = \sigma(t) d\, w_t \quad \mbox{where } \sigma(t) = \sigma_{\mathrm{min}} \left( \frac{\sigma_{\mathrm{max}}}{\sigma_{\mathrm{min}}} \right)^t\]

This class implements the reverse-time SDE associated with the VE-SDE, which serves as the generative process.
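The noise schedule \(\sigma(t)\) above interpolates geometrically between \(\sigma_{\mathrm{min}}\) at \(t = 0\) and \(\sigma_{\mathrm{max}}\) at \(t = 1\). The following is a minimal plain-Python sketch of that formula using the default values from the class signature; the helper name `ve_sigma` is hypothetical and is not part of deepinv.

```python
import math


def ve_sigma(t, sigma_min=0.02, sigma_max=100.0):
    """Geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min) ** t.

    Hypothetical helper for illustration only, not the library's implementation.
    """
    return sigma_min * (sigma_max / sigma_min) ** t


# Evaluate the schedule at the start, midpoint, and end of the time interval.
for t in (0.0, 0.5, 1.0):
    print(f"t={t:.1f}  sigma(t)={ve_sigma(t):.4f}")

# At t = 0.5 the schedule equals the geometric mean of the two endpoints.
midpoint = math.sqrt(0.02 * 100.0)
```

Because the interpolation is geometric rather than linear, most of the time interval is spent at comparatively small noise levels.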

Parameters:
  • denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the score \(\nabla \log p_t\) at time \(t\).

  • rescale (bool) – whether to rescale the input to the denoiser to [-1, 1].

  • sigma_min (float) – the minimum noise level.

  • sigma_max (float) – the maximum noise level.

  • alpha (float) – the weighting factor of the diffusion term.

  • solver (deepinv.sampling.BaseSDESolver) – the solver for solving the SDE.

  • dtype (torch.dtype) – data type of the computation, except for the denoiser, which always runs in torch.float32. We recommend torch.float64 for better stability and lower numerical error when solving the SDE in discrete time. The extra cost is small, since most of the computation is spent evaluating the denoiser.

  • device (torch.device) – device on which the computation is performed.

sample_init(shape, rng)

Sample from the initial distribution of the reverse-time diffusion SDE, which is a Gaussian with zero mean and covariance matrix \(\sigma_{\mathrm{max}}^2 \operatorname{Id}\).

Parameters:
  • shape (tuple) – The shape of the sample to generate

  • rng (torch.Generator) – Random number generator for reproducibility

Returns:

A sample from the prior distribution

Return type:

torch.Tensor
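As an illustration of the initial distribution described above, the sketch below draws i.i.d. \(\mathcal{N}(0, \sigma_{\mathrm{max}}^2)\) entries in plain Python. It is not the library's implementation: the function name `sample_init_sketch` is hypothetical, nested lists stand in for a `torch.Tensor`, and an integer seed with `random.Random` stands in for the `torch.Generator` argument.

```python
import random


def sample_init_sketch(shape, sigma_max=100.0, seed=0):
    """Return a nested list of the given shape with i.i.d. N(0, sigma_max^2) entries.

    Hypothetical stand-in for sample_init, written with the standard library only.
    """
    rng = random.Random(seed)  # seeded generator, for reproducibility

    def build(dims):
        # Base case: no dimensions left, draw one Gaussian scalar.
        if not dims:
            return rng.gauss(0.0, sigma_max)
        # Recursive case: build dims[0] sub-arrays of the remaining shape.
        return [build(dims[1:]) for _ in range(dims[0])]

    return build(tuple(shape))


x_init = sample_init_sketch((2, 3), seed=42)
x_same = sample_init_sketch((2, 3), seed=42)  # same seed, same sample
```

With tensors, the same idea would amount to scaling a standard normal draw by `sigma_max`, for example with `torch.randn(shape, generator=rng)`; whether the library does exactly this is an assumption.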

score(x, t, *args, **kwargs)

Approximate the score function \(\nabla \log p_t\) using the denoiser.

Parameters:
  • x (torch.Tensor) – current state

  • t (torch.Tensor, float) – current time step

  • *args – additional arguments for the denoiser.

  • **kwargs – additional keyword arguments for the denoiser, e.g., class_labels for class-conditional models.

Returns:

the score function \(\nabla \log p_t(x)\).

Return type:

torch.Tensor
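The standard link between a denoiser and the score under additive Gaussian noise is Tweedie's formula, \(\nabla \log p_\sigma(x) = (D(x, \sigma) - x) / \sigma^2\), where \(D\) is the MMSE denoiser. The sketch below checks this identity on a scalar toy problem with a Gaussian prior, where both the denoiser and the score are known in closed form. The function names are hypothetical, and whether `score` uses exactly this expression (for instance when `rescale=True`) is an assumption.

```python
def gaussian_denoiser(x, sigma, prior_std=1.0):
    """MMSE denoiser for x0 ~ N(0, prior_std^2) observed as x = x0 + sigma * n."""
    return prior_std**2 / (prior_std**2 + sigma**2) * x


def score_from_denoiser(x, sigma, denoiser):
    """Tweedie's formula: grad log p_sigma(x) = (D(x, sigma) - x) / sigma^2."""
    return (denoiser(x, sigma) - x) / sigma**2


x, sigma = 0.7, 0.5
approx = score_from_denoiser(x, sigma, gaussian_denoiser)

# The noisy marginal is N(0, 1 + sigma^2), so its exact score is -x / (1 + sigma^2).
exact = -x / (1.0 + sigma**2)
print(approx, exact)
```

With a learned denoiser the identity holds only approximately, which is why the method returns an approximation of the score rather than the score itself.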

sigma_t(t)

The standard deviation of the conditional distribution \(p(x_t \vert x_0) \sim \mathcal{N}(..., \sigma_t^2 \mathrm{Id})\).

Parameters:

t (torch.Tensor, float) – current time step

Returns:

the noise level at time step t.

Return type:

torch.Tensor

Examples using VarianceExplodingDiffusion:

Posterior Sampling for Inverse Problems with Stochastic Differential Equations modeling.
