least_squares
- deepinv.optim.utils.least_squares(A, AT, y, z=0.0, init=None, gamma=None, parallel_dim=0, AAT=None, ATA=None, solver='CG', max_iter=100, tol=1e-6, **kwargs)
Solves \(\min_x \|Ax-y\|^2 + \frac{1}{\gamma}\|x-z\|^2\) using the specified solver.
The solvers are stopped either when \(\|Ax-y\| \leq \text{tol} \times \|y\|\) or when the maximum number of iterations is reached.
The solution depends on the regularization parameter \(\gamma\):
- If gamma=None (\(\gamma = \infty\)), it solves the unregularized least squares problem \(\min_x \|Ax-y\|^2\). If \(A\) is overcomplete (rows >= columns), it computes the least squares solution \(x = (A^{\top}A)^{-1}A^{\top}y\). If \(A\) is undercomplete (columns > rows), it computes the minimum norm solution \(x = A^{\top}(AA^{\top})^{-1}y\).
- If \(0 < \gamma < \infty\), it computes the regularized least squares solution \(x = (A^{\top}A + \frac{1}{\gamma}I)^{-1}(A^{\top}y + \frac{1}{\gamma}z)\).
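The closed forms can be checked numerically on small dense matrices. The sketch below is an illustration only: it uses NumPy on unbatched vectors instead of torch tensors and explicit matrices instead of callables, and `closed_form_ls` is a hypothetical helper, not part of deepinv. On a tall matrix the normal-equation formula should match NumPy's least squares solver, on a wide matrix the minimum norm formula should match the pseudo-inverse, and in the regularized case the gradient of the objective should vanish.

```python
import numpy as np

def closed_form_ls(A, y, z=0.0, gamma=None):
    """Dense reference solution of min_x ||Ax - y||^2 + (1/gamma) ||x - z||^2.

    Hypothetical helper for illustration; assumes A has full rank when gamma is None.
    """
    m, n = A.shape
    if gamma is None:
        if m >= n:
            # Overcomplete (tall) A: least squares solution (A^T A)^{-1} A^T y
            return np.linalg.solve(A.T @ A, A.T @ y)
        # Undercomplete (wide) A: minimum norm solution A^T (A A^T)^{-1} y
        return A.T @ np.linalg.solve(A @ A.T, y)
    # Regularized case: (A^T A + I/gamma)^{-1} (A^T y + z/gamma)
    lhs = A.T @ A + np.eye(n) / gamma
    rhs = A.T @ y + z / gamma
    return np.linalg.solve(lhs, rhs)


rng = np.random.default_rng(0)

# Tall A: compare with NumPy's least squares solver
A_tall = rng.standard_normal((8, 3))
y_tall = rng.standard_normal(8)
x_tall = closed_form_ls(A_tall, y_tall)
print(np.allclose(x_tall, np.linalg.lstsq(A_tall, y_tall, rcond=None)[0]))

# Wide A: compare with the pseudo-inverse (minimum norm) solution
A_wide = rng.standard_normal((3, 8))
y_wide = rng.standard_normal(3)
x_wide = closed_form_ls(A_wide, y_wide)
print(np.allclose(x_wide, np.linalg.pinv(A_wide) @ y_wide))

# Regularized: the gradient 2 A^T (Ax - y) + (2/gamma)(x - z) should vanish
z = rng.standard_normal(8)
gamma = 0.5
x_reg = closed_form_ls(A_wide, y_wide, z=z, gamma=gamma)
grad = 2 * A_wide.T @ (A_wide @ x_reg - y_wide) + (2 / gamma) * (x_reg - z)
print(np.allclose(grad, 0.0))
```

All three checks print True.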
Warning
If \(\gamma \leq 0\), the problem can become non-convex, which the solvers are not designed to handle. A warning is raised but the solvers still run, except for LSQR, which cannot be used with negative \(\gamma\).
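To see why \(\gamma \leq 0\) is problematic, note that the Hessian of the objective is \(2(A^{\top}A + \frac{1}{\gamma}I)\). When \(A\) is wide, \(A^{\top}A\) has a nontrivial null space, so a negative \(\gamma\) yields negative eigenvalues and the objective is no longer convex. A minimal NumPy illustration (not deepinv code; the helper `hessian` is hypothetical):

```python
import numpy as np

def hessian(A, gamma):
    # Hessian of ||Ax - y||^2 + (1/gamma) ||x - z||^2 with respect to x
    return 2 * (A.T @ A + np.eye(A.shape[1]) / gamma)

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 5))  # wide: A^T A is 5x5 with rank at most 3

# Smallest Hessian eigenvalue for a positive and a negative gamma
min_eig = {g: np.linalg.eigvalsh(hessian(A, g)).min() for g in (1.0, -1.0)}
print(min_eig)  # about 2.0 for gamma=1 (convex), about -2.0 for gamma=-1 (non-convex)
```

For a tall, full-rank \(A\), a small negative \(1/\gamma\) may still leave the Hessian positive definite, which is why the problem only "can" become non-convex.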
Available solvers are:
- 'CG': Conjugate Gradient
- 'BiCGStab': Biconjugate Gradient Stabilized method
- 'lsqr': Least Squares QR
- 'minres': Minimal Residual Method
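A minimal textbook conjugate gradient shows how the regularized problem can be solved from the callables \(A\) and \(A^{\top}\) alone, without forming any matrix: CG is applied to the square, symmetric positive definite system \((A^{\top}A + \frac{1}{\gamma}I)x = A^{\top}y + \frac{1}{\gamma}z\). This is a sketch under stated assumptions, not deepinv's implementation: it works on unbatched NumPy vectors, the name `conjugate_gradient` is hypothetical, and its stopping rule checks the residual of the square system it solves, which may differ from the library's internal test.

```python
import numpy as np

def conjugate_gradient(op, b, x0=None, max_iter=100, tol=1e-6):
    """Solve op(x) = b for a symmetric positive definite linear map `op`.

    Textbook CG; stops when ||b - op(x)|| <= tol * ||b|| or after max_iter steps.
    """
    x = np.zeros_like(b) if x0 is None else x0.copy()
    r = b - op(x)          # residual of the square system
    p = r.copy()           # search direction
    rs_old = r @ r
    b_norm = np.linalg.norm(b)
    for _ in range(max_iter):
        if np.sqrt(rs_old) <= tol * b_norm:
            break
        Ap = op(p)
        alpha = rs_old / (p @ Ap)   # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # new A-conjugate direction
        rs_old = rs_new
    return x


# Usage: regularized least squares with A given only as callables
rng = np.random.default_rng(0)
M = rng.standard_normal((20, 10))
A = lambda x: M @ x      # forward operator
AT = lambda y: M.T @ y   # adjoint operator
y = rng.standard_normal(20)
z = np.zeros(10)
gamma = 10.0

# Square system (A^T A + I/gamma) x = A^T y + z/gamma, applied matrix-free
op = lambda x: AT(A(x)) + x / gamma
x_cg = conjugate_gradient(op, AT(y) + z / gamma, tol=1e-10)

# Dense reference solution for comparison
x_ref = np.linalg.solve(M.T @ M + np.eye(10) / gamma, M.T @ y + z / gamma)
print(np.allclose(x_cg, x_ref))
```

Only products with \(A\) and \(A^{\top}\) are needed, which is what makes this family of solvers suitable for large imaging operators.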
Note
Both 'CG' and 'BiCGStab' are designed for square linear systems, while 'lsqr' handles rectangular systems. If the chosen solver requires a square system, the problem is mapped to the normal equations: if \(y\) has more entries than \(x\) (overcomplete problem), it computes \((A^{\top} A)^{-1} A^{\top} y\); otherwise (undercomplete problem), it computes \(A^{\top} (A A^{\top})^{-1} y\).
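The mapping to a square system can be sketched using only the callables \(A\) and \(A^{\top}\). In this illustration (NumPy, unbatched vectors), the size of \(x\) is read off from \(A^{\top}y\), and for brevity the square operator is materialized as a dense matrix and solved with `np.linalg.solve`; the helpers `materialize` and `normal_equations_solve` are hypothetical and not part of deepinv.

```python
import numpy as np

def materialize(op, n):
    # Dense n x n matrix of a linear map on R^n, built column by column (illustration only)
    return np.stack([op(e) for e in np.eye(n)], axis=1)

def normal_equations_solve(A, AT, y):
    """Unregularized least squares through a square system (hypothetical helper)."""
    x_dim = AT(y).shape[0]   # size of x, inferred from A^T y
    m = y.shape[0]
    if m >= x_dim:
        # Overcomplete: solve (A^T A) x = A^T y
        ATA = lambda x: AT(A(x))
        return np.linalg.solve(materialize(ATA, x_dim), AT(y))
    # Undercomplete: solve (A A^T) u = y, then return x = A^T u
    AAT = lambda u: A(AT(u))
    u = np.linalg.solve(materialize(AAT, m), y)
    return AT(u)


rng = np.random.default_rng(1)
M_tall = rng.standard_normal((12, 4))
M_wide = rng.standard_normal((4, 12))
y_tall = rng.standard_normal(12)
y_wide = rng.standard_normal(4)

x_tall = normal_equations_solve(lambda x: M_tall @ x, lambda v: M_tall.T @ v, y_tall)
x_wide = normal_equations_solve(lambda x: M_wide @ x, lambda v: M_wide.T @ v, y_wide)
```

In practice the square system would be handed to an iterative solver such as CG rather than materialized; the dense solve only keeps the example short.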
- Parameters:
A (Callable) – Linear operator \(A\) as a callable function.
AT (Callable) – Adjoint operator \(A^{\top}\) as a callable function.
y (torch.Tensor) – input tensor of shape (B, …)
z (torch.Tensor) – reference point in the regularization term \(\frac{1}{\gamma}\|x-z\|^2\); tensor of shape (B, …) or scalar.
init (torch.Tensor) – (Optional) initial guess for the solver. If None, it is set to a tensor of zeros.
gamma (None, float, torch.Tensor) – (Optional) inverse regularization parameter. Can be batched (shape (B, …)) or a scalar. If it is a multi-dimensional tensor, its shape must match that of \(A^{\top} y\). If None, it is set to \(\infty\) (no regularization).
solver (str) – solver to be used; options are 'CG', 'BiCGStab', 'lsqr' and 'minres'.
AAT (Callable) – (Optional) Efficient implementation of \(A(A^{\top}(x))\). If not provided, it is computed as \(A(A^{\top}(x))\).
ATA (Callable) – (Optional) Efficient implementation of \(A^{\top}(A(x))\). If not provided, it is computed as \(A^{\top}(A(x))\).
max_iter (int) – maximum number of iterations.
tol (float) – relative tolerance for stopping the algorithm.
parallel_dim (None, int, list[int]) – dimensions to be considered as batch dimensions. If None, all dimensions are considered as batch dimensions.
kwargs – Keyword arguments to be passed to the solver.
- Returns:
(torch.Tensor) \(x\) of shape (B, …).
- Return type:
torch.Tensor
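The AAT and ATA parameters matter when \(AA^{\top}\) or \(A^{\top}A\) has a cheap closed form. A minimal NumPy sketch, assuming a hypothetical subsampling (inpainting-style) operator that keeps a fixed subset of entries: here \(AA^{\top}\) is the identity and \(A^{\top}A\) is an elementwise mask, so both the minimum norm and the regularized solutions reduce to elementwise operations. The operator and mask are illustrative only, and the sketch does not call deepinv; these are simply the kind of callables one could supply as AAT and ATA.

```python
import numpy as np

n = 10
# Hypothetical inpainting mask: True = observed entry
mask = np.array([True, False, True, True, False, False, True, False, True, False])
idx = np.flatnonzero(mask)

A = lambda x: x[idx]              # keep observed entries only (wide operator)

def AT(y):
    # Adjoint: put observations back in place, zeros elsewhere
    x = np.zeros(n)
    x[idx] = y
    return x

AAT = lambda u: u                 # A(A^T(u)) = u exactly for this operator
ATA = lambda x: mask * x          # A^T(A(x)) is elementwise masking

y = np.arange(1.0, idx.size + 1)  # example measurements

# Unregularized, undercomplete: x = A^T (A A^T)^{-1} y reduces to A^T y
x_min_norm = AT(y)

# Regularized: A^T A + I/gamma is diagonal, so it is inverted elementwise
gamma, z = 2.0, np.ones(n)
x_reg = (AT(y) + z / gamma) / (mask + 1 / gamma)
```

Compared with a generic iterative solve, these shortcuts avoid repeated applications of \(A\) and \(A^{\top}\) entirely for this operator.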
Examples using least_squares

Reducing the memory and computational complexity of unfolded network training