least_squares#
- deepinv.optim.utils.least_squares(A, AT, y, z=0.0, init=None, gamma=None, parallel_dim=0, AAT=None, ATA=None, solver='CG', max_iter=100, tol=1e-6, **kwargs)[source]#
Solves \(\min_x \|Ax-y\|^2 + \frac{1}{\gamma}\|x-z\|^2\) using the specified solver.
The solvers are stopped either when \(\|Ax-y\| \leq \text{tol} \times \|y\|\) or when the maximum number of iterations is reached.
Available solvers are:

- 'CG': Conjugate Gradient.
- 'BiCGStab': Biconjugate Gradient Stabilized method.
- 'lsqr': Least Squares QR.
Note

Both 'CG' and 'BiCGStab' are used for square linear systems, while 'lsqr' is used for rectangular systems. If the chosen solver requires a square system, we map the problem to the normal equations: if the size of \(y\) is larger than that of \(x\) (overcomplete problem), it computes \((A^{\top} A)^{-1} A^{\top} y\); otherwise (incomplete problem), it computes \(A^{\top} (A A^{\top})^{-1} y\).
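The mapping to the normal equations can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation: the `cg` helper below is a hypothetical textbook conjugate gradient, and the matrix data is made up.

```python
import numpy as np

def cg(apply_M, b, max_iter=100, tol=1e-6):
    """Plain conjugate gradient for a symmetric positive-definite operator M,
    given as a callable v -> M @ v (a sketch, not deepinv's solver)."""
    x = np.zeros_like(b)
    r = b - apply_M(x)          # initial residual
    p = r.copy()                # initial search direction
    rs = r @ r
    for _ in range(max_iter):
        Mp = apply_M(p)
        alpha = rs / (p @ Mp)   # step length along p
        x += alpha * p
        r -= alpha * Mp
        rs_new = r @ r
        if np.sqrt(rs_new) <= tol * np.linalg.norm(b):
            break               # relative-tolerance stopping rule
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# Overcomplete case (y larger than x): solve (A^T A) x = A^T y with CG.
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 3))            # more rows than columns
y = rng.standard_normal(8)
x = cg(lambda v: A.T @ (A @ v), A.T @ y)   # normal equations via CG
```

The recovered `x` agrees with the closed-form least-squares solution (e.g. `np.linalg.lstsq(A, y)`), which is the point of routing a rectangular problem through a square-system solver.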
- Parameters:
A (Callable) – Linear operator \(A\) as a callable function.
AT (Callable) – Adjoint operator \(A^{\top}\) as a callable function.
y (torch.Tensor) – input tensor of shape (B, …)
z (torch.Tensor) – input tensor of shape (B, …) or scalar.
init (torch.Tensor) – (Optional) initial guess for the solver.
gamma (None, float) – (Optional) inverse regularization parameter.
solver (str) – solver to be used.
AAT (Callable) – (Optional) Efficient implementation of \(A(A^{\top}(x))\). If not provided, it is computed by composing A and AT.
ATA (Callable) – (Optional) Efficient implementation of \(A^{\top}(A(x))\). If not provided, it is computed by composing AT and A.
max_iter (int) – maximum number of iterations.
tol (float) – relative tolerance for stopping the algorithm.
parallel_dim (None, int, list[int]) – dimensions to be considered as batch dimensions. If None, all dimensions are considered as batch dimensions.
kwargs – Keyword arguments to be passed to the solver.
- Returns:
(torch.Tensor) solution \(x\) of shape (B, …).
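When a regularization term is present (finite gamma and a tensor z), the minimizer of \(\|Ax-y\|^2 + \frac{1}{\gamma}\|x-z\|^2\) satisfies the regularized normal equations \((A^{\top}A + \frac{1}{\gamma}I)x = A^{\top}y + \frac{1}{\gamma}z\). The sketch below, with made-up NumPy data, solves these equations in closed form; it is a check of the objective being minimized, not the library's iterative method.

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((6, 4))   # hypothetical forward operator
y = rng.standard_normal(6)        # measurements
z = rng.standard_normal(4)        # regularization anchor
gamma = 2.0                       # inverse regularization parameter

# Regularized normal equations:
# (A^T A + (1/gamma) I) x = A^T y + (1/gamma) z
lhs = A.T @ A + np.eye(4) / gamma
rhs = A.T @ y + z / gamma
x = np.linalg.solve(lhs, rhs)
```

Setting the gradient \(2A^{\top}(Ax-y) + \frac{2}{\gamma}(x-z)\) to zero at this `x` confirms it minimizes the objective.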