OptimIterator

class deepinv.optim.OptimIterator(g_first=False, F_fn=None, has_cost=False, **kwargs)

Bases: Module

Base class for all Optim() iterators.

An optim iterator is an object that implements a fixed-point iteration for minimizing the sum of two functions \(F = f + \lambda g\), where \(f\) is a data-fidelity term modeled by an instance of physics and \(g\) is a regularizer. The fixed-point iteration takes the form

\[\qquad (x_{k+1}, z_{k+1}) = \operatorname{FixedPoint}(x_k, z_k, f, g, A, y, ...)\]

where \(x\) is a “primal” variable converging to the solution of the minimization problem, and \(z\) is a “dual” variable.
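As a rough illustration of this fixed-point form, the NumPy sketch below repeatedly applies a map \((x_k, z_k) \mapsto (x_{k+1}, z_{k+1})\) until the primal variable stops changing. The helper run_fixed_point, its relative-change stopping rule, and the toy contraction are hypothetical and not part of deepinv; NumPy arrays stand in for torch tensors.

```python
import numpy as np


def run_fixed_point(iterate, x0, z0, tol=1e-8, max_iter=1000):
    """Hypothetical driver: apply (x, z) -> iterate(x, z) until x stabilizes."""
    x, z = x0, z0
    for k in range(max_iter):
        x_new, z_new = iterate(x, z)
        # Relative change of the primal variable as a simple stopping criterion.
        rel_change = np.linalg.norm(x_new - x) / max(np.linalg.norm(x), 1e-12)
        x, z = x_new, z_new
        if rel_change < tol:
            break
    return x, z, k + 1


def toy_iterate(x, z):
    """Toy contraction: z <- 0.5 x + 1, then x <- 0.5 z + 1 (fixed point x = z = 2)."""
    z_new = 0.5 * x + 1.0
    x_new = 0.5 * z_new + 1.0
    return x_new, z_new


x, z, n_iter = run_fixed_point(toy_iterate, np.zeros(3), np.zeros(3))
```

Here the map is a contraction, so the iterates converge to its unique fixed point; in practice the map is built from steps on \(f\) and \(g\).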

Note

By an abuse of terminology, we call “primal” and “dual” variables the variables that are updated at each step; they may correspond to the actual primal and dual variables (for instance, in the case of the PD algorithm), but not necessarily (for instance, in the case of the PGD algorithm).
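To make the distinction concrete, here is a standalone sketch of a Chambolle-Pock-style primal-dual iteration for 1D total-variation denoising, \(\min_x \tfrac{1}{2}\|x-y\|^2 + \lambda\|Dx\|_1\) with \(D\) the forward-difference operator. In this case \(z\) is a genuine dual variable and stays in the box \([-\lambda, \lambda]\). The problem, the function names, and the step sizes \(\tau = \sigma = 0.4\) (chosen so that \(\tau\sigma\|D\|^2 < 1\), using \(\|D\|^2 \le 4\)) are assumptions for illustration; this is not deepinv's PD implementation.

```python
import numpy as np


def diff(x):
    """Forward differences D x, mapping length n to length n - 1."""
    return x[1:] - x[:-1]


def diff_adjoint(z):
    """Adjoint D^T z of the forward-difference operator, mapping n - 1 to n."""
    return np.concatenate(([-z[0]], z[:-1] - z[1:], [z[-1]]))


def tv_denoise_pd(y, lam, tau=0.4, sigma=0.4, n_iter=500):
    """Chambolle-Pock-style iteration: x is primal, z is the true dual variable."""
    x = y.copy()
    x_bar = x.copy()
    z = np.zeros(y.size - 1)
    for _ in range(n_iter):
        # Dual step: prox of sigma * F^* with F = lam * ||.||_1,
        # i.e. projection onto the box [-lam, lam].
        z = np.clip(z + sigma * diff(x_bar), -lam, lam)
        # Primal step: prox of tau * G with G = 0.5 * ||. - y||^2.
        x_new = (x - tau * diff_adjoint(z) + tau * y) / (1.0 + tau)
        # Extrapolation of the primal variable.
        x_bar = 2.0 * x_new - x
        x = x_new
    return x, z


rng = np.random.default_rng(0)
clean = np.repeat([0.0, 1.0, 0.5], 20)
y = clean + 0.1 * rng.standard_normal(clean.size)
x, z = tv_denoise_pd(y, lam=0.5)
```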

The implementation of the fixed-point algorithm in deepinv.optim() is split into two steps, alternating between a step on \(f\) and a step on \(g\); that is, for \(k=1,2,...\)

\[\begin{split}z_{k+1} = \operatorname{step}_f(x_k, z_k, y, A, ...)\\ x_{k+1} = \operatorname{step}_g(x_k, z_{k+1}, y, A, ...)\end{split}\]

where \(\operatorname{step}_f\) and \(\operatorname{step}_g\) are the steps on \(f\) and \(g\), respectively.
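A minimal NumPy sketch of this split for proximal gradient descent (PGD) on \(F(x) = \tfrac{1}{2}\|Ax-y\|^2 + \lambda\|x\|_1\): step_f is a gradient step on \(f\), and step_g is the proximal operator of \(\gamma\lambda\|\cdot\|_1\) (soft-thresholding). As the note above points out, the “dual” variable \(z\) is then just the output of the gradient step. The function names, the random test problem, and the step size \(\gamma = 1/\|A\|_2^2\) are assumptions, not deepinv's API.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1 (soft-thresholding)."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def step_f(x, A, y, gamma):
    """Gradient step on f(x) = 0.5 * ||A x - y||^2; the result plays the role of z."""
    return x - gamma * (A.T @ (A @ x - y))


def step_g(z, lam, gamma):
    """Proximal step on lam * ||.||_1 with step size gamma."""
    return soft_threshold(z, gamma * lam)


rng = np.random.default_rng(0)
A = rng.standard_normal((20, 10))
y = rng.standard_normal(20)
lam = 0.5
gamma = 1.0 / np.linalg.norm(A, 2) ** 2  # 1 / L with L = ||A||_2^2

x = np.zeros(10)
z = np.zeros(10)
for k in range(5000):
    z = step_f(x, A, y, gamma)  # z_{k+1}: step on f
    x = step_g(z, lam, gamma)   # x_{k+1}: step on g, using the new z

# At a fixed point, one more PGD iteration leaves x unchanged.
residual = np.linalg.norm(x - step_g(step_f(x, A, y, gamma), lam, gamma))
```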

Parameters:
  • g_first (bool) – If True, the algorithm starts with a step on \(g\) and finishes with a step on \(f\). Default: False.

  • F_fn – Function computing the objective \(F\) to be minimized; it is evaluated at each iteration when has_cost is True. Default: None.

  • has_cost (bool) – If True, the function F is computed at each iteration. Default: False.
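The plain-Python sketch below shows one way these three arguments could interact: g_first swaps the order of the two steps, and F_fn is only evaluated when has_cost is True. ToyIterator and its callable-based interface are hypothetical and much simpler than deepinv's OptimIterator, which works with data-fidelity, prior, and physics objects.

```python
class ToyIterator:
    """Hypothetical iterator illustrating g_first, F_fn and has_cost."""

    def __init__(self, step_f, step_g, g_first=False, F_fn=None, has_cost=False):
        self.step_f = step_f
        self.step_g = step_g
        self.g_first = g_first
        self.F_fn = F_fn
        self.has_cost = has_cost

    def __call__(self, x):
        if self.g_first:
            z = self.step_g(x)  # start with a step on g ...
            x = self.step_f(z)  # ... and finish with a step on f
        else:
            z = self.step_f(x)  # start with a step on f ...
            x = self.step_g(z)  # ... and finish with a step on g
        # The cost is only evaluated when has_cost is True.
        cost = self.F_fn(x) if self.has_cost else None
        return x, z, cost


calls = []


def toy_step_f(v):
    calls.append("f")
    return v - 1.0


def toy_step_g(v):
    calls.append("g")
    return 2.0 * v


it = ToyIterator(toy_step_f, toy_step_g, g_first=True,
                 F_fn=lambda v: v ** 2, has_cost=True)
x, z, cost = it(3.0)  # g first: z = 6.0, x = 5.0, cost = 25.0
```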

forward(X, cur_data_fidelity, cur_prior, cur_params, y, physics)

General form of a single iteration of splitting algorithms for minimizing \(F = f + \lambda g\), alternating between a step on \(f\) and a step on \(g\). The primal and dual variables as well as the estimated cost at the current iterate are stored in a dictionary \(X\) of the form {'est': (x, z), 'cost': F}.

Parameters:
  • X (dict) – Dictionary containing the current iterate and the estimated cost.

  • cur_data_fidelity (deepinv.optim.DataFidelity) – Instance of the DataFidelity class defining the current data_fidelity.

  • cur_prior (deepinv.optim.Prior) – Instance of the Prior class defining the current prior.

  • cur_params (dict) – Dictionary containing the current parameters of the algorithm.

  • y (torch.Tensor) – Input data, i.e. the observed measurements.

  • physics (deepinv.physics.Physics) – Instance of the physics modeling the observation.

Returns:

Dictionary {“est”: (x, z), “cost”: F} containing the updated current iterate and the estimated current cost.
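A hypothetical sketch of one forward call for PGD, using plain callables in place of cur_data_fidelity and cur_prior, and a params dictionary whose keys "stepsize" and "lambda" are assumed for illustration. It only mirrors the documented input/output layout {'est': (x, z), 'cost': F}; the internals of the actual method may differ.

```python
import numpy as np


def forward(X, grad_f, prox_g, objective, params, y, A):
    """Hypothetical single PGD iteration using the {'est': (x, z), 'cost': F} layout."""
    x_prev = X["est"][0]
    gamma, lam = params["stepsize"], params["lambda"]
    z = x_prev - gamma * grad_f(x_prev, y, A)  # step on f
    x = prox_g(z, gamma * lam)                 # step on g
    return {"est": (x, z), "cost": objective(x, y, A, lam)}


def grad_f(x, y, A):
    """Gradient of f(x) = 0.5 * ||A x - y||^2."""
    return A.T @ (A @ x - y)


def prox_g(v, t):
    """Proximal operator of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def objective(x, y, A, lam):
    """F(x) = f(x) + lam * ||x||_1."""
    return 0.5 * np.sum((A @ x - y) ** 2) + lam * np.sum(np.abs(x))


rng = np.random.default_rng(1)
A = rng.standard_normal((15, 8))
y = rng.standard_normal(15)
params = {"stepsize": 1.0 / np.linalg.norm(A, 2) ** 2, "lambda": 0.3}

X = {"est": (np.zeros(8), np.zeros(8)), "cost": None}
costs = []
for _ in range(100):
    X = forward(X, grad_f, prox_g, objective, params, y, A)
    costs.append(X["cost"])
```

With the step size \(1/\|A\|_2^2\), PGD decreases the cost monotonically, so the recorded costs should form a non-increasing sequence.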

relaxation_step(u, v, beta)

Performs a relaxation step of the form \(\beta u + (1-\beta) v\).

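Parameters:
  • u (torch.Tensor) – First tensor, weighted by \(\beta\).

  • v (torch.Tensor) – Second tensor, weighted by \(1-\beta\).

  • beta (float) – Relaxation parameter.

Returns: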

Relaxed tensor.
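A standalone NumPy sketch of the documented formula \(\beta u + (1-\beta) v\), with a short usage example (this free function is an illustration, not the method itself):

```python
import numpy as np


def relaxation_step(u, v, beta):
    """Return beta * u + (1 - beta) * v."""
    return beta * u + (1.0 - beta) * v


u = np.array([1.0, 2.0])
v = np.array([3.0, 4.0])
no_relax = relaxation_step(u, v, 1.0)  # equal to u
averaged = relaxation_step(u, v, 0.5)  # [2.0, 3.0]
```

If u is the freshly computed iterate and v the previous one, \(\beta = 1\) leaves the update unchanged and \(\beta = 1/2\) averages the two.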

Examples using OptimIterator:

PnP with custom optimization algorithm (Condat-Vu Primal-Dual)

Learned Primal-Dual algorithm for CT scan.
