- class deepinv.sampling.SKRock(prior, data_fidelity, step_size=1.0, inner_iter=10, eta=0.05, alpha=1.0, max_iter=1e3, burnin_ratio=0.2, thinning=10, clip=(-1.0, 2.0), thresh_conv=1e-3, save_chain=False, g_statistic=lambda x: ..., verbose=False, sigma=0.05)
Plug-and-Play SKROCK algorithm.
Obtains samples of the posterior distribution using an orthogonal Runge-Kutta-Chebyshev stochastic approximation to accelerate the standard Unadjusted Langevin Algorithm.
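A minimal NumPy sketch of one SKROCK step for a target \(\pi(x) \propto \exp(-U(x))\) is given below, following the stabilised Chebyshev recursion of Vargas, Pereyra and Zygalakis as we understand it. The names `skrock_step`, `chebyshev_t` and `grad_U` are hypothetical and are not the deepinv API. In the plug-and-play setting, `grad_U` would play the role of the data-fidelity gradient plus `alpha` times the prior gradient; here it is a plain callable.

```python
import math

import numpy as np


def chebyshev_t(j, u):
    """Chebyshev polynomial of the first kind T_j(u), via the cosh form valid for u >= 1."""
    return math.cosh(j * math.acosh(u))


def skrock_step(x, grad_U, step_size, inner_iter=10, eta=0.05, rng=None):
    """One SKROCK step for a target pi(x) proportional to exp(-U(x)).

    x: NumPy array (current state); grad_U: callable returning the gradient of U;
    inner_iter: number of stages s; eta: damping parameter (must be > 0).
    """
    rng = np.random.default_rng() if rng is None else rng
    s = inner_iter
    w0 = 1.0 + eta / s**2
    # T_s'(w0) = s * sinh(s * acosh(w0)) / sqrt(w0^2 - 1)
    t_s_prime = s * math.sinh(s * math.acosh(w0)) / math.sqrt(w0**2 - 1.0)
    w1 = chebyshev_t(s, w0) / t_s_prime
    mu1, nu1, kappa1 = w1 / w0, s * w1 / 2.0, s * w1 / w0

    # A single Gaussian draw, scaled by sqrt(2 * step_size), is used for the whole step.
    noise = math.sqrt(2.0 * step_size) * rng.standard_normal(x.shape)
    k_prev = x  # K_0
    k_cur = x - mu1 * step_size * grad_U(x + nu1 * noise) + kappa1 * noise  # K_1
    for j in range(2, s + 1):
        # Three-term Chebyshev recursion: K_j from K_{j-1} and K_{j-2}.
        ratio = chebyshev_t(j - 1, w0) / chebyshev_t(j, w0)
        mu, nu = 2.0 * w1 * ratio, 2.0 * w0 * ratio
        kappa = 1.0 - nu
        k_next = -mu * step_size * grad_U(k_cur) + nu * k_cur + kappa * k_prev
        k_prev, k_cur = k_cur, k_next
    return k_cur  # K_s is the next state of the chain


# Usage: sample a 1-D standard Gaussian (U(x) = x^2 / 2) with 2000 parallel chains.
rng = np.random.default_rng(0)
samples = np.zeros(2000)
for _ in range(200):
    samples = skrock_step(samples, lambda u: u, step_size=0.5, rng=rng)
print(samples.mean(), samples.var())  # should be close to 0 and 1
```

Each step costs `inner_iter` gradient evaluations; in exchange, the stability region of Runge-Kutta-Chebyshev schemes grows roughly quadratically with the number of stages, which is what allows larger steps than the standard Unadjusted Langevin Algorithm.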
The algorithm was introduced in “Accelerating proximal Markov chain Monte Carlo by using an explicit stabilised method” by L. Vargas, M. Pereyra and K. Zygalakis.
SKROCK assumes that the denoiser is \(L\)-Lipschitz differentiable.
For convergence, SKROCK requires step_size to be smaller than \(\frac{1}{L+\|A\|_2^2}\).
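The convergence condition on step_size needs \(\|A\|_2^2\), the largest eigenvalue of \(A^\top A\), which can be estimated by power iteration. The sketch below assumes that \(A\) is available as a dense NumPy matrix and that \(L\) is already known; the helper names are hypothetical. For a matrix-free forward operator, one would apply the operator and its adjoint in place of the two matrix products.

```python
import numpy as np


def squared_spectral_norm(A, n_iter=100, seed=0):
    """Estimate ||A||_2^2, the largest eigenvalue of A^T A, by power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(A.shape[1])
    v /= np.linalg.norm(v)
    estimate = 0.0
    for _ in range(n_iter):
        w = A.T @ (A @ v)  # apply A^T A
        estimate = float(v @ w)  # Rayleigh quotient with the unit vector v
        v = w / np.linalg.norm(w)
    return estimate


def skrock_step_size_bound(L, A, n_iter=100):
    """Upper bound 1 / (L + ||A||_2^2) on step_size from the convergence condition."""
    return 1.0 / (L + squared_spectral_norm(A, n_iter))


A = np.array([[2.0, 1.0], [0.0, 1.0]])
bound = skrock_step_size_bound(L=1.0, A=A)
step_size = 0.9 * bound  # stay strictly below the bound
```

Since the condition is a strict inequality, the example takes a fixed fraction of the bound rather than the bound itself.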
- Parameters:
prior (deepinv.optim.ScorePrior, torch.nn.Module) – negative log-prior based on a trained or model-based denoiser.
data_fidelity (deepinv.optim.DataFidelity, torch.nn.Module) – negative log-likelihood function linked with the noise distribution in the acquisition physics.
step_size (float) – Step size of the algorithm. Tip: use physics.lipschitz to compute the Lipschitz constant.
eta (float) – \(\eta\) SKROCK damping parameter.
alpha (float) – regularization parameter \(\alpha\).
inner_iter (int) – Number of inner SKROCK iterations.
max_iter (int) – Number of outer iterations.
thinning (int) – Thins the Markov chain by an integer factor \(\geq 1\) (i.e., only one out of every thinning samples is kept to compute posterior statistics).
burnin_ratio (float) – Fraction of the iterations used for the burn-in period. The burn-in samples are discarded and are not used to compute posterior statistics.
clip (tuple) – Tuple containing the box-constraints \([a,b]\) onto which the samples are projected. If no box-constraints are given, the algorithm will not project the samples.
verbose (bool) – Prints the progress of the algorithm.
sigma (float) – noise level used in the plug-and-play prior denoiser. A larger value of sigma will result in a more regularized reconstruction.
g_statistic (Callable) – The sampler will compute the posterior mean and variance of the function g_statistic. By default, it is the identity function (lambda x: x), and thus the sampler computes the posterior mean and variance.
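The burn-in, thinning, box-constraint and g_statistic parameters can be illustrated with a generic MCMC driver. The sketch below is hypothetical and is not deepinv's implementation: it discards the first `int(burnin_ratio * max_iter)` iterations, keeps one sample every `thinning` iterations afterwards, clips each sample to the box given by `clip` (with `None` meaning no projection, a convention assumed here), and accumulates the running mean and (population) variance of `g_statistic` with Welford's algorithm. The `step` argument stands for any one-step transition, such as a SKROCK step.

```python
import numpy as np


def posterior_stats(step, x0, max_iter=1000, burnin_ratio=0.2, thinning=10,
                    clip=(-1.0, 2.0), g_statistic=lambda x: x):
    """Run a Markov chain and return the mean and variance of g_statistic.

    step: callable mapping the current sample to the next one (e.g. a SKROCK step).
    clip=None disables the box projection (a convention of this sketch).
    """
    x = np.asarray(x0, dtype=float)
    n_burnin = int(burnin_ratio * max_iter)
    count, mean, m2 = 0, None, None
    for it in range(max_iter):
        x = step(x)
        if clip is not None:
            x = np.clip(x, clip[0], clip[1])  # project onto the box [a, b]
        # Discard burn-in, then keep one sample every `thinning` iterations.
        if it >= n_burnin and (it - n_burnin) % thinning == 0:
            g = np.asarray(g_statistic(x), dtype=float)
            count += 1
            if mean is None:
                mean, m2 = g.copy(), np.zeros_like(g)
            else:
                # Welford's online update of the mean and sum of squared deviations.
                delta = g - mean
                mean = mean + delta / count
                m2 = m2 + delta * (g - mean)
    if count == 0:
        raise ValueError("no samples kept after burn-in")
    return mean, m2 / count, count


# Deterministic toy chain x -> x + 1 started at 0, without box projection.
mean, var, n_kept = posterior_stats(lambda x: x + 1.0, np.zeros(1), max_iter=100, clip=None)
# kept samples are 21, 31, ..., 91, so mean = [56.], var = [525.], n_kept = 8
```

Computing the statistics online avoids storing the chain, which matters when each sample is a full image.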