DiffusionSampler
- class deepinv.sampling.DiffusionSampler(diffusion, max_iter=1e2, clip=(-1, 2), thres_conv=1e-1, g_statistic=lambda x: ..., verbose=True, save_chain=False)
Bases: BaseSampling
Turns a diffusion method into a Monte Carlo sampler.
Unlike a plain diffusion method, which returns a single sample per run, the resulting sampler estimates the mean and variance of the distribution by running the diffusion multiple times.
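The Monte Carlo idea above can be sketched in a few lines. This is a minimal NumPy illustration under stated assumptions, not deepinv's implementation: `monte_carlo_moments` and `sample_fn` are hypothetical names, `sample_fn(rng)` stands in for one complete diffusion run, NumPy arrays replace `torch.Tensor`, and the stopping rule (relative change of both the mean and the variance below `thres_conv`) is an assumed criterion. The arguments `max_iter`, `clip`, `thres_conv` and `g` mirror the roles of the documented parameters.

```python
import numpy as np


def _rel_change(new, old, eps=1e-12):
    # Relative change ||new - old|| / ||old||, guarded against division by zero.
    return np.linalg.norm(new - old) / (np.linalg.norm(old) + eps)


def monte_carlo_moments(sample_fn, max_iter=100, clip=(-1.0, 2.0),
                        thres_conv=1e-1, g=lambda x: x, seed=None):
    """Estimate the mean and variance of g(x) from repeated stochastic runs.

    sample_fn(rng) is a hypothetical stand-in for one complete diffusion run:
    each call must return a new random sample as a NumPy array.
    Returns (mean, var, n_samples_used).
    """
    rng = np.random.default_rng(seed)
    mean = m2 = prev_var = None
    n = 0
    for n in range(1, int(max_iter) + 1):
        x = np.clip(sample_fn(rng), clip[0], clip[1])  # keep each sample in range
        gx = np.asarray(g(x), dtype=float)              # statistic of interest
        if mean is None:
            mean = np.zeros_like(gx)
            m2 = np.zeros_like(gx)
        prev_mean = mean.copy()
        # Welford's online update of the mean and the sum of squared deviations.
        delta = gx - mean
        mean += delta / n
        m2 += delta * (gx - mean)
        if n >= 2:
            var = m2 / (n - 1)
            # Assumed stopping rule: both moments changed by less than thres_conv.
            if (prev_var is not None
                    and _rel_change(mean, prev_mean) < thres_conv
                    and _rel_change(var, prev_var) < thres_conv):
                break
            prev_var = var
    var = m2 / max(n - 1, 1)  # unbiased sample variance
    return mean, var, n
```

Welford's update avoids storing every sample, which matters when each sample is a full image. Passing `thres_conv=0` disables early stopping, so exactly `max_iter` runs are made.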
See the docs for deepinv.sampling.BaseSampling for more information. It uses the helper class deepinv.sampling.DiffusionIterator.
- Parameters:
diffusion (torch.nn.Module) – a diffusion model
max_iter (int) – the maximum number of samples to generate
clip (tuple) – the range (min, max) to which the generated samples are clipped
g_statistic (Callable) – the function \(g\) whose mean and variance are computed over the samples; by default \(g(x) = x\).
thres_conv (float) – the convergence threshold for the mean and variance
verbose (bool) – whether to print the progress
save_chain (bool) – whether to save the chain
thinning (int) – the thinning factor
burnin_ratio (float) – the burn-in ratio, i.e. the fraction of initial samples discarded
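The `thinning` and `burnin_ratio` parameters listed above carry their usual Markov chain Monte Carlo meaning. The sketch below shows that generic meaning on a stored chain of samples; `select_samples` is a hypothetical helper, and how `DiffusionSampler` applies these two parameters internally is not specified on this page.

```python
def select_samples(chain, burnin_ratio=0.2, thinning=10):
    """Drop the first burnin_ratio fraction of a chain, then keep every thinning-th sample.

    chain is a list of samples ordered by generation time (hypothetical helper,
    illustrating the generic MCMC meaning of burn-in and thinning).
    """
    if not 0.0 <= burnin_ratio < 1.0:
        raise ValueError("burnin_ratio must be in [0, 1)")
    if thinning < 1:
        raise ValueError("thinning must be a positive integer")
    start = int(burnin_ratio * len(chain))  # number of discarded burn-in samples
    return chain[start::thinning]
```

For a chain of 100 samples, `burnin_ratio=0.2` discards the first 20 and `thinning=10` keeps samples 20, 30, ..., 90.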
- forward(y, physics, seed=None)
Runs the diffusion model multiple times to estimate the posterior mean and variance of the reconstruction from the measurements y.
- Parameters:
y (torch.Tensor) – Measurements
physics (deepinv.physics.Physics) – Forward operator associated with the measurements
seed (float) – Random seed for generating the samples
- Returns:
(tuple of torch.Tensor) the posterior mean and variance.