FISTAIteration

class deepinv.optim.optim_iterators.FISTAIteration(a=3, **kwargs)

Bases: OptimIterator

Iterator for fast iterative soft-thresholding (FISTA).

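Class implementing a single iteration of the FISTA algorithm for minimizing \(f(x) + \lambda \regname(x)\), where \(f\) is the data-fidelity term and \(\regname\) is the prior (regularization) term, in the variant proposed by Chambolle & Dossal.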

The iteration is given by

\[\begin{aligned} u_{k} &= z_k - \gamma \nabla f(z_k) \\ x_{k+1} &= \operatorname{prox}_{\gamma \lambda \regname}(u_k) \\ z_{k+1} &= x_{k+1} + \alpha_k (x_{k+1} - x_k), \end{aligned}\]

where \(\gamma\) is a step size that should satisfy \(\gamma \leq 1/\operatorname{Lip}(\nabla f)\), with \(\operatorname{Lip}(\nabla f)\) the Lipschitz constant of \(\nabla f\), and \(\alpha_k = (k+a-1)/(k+a)\).
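The iteration can be illustrated with a minimal NumPy sketch. It assumes, purely for illustration, that \(f(x) = \tfrac{1}{2}\|Ax - y\|^2\) and that the prior is the \(\ell_1\) norm, so the proximal step reduces to soft-thresholding; the helper names `soft_threshold` and `fista_lasso` are hypothetical and are not part of deepinv. The step size is set to \(1/\operatorname{Lip}(\nabla f)\), using the fact that \(\nabla f(x) = A^\top(Ax - y)\) is Lipschitz with constant \(\|A\|_2^2\), and the momentum follows the \(\alpha_k\) rule stated above.

```python
import numpy as np


def soft_threshold(u, tau):
    """Proximal operator of tau * ||.||_1, applied elementwise (soft-thresholding)."""
    return np.sign(u) * np.maximum(np.abs(u) - tau, 0.0)


def fista_lasso(A, y, lam, a=3.0, n_iter=500):
    """Hypothetical FISTA sketch for f(x) = 0.5 * ||A x - y||^2 with the l1 norm as prior.

    Assumptions (not from deepinv): f is a least-squares data term and the prior is
    the l1 norm, so prox_{gamma * lam * ||.||_1} is soft-thresholding.
    """
    # grad f(x) = A^T (A x - y) is Lipschitz with constant ||A||_2^2.
    lip = np.linalg.norm(A, ord=2) ** 2
    gamma = 1.0 / lip  # largest step size allowed by gamma <= 1 / Lip(grad f)

    x = np.zeros(A.shape[1])
    z = x.copy()
    for k in range(n_iter):
        u = z - gamma * (A.T @ (A @ z - y))      # u_k: gradient step on f at z_k
        x_next = soft_threshold(u, gamma * lam)  # x_{k+1}: prox step on lam * ||.||_1
        alpha = (k + a - 1) / (k + a)            # alpha_k as defined above
        z = x_next + alpha * (x_next - x)        # z_{k+1}: extrapolation (momentum)
        x = x_next
    return x


# Diagonal A, so the exact minimizer is known coordinate-wise: soft(d * y, lam) / d**2.
A = np.diag([1.0, 2.0, 1.5])
y = np.array([3.0, 0.2, -2.0])
x_hat = fista_lasso(A, y, lam=1.0)
print(x_hat)  # should be close to [2, 0, -8/9]
```

A diagonal operator is chosen only so that the result can be checked against a closed-form solution; the loop itself does not rely on \(A\) being diagonal.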

Parameters:

a (float) – Parameter \(a\) of the momentum coefficient \(\alpha_k\) in the FISTA algorithm; should be strictly greater than 2. Defaults to 3.
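To see how \(a\) shapes the momentum, the short sketch below evaluates \(\alpha_k = (k+a-1)/(k+a)\) for a few values of \(a\). The function name `fista_momentum` and its `ValueError` check are hypothetical illustrations of the documented requirement \(a > 2\), not deepinv behavior.

```python
def fista_momentum(k, a=3.0):
    """Momentum coefficient alpha_k = (k + a - 1) / (k + a) (hypothetical helper)."""
    if a <= 2:
        # Mirrors the documented requirement a > 2; this check is an illustration only.
        raise ValueError("a should be strictly greater than 2")
    return (k + a - 1) / (k + a)


# Compare the first few coefficients for several values of a.
for a in (3.0, 5.0, 10.0):
    print(a, [round(fista_momentum(k, a), 3) for k in range(5)])
```

Since \(\alpha_k = 1 - 1/(k+a)\), the coefficient increases toward 1 as \(k\) grows, and a larger \(a\) brings it closer to 1 from the first iteration.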

forward(X, cur_data_fidelity, cur_prior, cur_params, y, physics, *args, **kwargs)

Performs one iteration (forward pass) of the FISTA algorithm.

Parameters:
  • X (dict) – Dictionary containing the current iterate and the estimated cost.

  • cur_data_fidelity (deepinv.optim.DataFidelity) – Instance of the DataFidelity class defining the current data-fidelity term.

  • cur_prior (deepinv.optim.Prior) – Instance of the Prior class defining the current prior.

  • cur_params (dict) – Dictionary containing the current parameters of the algorithm.

  • y (torch.Tensor) – Measurements (the observed data).

  • physics (deepinv.physics.Physics) – Instance of the physics operator modeling the observation.

Returns:

Dictionary {"est": (x, z), "cost": F} containing the updated current iterate and the estimated current cost.
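The dictionary-in, dictionary-out convention lets the iterate pair \((x_k, z_k)\) be carried from one call to the next. The sketch below mimics that layout with a hypothetical stand-alone function `fista_forward`; it is not the deepinv implementation. Plain callables (`grad_f`, `prox_l1`, `cost`, all hypothetical) replace `cur_data_fidelity` and `cur_prior`, and the iteration index `k` is passed explicitly as an assumption of this sketch.

```python
import numpy as np


def fista_forward(X, k, grad_f, prox_g, cost, gamma, lam, a=3.0):
    """Hypothetical stand-in for one FISTA step using the {"est": (x, z), "cost": F} layout.

    Not the deepinv implementation: plain callables replace cur_data_fidelity and
    cur_prior, and the iteration index k is passed explicitly.
    """
    x_prev, z_prev = X["est"]               # (x_k, z_k) from the previous call
    u = z_prev - gamma * grad_f(z_prev)     # gradient step on f
    x = prox_g(u, gamma * lam)              # prox of gamma * lam * g
    alpha = (k + a - 1) / (k + a)           # momentum coefficient alpha_k
    z = x + alpha * (x - x_prev)            # extrapolated point z_{k+1}
    return {"est": (x, z), "cost": cost(x)}


# Toy problem (assumption): f(x) = 0.5 * ||x - y||^2 and g = ||.||_1 with lam = 1.
y = np.array([3.0, -0.5, 1.2])


def grad_f(x):
    return x - y  # gradient of 0.5 * ||x - y||^2, Lipschitz with constant 1


def prox_l1(u, tau):
    return np.sign(u) * np.maximum(np.abs(u) - tau, 0.0)  # soft-thresholding


def cost(x):
    return 0.5 * np.sum((x - y) ** 2) + np.sum(np.abs(x))  # f(x) + lam * g(x)


X = {"est": (np.zeros(3), np.zeros(3)), "cost": None}
for k in range(20):
    X = fista_forward(X, k, grad_f, prox_l1, cost, gamma=1.0, lam=1.0)
print(X["est"][0], X["cost"])
```

Here \(\operatorname{Lip}(\nabla f) = 1\), so \(\gamma = 1\) is the largest admissible step size, and the iterate settles at the soft-thresholded measurements \((2, 0, 0.2)\).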