EDMDiffusionSDE

class deepinv.sampling.EDMDiffusionSDE(sigma_t, scale_t=None, sigma_prime_t=None, scale_prime_t=None, variance_preserving=False, variance_exploding=False, alpha=1.0, T=1.0, denoiser=None, solver=None, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)

Bases: DiffusionSDE

Generative diffusion stochastic differential equation (SDE).

This class implements the generative diffusion SDE following the formulation of Karras et al. [1], with \(\beta(t) = \alpha(t) s(t)^2 \sigma(t) \sigma'(t)\):

\[d x_t = \left(\frac{s'(t)}{s(t)} x_t - (1 + \alpha(t)) s(t)^2 \sigma(t) \sigma'(t) \nabla \log p_t(x_t) \right) dt + s(t) \sqrt{2 \alpha(t) \sigma(t) \sigma'(t)} d w_t\]

where \(s(t)\) is a time-dependent scale, \(\sigma(t)\) is a time-dependent noise level, and \(\alpha(t)\) weights the diffusion term. It is the reverse-time counterpart of the following forward-time SDE:

\[d x_t = \frac{s'(t)}{s(t)} x_t dt + s(t) \sqrt{2 \sigma(t) \sigma'(t)} d w_t\]
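The link between the two SDEs follows from the usual reverse-time construction. Write the forward SDE as \(d x_t = f(x_t, t) dt + g(t) d w_t\). The family of reverse-time processes sharing its marginals (the reverse SDE of Anderson at \(\alpha = 1\), the probability-flow ODE at \(\alpha = 0\)) is

\[d x_t = \left(f(x_t, t) - \frac{1 + \alpha(t)}{2} g(t)^2 \nabla \log p_t(x_t) \right) dt + \sqrt{\alpha(t)}\, g(t)\, d w_t\]

Substituting \(f(x, t) = \frac{s'(t)}{s(t)} x\) and \(g(t)^2 = 2 s(t)^2 \sigma(t) \sigma'(t)\) gives \(\frac{1 + \alpha(t)}{2} g(t)^2 = (1 + \alpha(t)) s(t)^2 \sigma(t) \sigma'(t)\) and \(\sqrt{\alpha(t)}\, g(t) = s(t) \sqrt{2 \alpha(t) \sigma(t) \sigma'(t)}\), which are exactly the drift and diffusion coefficients of the generative SDE.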

The scale \(s(t)\) and noise \(\sigma(t)\) schedules must satisfy \(s(0) = 1\), \(\sigma(0) = 0\) and \(\lim_{t \to \infty} \sigma(t) = +\infty\).
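These conditions can be read off the closed-form marginal of the forward SDE. Setting \(y_t = x_t / s(t)\) and using that \(s(t)\) is deterministic, the product rule removes the drift:

\[d y_t = \frac{d x_t}{s(t)} - \frac{s'(t)}{s(t)^2} x_t \, dt = \sqrt{2 \sigma(t) \sigma'(t)}\, d w_t\]

so \(y_t - y_0\) is Gaussian with variance \(\int_0^t 2 \sigma(u) \sigma'(u) \, du = \sigma(t)^2 - \sigma(0)^2\). With \(s(0) = 1\) and \(\sigma(0) = 0\) this yields

\[x_t = s(t) \left(x_0 + \sigma(t) n\right), \quad n \sim \mathcal{N}(0, \operatorname{Id})\]

Hence \(x_0\) is recovered at \(t = 0\), while for large \(\sigma(t)\) the noise term dominates, which is why the reverse process can start from \(\mathcal{N}(0, s(T)^2 \sigma(T)^2 \operatorname{Id})\).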

Common choices include the variance-preserving formulation \(s(t) = \left(1 + \sigma(t)^2\right)^{-1/2}\) and the variance-exploding formulation \(s(t) = 1\).

  • To use the variance-preserving formulation, set variance_preserving=True and do not provide scale_t or scale_prime_t.

  • To use the variance-exploding formulation, set variance_exploding=True and do not provide scale_t or scale_prime_t.
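The selection rules above, together with the description of scale_t, can be summarized by a small helper. This is a sketch of the documented logic only: resolve_scale is a hypothetical name, not a deepinv function, and rejecting conflicting combinations (both flags set, or scale_t passed together with a flag) is an assumption of the sketch rather than confirmed library behavior.

```python
import math


def resolve_scale(sigma_t, scale_t=None, variance_preserving=False, variance_exploding=False):
    """Hypothetical helper: pick the scale schedule s(t) from the constructor options.

    The error handling for conflicting options is an assumption of this sketch.
    """
    if variance_preserving and variance_exploding:
        raise ValueError("set at most one of variance_preserving and variance_exploding")
    if scale_t is not None:
        if variance_preserving or variance_exploding:
            raise ValueError("do not pass scale_t together with a preset formulation")
        return scale_t
    if variance_preserving:
        # s(t) = (1 + sigma(t)^2)^(-1/2)
        return lambda t: 1.0 / math.sqrt(1.0 + sigma_t(t) ** 2)
    if variance_exploding:
        # s(t) = 1
        return lambda t: 1.0
    raise ValueError("scale_t is required when both preset flags are False")


# Usage with a hypothetical noise schedule sigma(t) = t.
vp = resolve_scale(sigma_t=lambda t: t, variance_preserving=True)
ve = resolve_scale(sigma_t=lambda t: t, variance_exploding=True)
```

Returning a callable keeps the downstream code identical whichever formulation is chosen: the sampler only ever calls s(t).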

Note

This SDE must be solved backward in time, i.e., from \(t=T\) to \(t=0\).
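To make "backward in time" concrete, here is a minimal sketch assuming a scalar state, a user-supplied score function, and a plain Euler-Maruyama discretization with a negative time step. In deepinv the integration is delegated to the solver argument; reverse_euler_maruyama is a hypothetical function, not that solver. Integration stops at a small t_end > 0 because \(\sigma(0) = 0\) makes the score singular at \(t = 0\).

```python
import math
import random


def reverse_euler_maruyama(x, score, sigma_t, sigma_prime_t, scale_t, scale_prime_t,
                           alpha=1.0, T=1.0, t_end=1e-3, n_steps=100, rng=None):
    """Hypothetical scalar Euler-Maruyama integrator for the reverse-time SDE.

    Integrates from t = T down to t = t_end (> 0, to avoid sigma(0) = 0).
    """
    if rng is None:
        rng = random.Random(0)
    # alpha may be a float or a callable, as for the constructor argument.
    alpha_fn = alpha if callable(alpha) else (lambda t: alpha)
    dt = (t_end - T) / n_steps  # negative: time runs backwards
    t = T
    for _ in range(n_steps):
        s, ds = scale_t(t), scale_prime_t(t)
        sig, dsig = sigma_t(t), sigma_prime_t(t)
        a = alpha_fn(t)
        drift = (ds / s) * x - (1.0 + a) * s ** 2 * sig * dsig * score(x, t)
        diffusion = s * math.sqrt(2.0 * a * sig * dsig)
        # The Brownian increment has standard deviation sqrt(|dt|).
        x = x + drift * dt + diffusion * math.sqrt(-dt) * rng.gauss(0.0, 1.0)
        t = t + dt
    return x


# Toy check (assumption): data is a point mass at 0, with sigma(t) = t and s(t) = 1 (VE),
# so p_t = N(0, t^2) and the exact score is -x / t^2.
x_end = reverse_euler_maruyama(
    2.0,
    score=lambda x, t: -x / t ** 2,
    sigma_t=lambda t: t,
    sigma_prime_t=lambda t: 1.0,
    scale_t=lambda t: 1.0,
    scale_prime_t=lambda t: 0.0,
    alpha=0.0,  # probability-flow ODE: no noise is injected
)
```

With alpha=0.0 the drift in this toy case reduces to \(x/t\), so each step multiplies x by \(t_{k+1}/t_k\) and the product telescopes to \(x_T \cdot t_{\text{end}} / T\). With alpha > 0 the same loop injects noise and becomes an SDE sampler.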

Parameters:
  • sigma_t (Callable) – a time-dependent noise level schedule. It takes a time step t (either a Python float or a torch.Tensor) as input and returns the noise level \(\sigma(t)\) (either a Python float or a torch.Tensor). This argument is required.

  • scale_t (Callable) – a time-dependent scale schedule. It takes a time step t (either a Python float or a torch.Tensor) as input and returns the scale \(s(t)\) (either a Python float or a torch.Tensor). If not provided, it is set to \(s(t) = \left(1 + \sigma(t)^2\right)^{-1/2}\) if variance_preserving=True, or to \(s(t) = 1\) if variance_exploding=True. If both variance_preserving and variance_exploding are False, scale_t must be provided. Defaults to None.

  • sigma_prime_t (Callable) – the derivative of sigma_t. It takes a time step t (either a Python float or a torch.Tensor) as input and returns \(\sigma'(t)\) (either a Python float or a torch.Tensor). If not provided, it is computed using autograd. Defaults to None.

  • scale_prime_t (Callable) – the derivative of scale_t. It takes a time step t (either a Python float or a torch.Tensor) as input and returns \(s'(t)\) (either a Python float or a torch.Tensor). If not provided, it is computed using autograd. Defaults to None.

  • variance_preserving (bool) – whether to use a variance-preserving diffusion schedule, which imposes \(s(t) = \left(1 + \sigma(t)^2\right)^{-1/2}\). Defaults to False.

  • variance_exploding (bool) – whether to use a variance-exploding diffusion schedule, which imposes \(s(t) = 1\). Defaults to False.

  • alpha (Callable, float) – a (possibly time-dependent) non-negative scalar weighting the diffusion term. Taking \(\alpha(t) = 0\) for all \(t\) corresponds to ODE sampling, while \(\alpha(t) > 0\) corresponds to SDE sampling. Defaults to 1.0.

  • T (float) – the end time of the forward SDE. Defaults to 1.0.

  • denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the score \(\nabla \log p_t\) at time \(t\). Defaults to None.

  • solver (deepinv.sampling.BaseSDESolver) – the solver used to integrate the SDE. Defaults to None.

  • dtype (torch.dtype) – data type of the computation, except for the denoiser, which always runs in torch.float32. Defaults to torch.float64. We recommend torch.float64 for better numerical stability when solving the SDE in discrete time; the extra cost is small, since most of the computation is spent evaluating the denoiser, which runs in torch.float32 regardless.

  • device (torch.device) – device on which the computation is performed. Defaults to CPU.

  • *args – additional positional arguments passed to deepinv.sampling.DiffusionSDE.

  • **kwargs – additional keyword arguments passed to deepinv.sampling.DiffusionSDE.
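When sigma_prime_t and scale_prime_t are supplied by hand instead of being computed with autograd, they must be the exact derivatives of the schedules. As a sketch, take the hypothetical schedule \(\sigma(t) = t\) with the variance-preserving scale; the chain rule gives \(s'(t) = -\sigma(t) \sigma'(t) \left(1 + \sigma(t)^2\right)^{-3/2}\). The code below checks this formula against a central finite difference, which is used here only as a numerical sanity check and is not what deepinv does internally.

```python
def sigma_t(t):
    # Hypothetical noise schedule sigma(t) = t (sigma(0) = 0, unbounded as t grows).
    return t


def sigma_prime_t(t):
    return 1.0


def vp_scale_t(t):
    # Variance-preserving scale s(t) = (1 + sigma(t)^2)^(-1/2).
    return (1.0 + sigma_t(t) ** 2) ** -0.5


def vp_scale_prime_t(t):
    # Chain rule: s'(t) = -sigma(t) sigma'(t) (1 + sigma(t)^2)^(-3/2).
    return -sigma_t(t) * sigma_prime_t(t) * (1.0 + sigma_t(t) ** 2) ** -1.5


def central_difference(f, t, h=1e-6):
    # Numerical derivative, used here only to check the hand-written formula.
    return (f(t + h) - f(t - h)) / (2.0 * h)


max_error = max(
    abs(vp_scale_prime_t(t) - central_difference(vp_scale_t, t))
    for t in (0.1, 0.5, 1.0, 2.0, 5.0)
)
```

A wrong hand-written derivative would silently bias the drift term \(\frac{s'(t)}{s(t)} x_t\), so a quick check like this is cheap insurance.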


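References:

[1] Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the Design Space of Diffusion-Based Generative Models. Advances in Neural Information Processing Systems (NeurIPS).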

sample_init(shape, rng)

Sample from the initial distribution of the reverse-time diffusion SDE, which is a Gaussian with zero mean and covariance matrix \(s(T)^2 \sigma(T)^2 \operatorname{Id}\).

Parameters:
  • shape (tuple) – the shape of the sample to generate.

  • rng (torch.Generator) – random number generator, for reproducibility.

Returns:

A sample from the prior distribution

Return type:

torch.Tensor
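The initial draw amounts to scaling standard Gaussian noise by \(s(T) \sigma(T)\). The sketch below is a dependency-free stand-in: it assumes Python's random.Random in place of torch.Generator and returns a flat list rather than a torch.Tensor, and sample_init here is a hypothetical free function, not the method itself. In PyTorch the corresponding draw would be torch.randn(shape, generator=rng) * std.

```python
import math
import random


def sample_init(shape, rng, sigma_t, scale_t, T=1.0):
    """Hypothetical stand-in for sample_init: i.i.d. N(0, s(T)^2 sigma(T)^2) draws.

    Returns a flat list of prod(shape) values instead of a torch.Tensor.
    """
    std = scale_t(T) * sigma_t(T)  # standard deviation s(T) * sigma(T)
    n = math.prod(shape)
    return [std * rng.gauss(0.0, 1.0) for _ in range(n)]


# Example with a hypothetical VE schedule sigma(t) = 80 t and s(t) = 1, so the std is 80.
samples = sample_init(
    (100, 200), random.Random(0), sigma_t=lambda t: 80.0 * t, scale_t=lambda t: 1.0
)
```

Seeding the generator makes the starting point, and hence the whole deterministic (alpha=0) trajectory, reproducible.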

score(x, t, *args, **kwargs)

Approximate the score function \(\nabla \log p_t\) using the denoiser.

Parameters:
  • x (torch.Tensor) – current state

  • t (torch.Tensor, float) – current time step

  • *args – additional arguments for the denoiser.

  • **kwargs – additional keyword arguments for the denoiser, e.g., class_labels for class-conditional models.

Returns:

the score function \(\nabla \log p_t(x)\).

Return type:

torch.Tensor
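One standard way to turn a denoiser into a score is Tweedie's formula. Assuming \(x_t = s(t)\left(x_0 + \sigma(t) n\right)\) and a denoiser \(D(y, \sigma) \approx \mathbb{E}[x_0 \mid y]\) that operates on the unscaled input \(y = x_t / s(t)\), one gets \(\nabla \log p_t(x) = \frac{s(t) D(x / s(t), \sigma(t)) - x}{s(t)^2 \sigma(t)^2}\). The sketch below implements this relation with a hypothetical helper edm_score; that the method uses exactly this convention is an assumption, so check the deepinv source before relying on it. The toy check uses a standard Gaussian prior, whose exact MMSE denoiser is \(y / (1 + \sigma^2)\).

```python
def edm_score(x, t, denoiser, sigma_t, scale_t):
    """Hypothetical helper: score of p_t at x from a denoiser D(y, sigma) ~ E[x_0 | y]."""
    s, sigma = scale_t(t), sigma_t(t)
    # The denoiser expects the unscaled input y = x / s = x_0 + sigma * n.
    x0_hat = denoiser(x / s, sigma)
    # Tweedie's formula, mapped back to the scaled variable x = s * y.
    return (s * x0_hat - x) / (s ** 2 * sigma ** 2)


# Toy check (assumption): x_0 ~ N(0, 1), whose exact MMSE denoiser is y / (1 + sigma^2).
# With the VP scale and sigma(t) = t, x_t ~ N(0, s^2 (1 + sigma^2)).
def gaussian_denoiser(y, sigma):
    return y / (1.0 + sigma ** 2)


def sigma_t(t):
    return t


def vp_scale_t(t):
    return (1.0 + t ** 2) ** -0.5


score_value = edm_score(3.0, 2.0, gaussian_denoiser, sigma_t, vp_scale_t)
```

In this Gaussian case the exact score is \(-x / \left(s(t)^2 (1 + \sigma(t)^2)\right)\), and with the VP scale the denominator equals 1, so the score at \(x = 3\) is \(-3\) at every \(t > 0\).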

Examples using EDMDiffusionSDE:

Using state-of-the-art diffusion models from HuggingFace Diffusers with DeepInverse

Building your diffusion posterior sampling method using SDEs

Flow-Matching for posterior sampling and unconditional generation