SKRock

class deepinv.sampling.SKRock(prior, data_fidelity, step_size=1.0, inner_iter=10, eta=0.05, alpha=1.0, max_iter=1e3, burnin_ratio=0.2, thinning=10, clip=(-1.0, 2.0), thresh_conv=1e-3, save_chain=False, verbose=False, sigma=0.05)

Bases: BaseSampling

Plug-and-Play SKROCK algorithm.

Obtains samples from the posterior distribution using a stochastic orthogonal Runge-Kutta-Chebyshev method, an explicit stabilized scheme that accelerates the standard Unadjusted Langevin Algorithm (ULA).

The algorithm was introduced in “Accelerating proximal Markov chain Monte Carlo by using an explicit stabilised method” by L. Vargas, M. Pereyra and K. Zygalakis (https://arxiv.org/abs/1908.08845).
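The recursion behind the method can be illustrated with a minimal NumPy sketch of one SK-ROCK step for a target \(\pi(x) \propto e^{-U(x)}\), using the damped Chebyshev stage coefficients described by Vargas, Pereyra and Zygalakis. It is a standalone illustration, not deepinv's internal code: the function name `skrock_step` and its arguments (`grad_U`, and `s` for the number of internal stages, which plays the role of inner_iter) are hypothetical, and the sketch assumes eta > 0 so that the Chebyshev derivative is well defined.

```python
import numpy as np


def skrock_step(x, grad_U, step_size, s=10, eta=0.05, rng=None):
    """One SK-ROCK step targeting pi(x) ~ exp(-U(x)) (hypothetical helper).

    grad_U(u) returns the gradient of the potential U (the negative
    log-posterior) at u; s is the number of internal stages (inner_iter).
    """
    rng = np.random.default_rng() if rng is None else rng

    # Chebyshev polynomial of the first kind, T_k(u) = cosh(k arccosh u) for u >= 1,
    # and its derivative T_k'(u).
    def T(k, u):
        return np.cosh(k * np.arccosh(u))

    def dT(k, u):
        return k * np.sinh(k * np.arccosh(u)) / np.sqrt(u**2 - 1)

    # Damped Chebyshev parameters (eta > 0 keeps w0 > 1).
    w0 = 1 + eta / s**2
    w1 = T(s, w0) / dT(s, w0)
    mu1, nu1, kappa1 = w1 / w0, s * w1 / 2, s * w1 / w0

    # A single Gaussian increment is shared by all stages.
    noise = np.sqrt(2 * step_size) * rng.standard_normal(np.shape(x))

    # Stage 1: gradient evaluated at a noise-perturbed point.
    k_prev2 = x
    k_prev = x - mu1 * step_size * grad_U(x + nu1 * noise) + kappa1 * noise

    # Stages 2..s: deterministic three-term Chebyshev recursion.
    for j in range(2, s + 1):
        ratio = T(j - 1, w0) / T(j, w0)
        mu, nu = 2 * w1 * ratio, 2 * w0 * ratio
        kappa = 1 - nu
        k_new = -mu * step_size * grad_U(k_prev) + nu * k_prev + kappa * k_prev2
        k_prev2, k_prev = k_prev, k_new
    return k_prev


# Usage: 20,000 independent chains targeting a standard Gaussian, U(x) = x**2 / 2.
rng = np.random.default_rng(0)
x = np.full(20_000, 3.0)  # start every chain far from the mode
for _ in range(200):
    x = skrock_step(x, grad_U=lambda u: u, step_size=0.1, s=5, eta=0.05, rng=rng)
print(x.mean(), x.var())  # close to 0 and 1
```

Each step costs s gradient evaluations, which is the price paid per outer iteration for the inner_iter stages.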

  • SKROCK assumes that the denoiser is \(L\)-Lipschitz differentiable.

  • For convergence, SKROCK requires step_size to be smaller than \(\frac{1}{L+\|A\|_2^2}\).
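As a worked example of the step-size condition, the sketch below estimates \(\|A\|_2^2\) (the largest eigenvalue of \(A^\top A\)) by power iteration and returns the bound \(\frac{1}{L+\|A\|_2^2}\). The helper names are hypothetical, and treating the forward operator as an explicit dense NumPy matrix is an assumption made for illustration only.

```python
import numpy as np


def spectral_norm_sq(A, n_iter=100, seed=0):
    """Estimate ||A||_2^2, the largest eigenvalue of A^T A, by power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(A.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        w = A.T @ (A @ v)
        v = w / np.linalg.norm(w)
    # Rayleigh quotient of A^T A at the unit-norm iterate.
    return float(v @ (A.T @ (A @ v)))


def step_size_bound(L, A):
    """Upper bound 1 / (L + ||A||_2^2) on step_size (hypothetical helper)."""
    return 1.0 / (L + spectral_norm_sq(A))


A = np.diag([2.0, 1.0])  # ||A||_2^2 = 4
print(step_size_bound(L=1.0, A=A))  # approximately 0.2; choose step_size strictly below it
```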

Warning

This is a legacy class provided for convenience. See the example in Markov Chain Monte Carlo for details on how to build a SKRock sampler.

Parameters:
  • prior (deepinv.optim.ScorePrior, torch.nn.Module) – negative log-prior based on a trained or model-based denoiser.

  • data_fidelity (deepinv.optim.DataFidelity, torch.nn.Module) – negative log-likelihood function linked with the noise distribution in the acquisition physics.

  • step_size (float) – Step size of the algorithm. Tip: use physics.lipschitz to compute the Lipschitz constant.

  • eta (float) – \(\eta\) SKROCK damping parameter.

  • alpha (float) – regularization parameter \(\alpha\).

  • inner_iter (int) – Number of inner SKROCK iterations.

  • max_iter (int) – Number of outer iterations.

  • thinning (int) – Thins the Markov chain by an integer \(\geq 1\), i.e., keeps one out of every thinning samples when computing posterior statistics.

  • burnin_ratio (float) – Fraction of the iterations used as the burn-in period. The burn-in samples are discarded and not used to compute posterior statistics.

  • clip (tuple) – Tuple containing the box-constraints \([a,b]\). If None, the algorithm will not project the samples.

  • verbose (bool) – If True, prints the progress of the algorithm.

  • sigma (float) – noise level used in the plug-and-play prior denoiser. A larger value of sigma will result in a more regularized reconstruction.
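The parameters alpha, sigma and clip act on each iteration through the gradient of the negative log-posterior and an optional box projection. The NumPy sketch below shows one common way these pieces fit together, assuming a linear forward matrix A, a Gaussian data-fidelity term \(\frac{1}{2\sigma_n^2}\|Ax-y\|^2\), and a denoiser-based prior gradient via Tweedie's formula, \(-\nabla \log p_\sigma(x) \approx (x - D_\sigma(x))/\sigma^2\). All names (`grad_neg_log_posterior`, `project`, `noise_std`, `gaussian_denoiser`) are hypothetical and are not deepinv's API.

```python
import numpy as np


def grad_neg_log_posterior(x, y, A, noise_std, denoiser, alpha=1.0, sigma=0.05):
    """Gradient of the negative log-posterior (hypothetical helper)."""
    # Gaussian data fidelity f(x) = ||A x - y||^2 / (2 noise_std^2).
    grad_fidelity = A.T @ (A @ x - y) / noise_std**2
    # Denoiser-based prior gradient (Tweedie): (x - D_sigma(x)) / sigma^2.
    grad_prior = (x - denoiser(x, sigma)) / sigma**2
    # alpha weights the prior term; sigma is the denoiser noise level.
    return grad_fidelity + alpha * grad_prior


def project(x, clip):
    """Box projection onto [a, b]; clip=None leaves x unchanged."""
    return x if clip is None else np.clip(x, clip[0], clip[1])


# Toy denoiser: MMSE denoiser for a N(0, tau^2 I) prior, D(x) = tau^2 / (tau^2 + sigma^2) * x.
def gaussian_denoiser(x, sigma, tau=1.0):
    return tau**2 / (tau**2 + sigma**2) * x


x = np.ones(3)
g = grad_neg_log_posterior(x, y=np.zeros(3), A=np.eye(3), noise_std=1.0,
                           denoiser=gaussian_denoiser, alpha=1.0, sigma=0.05)
print(g)  # each entry equals 1 + 1 / (1 + 0.05**2)
print(project(np.array([-2.0, 0.5, 3.0]), clip=(-1.0, 2.0)))  # clipped to -1.0, 0.5, 2.0
```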

forward(y, physics, seed=None, x_init=None, g_statistics=lambda d: ...)

Runs the chain to obtain the posterior mean and variance of the reconstruction from the measurements y.

Parameters:
  • y (torch.Tensor) – Measurements

  • physics (deepinv.physics.Physics) – Forward operator associated with the measurements

  • seed (float) – Random seed for generating the Monte Carlo samples

  • g_statistics (List[Callable] | Callable) – List of functions for which to compute posterior statistics, or a single function. The sampler computes the posterior mean and variance of each function in the list. Since the sampler passes its state to these functions as a dictionary, they must act on d["x"]. Default: lambda d: d["x"] (identity function)

Returns:

(tuple of torch.Tensor) containing the posterior mean and variance.

Return type:

tuple[Tensor, Tensor]
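The post-processing described for forward (discarding the burn-in, thinning, and reporting the mean and variance of each statistic in g_statistics) can be sketched as follows. This is a simplified NumPy illustration under stated assumptions: it stores the whole chain instead of updating the moments online, keeps iteration k when k is past the burn-in and divisible by thinning (deepinv's exact indexing convention may differ), uses the population variance, and the helper name `posterior_statistics` is hypothetical.

```python
import numpy as np


def posterior_statistics(chain, burnin_ratio=0.2, thinning=10, g_statistics=lambda d: d["x"]):
    """Mean and variance of each statistic over the kept samples (hypothetical helper)."""
    if callable(g_statistics):  # a single function is treated as a one-element list
        g_statistics = [g_statistics]
    start = int(len(chain) * burnin_ratio)  # burn-in samples are discarded
    kept = [chain[k] for k in range(start, len(chain)) if k % thinning == 0]
    means, variances = [], []
    for g in g_statistics:
        # Each statistic receives the sampler state as a dictionary and reads d["x"].
        values = np.stack([g({"x": x}) for x in kept])
        means.append(values.mean(axis=0))
        variances.append(values.var(axis=0))  # population variance (ddof=0)
    return means, variances


# Usage on a toy "chain" whose k-th sample is the scalar k (as a 1-element array).
chain = [np.array([float(k)]) for k in range(100)]
means, variances = posterior_statistics(
    chain, g_statistics=[lambda d: d["x"], lambda d: d["x"] ** 2]
)
print(means[0], variances[0])  # kept samples 20, 30, ..., 90: mean 55, variance 525
```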