SamplingIterator

class deepinv.sampling.SamplingIterator(algo_params, **kwargs)

Bases: Module

Base class for sampling iterators.

All samplers should implement the forward method, which performs one step of the Markov chain Monte Carlo (MCMC) sampling process: it generates the next state \(X_{t+1}\) from the current state \(X_t\), where \(X_t\) is a dict containing the image \(x_t\) as well as any latent variables. See the docs for deepinv.sampling.BaseSampling for an example along with more information.

Parameters:

algo_params (dict) – Dictionary containing the parameters for the sampling algorithm
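To make the contract concrete, here is a minimal sketch of how the state dict can be threaded through a chain, assuming a sampler first calls initialize_latent_variables and then calls forward once per iteration with a zero-based iteration number. ShrinkIterator and run_chain are hypothetical names used only for illustration: ShrinkIterator is a deterministic toy rather than a real MCMC step, and run_chain is a simplified stand-in, not the loop that deepinv.sampling.BaseSampling actually runs. The try/except falls back to an assumed minimal base class when deepinv is not installed.

```python
import torch

try:
    # Use the real base class when deepinv is installed.
    from deepinv.sampling import SamplingIterator
except ImportError:
    # Assumed minimal stand-in so the sketch also runs without deepinv.
    class SamplingIterator(torch.nn.Module):
        def __init__(self, algo_params, **kwargs):
            super().__init__()
            self.algo_params = algo_params

        def initialize_latent_variables(self, x_init, y, physics, cur_data_fidelity, cur_prior):
            return {"x": x_init}


class ShrinkIterator(SamplingIterator):
    """Hypothetical toy iterator (not a real sampler): x_{t+1} = decay * x_t."""

    def __init__(self, algo_params, **kwargs):
        super().__init__(algo_params, **kwargs)
        self.decay = algo_params["decay"]  # read the algorithm parameter once

    def forward(self, X, y, physics, cur_data_fidelity, cur_prior, iteration, *args, **kwargs):
        # Build the next state dict; the image is kept under the key "x".
        return {"x": self.decay * X["x"], "last_iteration": iteration}


def run_chain(iterator, x_init, y, physics, data_fidelity, prior, n_steps):
    """Hypothetical driver: thread the state dict through n_steps forward calls."""
    X = iterator.initialize_latent_variables(x_init, y, physics, data_fidelity, prior)
    for t in range(n_steps):  # iteration numbers start at 0
        X = iterator(X, y, physics, data_fidelity, prior, t)
    return X


# The toy iterator ignores y, physics and the negative log-terms, so None is passed.
X = run_chain(ShrinkIterator({"decay": 0.5}), torch.ones(2), None, None, None, None, n_steps=3)
print(X["x"], X["last_iteration"])  # tensor([0.1250, 0.1250]) 2
```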

forward(X, y, physics, cur_data_fidelity, cur_prior, iteration, *args, **kwargs)

Performs a single sampling step \(X_t \rightarrow X_{t+1}\), where \(X_t\) is a dict containing the image \(x_t\) as well as any latent variables.

Parameters:
  • X (Dict) – Dictionary containing the current state \(X_t\) of the Markov chain, i.e. the image \(x_t\) (under the key "x") along with any latent variables.

  • y (torch.Tensor) – Observed measurements/data tensor

  • physics (Physics) – Forward operator

  • cur_data_fidelity (DataFidelity) – Negative log-likelihood

  • cur_prior (Prior) – Negative log-prior term

  • iteration (int) – Current iteration number in the sampling process (zero-indexed)

  • args – Additional positional arguments

  • kwargs – Additional keyword arguments

Returns:

Dictionary {"x": x, ...} containing the next state along with any latent variables.

Return type:

dict[str, Any]
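As an illustration of forward, the sketch below implements one step of the unadjusted Langevin algorithm (ULA) on the negative log-posterior U = data fidelity + prior. It assumes cur_data_fidelity.grad(x, y, physics) and cur_prior.grad(x) return the gradients of the two negative log-terms; a prior whose grad needs extra arguments (for example a noise level) would have to receive them via algo_params or kwargs. ULASketchIterator, ToyGaussianFidelity, ToyGaussianPrior and the "step_size" key are hypothetical names, and this is not deepinv's own ULA implementation.

```python
import math

import torch

try:
    from deepinv.sampling import SamplingIterator
except ImportError:
    # Assumed minimal stand-in so the sketch also runs without deepinv.
    class SamplingIterator(torch.nn.Module):
        def __init__(self, algo_params, **kwargs):
            super().__init__()
            self.algo_params = algo_params

        def initialize_latent_variables(self, x_init, y, physics, cur_data_fidelity, cur_prior):
            return {"x": x_init}


class ULASketchIterator(SamplingIterator):
    """One unadjusted Langevin step on U = data fidelity + prior:
    x_{t+1} = x_t - h * grad U(x_t) + sqrt(2h) * noise."""

    def __init__(self, algo_params, **kwargs):
        super().__init__(algo_params, **kwargs)
        self.step_size = algo_params["step_size"]  # h

    def forward(self, X, y, physics, cur_data_fidelity, cur_prior, iteration, *args, **kwargs):
        x = X["x"]
        # Gradient of the negative log-posterior; assumes prior.grad needs only x.
        grad_u = cur_data_fidelity.grad(x, y, physics) + cur_prior.grad(x)
        noise = torch.randn_like(x)
        x_next = x - self.step_size * grad_u + math.sqrt(2 * self.step_size) * noise
        return {"x": x_next}


class ToyGaussianFidelity:
    """Hypothetical: negative log-likelihood ||x - y||^2 / (2 sigma^2), identity physics."""

    def __init__(self, sigma):
        self.sigma = sigma

    def grad(self, x, y, physics):
        return (x - y) / self.sigma**2


class ToyGaussianPrior:
    """Hypothetical: negative log-prior ||x||^2 / (2 tau^2)."""

    def __init__(self, tau):
        self.tau = tau

    def grad(self, x):
        return x / self.tau**2


torch.manual_seed(0)
iterator = ULASketchIterator({"step_size": 0.05})
fidelity, prior = ToyGaussianFidelity(1.0), ToyGaussianPrior(1.0)
y = torch.full((10_000,), 2.0)  # 10,000 independent scalar chains, same observation
X = iterator.initialize_latent_variables(torch.zeros_like(y), y, None, fidelity, prior)
for t in range(500):
    X = iterator(X, y, None, fidelity, prior, t)
# The exact posterior is N(1, 0.5); the sample mean should be close to 1.
print(X["x"].mean().item(), X["x"].var().item())
```

With a Gaussian likelihood and prior the posterior is Gaussian, which makes the step easy to check: for y = 2 and sigma = tau = 1 it is N(1, 0.5). Because ULA has no Metropolis correction, the sample variance comes out slightly inflated by the step size.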

initialize_latent_variables(x_init, y, physics, cur_data_fidelity, cur_prior)

Initializes latent variables for the sampling iterator.

This method is intended to be overridden by subclasses that need to initialize latent variables required by their specific sampling algorithm. The default implementation simply returns the initial state in a dictionary, {"x": x_init}.

Parameters:
  • x_init (torch.Tensor) – Initial state tensor.

  • y (torch.Tensor) – Observed measurements/data tensor.

  • physics (Physics) – Forward operator.

  • cur_data_fidelity (DataFidelity) – Negative log-likelihood.

  • cur_prior (Prior) – Negative log-prior term.

Returns:

Dictionary containing the initial state x and any latent variables.

Return type:

dict[str, Any]
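Overriding initialize_latent_variables is how an iterator adds its own latent state. As a sketch, the iterator below carries a momentum variable v next to x, using a simple discretization of underdamped (kinetic) Langevin dynamics with unit mass and friction gamma; the scheme is chosen for illustration and is not taken from deepinv. KineticLangevinSketch, ZeroFidelity, StandardGaussianPrior and the "step_size"/"friction" keys are hypothetical, and the gradient calls assume cur_data_fidelity.grad(x, y, physics) and cur_prior.grad(x) return gradients of the negative log-terms.

```python
import math

import torch

try:
    from deepinv.sampling import SamplingIterator
except ImportError:
    # Assumed minimal stand-in so the sketch also runs without deepinv.
    class SamplingIterator(torch.nn.Module):
        def __init__(self, algo_params, **kwargs):
            super().__init__()
            self.algo_params = algo_params


class KineticLangevinSketch(SamplingIterator):
    """Underdamped Langevin (unit mass) with a momentum latent variable v:
    v_{t+1} = v_t - h*gamma*v_t - h*grad U(x_t) + sqrt(2*gamma*h) * noise
    x_{t+1} = x_t + h * v_{t+1}"""

    def __init__(self, algo_params, **kwargs):
        super().__init__(algo_params, **kwargs)
        self.step_size = algo_params["step_size"]  # h
        self.friction = algo_params["friction"]  # gamma

    def initialize_latent_variables(self, x_init, y, physics, cur_data_fidelity, cur_prior):
        # The momentum starts at zero and is carried alongside x in the state dict.
        return {"x": x_init, "v": torch.zeros_like(x_init)}

    def forward(self, X, y, physics, cur_data_fidelity, cur_prior, iteration, *args, **kwargs):
        x, v = X["x"], X["v"]
        h, gamma = self.step_size, self.friction
        grad_u = cur_data_fidelity.grad(x, y, physics) + cur_prior.grad(x)
        v = v - h * gamma * v - h * grad_u + math.sqrt(2 * gamma * h) * torch.randn_like(v)
        x = x + h * v  # position update uses the new momentum
        return {"x": x, "v": v}


class ZeroFidelity:
    """Hypothetical: no data term, so the target is the prior alone."""

    def grad(self, x, y, physics):
        return torch.zeros_like(x)


class StandardGaussianPrior:
    """Hypothetical: negative log-prior ||x||^2 / 2, i.e. N(0, 1)."""

    def grad(self, x):
        return x


torch.manual_seed(0)
sampler = KineticLangevinSketch({"step_size": 0.05, "friction": 1.0})
fidelity, prior = ZeroFidelity(), StandardGaussianPrior()
X = sampler.initialize_latent_variables(torch.full((10_000,), 3.0), None, None, fidelity, prior)
for t in range(1000):
    X = sampler(X, None, None, fidelity, prior, t)
print(sorted(X.keys()), X["x"].mean().item(), X["x"].var().item())  # keys: ['v', 'x']
```

Because the momentum lives in the state dict, forward receives it back on the next call without extra bookkeeping; with no data term, the chains started at 3.0 should settle around the standard Gaussian prior.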

Examples using SamplingIterator:

Building your custom MCMC sampling algorithm.