SKRockIterator

class deepinv.sampling.SKRockIterator(algo_params, clip=None)

Bases: SamplingIterator

Single iteration of the SK-ROCK (Stabilized Runge-Kutta-Chebyshev) algorithm.

Draws samples from the posterior distribution using a stochastic orthogonal Runge-Kutta-Chebyshev scheme, which accelerates the standard Unadjusted Langevin Algorithm (ULA).

The algorithm was introduced in “Accelerating proximal Markov chain Monte Carlo by using an explicit stabilised method” by L. Vargas, M. Pereyra and K. Zygalakis (https://arxiv.org/abs/1908.08845).
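For intuition, here is a minimal sketch of one SK-ROCK step for a scalar state, written from the stage recursion described in the paper above: \(s\) stages (assumed here to correspond to inner_iter) and the damping \(\eta\) (eta) define Chebyshev-based coefficients, and the stages are combined through a three-term recurrence. This is not deepinv's implementation. The function names chebyshev_t, chebyshev_t_prime and skrock_step are hypothetical, and grad_log_pi is an assumed callable returning \(\nabla \log \pi\), the gradient of the log-posterior (in the class this would presumably be assembled from the data fidelity and the prior).

```python
import math
import random


def chebyshev_t(k, x):
    """Chebyshev polynomial of the first kind T_k(x), via the three-term recurrence."""
    if k == 0:
        return 1.0
    t_prev, t_cur = 1.0, x  # T_0(x), T_1(x)
    for _ in range(k - 1):
        t_prev, t_cur = t_cur, 2.0 * x * t_cur - t_prev
    return t_cur


def chebyshev_t_prime(k, x):
    """Derivative T_k'(x) = k * U_{k-1}(x), with U the Chebyshev polynomials of the second kind."""
    if k == 0:
        return 0.0
    if k == 1:
        return 1.0  # T_1'(x) = 1
    u_prev, u_cur = 1.0, 2.0 * x  # U_0(x), U_1(x)
    for _ in range(k - 2):
        u_prev, u_cur = u_cur, 2.0 * x * u_cur - u_prev
    return k * u_cur


def skrock_step(x, grad_log_pi, step_size, s=10, eta=0.05, rng=random):
    """Hypothetical sketch: one SK-ROCK step for a scalar state x targeting pi."""
    # Damped Chebyshev parameters omega_0 and omega_1.
    w0 = 1.0 + eta / s**2
    w1 = chebyshev_t(s, w0) / chebyshev_t_prime(s, w0)
    mu1, nu1, kappa1 = w1 / w0, s * w1 / 2.0, s * w1 / w0
    # A single Gaussian draw is shared by the whole step.
    noise = math.sqrt(2.0 * step_size) * rng.gauss(0.0, 1.0)
    k_prev = x  # stage K_0
    k_cur = x + mu1 * step_size * grad_log_pi(x + nu1 * noise) + kappa1 * noise  # stage K_1
    for j in range(2, s + 1):
        ratio = chebyshev_t(j - 1, w0) / chebyshev_t(j, w0)
        mu, nu = 2.0 * w1 * ratio, 2.0 * w0 * ratio
        kappa = 1.0 - nu
        # K_j = mu_j * delta * grad(K_{j-1}) + nu_j * K_{j-1} + kappa_j * K_{j-2}
        k_prev, k_cur = k_cur, mu * step_size * grad_log_pi(k_cur) + nu * k_cur + kappa * k_prev
    return k_cur  # stage K_s is the next state of the chain
```

For a Gaussian target \(\mathcal{N}(2, 1)\), iterating skrock_step with grad_log_pi = lambda u: -(u - 2.0) and step_size=0.1 gives a chain whose empirical mean and variance are close to 2 and 1. Because the stages follow a Chebyshev recurrence, the step also stays stable at step sizes well beyond ULA's limit of \(2/L\) for such a target (for example step_size=50 with grad_log_pi = lambda u: -u), which is the acceleration the method is designed to provide.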

  • SK-ROCK assumes that the denoiser is \(L\)-Lipschitz differentiable.

  • For convergence, SK-ROCK requires step_size to be smaller than \(\frac{1}{L+\|A\|_2^2}\), where \(\|A\|_2\) is the spectral norm of the forward operator \(A\).
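The step-size condition can be evaluated numerically once \(L\) is known. The sketch below estimates \(\|A\|_2^2\) (the largest eigenvalue of \(A^\top A\)) by power iteration for a small dense matrix stored as nested lists, then forms the bound \(1/(L+\|A\|_2^2)\). The helper names spectral_norm_squared and skrock_max_step are hypothetical; with deepinv operators, the documented tip is to use physics.lipschitz instead of a hand-written loop like this one.

```python
import math
import random


def spectral_norm_squared(A, n_iter=100, seed=0):
    """Hypothetical helper: estimate ||A||_2^2 (largest eigenvalue of A^T A) by power iteration."""
    rows, cols = len(A), len(A[0])
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(cols)]  # random starting vector
    lam = 0.0
    for _ in range(n_iter):
        norm = math.sqrt(sum(vi * vi for vi in v))
        v = [vi / norm for vi in v]
        Av = [sum(A[i][j] * v[j] for j in range(cols)) for i in range(rows)]
        AtAv = [sum(A[i][j] * Av[i] for i in range(rows)) for j in range(cols)]
        lam = sum(v[j] * AtAv[j] for j in range(cols))  # Rayleigh quotient v^T A^T A v
        v = AtAv
    return lam


def skrock_max_step(L, A):
    """Hypothetical helper: the upper bound 1 / (L + ||A||_2^2) on step_size."""
    return 1.0 / (L + spectral_norm_squared(A))


print(skrock_max_step(1.0, [[3.0, 0.0], [0.0, 1.0]]))  # approximately 0.1
```

In practice one would pick step_size somewhat below this bound, since it is a strict upper limit.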

Parameters:
  • algo_params (dict) – Dictionary containing the algorithm parameters (see table below).

  • clip (tuple(int,int)) – Tuple of (min, max) values used to clip (project) the samples into a bounded range during sampling. Useful for images whose pixel values should stay within a specific range (e.g., (0,1) or (0,255)). Default: None

Parameter     Type    Description

step_size     float   Step size of the algorithm (default: 1.0). Tip: use physics.lipschitz to compute the Lipschitz constant.
alpha         float   Regularization parameter \(\alpha\) (default: 1.0).
inner_iter    int     Number of internal iterations (Runge-Kutta-Chebyshev stages) per sampling step (default: 10).
eta           float   Damping parameter \(\eta\) (default: 0.05).
sigma         float   Noise level for the score prior denoiser (default: 0.05). A larger value of sigma results in a more regularized reconstruction.
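As a configuration sketch, the snippet below assembles an algo_params dictionary from the defaults listed in the table and shows the (min, max) projection that clip describes. The key names and default values come from the table; build_algo_params and clip_value are hypothetical helpers, not part of deepinv, and this document does not state whether the class fills in missing keys itself.

```python
# Defaults copied from the parameter table; the helpers below are hypothetical.
SKROCK_DEFAULTS = {
    "step_size": 1.0,
    "alpha": 1.0,
    "inner_iter": 10,
    "eta": 0.05,
    "sigma": 0.05,
}


def build_algo_params(**overrides):
    """Merge user overrides into the documented defaults, rejecting unknown keys."""
    unknown = set(overrides) - set(SKROCK_DEFAULTS)
    if unknown:
        raise KeyError(f"unknown SK-ROCK parameters: {sorted(unknown)}")
    params = {**SKROCK_DEFAULTS, **overrides}  # new dict; defaults stay untouched
    if params["inner_iter"] < 1:
        raise ValueError("inner_iter must be a positive integer")
    return params


def clip_value(x, clip=None):
    """Project a scalar into [min, max] when clip=(min, max) is given."""
    if clip is None:
        return x
    lo, hi = clip
    return min(max(x, lo), hi)


params = build_algo_params(step_size=0.01, inner_iter=5)
```

The resulting dictionary would then be passed as SKRockIterator(params, clip=(0, 1)), matching the documented constructor signature.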

forward(X, y, physics, cur_data_fidelity, cur_prior, iteration, *args, **kwargs)

Performs a single SK-ROCK sampling step.

Parameters:
  • X (Dict) – Dictionary containing the current state \(x_t\).

  • y (torch.Tensor) – Observed measurements/data tensor.

  • physics (Physics) – Forward operator.

  • cur_data_fidelity (DataFidelity) – Negative log-likelihood function.

  • cur_prior (ScorePrior) – Score-based prior.

Returns:

Dictionary {"x": x} containing the next state \(x_{t+1}\) in the Markov chain.

Return type:

Dict