SKRock

class deepinv.sampling.SKRock(prior, data_fidelity, step_size=1.0, inner_iter=10, eta=0.05, alpha=1.0, max_iter=1e3, burnin_ratio=0.2, thinning=10, clip=(-1.0, 2.0), thresh_conv=1e-3, save_chain=False, verbose=False, sigma=0.05)

Bases: BaseSampling

Plug-and-Play SKROCK algorithm.

Obtains samples of the posterior distribution using a stochastic orthogonal Runge-Kutta-Chebyshev (SKROCK) discretization of the Langevin diffusion, which accelerates the standard Unadjusted Langevin Algorithm (ULA).

The algorithm was introduced by Pereyra et al. [1].
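The inner recursion of SKROCK can be sketched in a few lines. The NumPy code below is a minimal, self-contained illustration, not deepinv's implementation: `skrock_step` and `grad_potential` are hypothetical names, the target is assumed to be \(\pi(x) \propto \exp(-U(x))\) with a known gradient \(\nabla U\) (in the plug-and-play setting this would be the data-fidelity gradient plus \(\alpha\) times the prior gradient), and the stage coefficients follow the damped Chebyshev recursion of Pereyra et al. as we understand it.

```python
import numpy as np


def skrock_step(x, grad_potential, step_size, inner_iter=10, eta=0.05, rng=None):
    """One SKROCK step targeting pi(x) ~ exp(-U(x)), given grad_potential = grad U.

    Illustrative sketch only (hypothetical helper, not the deepinv code).
    """
    rng = np.random.default_rng() if rng is None else rng
    s = inner_iter
    # Chebyshev polynomial of the first kind: T_j(w) = cosh(j * arccosh(w)) for w >= 1
    T = lambda j, w: np.cosh(j * np.arccosh(w))
    # Its derivative: T_j'(w) = j * sinh(j * arccosh(w)) / sqrt(w^2 - 1)
    T_prime = lambda j, w: j * np.sinh(j * np.arccosh(w)) / np.sqrt(w**2 - 1)

    w0 = 1 + eta / s**2              # damped argument omega_0 (eta = damping)
    w1 = T(s, w0) / T_prime(s, w0)   # omega_1
    mu1 = w1 / w0
    nu1 = s * w1 / 2
    kappa1 = s * w1 / w0

    # One Gaussian draw per outer step, scaled as in an unadjusted Langevin step
    noise = np.sqrt(2 * step_size) * rng.standard_normal(x.shape)

    # First internal stage (j = 1): the noise also shifts the gradient evaluation point
    x_prev = x.copy()  # X_{j-2}
    x_cur = x - mu1 * step_size * grad_potential(x + nu1 * noise) + kappa1 * noise

    # Stages j = 2, ..., s: three-term Chebyshev recursion
    for j in range(2, s + 1):
        mu = 2 * w1 * T(j - 1, w0) / T(j, w0)
        nu = 2 * w0 * T(j - 1, w0) / T(j, w0)
        kappa = 1 - nu
        x_next = -mu * step_size * grad_potential(x_cur) + nu * x_cur + kappa * x_prev
        x_prev, x_cur = x_cur, x_next
    return x_cur


# Toy check: Gaussian target N(3, 1), i.e. U(x) = (x - 3)^2 / 2
rng = np.random.default_rng(0)
target_mean = 3.0
grad_U = lambda x: x - target_mean
x = np.zeros(2000)  # 2000 independent chains, all started at 0
for _ in range(200):
    x = skrock_step(x, grad_U, step_size=1.0, inner_iter=10, eta=0.05, rng=rng)
print("mean of final states:", x.mean())
```

On this toy target the average of the final states should be close to 3. Because the \(s\) inner stages enlarge the stability region roughly like \(s^2\), a single outer step can use a much larger step size than plain ULA at the cost of \(s\) gradient evaluations.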

  • SKROCK assumes that the denoiser is \(L\)-Lipschitz differentiable.

  • For convergence, SKROCK requires step_size to be smaller than \(\frac{1}{L+\|A\|_2^2}\).
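The step-size condition can be evaluated numerically. The sketch below assumes a small dense matrix `A` standing in for the forward operator and a user-supplied Lipschitz constant `L` for the denoiser term; it estimates \(\|A\|_2^2\), the largest eigenvalue of \(A^\top A\), by power iteration. `spectral_norm_sq` is a hypothetical helper, not a deepinv function.

```python
import numpy as np


def spectral_norm_sq(A, n_iter=100, seed=0):
    """Estimate ||A||_2^2 (largest eigenvalue of A^T A) by power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(A.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        w = A.T @ (A @ v)          # apply A^T A
        v = w / np.linalg.norm(w)  # renormalize
    # Rayleigh quotient of A^T A at the (approximate) top eigenvector
    return float(v @ (A.T @ (A @ v)))


A = np.diag([2.0, 1.0])  # toy forward operator with ||A||_2^2 = 4
L = 1.0                  # assumed Lipschitz constant of the denoiser term
norm_sq = spectral_norm_sq(A)
step_bound = 1.0 / (L + norm_sq)  # step_size should stay below this value
print(norm_sq, step_bound)
```

For large imaging operators, the dense products would be replaced by matrix-free calls to the forward operator and its adjoint; the power iteration itself is unchanged.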

Warning

This is a legacy class provided for convenience. See the example in Markov Chain Monte Carlo for details on how to build a SKRock sampler.

Parameters:
  • prior (deepinv.optim.ScorePrior, torch.nn.Module) – negative log-prior based on a trained or model-based denoiser.

  • data_fidelity (deepinv.optim.DataFidelity, torch.nn.Module) – negative log-likelihood function linked with the noise distribution in the acquisition physics.

  • step_size (float) – Step size of the algorithm. Tip: use physics.lipschitz to compute the Lipschitz constant.

  • eta (float) – Damping parameter \(\eta\) of SKROCK.

  • alpha (float) – regularization parameter \(\alpha\).

  • inner_iter (int) – Number of inner SKROCK iterations.

  • max_iter (int) – Number of outer iterations.

  • thinning (int) – Thins the Markov chain by an integer \(\geq 1\) (i.e., keeps one out of every thinning samples to compute posterior statistics).

  • burnin_ratio (float) – Fraction of the iterations used as burn-in period (e.g. 0.2 for the first 20%). The burn-in samples are discarded and not used to compute the posterior statistics.

  • clip (tuple) – Tuple containing the box-constraints \([a,b]\). If None, the algorithm will not project the samples.

  • verbose (bool) – If True, prints the progress of the algorithm.

  • sigma (float) – noise level used in the plug-and-play prior denoiser. A larger value of sigma will result in a more regularized reconstruction.
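To make max_iter, burnin_ratio and thinning concrete, the sketch below lists which outer iterations would contribute to the posterior statistics. It assumes a simple rule (discard the first max_iter * burnin_ratio iterations, then keep every thinning-th one); `kept_iterations` is a hypothetical helper, and deepinv's exact boundary handling may differ.

```python
def kept_iterations(max_iter, burnin_ratio, thinning):
    """Hypothetical helper: outer iterations whose samples enter the statistics."""
    burnin = int(max_iter * burnin_ratio)  # number of burn-in iterations to discard
    return [it for it in range(int(max_iter)) if it >= burnin and it % thinning == 0]


# Default values from the signature: max_iter=1e3, burnin_ratio=0.2, thinning=10
kept = kept_iterations(max_iter=1e3, burnin_ratio=0.2, thinning=10)
print(len(kept), kept[0], kept[-1])  # 80 200 990
```

Under this assumed rule, only 80 of the 1000 samples produced with the default settings would be used for the mean and variance.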


References:

forward(y, physics, seed=None, x_init=None, g_statistics=lambda d: ...)

Runs the chain to obtain the posterior mean and variance of the reconstruction of the measurements y.

Parameters:
  • y (torch.Tensor) – Measurements

  • physics (deepinv.physics.Physics) – Forward operator associated with the measurements

  • seed (float) – Random seed for generating the Monte Carlo samples

  • g_statistics (List[Callable] | Callable) – List of functions for which to compute posterior statistics, or a single function. The sampler will compute the posterior mean and variance of each function in the list. Note that the samples are passed to these functions as a dictionary, so they must act on d["x"]. Default: lambda d: d["x"] (identity function)

Returns:

(tuple of torch.Tensor) containing the posterior mean and variance.

Return type:

tuple[Tensor, Tensor]
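The posterior mean and variance returned by forward can be accumulated online, without storing the chain. The sketch below uses Welford's algorithm, which is our assumption about how such running statistics could be computed (here with the unbiased \(1/(k-1)\) variance); `OnlineMeanVar` is a hypothetical class, and the statistics follow the g_statistics convention of acting on d["x"].

```python
import numpy as np


class OnlineMeanVar:
    """Welford's running mean and unbiased variance (hypothetical helper)."""

    def __init__(self):
        self.k = 0        # number of samples seen so far
        self.mean = None
        self.m2 = None    # running sum of squared deviations

    def update(self, value):
        value = np.asarray(value, dtype=float)
        self.k += 1
        if self.mean is None:
            self.mean = value.copy()
            self.m2 = np.zeros_like(value)
            return
        delta = value - self.mean
        self.mean = self.mean + delta / self.k
        self.m2 = self.m2 + delta * (value - self.mean)

    def variance(self):
        return self.m2 / (self.k - 1)


# Each statistic acts on the sample dictionary, as g_statistics requires
g_statistics = [lambda d: d["x"], lambda d: d["x"] ** 2]
stats = [OnlineMeanVar() for _ in g_statistics]

for sample in [np.array([1.0]), np.array([2.0]), np.array([3.0]), np.array([4.0])]:
    d = {"x": sample}
    for g, acc in zip(g_statistics, stats):
        acc.update(g(d))

print(stats[0].mean, stats[0].variance())  # mean 2.5, variance 5/3
```

Keeping one accumulator per function mirrors the documented behaviour of returning a mean and a variance for each entry of g_statistics.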