least_squares

deepinv.optim.utils.least_squares(A, AT, y, z=0.0, init=None, gamma=None, parallel_dim=0, AAT=None, ATA=None, solver='CG', max_iter=100, tol=1e-6, **kwargs)

Solves \(\min_x \|Ax-y\|^2 + \frac{1}{\gamma}\|x-z\|^2\) using the specified solver.

The solvers are stopped either when \(\|Ax-y\| \leq \text{tol} \times \|y\|\) or when the maximum number of iterations is reached.
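As an illustration of the objective and the stopping rule, the sketch below solves the regularized problem with a hand-written conjugate gradient loop. It is not the library's implementation: it uses NumPy arrays instead of torch tensors, the names `conjugate_gradient`, `mat`, `M` and `b` are hypothetical, and it assumes the stopping rule is applied to the residual of the square system being solved. Setting the gradient of the objective to zero gives \((A^{\top} A + \frac{1}{\gamma} I)x = A^{\top} y + \frac{1}{\gamma} z\), which is the system passed to the loop.

```python
import numpy as np


def conjugate_gradient(M, b, max_iter=100, tol=1e-6):
    """Solve M(x) = b for a symmetric positive definite operator M."""
    x = np.zeros_like(b)
    r = b - M(x)                      # residual of the linear system
    p = r.copy()                      # first search direction
    rs_old = float(r @ r)
    b_norm = np.linalg.norm(b)
    for _ in range(max_iter):
        # Relative residual test, checked before each update.
        if np.sqrt(rs_old) <= tol * b_norm:
            break
        Mp = M(p)
        alpha = rs_old / float(p @ Mp)
        x = x + alpha * p
        r = r - alpha * Mp
        rs_new = float(r @ r)
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


rng = np.random.default_rng(0)
mat = rng.standard_normal((8, 5))     # hypothetical forward matrix
y = rng.standard_normal(8)
z = rng.standard_normal(5)
gamma = 2.0

A = lambda x: mat @ x                 # forward operator as a callable
AT = lambda v: mat.T @ v              # adjoint operator as a callable

# Optimality condition of ||Ax - y||^2 + (1/gamma) ||x - z||^2:
# (A^T A + I / gamma) x = A^T y + z / gamma
M = lambda x: AT(A(x)) + x / gamma
b = AT(y) + z / gamma
x_cg = conjugate_gradient(M, b, tol=1e-10)

# Dense reference solution, for comparison only.
x_ref = np.linalg.solve(mat.T @ mat + np.eye(5) / gamma, mat.T @ y + z / gamma)
```

The operator `M` is only ever applied, never formed as a matrix, which is why the function takes `A` and `AT` as callables.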

Available solvers include 'CG', 'BiCGStab', and 'lsqr'.

Note

Both 'CG' and 'BiCGStab' are used for square linear systems, while 'lsqr' is used for rectangular systems.

If the chosen solver requires a square system, the problem is mapped to the normal equations: if \(y\) has more entries than \(x\) (overdetermined problem), it computes \((A^{\top} A)^{-1} A^{\top} y\); otherwise (underdetermined problem), it computes \(A^{\top} (A A^{\top})^{-1} y\).
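The two normal-equation formulas can be checked numerically on small dense matrices. The sketch below assumes NumPy and random full-rank matrices (the names `tall` and `wide` are hypothetical); it uses direct solves rather than the iterative solvers above, purely to show that both formulas agree with the pseudoinverse solution.

```python
import numpy as np

rng = np.random.default_rng(0)

# Overdetermined case: y has more entries than x (tall matrix).
tall = rng.standard_normal((6, 3))
y_tall = rng.standard_normal(6)
# x = (A^T A)^{-1} A^T y
x_over = np.linalg.solve(tall.T @ tall, tall.T @ y_tall)

# Underdetermined case: y has fewer entries than x (wide matrix).
wide = rng.standard_normal((3, 6))
y_wide = rng.standard_normal(3)
# x = A^T (A A^T)^{-1} y
x_under = wide.T @ np.linalg.solve(wide @ wide.T, y_wide)

# Pseudoinverse solutions, for comparison only.
x_over_ref = np.linalg.pinv(tall) @ y_tall
x_under_ref = np.linalg.pinv(wide) @ y_wide
```

In the underdetermined case the result fits the data exactly and is the minimum-norm solution among all exact fits.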

Parameters:
  • A (Callable) – Linear operator \(A\) as a callable function.

  • AT (Callable) – Adjoint operator \(A^{\top}\) as a callable function.

  • y (torch.Tensor) – input tensor of shape (B, …)

  • z (torch.Tensor) – input tensor of shape (B, …) or scalar.

  • gamma (None, float) – (Optional) inverse regularization parameter \(\gamma\); the term \(\|x-z\|^2\) is weighted by \(\frac{1}{\gamma}\).

  • solver (str) – solver to be used.

  • AAT (Callable) – (Optional) Efficient implementation of \(A(A^{\top}(x))\). If not provided, it is computed by applying AT and then A.

  • ATA (Callable) – (Optional) Efficient implementation of \(A^{\top}(A(x))\). If not provided, it is computed by applying A and then AT.

  • max_iter (int) – maximum number of iterations.

  • tol (float) – relative tolerance for stopping the algorithm.

  • parallel_dim (None, int, list[int]) – dimensions to be considered as batch dimensions. If None, all dimensions are considered as batch dimensions.

  • kwargs – Keyword arguments to be passed to the solver.

Returns:

(torch.Tensor) \(x\) of shape (B, …).
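To illustrate when supplying ATA or AAT can pay off, consider a diagonal operator such as a binary sampling mask. The sketch below assumes NumPy arrays in place of torch tensors and a hypothetical `mask`; for such an operator the composition \(A^{\top}(A(x))\) collapses to a single elementwise multiply.

```python
import numpy as np

rng = np.random.default_rng(0)
mask = (rng.random(16) > 0.5).astype(float)   # hypothetical binary sampling mask

A = lambda x: mask * x      # forward operator: zero out unobserved entries
AT = lambda v: mask * v     # a real diagonal operator is its own adjoint
ATA = lambda x: mask * x    # mask**2 == mask for a binary mask: one multiply, not two
AAT = ATA                   # same shortcut for A(A^T(x))

x = rng.standard_normal(16)
v = rng.standard_normal(16)

# Adjoint check: <A x, v> should equal <x, A^T v>.
lhs = float(A(x) @ v)
rhs = float(x @ AT(v))
```

Callables written this way, acting on torch tensors, are the kind of argument the AAT and ATA parameters describe; the saving grows with the cost of applying A and AT separately.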