DiffusionSampler

class deepinv.sampling.DiffusionSampler(diffusion, max_iter=1e2, clip=(-1, 2), thres_conv=1e-1, g_statistic=lambda x: ..., verbose=True, save_chain=False)

Bases: BaseSampling

Turns a diffusion method into a Monte Carlo sampler.

Unlike a plain diffusion method, which returns a single sample per call, the resulting sampler estimates the mean and variance of the posterior distribution by running the diffusion multiple times.

See the docs for deepinv.sampling.BaseSampling for more information. It uses the helper class deepinv.sampling.DiffusionIterator.
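The core idea of repeated runs followed by averaging can be sketched in plain Python. This is a minimal sketch under stated assumptions. It is not deepinv's implementation. `toy_diffusion` is a hypothetical stand-in for one stochastic diffusion run. Pixels are plain floats rather than tensors. The statistics are accumulated with Welford's online update, and we make no claim that deepinv uses it.

```python
import random


def toy_diffusion(y, rng):
    # Hypothetical stand-in for one diffusion run: returns one random
    # reconstruction (a list of pixel values) scattered around y.
    return [yi + rng.gauss(0.0, 0.5) for yi in y]


def monte_carlo_mean_var(y, n_samples=100, seed=0):
    """Estimate per-pixel mean and variance over repeated diffusion runs.

    Uses Welford's online update, so individual samples are not stored.
    """
    rng = random.Random(seed)
    mean = [0.0] * len(y)
    m2 = [0.0] * len(y)  # running sum of squared deviations from the mean
    for k in range(1, n_samples + 1):
        x = toy_diffusion(y, rng)  # one independent run of the "diffusion"
        for i, xi in enumerate(x):
            delta = xi - mean[i]
            mean[i] += delta / k
            m2[i] += delta * (xi - mean[i])
    var = [s / n_samples for s in m2]  # population variance (divides by N)
    return mean, var


mean, var = monte_carlo_mean_var([0.2, 0.8], n_samples=2000)
```

Each call to `toy_diffusion` plays the role of one diffusion run. Increasing `n_samples`, the analogue of `max_iter`, shrinks the Monte Carlo error of the mean roughly as 1/sqrt(n_samples).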

Parameters:
  • diffusion (torch.nn.Module) – a diffusion model

  • max_iter (int) – the number of samples to generate

  • clip (tuple) – the (min, max) range to which the samples are clipped

  • g_statistic (Callable) – the function g whose mean and variance over the samples the algorithm computes; by default \(g(x) = x\).

  • thres_conv (float) – the convergence threshold for the mean and variance

  • verbose (bool) – whether to print the progress

  • save_chain (bool) – whether to save the chain

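  • thinning (int) – the thinning factor: only one out of every thinning samples is kept

  • burnin_ratio (float) – the burn-in ratio: the fraction of initial samples discarded before the statistics are computed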
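The parameters above can be read as settings on a generic Monte Carlo loop. The following is a hypothetical sketch, not deepinv's code. `run_sampler`, `draw` and `_clamp` are illustrative names, and the samples are scalars. The convergence test compares the relative change of the running mean and variance between the last two kept samples against `thres_conv`. That test is an assumption about how such a threshold is typically used.

```python
import math
import random


def _clamp(x, lo, hi):
    # Restrict a sample to the [lo, hi] range.
    return max(lo, min(hi, x))


def run_sampler(draw, max_iter=1e2, clip=(-1.0, 2.0), g_statistic=lambda x: x,
                thres_conv=1e-1, save_chain=False, thinning=1, burnin_ratio=0.0):
    """Hypothetical Monte Carlo loop; returns (mean, var, converged, chain)."""
    n_iter = int(max_iter)
    burnin = int(burnin_ratio * n_iter)
    n, total, total_sq = 0, 0.0, 0.0
    mean = var = math.nan
    prev = None
    converged = False
    chain = []
    for it in range(n_iter):
        x = _clamp(draw(), *clip)  # clip every sample to the given range
        if it < burnin or (it - burnin) % thinning != 0:
            continue  # discard burn-in samples and thinned-out samples
        if save_chain:
            chain.append(x)
        g = g_statistic(x)  # statistics are computed on g(x), not x
        n += 1
        total += g
        total_sq += g * g
        mean = total / n
        # Sum-of-squares variance: fine for a sketch, Welford is more stable.
        var = max(total_sq / n - mean * mean, 0.0)
        if prev is not None:
            # Relative change of the running statistics since the last kept sample.
            rel_mean = abs(mean - prev[0]) / max(abs(mean), 1e-12)
            rel_var = abs(var - prev[1]) / max(abs(var), 1e-12)
            converged = rel_mean < thres_conv and rel_var < thres_conv
        prev = (mean, var)
    return mean, var, converged, chain


rng = random.Random(0)
mean, var, converged, chain = run_sampler(
    lambda: rng.gauss(0.5, 0.3), max_iter=1000, save_chain=True,
    thinning=2, burnin_ratio=0.1)
```

In this reading, `max_iter=1000` with `burnin_ratio=0.1` and `thinning=2` keeps 450 of the 1000 draws. Those are iterations 100, 102, ..., 998.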

forward(y, physics, seed=None)

Runs the diffusion model repeatedly to estimate the posterior mean and variance of the reconstruction from the measurements y.

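Parameters:
  • y (torch.Tensor) – the measurements

  • physics (deepinv.physics.Physics) – the forward operator associated with the measurements

  • seed – the random seed used to generate the samples (default None)
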
Returns:

(tuple of torch.Tensor) the posterior mean and variance.
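The following hypothetical toy illustrates what `forward` returns. The posterior is known in closed form: a scalar linear Gaussian model y = a*x + n with noise n ~ N(0, sigma^2) and prior x ~ N(0, prior_std^2). Exact posterior draws stand in for diffusion runs. Here `physics` is a plain dict of model constants, not a deepinv.physics.Physics object, and the function names are illustrative only. The Monte Carlo estimates should approach the exact posterior mean and variance, and fixing `seed` makes them reproducible.

```python
import math
import random
import statistics


def exact_posterior(y, a, sigma, prior_std):
    """Closed-form posterior (mean, var) of x for y = a*x + n,
    with n ~ N(0, sigma^2) and prior x ~ N(0, prior_std^2)."""
    var = 1.0 / (1.0 / prior_std**2 + a**2 / sigma**2)
    return var * a * y / sigma**2, var


def forward(y, physics, seed=None, n_samples=5000):
    """Hypothetical analogue of forward(y, physics, seed): returns the
    (mean, var) tuple estimated from repeated posterior draws."""
    rng = random.Random(seed)  # same seed -> same sequence of draws
    mu, v = exact_posterior(y, **physics)
    # Stand-in for repeated diffusion runs: exact posterior samples.
    samples = [rng.gauss(mu, math.sqrt(v)) for _ in range(n_samples)]
    return statistics.fmean(samples), statistics.pvariance(samples)


physics = {"a": 2.0, "sigma": 1.0, "prior_std": 1.0}
mean, var = forward(1.0, physics, seed=42)
```

With a = 2, sigma = 1, prior_std = 1 and y = 1, the exact posterior variance is 1/(1 + 4) = 0.2. The exact posterior mean is 0.2 * 2 * 1 = 0.4. The estimates match these values up to Monte Carlo error.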

Examples using DiffusionSampler:

Image reconstruction with a diffusion model