BaseOptim

class deepinv.optim.BaseOptim(iterator, params_algo={'lambda': 1.0, 'stepsize': 1.0}, data_fidelity=None, prior=None, max_iter=100, crit_conv='residual', thres_conv=1e-05, early_stop=False, has_cost=False, backtracking=False, gamma_backtracking=0.1, eta_backtracking=0.9, custom_metrics=None, custom_init=None, get_output=<function BaseOptim.<lambda>>, anderson_acceleration=False, history_size=5, beta_anderson_acc=1.0, eps_anderson_acc=0.0001, verbose=False)

Bases: Module

Class for optimization algorithms that consist in iterating a fixed-point operator.

This module solves the problem

\[\begin{equation} \label{eq:min_prob} \tag{1} \underset{x}{\arg\min} \quad \datafid{x}{y} + \lambda \reg{x}, \end{equation}\]

where the first term \(\datafidname:\xset\times\yset \mapsto \mathbb{R}_{+}\) enforces data-fidelity, the second term \(\regname:\xset\mapsto \mathbb{R}_{+}\) acts as a regularization and \(\lambda > 0\) is a regularization parameter. More precisely, the data-fidelity term penalizes the discrepancy between the data \(y\) and the forward operator \(A\) applied to the variable \(x\), as

\[\datafid{x}{y} = \distance{Ax}{y}\]

where \(\distance{\cdot}{\cdot}\) is a distance function and \(A:\xset\mapsto \yset\) is the forward operator (see deepinv.physics.Physics()).

Optimization algorithms for solving problem (1) can be written as fixed-point algorithms, i.e. for \(k=1,2,...\)

\[\qquad (x_{k+1}, z_{k+1}) = \operatorname{FixedPoint}(x_k, z_k, f, g, A, y, ...)\]

where \(x_k\) is a variable converging to the solution of the minimization problem, \(z_k\) is an additional variable that may be required in the computation of the fixed-point operator, and \(f\) and \(g\) denote the data-fidelity and regularization terms, respectively.
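
The fixed-point view can be illustrated without deepinv. The following is a minimal pure-Python sketch, not deepinv's implementation: it assumes the proximal gradient (PGD) operator for the toy problem \(\min_x \tfrac12\|x-y\|^2 + \lambda\|x\|_1\) (\(A\) the identity), initializes with \(x_0 = z_0 = A^{\top}y\), and stops when the norm of the difference between consecutive iterates falls below a threshold. The helper names and the choice of \(z_k\) (the gradient step taken before the proximal map) are hypothetical.

```python
import math


def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1, applied entrywise."""
    return [math.copysign(max(abs(vi) - t, 0.0), vi) for vi in v]


def pgd_fixed_point(x, z, y, stepsize, lam):
    """One PGD iteration for 0.5 * ||x - y||^2 + lam * ||x||_1 (A = identity).

    PGD does not need the incoming z; it is passed only to mirror the
    (x_k, z_k) -> (x_{k+1}, z_{k+1}) signature. The returned z stores the
    gradient step on the data-fidelity term (hypothetical choice).
    """
    z_next = [xi - stepsize * (xi - yi) for xi, yi in zip(x, y)]
    x_next = soft_threshold(z_next, stepsize * lam)
    return x_next, z_next


def run_fixed_point(y, stepsize=0.5, lam=1.0, max_iter=100, thres_conv=1e-5):
    """Iterate (x, z) <- FixedPoint(x, z) until the residual is small."""
    x, z = list(y), list(y)  # x_0 = z_0 = A^T y with A = identity
    n_iter = 0
    for _ in range(max_iter):
        x_next, z = pgd_fixed_point(x, z, y, stepsize, lam)
        n_iter += 1
        # Euclidean norm of the difference between consecutive iterates
        residual = math.sqrt(sum((a - b) ** 2 for a, b in zip(x_next, x)))
        x = x_next
        if residual < thres_conv:
            break
    return x, n_iter


x_hat, n_iter = run_fixed_point([2.0, 2.0])
print(x_hat, n_iter)
```

With \(y = [2, 2]\), \(\lambda = 1\) and stepsize 0.5, the iterates approach \([1, 1]\), the same solution as in the doctest example below.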

The optim_builder() function can be used to instantiate this class with a specific fixed point operator.

If the algorithm minimizes an explicit and fixed cost function \(F(x) = \datafid{x}{y} + \lambda \reg{x}\), the value of the cost function is computed along the iterations and can be used as a convergence criterion. Moreover, backtracking can be used to adapt the stepsize at each iteration. Backtracking consists in choosing the largest stepsize \(\tau\) such that, at each iteration, a sufficient decrease of the cost function \(F\) is achieved. More precisely, given \(\gamma \in (0,1/2)\), \(\eta \in (0,1)\) and an initial stepsize \(\tau > 0\), the following update rule is applied at each iteration \(k\):

\[\text{ while } F(x_k) - F(x_{k+1}) < \frac{\gamma}{\tau} \| x_{k+1} - x_k \|^2, \,\, \text{ do } \tau \leftarrow \eta \tau\]
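
The rule can be sketched in pure Python as follows. This is an illustration under stated assumptions, not deepinv's code: `step(x, tau)` is a hypothetical stepsize-dependent update, \(x_{k+1}\) is recomputed after every shrink of \(\tau\), and a cap on the number of shrinks is added as a safeguard. For a gradient step on \(F(x) = \tfrac12\|x-y\|^2\), one can check that the sufficient-decrease test holds exactly when \(\tau \le 2(1-\gamma)\); with \(\gamma = 0.1\) this means \(\tau \le 1.8\), so starting from \(\tau = 3\) the rule shrinks \(\tau\) five times, to \(3 \cdot 0.9^5 \approx 1.77\).

```python
def sq_norm(v):
    """Squared Euclidean norm of a list of floats."""
    return sum(vi * vi for vi in v)


def backtracking_step(x, F, step, tau, gamma=0.1, eta=0.9, max_backtracks=50):
    """Apply one iteration with the backtracking rule; return (x_next, tau).

    `step(x, tau)` is any stepsize-dependent fixed-point update (hypothetical
    interface). x_next is recomputed each time tau is shrunk.
    """
    x_next = step(x, tau)
    for _ in range(max_backtracks):
        diff = [a - b for a, b in zip(x_next, x)]
        # accept tau if F(x_k) - F(x_{k+1}) >= (gamma / tau) * ||x_{k+1} - x_k||^2
        if F(x) - F(x_next) >= (gamma / tau) * sq_norm(diff):
            break
        tau *= eta
        x_next = step(x, tau)
    return x_next, tau


# Toy problem: F(x) = 0.5 * ||x - y||^2, updated by a gradient step.
y = [2.0, -1.0]


def F(x):
    return 0.5 * sq_norm([a - b for a, b in zip(x, y)])


def gd_step(x, tau):
    return [a - tau * (a - b) for a, b in zip(x, y)]


x0 = [0.0, 0.0]
x1, tau = backtracking_step(x0, F, gd_step, tau=3.0)  # tau = 3 is too large here
print(tau, F(x0), F(x1))
```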

The variable params_algo is a dictionary containing all the relevant parameters for running the algorithm. If the value associated with a key is a float, the algorithm uses the same parameter across all iterations. If the value is a list of length max_iter, the algorithm uses the corresponding entry at each iteration.

The variable data_fidelity is a list of instances of deepinv.optim.DataFidelity() (or a single instance). If a single instance, the same data-fidelity is used at each iteration. If a list, the data-fidelity can change at each iteration. The same holds for the variable prior which is a list of instances of deepinv.optim.Prior() (or a single instance).
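
The per-iteration selection described in the two paragraphs above can be sketched with a small helper. The helper `value_at_iteration` is hypothetical and only mirrors the rule as stated here (a single value, or a list of length 1, is shared by all iterations; a longer list holds one entry per iteration); strings stand in for DataFidelity or Prior instances.

```python
def value_at_iteration(value, it):
    """Hypothetical helper: pick the entry used at iteration `it`.

    A non-list value, or a list/tuple of length 1, is shared by all
    iterations; a longer list holds one entry per iteration.
    """
    if isinstance(value, (list, tuple)):
        return value[0] if len(value) == 1 else value[it]
    return value


max_iter = 100
params_algo = {
    "stepsize": [0.95 ** k for k in range(max_iter)],  # decaying stepsize
    "lambda": 1.0,  # same value at every iteration
}

it = 10
params_it = {key: value_at_iteration(val, it) for key, val in params_algo.items()}
print(params_it)

# The same rule applies to data_fidelity and prior (strings stand in for objects).
prior = ["prior_0", "prior_1", "prior_2"]
print(value_at_iteration(prior, 2))  # prior_2
```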

>>> import torch
>>> import deepinv as dinv
>>> # This minimal example shows how to use the BaseOptim class to solve the problem
>>> #                min_x 0.5  ||Ax-y||_2^2 + \lambda ||x||_1
>>> # with the PGD algorithm, where A is the identity operator, lambda = 1 and y = [2, 2].
>>>
>>> # Create the measurement operator A
>>> A = torch.tensor([[1, 0], [0, 1]], dtype=torch.float64)
>>> A_forward = lambda v: A @ v
>>> A_adjoint = lambda v: A.transpose(0, 1) @ v
>>>
>>> # Define the physics model associated to this operator
>>> physics = dinv.physics.LinearPhysics(A=A_forward, A_adjoint=A_adjoint)
>>>
>>> # Define the measurement y
>>> y = torch.tensor([2, 2], dtype=torch.float64)
>>>
>>> # Define the data fidelity term
>>> data_fidelity = dinv.optim.data_fidelity.L2()
>>>
>>> # Define the prior
>>> prior = dinv.optim.Prior(g = lambda x, *args: torch.norm(x, p=1))
>>>
>>> # Define the parameters of the algorithm
>>> params_algo = {"stepsize": 0.5, "lambda": 1.0}
>>>
>>> # Define the fixed-point iterator
>>> iterator = dinv.optim.optim_iterators.PGDIteration()
>>>
>>> # Define the optimization algorithm
>>> optimalgo = dinv.optim.BaseOptim(iterator,
...                     data_fidelity=data_fidelity,
...                     params_algo=params_algo,
...                     prior=prior)
>>>
>>> # Run the optimization algorithm
>>> with torch.no_grad(): xhat = optimalgo(y, physics)
>>> print(xhat)
tensor([1., 1.], dtype=torch.float64)
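
For this particular problem the minimizer is available in closed form, which gives an independent check of the printed output. With \(A\) the identity and the data-fidelity \(\tfrac12\|x-y\|_2^2\) used in the comment above, the solution is the proximal operator of \(\lambda\|\cdot\|_1\) evaluated at \(y\), i.e. entrywise soft-thresholding \(x_i = \operatorname{sign}(y_i)\max(|y_i| - \lambda, 0)\). The pure-Python sketch below evaluates it for \(y = [2, 2]\) and \(\lambda = 1\).

```python
import math


def soft_threshold(v, t):
    # prox of t * ||.||_1: shrink each entry towards zero by t
    return [math.copysign(max(abs(vi) - t, 0.0), vi) for vi in v]


y = [2.0, 2.0]
lam = 1.0
x_star = soft_threshold(y, lam)
print(x_star)  # [1.0, 1.0]
```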

Parameters:
  • iterator (deepinv.optim.optim_iterators.OptimIterator) – Fixed-point iterator of the optimization algorithm of interest.

  • params_algo (dict) – dictionary containing all the relevant parameters for running the algorithm, e.g. the stepsize, regularization parameter, denoising standard deviation. Each value of the dictionary can be either an iterable (distinct value for each iteration) or a single float (same value for each iteration). Default: {"stepsize": 1.0, "lambda": 1.0}. See Parameters for more details.

  • data_fidelity (list, deepinv.optim.DataFidelity) – data-fidelity term. Either a single instance (same data-fidelity for each iteration) or a list of instances of deepinv.optim.DataFidelity() (distinct data-fidelity for each iteration). Default: None.

  • prior (list, deepinv.optim.Prior) – regularization prior. Either a single instance (same prior for each iteration) or a list of instances of deepinv.optim.Prior() (distinct prior for each iteration). Default: None.

  • max_iter (int) – maximum number of iterations of the optimization algorithm. Default: 100.

  • crit_conv (str) – convergence criterion used for claiming convergence, either "residual" (on the residual between consecutive iterates) or "cost" (on the cost function). Default: "residual".

  • thres_conv (float) – value of the threshold for claiming convergence. Default: 1e-05.

  • early_stop (bool) – whether to stop the algorithm once the convergence criterion is reached. Default: False.

  • has_cost (bool) – whether the algorithm has an explicit cost function or not. Default: False.

  • custom_metrics (dict) – dictionary containing custom metrics to be computed at each iteration. Default: None.

  • backtracking (bool) – whether to apply a backtracking strategy for stepsize selection. Default: False.

  • gamma_backtracking (float) – \(\gamma\) parameter in the backtracking selection. Default: 0.1.

  • eta_backtracking (float) – \(\eta\) parameter in the backtracking selection. Default: 0.9.

  • custom_init (function) – initializes the algorithm with custom_init(y, physics). If None (default value), the algorithm is initialized with \(A^{\top}y\). Default: None.

  • get_output (function) – returns the image output given the current dictionary X = {'est': (primal, aux)} containing the primal and auxiliary variables. Default: X['est'][0].

  • anderson_acceleration (bool) – whether to use Anderson acceleration for accelerating the forward fixed-point iterations. Default: False.

  • history_size (int) – size of the history of iterates used for Anderson acceleration. Default: 5.

  • beta_anderson_acc (float) – momentum of the Anderson acceleration step. Default: 1.0.

  • eps_anderson_acc (float) – regularization parameter of the Anderson acceleration step. Default: 1e-4.

  • verbose (bool) – whether to print relevant information of the algorithm during its run, such as convergence criterion at each iterate. Default: False.

Returns:

a torch model that solves the optimization problem.

check_conv_fn(it, X_prev, X)

Checks the convergence of the algorithm.

Parameters:
  • it (int) – iteration number.

  • X_prev (dict) – dictionary containing the previous primal and dual iterates.

  • X (dict) – dictionary containing the current primal and dual iterates.

Return bool:

True if the algorithm has converged, False otherwise.
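
A minimal pure-Python sketch of the two criteria, operating on the dictionary layout X = {'est': (x, z), 'cost': F} used on this page. The normalization by the norm of the current iterate (or the magnitude of the current cost) and the small constant eps are assumptions of this sketch; deepinv's exact formula may differ.

```python
def l2_norm(v):
    """Euclidean norm of a list of floats."""
    return sum(vi * vi for vi in v) ** 0.5


def has_converged(X_prev, X, crit_conv="residual", thres_conv=1e-5, eps=1e-6):
    """Relative-change convergence test (normalization is an assumption)."""
    if crit_conv == "residual":
        x_prev, x = X_prev["est"][0], X["est"][0]
        crit = l2_norm([a - b for a, b in zip(x_prev, x)]) / (l2_norm(x) + eps)
    elif crit_conv == "cost":
        crit = abs(X_prev["cost"] - X["cost"]) / (abs(X["cost"]) + eps)
    else:
        raise ValueError(f"unknown convergence criterion: {crit_conv}")
    return crit < thres_conv


X_prev = {"est": ([1.0, 1.0], [1.0, 1.0]), "cost": 3.0}
X = {"est": ([1.0, 1.000001], [1.0, 1.000001]), "cost": 3.0000001}
print(has_converged(X_prev, X), has_converged(X_prev, X, crit_conv="cost"))
```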

check_iteration_fn(X_prev, X)

Performs stepsize backtracking.

Parameters:
  • X_prev (dict) – dictionary containing the previous primal and dual iterates.

  • X (dict) – dictionary containing the current primal and dual iterates.

forward(y, physics, x_gt=None, compute_metrics=False)

Runs the fixed-point iteration algorithm for solving (1).

Parameters:
  • y (torch.Tensor) – measurement vector.

  • physics (deepinv.physics.Physics) – physics of the problem for the acquisition of y.

  • x_gt (torch.Tensor) – (optional) ground truth image, used to compute the PSNR across iterations.

  • compute_metrics (bool) – whether to compute the metrics or not. Default: False.

Returns:

If compute_metrics is False, returns (torch.Tensor) the output of the algorithm. Else, returns (torch.Tensor, dict) the output of the algorithm and the metrics.

init_iterate_fn(y, physics, F_fn=None)

Initializes the iterate of the algorithm. The first iterate is stored in a dictionary of the form X = {'est': (x_0, u_0), 'cost': F_0} where:

  • est is a tuple containing the first primal and auxiliary iterates.

  • cost is the value of the cost function at the first iterate.

By default, the first (primal, auxiliary) iterate of the algorithm is chosen as \((A^{\top}y, A^{\top}y)\). A custom initialization is possible with the custom_init argument.

Parameters:
  • y (torch.Tensor) – measurement vector.

  • physics (deepinv.physics.Physics) – physics of the problem.

  • F_fn (function) – (optional) function that computes the cost function. Default: None.

Returns:

a dictionary containing the first iterate of the algorithm.
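
The default initialization can be sketched in pure Python with a dense matrix stored as a list of rows. This is an illustration of the dictionary layout described above, not deepinv's code; the matrix, measurement, and cost function F_fn below are hypothetical (here \(\tfrac12\|Ax-y\|^2 + \|x\|_1\)).

```python
def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * vj for m, vj in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def init_iterate(y, A, F_fn=None):
    """Return X = {'est': (x_0, u_0), 'cost': F_0} with x_0 = u_0 = A^T y."""
    x0 = matvec(transpose(A), y)
    u0 = list(x0)  # separate copy for the auxiliary variable
    cost = F_fn(x0) if F_fn is not None else None
    return {"est": (x0, u0), "cost": cost}


A = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
y = [1.0, 2.0, 3.0]


def F_fn(x):
    # hypothetical cost: 0.5 * ||Ax - y||^2 + ||x||_1
    r = [a - b for a, b in zip(matvec(A, x), y)]
    return 0.5 * sum(ri * ri for ri in r) + sum(abs(xi) for xi in x)


X = init_iterate(y, A, F_fn)
print(X)  # {'est': ([4.0, 7.0], [4.0, 7.0]), 'cost': 119.5}
```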

init_metrics_fn(X_init, x_gt=None)

Initializes the metrics.

Metrics are computed for each batch and for each iteration. They are represented by a list of lists, and metrics[metric_name][i][j] contains the metric metric_name computed for batch i at iteration j.

Parameters:
  • X_init (dict) – dictionary containing the primal and auxiliary initial iterates.

  • x_gt (torch.Tensor) – ground truth image, required for PSNR computation. Default: None.

Return dict:

A dictionary containing the metrics.
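
The list-of-lists layout, and how it grows when metrics are updated at each iteration, can be sketched in pure Python. The metric names "residual" and "cost" and both helpers are hypothetical; the sketch only illustrates the indexing metrics[metric_name][i][j] (batch i, iteration j), here starting at the first update.

```python
def init_metrics(batch_size, names=("residual", "cost")):
    """One empty list per batch element for each (hypothetical) metric name."""
    return {name: [[] for _ in range(batch_size)] for name in names}


def update_metrics(metrics, x_prev, x, cost):
    """Append this iteration's values; x_prev, x and cost are per-batch lists."""
    for i in range(len(x)):
        res = sum((a - b) ** 2 for a, b in zip(x_prev[i], x[i])) ** 0.5
        metrics["residual"][i].append(res)
        metrics["cost"][i].append(cost[i])
    return metrics


metrics = init_metrics(batch_size=2)
iterates = [
    [[0.0, 0.0], [1.0, 1.0]],  # iteration 0: (batch 0, batch 1)
    [[3.0, 4.0], [1.0, 1.0]],  # iteration 1
    [[3.0, 4.0], [2.0, 1.0]],  # iteration 2
]
costs = [[10.0, 5.0], [2.0, 5.0], [2.0, 1.0]]
for j in range(1, len(iterates)):
    update_metrics(metrics, iterates[j - 1], iterates[j], costs[j])

print(metrics["residual"][0])  # [5.0, 0.0]
```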

update_data_fidelity_fn(it)

Selects the data-fidelity term used at iteration it (if the data-fidelity depends on the iteration number).

Parameters:

it (int) – iteration number.

Returns:

a dictionary containing the data_fidelity of iteration it.

update_metrics_fn(metrics, X_prev, X, x_gt=None)

Computes all the metrics, across all batches, for the current iteration.

Parameters:
  • metrics (dict) – dictionary containing the metrics. Each metric is computed for each batch.

  • X_prev (dict) – dictionary containing the previous primal and dual iterates.

  • X (dict) – dictionary containing the current primal and dual iterates.

  • x_gt (torch.Tensor) – ground truth image, required for PSNR computation. Default: None.

Return dict:

a dictionary containing the updated metrics.

update_params_fn(it)

For each parameter in params_algo, selects the parameter value for iteration it (if this parameter depends on the iteration number).

Parameters:

it (int) – iteration number.

Returns:

a dictionary containing the parameters of iteration it.

update_prior_fn(it)

Selects the prior used at iteration it (if the prior depends on the iteration number).

Parameters:

it (int) – iteration number.

Returns:

a dictionary containing the prior of iteration it.

Examples using BaseOptim:

Radio interferometric imaging with deepinverse

Image deblurring with custom deep explicit prior.

Random phase retrieval and reconstruction methods.

Image deblurring with Total-Variation (TV) prior

Image inpainting with wavelet prior

Vanilla PnP for computed tomography (CT).

DPIR method for PnP image deblurring.

Regularization by Denoising (RED) for Super-Resolution.

PnP with custom optimization algorithm (Condat-Vu Primal-Dual)