EDMDiffusionSDE
- class deepinv.sampling.EDMDiffusionSDE(sigma_t, scale_t=None, sigma_prime_t=None, scale_prime_t=None, variance_preserving=False, variance_exploding=False, alpha=1.0, T=1.0, denoiser=None, solver=None, dtype=torch.float64, device=torch.device('cpu'), *args, **kwargs)
Bases: DiffusionSDE
Generative diffusion Stochastic Differential Equation.
This class implements the diffusion generative SDE based on the formulation from Karras et al. [1] (with \(\beta(t) = \alpha(t) s(t)^2 \sigma(t) \sigma'(t)\)):
\[d x_t = \left(\frac{s'(t)}{s(t)} x_t - (1 + \alpha(t)) s(t)^2 \sigma(t) \sigma'(t) \nabla \log p_t(x_t) \right) dt + s(t) \sqrt{2 \alpha(t) \sigma(t) \sigma'(t)} d w_t\]
where \(s(t)\) is a time-dependent scale, \(\sigma(t)\) is a time-dependent noise level, and \(\alpha(t)\) weights the diffusion term. It is the reverse-time SDE of the following forward-time SDE:
\[d x_t = \frac{s'(t)}{s(t)} x_t dt + s(t) \sqrt{2 \sigma(t) \sigma'(t)} d w_t\]
The scale \(s(t)\) and noise \(\sigma(t)\) schedules must satisfy \(s(0) = 1\), \(\sigma(0) = 0\) and \(\lim_{t \to \infty} \sigma(t) = +\infty\).
Common choices include the variance-preserving formulation \(s(t) = \left(1 + \sigma(t)^2\right)^{-1/2}\) and the variance-exploding formulation \(s(t) = 1\).
To use the variance-preserving formulation, set variance_preserving=True and do not provide scale_t and scale_prime_t.
To use the variance-exploding formulation, set variance_exploding=True and do not provide scale_t and scale_prime_t.
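As a minimal sketch of these two choices, the plain-Python callables below define a hypothetical linear noise schedule \(\sigma(t) = t\) (an assumption made for illustration, not a library default) together with the variance-preserving and variance-exploding scales. The names sigma_t, scale_vp, scale_ve and marginal_variance are illustrative only. The check relies on the forward SDE above having the solution \(x_t = s(t)\,(x_0 + \sigma(t)\,\varepsilon)\), so that unit-variance data has marginal variance \(s(t)^2 (1 + \sigma(t)^2)\).

```python
def sigma_t(t):
    # Hypothetical linear noise schedule (assumption); satisfies sigma(0) = 0.
    return t


def scale_vp(t):
    # Variance-preserving scale: s(t) = (1 + sigma(t)^2)^(-1/2), so s(0) = 1.
    return (1 + sigma_t(t) ** 2) ** -0.5


def scale_ve(t):
    # Variance-exploding scale: s(t) = 1 for all t.
    return 1.0


def marginal_variance(scale, t):
    # Variance of x_t = s(t) * (x_0 + sigma(t) * eps) for unit-variance x_0.
    return scale(t) ** 2 * (1 + sigma_t(t) ** 2)


for t in (0.0, 0.5, 2.0, 10.0):
    # VP stays at 1 for every t, while VE grows as 1 + sigma(t)^2.
    print(t, marginal_variance(scale_vp, t), marginal_variance(scale_ve, t))
```

Because these callables use only arithmetic operators, they should work on Python floats and on tensors alike. Under the signature documented above, such a schedule would be passed as sigma_t, with variance_preserving=True or variance_exploding=True taking the place of an explicit scale_t.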
Note
This SDE must be solved backward in time, i.e., from \(t=T\) to \(t=0\).
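To illustrate what integrating backward in time means, here is a small sketch that applies explicit Euler steps to the reverse-time drift in the deterministic case \(\alpha = 0\). It is not the library's solver. It assumes a one-dimensional toy problem with unit-variance Gaussian data, a hypothetical schedule \(\sigma(t) = t\) and the variance-exploding scale \(s(t) = 1\), for which the score is known in closed form and the exact solution is \(x_0 = x_T / \sqrt{1 + T^2}\).

```python
import math


def sigma(t):
    return t  # hypothetical noise schedule (assumption)


def sigma_prime(t):
    return 1.0


def scale(t):
    return 1.0  # variance-exploding scale


def scale_prime(t):
    return 0.0


def exact_score(x, t):
    # Score of N(0, s(t)^2 (1 + sigma(t)^2)), valid for unit-variance Gaussian data.
    return -x / (scale(t) ** 2 * (1.0 + sigma(t) ** 2))


def reverse_drift(x, t, alpha=0.0):
    # Drift of the reverse-time SDE; with alpha = 0 the noise term vanishes (ODE sampling).
    return (scale_prime(t) / scale(t)) * x - (1.0 + alpha) * scale(t) ** 2 * sigma(t) * sigma_prime(t) * exact_score(x, t)


def solve_reverse_ode(x_T, T=10.0, n_steps=20000):
    dt = -T / n_steps  # negative step: time runs from T down to 0
    x, t = x_T, T
    for _ in range(n_steps):
        x += reverse_drift(x, t) * dt
        t += dt
    return x


x0 = solve_reverse_ode(3.0)
print(x0, 3.0 / math.sqrt(1.0 + 10.0 ** 2))  # numerical vs. closed-form value
```

For \(\alpha > 0\) each step would also need a Gaussian increment scaled by \(s(t) \sqrt{2 \alpha(t) \sigma(t) \sigma'(t) \, |dt|}\), which this sketch omits.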
- Parameters:
sigma_t (Callable) – a time-dependent noise level schedule. It takes a time step t (either a Python float or a torch.Tensor) as input and returns the noise level at time t (either a Python float or a torch.Tensor). Note that this is a required argument.
scale_t (Callable) – a time-dependent scale schedule. It takes a time step t (either a Python float or a torch.Tensor) as input and returns the scale at time t (either a Python float or a torch.Tensor). If not provided, it is set to \(s(t) = \left(1 + \sigma(t)^2\right)^{-1/2}\) if variance_preserving=True, or \(s(t) = 1\) if variance_exploding=True. If both variance_preserving and variance_exploding are False, scale_t must be provided. Defaults to None.
sigma_prime_t (Callable) – the derivative of sigma_t. It takes a time step t (either a Python float or a torch.Tensor) as input and returns the derivative of the noise level at time t (either a Python float or a torch.Tensor). If not provided, it is computed using autograd. Defaults to None.
scale_prime_t (Callable) – the derivative of scale_t. It takes a time step t (either a Python float or a torch.Tensor) as input and returns the derivative of the scale at time t (either a Python float or a torch.Tensor). If not provided, it is computed using autograd. Defaults to None.
variance_preserving (bool) – whether to use a variance-preserving diffusion schedule, which imposes \(s(t) = \left(1 + \sigma(t)^2\right)^{-1/2}\). Defaults to False.
variance_exploding (bool) – whether to use a variance-exploding diffusion schedule, which imposes \(s(t) = 1\). Defaults to False.
alpha (Callable, float) – a (possibly time-dependent) non-negative scalar weighting the diffusion term. A constant function \(\alpha(t) = 0\) corresponds to ODE sampling and \(\alpha(t) > 0\) corresponds to SDE sampling. Defaults to 1.0.
T (float) – the end time of the forward SDE. Defaults to 1.0.
denoiser (deepinv.models.Denoiser) – a denoiser used to provide an approximation of the score at time \(t\): \(\nabla \log p_t\). Defaults to None.
solver (deepinv.sampling.BaseSDESolver) – the solver for solving the SDE. Defaults to None.
dtype (torch.dtype) – data type of the computation, except for the denoiser, which always uses torch.float32. We recommend torch.float64 for better stability and lower numerical error when solving the SDE in discrete time; the extra cost is small because most of the computation is spent evaluating the denoiser, which always runs in torch.float32. Defaults to torch.float64.
device (torch.device) – device on which the computation is performed. Defaults to CPU.
*args – additional arguments for the deepinv.sampling.DiffusionSDE.
**kwargs – additional keyword arguments for the deepinv.sampling.DiffusionSDE.
- References:
- sample_init(shape, rng)
Sample from the initial distribution of the reverse-time diffusion SDE, which is a Gaussian with zero mean and covariance matrix \(s(T)^2 \sigma(T)^2 \operatorname{Id}\).
- Parameters:
shape (tuple) – The shape of the sample to generate
rng (torch.Generator) – Random number generator for reproducibility
- Returns:
A sample from the prior distribution
- Return type:
torch.Tensor
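The initial distribution described for sample_init can be mimicked with the standard library alone. The sketch below is a hypothetical stand-in, not the library method: sample_init_sketch, sigma_T and scale_T are illustrative names, and random.Random plays the role that a torch.Generator plays in the real signature. It draws independent Gaussian values with standard deviation \(s(T)\,\sigma(T)\) and shows that a seeded generator makes the draw reproducible.

```python
import random
import statistics


def sample_init_sketch(n, sigma_T, scale_T, rng):
    # Zero-mean Gaussian with standard deviation s(T) * sigma(T) in every coordinate.
    std = scale_T * sigma_T
    return [rng.gauss(0.0, std) for _ in range(n)]


# Hypothetical end-time values: sigma(T) = 80 with the variance-exploding scale s(T) = 1.
first = sample_init_sketch(20000, sigma_T=80.0, scale_T=1.0, rng=random.Random(0))
second = sample_init_sketch(20000, sigma_T=80.0, scale_T=1.0, rng=random.Random(0))

print(first == second)            # same seed, same sample
print(statistics.pstdev(first))   # empirical standard deviation, close to s(T) * sigma(T)
```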
- score(x, t, *args, **kwargs)
Approximate the score function \(\nabla \log p_t\) using the denoiser.
- Parameters:
x (torch.Tensor) – current state
t (torch.Tensor, float) – current time step
*args – additional arguments for the denoiser.
**kwargs – additional keyword arguments for the denoiser, e.g., class_labels for class-conditional models.
- Returns:
the score function \(\nabla \log p_t(x)\).
- Return type:
torch.Tensor
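One common way to turn a denoiser into a score estimate is Tweedie's formula, which for the scaled process reads \(\nabla \log p_t(x) \approx \left(D(x / s(t); \sigma(t)) - x / s(t)\right) / \left(s(t)\, \sigma(t)^2\right)\). Whether score uses exactly this expression internally is an assumption here. The sketch below checks the formula on a one-dimensional toy case with unit-variance Gaussian data, where both the optimal denoiser and the exact score are known; toy_denoiser, score_from_denoiser and the linear schedule are hypothetical names and choices made for illustration.

```python
def sigma_t(t):
    return t  # hypothetical noise schedule (assumption)


def scale_t(t):
    return (1 + sigma_t(t) ** 2) ** -0.5  # variance-preserving scale


def toy_denoiser(y, sigma):
    # MMSE denoiser for x_0 ~ N(0, 1) observed as y = x_0 + sigma * noise.
    return y / (1 + sigma ** 2)


def score_from_denoiser(x, t, denoiser):
    # Tweedie's formula for the scaled process; requires sigma(t) > 0, i.e. t > 0.
    s, sig = scale_t(t), sigma_t(t)
    y = x / s  # undo the scaling before calling the denoiser
    return (denoiser(y, sig) - y) / (s * sig ** 2)


def exact_score(x, t):
    # Score of N(0, s(t)^2 (1 + sigma(t)^2)), the marginal for unit-variance Gaussian data.
    return -x / (scale_t(t) ** 2 * (1 + sigma_t(t) ** 2))


x, t = 0.7, 1.5
print(score_from_denoiser(x, t, toy_denoiser), exact_score(x, t))
```

With a learned denoiser the two values would agree only approximately, and extra keyword arguments such as class_labels would be forwarded to the denoiser call.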
Examples using EDMDiffusionSDE:
Using state-of-the-art diffusion models from HuggingFace Diffusers with DeepInverse
Building your diffusion posterior sampling method using SDEs
Flow-Matching for posterior sampling and unconditional generation