Diffusion and MCMC Algorithms#
This package contains posterior sampling algorithms based on diffusion models and Markov Chain Monte Carlo (MCMC) methods.
These methods build a Markov chain such that, for large \(t\), the samples \(x_t\) are approximately distributed according to the posterior distribution \(p(x|y)\).
Diffusion models#
We provide a unified framework for image generation using diffusion models.
Diffusion models for posterior sampling are defined using deepinv.sampling.PosteriorDiffusion, which is a subclass of deepinv.models.Reconstructor.
Below, we explain the main components of diffusion models; see Building your diffusion posterior sampling method using SDEs for an example usage and visualizations.
Stochastic Differential Equations#
We define diffusion models as Stochastic Differential Equations (SDEs).
The forward-time SDE is defined as follows, from time \(0\) to \(T\):

\(d x_t = f(t) x_t \, dt + g(t) \, d w_t,\)

where \(w_t\) is a Brownian process. Let \(p_t\) denote the distribution of the random vector \(x_t\). Under this forward process, we have that:

\(x_t \sim \mathcal{N}\left(s(t)\, x_0,\; s(t)^2 \sigma(t)^2 \mathrm{Id}\right),\)

where the scaling over time is \(s(t) = \exp\left( \int_0^t f(r) d\,r \right)\) and the normalized noise level is \(\sigma(t) = \sqrt{\int_0^t \frac{g(r)^2}{s(r)^2} d\,r}\).
The reverse-time SDE is defined as follows, running backwards in time (from \(T\) to \(0\)):

\(d x_t = \left( f(t) x_t - \frac{1+\alpha}{2} g(t)^2 \nabla \log p_t(x_t) \right) dt + g(t)\sqrt{\alpha}\, d w_t,\)

where \(\alpha \in [0,1]\) is a scalar weighting the diffusion term (\(\alpha = 0\) corresponds to the ordinary differential equation (ODE) sampling and \(\alpha > 0\) corresponds to the SDE sampling), and \(\nabla \log p_{t}(x_t)\) is the score function, which can be approximated by (a properly scaled version of) Tweedie's formula:

\(\nabla \log p_t(x_t) = \frac{\denoiser{\frac{x_t}{s(t)}}{\sigma(t)} - \frac{x_t}{s(t)}}{s(t)\sigma(t)^2},\)

where \(\denoiser{\cdot}{\sigma}\) is a denoiser trained to denoise images with noise level \(\sigma\), that is, \(\denoiser{x+\sigma\omega}{\sigma} \approx \mathbb{E} [ x|x+\sigma\omega ]\) with \(\omega\sim\mathcal{N}(0,\mathrm{I})\).
Note
Using normalized noise levels \(\sigma(t)\) and scalings \(s(t)\) lets us use any denoiser in the library trained for multiple noise levels, assuming pixel values are in the range \([0,1]\).
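For intuition, the scaled Tweedie formula above can be written directly in PyTorch. This is a minimal sketch assuming a generic denoiser(x, sigma) callable following the convention of the note above; it is not tied to a particular deepinv class.

import torch

def score_from_denoiser(denoiser, x_t, s_t, sigma_t):
    # Scaled Tweedie formula: undo the scaling s(t), denoise at noise level sigma(t),
    # and turn the denoising residual into an estimate of the score at x_t.
    x_bar = x_t / s_t
    x_denoised = denoiser(x_bar, sigma_t)  # approximates E[x_0 | x_bar]
    return (x_denoised - x_bar) / (s_t * sigma_t ** 2)

# Shape check with a placeholder identity denoiser (illustration only).
x_t = torch.randn(1, 3, 32, 32)
score = score_from_denoiser(lambda x, sigma: x, x_t, s_t=1.0, sigma_t=0.1)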
Starting from a random point drawn from the end-point distribution \(p_T\) of the forward process, solving the reverse-time SDE gives us a sample from the data distribution \(p_0\).
The base classes for defining SDEs are deepinv.sampling.BaseSDE and deepinv.sampling.DiffusionSDE.
| SDE | \(f(t)\) | \(g(t)\) | Scaling \(s(t)\) | Noise level \(\sigma(t)\) |
|---|---|---|---|---|
| Variance exploding (VE) | \(0\) | \(\sigma_{\mathrm{min}}\left(\frac{\sigma_{\mathrm{max}}}{\sigma_{\mathrm{min}}}\right)^t \sqrt{2\log\left(\frac{\sigma_{\mathrm{max}}}{\sigma_{\mathrm{min}}}\right)}\) | \(1\) | \(\sigma_{\mathrm{min}}\left(\frac{\sigma_{\mathrm{max}}}{\sigma_{\mathrm{min}}}\right)^t\) |
| Variance preserving (VP) | \(-\frac{1}{2}\left(\beta_{\mathrm{min}} + t \beta_d \right)\) | \(\sqrt{\beta_{\mathrm{min}} + t \beta_{d}}\) | \(1/\sqrt{e^{\frac{1}{2}\beta_{d}t^2+\beta_{\mathrm{min}}t}}\) | \(\sqrt{e^{\frac{1}{2}\beta_{d}t^2+\beta_{\mathrm{min}}t}-1}\) |
Solvers#
Once the SDE is defined, we can obtain an approximate sample with any of the SDE solvers provided in the package.
The base class for solvers is deepinv.sampling.BaseSDESolver, and deepinv.sampling.SDEOutput provides a container for storing the output of the solver.
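For intuition, a single Euler-Maruyama step of the reverse-time SDE above can be sketched in a few lines of PyTorch. This is a schematic illustration of the update rule only, written for generic \(f\), \(g\) and score callables; it is not the implementation behind deepinv.sampling.BaseSDESolver.

import torch

def euler_maruyama_step(x, t, dt, f, g, score, alpha=1.0):
    # One reverse-time step from t to t - dt (dt > 0): the drift of the reverse SDE
    # is integrated backwards and the stochastic term is scaled by sqrt(alpha * dt).
    drift = f(t) * x - 0.5 * (1.0 + alpha) * g(t) ** 2 * score(x, t)
    noise = torch.randn_like(x)
    return x - drift * dt + g(t) * (alpha * dt) ** 0.5 * noise

With alpha = 0 the noise term vanishes and the step reduces to an Euler step on the corresponding ODE.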
Posterior sampling#
In the case of posterior sampling, we simply need to replace the (unconditional) score function \(\nabla \log p_t(x_t)\) by the conditional score function \(\nabla \log p_t(x_t|y)\). The conditional score can be decomposed using Bayes' rule:

\(\nabla \log p_t(x_t|y) = \nabla \log p_t(x_t) + \nabla \log p_t(y|x_t).\)

The first term is the unconditional score function and can be approximated with a denoiser, as explained previously.
The second term is the score of the noisy likelihood \(\nabla \log p_t(y|x_t)\), which can be approximated by a (noisy) data-fidelity term.
We implement the following data-fidelity terms, which inherit from the deepinv.sampling.NoisyDataFidelity base class.
| Class | \(\nabla_x \log p_t(y|x + \epsilon\sigma(t))\) |
|---|---|
| deepinv.sampling.DPSDataFidelity | \(\nabla_x \frac{\lambda}{2\sqrt{m}} \| \forw{\denoiser{x}{\sigma}} - y \|\) |
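To make the decomposition concrete, the sketch below combines a Tweedie-based unconditional score with the gradient of a measurement-consistency term obtained by automatic differentiation, in the spirit of the data-fidelity term above. The denoiser and forward-operator callables and the guidance weight are illustrative assumptions; this is not the deepinv.sampling.NoisyDataFidelity implementation.

import torch

def conditional_score(x_t, y, s_t, sigma_t, denoiser, forward_op, weight=1.0):
    # Unconditional part: scaled Tweedie formula (see the diffusion models section).
    x_bar = (x_t / s_t).detach().requires_grad_(True)
    x0_hat = denoiser(x_bar, sigma_t)
    prior_score = (x0_hat - x_bar) / (s_t * sigma_t ** 2)

    # Likelihood part: gradient of a norm-based data-fidelity term, backpropagated
    # through the denoiser; `weight` plays the role of the guidance strength.
    residual = torch.linalg.vector_norm(forward_op(x0_hat) - y)
    data_grad = torch.autograd.grad(residual, x_bar)[0]

    # The data-fidelity gradient enters with a negative sign, pushing the sample
    # towards measurement consistency.
    return prior_score.detach() - weight * data_grad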
Popular posterior samplers#
We also provide custom implementations of some popular diffusion methods for posterior sampling, which can be used directly without the need to define the SDE and the solvers.
| Method | Description | Limitations |
|---|---|---|
| deepinv.sampling.DDRM | Denoising Diffusion Restoration Models | Only for linear problems with an efficient singular value decomposition (SVD). |
| deepinv.sampling.DiffPIR | Diffusion PnP Image Restoration | Only for linear problems. |
| deepinv.sampling.DPS | Diffusion Posterior Sampling | Can be slow, requires backpropagation through the denoiser. |
Uncertainty quantification#
Diffusion methods obtain a single sample per call. If multiple samples are required, deepinv.sampling.DiffusionSampler can be used to convert a diffusion method into a sampler that obtains multiple samples to compute posterior statistics such as the mean or variance.
It uses the helper class deepinv.sampling.DiffusionIterator to interface diffusion samplers with deepinv.sampling.BaseSampling.
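A minimal usage sketch is given below, assuming a simple denoising problem and the DDRM sampler listed in the table above. The constructor arguments shown (pretrained weights, number of samples) are illustrative and should be checked against the API reference.

import torch
import deepinv as dinv

device = "cpu"
x = torch.rand(1, 3, 32, 32, device=device)  # toy image in [0, 1]
physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=0.1))
y = physics(x)  # noisy measurements

denoiser = dinv.models.DRUNet(pretrained="download", device=device)
diffusion = dinv.sampling.DDRM(denoiser=denoiser)  # diffusion-based posterior sampler

# Draw several posterior samples and compute their pixel-wise statistics.
sampler = dinv.sampling.DiffusionSampler(diffusion, max_iter=10)  # 10 posterior samples
mean, var = sampler(y, physics)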
Markov Chain Monte Carlo#
Markov Chain Monte Carlo (MCMC) methods build a chain of samples which aim at sampling the posterior distribution, whose negative log is given by:

\(-\log p(x|y) \propto d(Ax,y) + \reg{x},\)

where \(x\) is the image to be reconstructed, \(y\) are the measurements, \(d(Ax,y) \propto - \log p(y|x,A)\) is the negative log-likelihood and \(\reg{x} \propto - \log p_{\sigma}(x)\) is the negative log-prior.
The negative log-likelihood can be chosen from this list, and the negative log-prior can be approximated using deepinv.optim.ScorePrior with a pretrained denoiser, which leverages Tweedie's formula with \(\sigma\) typically set to a small value.
Unlike diffusion sampling methods, MCMC methods generally use a fixed noise level \(\sigma\) during the sampling process, i.e.,
\(\nabla \log p_{\sigma}(x) = \frac{\left(\denoiser{x}{\sigma} - x \right)}{\sigma^2}\).
Note
The approximation of the prior obtained via deepinv.optim.ScorePrior is also valid for maximum-a-posteriori (MAP) denoisers; in that case, however, \(p_{\sigma}(x)\) is not given by the convolution with a Gaussian kernel, but rather by the Moreau-Yosida envelope of \(p(x)\), i.e.,

\(p_{\sigma}(x) \propto e^{- \inf_z \left( -\log p(z) + \frac{1}{2\sigma^2}\|x-z\|^2 \right)}.\)
All MCMC methods inherit from deepinv.sampling.BaseSampling.
The function deepinv.sampling.sampling_builder() returns an instance of deepinv.sampling.BaseSampling with the sampling algorithm of choice, either a predefined one ("SKRock", "ULA") or a user-defined one (an instance of deepinv.sampling.SamplingIterator). For example, we can use ULA with a score prior:
model = dinv.sampling.sampling_builder(iteration="ULA", prior=prior, data_fidelity=data_fidelity,
                                       params={"step_size": step_size, "alpha": alpha, "sigma": sigma},
                                       max_iter=max_iter)
x_hat = model(y, physics)
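A more complete, self-contained version of the snippet above might look as follows; the choice of denoiser, physics and hyperparameter values is purely illustrative.

import torch
import deepinv as dinv

device = "cpu"
x = torch.rand(1, 3, 32, 32, device=device)  # toy ground-truth image in [0, 1]
physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=0.1))
y = physics(x)  # noisy measurements

denoiser = dinv.models.DRUNet(pretrained="download", device=device)
prior = dinv.optim.ScorePrior(denoiser=denoiser)  # approximates -log p(x) via Tweedie
data_fidelity = dinv.optim.L2()  # -log p(y|x) for Gaussian noise

step_size, alpha, sigma, max_iter = 1e-4, 1.0, 2 / 255, 1000  # illustrative values

model = dinv.sampling.sampling_builder(iteration="ULA", prior=prior, data_fidelity=data_fidelity,
                                       params={"step_size": step_size, "alpha": alpha, "sigma": sigma},
                                       max_iter=max_iter)
x_hat = model(y, physics)  # posterior mean estimate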
We provide a flexible framework for MCMC algorithms: alongside the predefined algorithms, it is easy to implement your own custom sampling algorithm by creating your own sampling iterator, i.e., by subclassing deepinv.sampling.SamplingIterator (see deepinv.sampling.SamplingIterator for a short example).
A custom iterator needs to implement two methods (a minimal example is sketched below):

initialize_latent_variables(self, x_init, y, physics, data_fidelity, prior): This method sets up the initial state of your Markov chain. It receives the initial image estimate \(x_{\text{init}}\), measurements \(y\), the physics operator, data fidelity term, and prior. It should return a dictionary representing the initial state \(X_0\), which must include the image as {"x": x_init, ...} and can include any other latent variables your sampler requires. The default (non-overridden) behavior is to return {"x": x_init}.

forward(self, X, y, physics, data_fidelity, prior, iteration_number, **iterator_specific_params): This method defines a single step of your MCMC algorithm. It takes the previous state \(X\) (a dictionary containing at least the previous image {"x": x, ...}), measurements \(y\), the data fidelity, and the prior, and returns the new state \(X_{\text{next}}\) (again, a dictionary including {"x": x_next, ...}).
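The sketch below illustrates both methods with a plain unadjusted-Langevin step. The parameter handling is an assumption for illustration (the hyperparameters are assumed to be forwarded from the params dictionary of sampling_builder as keyword arguments); check deepinv.sampling.SamplingIterator for the exact interface.

import torch
import deepinv as dinv

class ULAIteratorSketch(dinv.sampling.SamplingIterator):
    # Minimal custom iterator: one unadjusted Langevin step per call.

    def initialize_latent_variables(self, x_init, y, physics, data_fidelity, prior):
        # The chain state only needs the current image, no extra latent variables.
        return {"x": x_init}

    def forward(self, X, y, physics, data_fidelity, prior, iteration_number,
                step_size=1e-4, sigma=0.05, alpha=1.0, **kwargs):
        x = X["x"]
        # Gradient of the negative log-posterior: data-fidelity term plus score prior
        # (here `prior` is assumed to be a ScorePrior whose grad takes a noise level).
        grad = data_fidelity.grad(x, y, physics) + alpha * prior.grad(x, sigma)
        # Langevin update: gradient step plus properly scaled Gaussian noise.
        x_next = x - step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(x)
        return {"x": x_next}

An instance of such an iterator can then be passed to deepinv.sampling.sampling_builder() in place of the algorithm name.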
Some predefined iterators are provided, corresponding to the algorithm names accepted by deepinv.sampling.sampling_builder() (such as "ULA" and "SKRock"), as well as deepinv.sampling.DiffusionIterator, which takes no parameters (see the uncertainty quantification section above).
Some legacy predefined classes are also provided:
| Method | Description |
|---|---|
| deepinv.sampling.ULA | Unadjusted Langevin algorithm. |
| deepinv.sampling.SKRock | Runge-Kutta-Chebyshev stochastic approximation to accelerate the standard Unadjusted Langevin Algorithm. |