MonteCarlo
- class deepinv.sampling.MonteCarlo(iterator: torch.nn.modules.module.Module, prior: deepinv.optim.prior.ScorePrior, data_fidelity: deepinv.optim.data_fidelity.DataFidelity, max_iter=1000.0, burnin_ratio=0.2, thinning=10, clip=(-1.0, 2.0), thresh_conv=0.001, crit_conv='residual', save_chain=False, g_statistic=lambda x: x, verbose=False)
Bases: torch.nn.Module
Base class for Monte Carlo sampling.
This class can be used to create new Monte Carlo samplers, by only defining their kernel inside a torch.nn.Module:
    import torch
    from deepinv.sampling import MonteCarlo

    # define custom sampling kernel (possibly a Markov kernel which depends on the previous sample).
    class MyKernel(torch.nn.Module):
        def __init__(self, iterator_params):
            super().__init__()
            self.iterator_params = iterator_params

        def forward(self, x, y, physics, likelihood, prior):
            # run one sampling kernel iteration
            # (f is a placeholder for the user-defined update rule)
            new_x = f(x, y, physics, likelihood, prior, self.iterator_params)
            return new_x

    class MySampler(MonteCarlo):
        def __init__(self, prior, data_fidelity, iterator_params,
                     max_iter=1e3, burnin_ratio=.1, clip=(-1, 2), verbose=True):
            # generate an iterator
            iterator = MyKernel(iterator_params)
            # set the params of the base class
            super().__init__(iterator, prior, data_fidelity, max_iter=max_iter,
                             burnin_ratio=burnin_ratio, clip=clip, verbose=verbose)

    # create the sampler
    sampler = MySampler(prior, data_fidelity, iterator_params)

    # compute posterior mean and variance of reconstruction of measurement y
    mean, var = sampler(y, physics)
This class computes the mean and variance of the chain using Welford’s algorithm, which updates both estimates one sample at a time and therefore avoids storing the full set of Monte Carlo samples.
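For readers unfamiliar with Welford's algorithm, the following is a minimal scalar sketch of the running update. The class name RunningStats and the use of the unbiased (n - 1) normalisation are assumptions made for illustration; they are not taken from the library, which operates on tensors and may normalise differently.

```python
class RunningStats:
    """Running mean and variance via Welford's algorithm (scalar sketch)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x):
        # One-pass update: only count, mean and m2 are kept in memory.
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def variance(self):
        # Unbiased sample variance (assumed convention); needs >= 2 samples.
        return self.m2 / (self.count - 1) if self.count > 1 else 0.0


stats = RunningStats()
for sample in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]:
    stats.update(sample)

print(stats.mean, stats.variance())
```

The memory cost is independent of the number of samples, which is what makes the approach attractive for long chains of image-sized samples.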
- Parameters:
prior (deepinv.optim.ScorePrior) – negative log-prior based on a trained or model-based denoiser.
data_fidelity (deepinv.optim.DataFidelity) – negative log-likelihood function linked with the noise distribution in the acquisition physics.
max_iter (int) – number of Monte Carlo iterations.
thinning (int) – thins the Monte Carlo samples by an integer \(\geq 1\) (i.e., keeping one out of every thinning samples to compute posterior statistics).
burnin_ratio (float) – fraction of iterations used for the burn-in period; should be set between 0 and 1. The burn-in samples are discarded and do not contribute to the posterior statistics.
clip (tuple) – tuple containing the box constraints \([a,b]\). If None, the algorithm will not project the samples.
thresh_conv (float) – threshold for verifying the convergence of the mean and variance estimates.
g_statistic (function_handle) – The sampler will compute the posterior mean and variance of the function g_statistic. By default, it is the identity function (lambda x: x), and thus the sampler computes the posterior mean and variance.
verbose (bool) – prints progress of the algorithm.
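The interaction between max_iter, burnin_ratio, thinning and clip can be sketched in a few lines. The helpers kept_iterations and clip_sample below are hypothetical and only illustrate one plausible selection rule (skip the first burnin_ratio * max_iter iterations, then keep every thinning-th one); the exact indexing used by the library is an assumption here.

```python
def kept_iterations(max_iter, burnin_ratio, thinning):
    """Indices of the iterations that would feed the posterior statistics."""
    max_iter = int(max_iter)  # the default max_iter is given as a float (1000.0)
    burnin = int(burnin_ratio * max_iter)
    return [it for it in range(max_iter) if it >= burnin and it % thinning == 0]


def clip_sample(x, clip):
    """Project a scalar onto the box [a, b]; leave it unchanged if clip is None."""
    if clip is None:
        return x
    a, b = clip
    return min(max(x, a), b)


# Values mirror the defaults in the class signature.
kept = kept_iterations(max_iter=1000.0, burnin_ratio=0.2, thinning=10)
print(len(kept), kept[0], kept[-1])
print(clip_sample(2.5, (-1.0, 2.0)), clip_sample(2.5, None))
```

Under this rule, increasing thinning reduces the correlation between retained samples at the cost of fewer samples per chain, so max_iter usually has to grow with it.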
- forward(y, physics, seed=None, x_init=None)
Runs a Monte Carlo chain to obtain the posterior mean and variance of the reconstruction of the measurements y.
- Parameters:
y (torch.Tensor) – Measurements
physics (deepinv.physics.Physics) – Forward operator associated with the measurements
seed (float) – Random seed for generating the Monte Carlo samples
- Returns:
(tuple of torch.Tensor) containing the posterior mean and variance.
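To make the role of seed and of the returned (mean, variance) pair concrete, here is a toy, library-independent sketch of a seeded chain. It assumes a scalar model y = x + noise with noise ~ N(0, sigma^2) and a standard normal prior on x, and uses an unadjusted Langevin step as the kernel. The function toy_sampler and all of its parameters are hypothetical and do not reproduce the library's implementation.

```python
import math
import random


def toy_sampler(y, sigma=1.0, step=0.1, max_iter=20000,
                burnin_ratio=0.2, thinning=10, seed=None):
    """Unadjusted Langevin chain for a scalar Gaussian posterior (toy model)."""
    rng = random.Random(seed)  # a fixed seed makes the chain reproducible
    x = y                      # initialise the chain at the measurement
    count, mean, m2 = 0, 0.0, 0.0
    burnin = int(burnin_ratio * max_iter)
    for it in range(max_iter):
        # Gradient of the negative log-posterior: likelihood term + prior term.
        grad = (x - y) / sigma**2 + x
        x = x - step * grad + math.sqrt(2.0 * step) * rng.gauss(0.0, 1.0)
        if it >= burnin and it % thinning == 0:
            # Running (Welford-style) update of the mean and variance.
            count += 1
            delta = x - mean
            mean += delta / count
            m2 += delta * (x - mean)
    return mean, m2 / (count - 1)


# For sigma = 1 the exact posterior is N(y / 2, 1 / 2); the estimates should be
# close to that, up to Monte Carlo error and the step-size bias of the kernel.
mean, var = toy_sampler(y=2.0, seed=0)
print(mean, var)
```

Calling the function twice with the same seed yields identical estimates, while seed=None gives a different chain on each call.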
- get_chain()
Returns the thinned Monte Carlo samples (after burn-in iterations). Requires save_chain=True.
Examples using MonteCarlo:
Uncertainty quantification with PnP-ULA
Image reconstruction with a diffusion model
Building your custom sampling algorithm