from __future__ import annotations
import torch
from torch import Tensor
from typing import Callable
from deepinv.utils import TensorList, zeros_like
from deepinv.utils.compat import zip_strict
def lsqr(
A: Callable,
AT: Callable,
b: Tensor,
    eta: float | Tensor | None = 0.0,
    x0: Tensor | float | None = None,
tol: float = 1e-6,
conlim: float = 1e8,
max_iter: int = 100,
parallel_dim: None | int | list[int] = 0,
verbose: bool = False,
**kwargs,
) -> tuple[Tensor, Tensor]:
r"""
LSQR algorithm for solving linear systems.
Code adapted from SciPy's implementation of LSQR: https://github.com/scipy/scipy/blob/v1.15.1/scipy/sparse/linalg/_isolve/lsqr.py
    The function solves the regularized least-squares problem
    :math:`\min_x \|Ax-b\|^2 + \eta \|x-x_0\|^2`
    using the LSQR algorithm of :cite:t:`paige1982lsqr`.
:param Callable A: Linear operator as a callable function.
:param Callable AT: Adjoint operator as a callable function.
:param torch.Tensor b: input tensor of shape (B, ...)
    :param float, torch.Tensor eta: damping parameter :math:`\eta \geq 0`. Can be batched (shape (B, ...)) or a scalar.
    :param None, float, torch.Tensor x0: optional :math:`x_0` (tensor or scalar), which is also used as the initial guess.
:param float tol: relative tolerance for stopping the LSQR algorithm.
:param float conlim: maximum value of the condition number of the system.
:param int max_iter: maximum number of LSQR iterations.
:param None, int, list[int] parallel_dim: dimensions to be considered as batch dimensions. If None, all dimensions are considered as batch dimensions.
:param bool verbose: Output progress information in the console.
    :return: (:class:`torch.Tensor`) :math:`x` of shape (B, ...), (:class:`torch.Tensor`) estimate of the condition number of the system.
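    Example:
        A minimal usage sketch with a dense matrix as the forward operator
        (the matrix ``M`` below is purely illustrative):

        >>> import torch
        >>> M = torch.randn(5, 3)
        >>> A = lambda x: x @ M.T
        >>> AT = lambda y: y @ M
        >>> b = torch.randn(2, 5)  # batch of 2 measurements
        >>> x, acond = lsqr(A, AT, b)
        >>> x.shape
        torch.Size([2, 3])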
"""
xt = AT(b)
if isinstance(parallel_dim, int):
parallel_dim = [parallel_dim]
if parallel_dim is None:
parallel_dim = []
if isinstance(b, TensorList):
device = b[0].device
else:
device = b.device
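    # batched norm: reduce over all non-parallel (signal) dimensions,
    # returning one scalar per batch element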
def normf(u):
if isinstance(u, TensorList):
total = 0.0
dims = [[i for i in range(bi.ndim) if i not in parallel_dim] for bi in b]
for k in range(len(u)):
total += torch.linalg.vector_norm(
u[k], dim=dims[k], keepdim=False
) # don't keep dim as dims might be different
return total
else:
dim = [i for i in range(u.ndim) if i not in parallel_dim]
return torch.linalg.vector_norm(u, dim=dim, keepdim=False)
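    # broadcast shapes: keep the parallel (batch) dims, set the rest to 1,
    # so per-batch scalars can be view()-ed against measurement/solution tensors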
b_shape = []
if isinstance(b, TensorList):
for j in range(len(b)):
b_shape.append([])
for i in range(len(b[j].shape)):
b_shape[j].append(b[j].shape[i] if i in parallel_dim else 1)
else:
for i in range(len(b.shape)):
b_shape.append(b.shape[i] if i in parallel_dim else 1)
Atb_shape = []
for i in range(len(xt.shape)):
Atb_shape.append(xt.shape[i] if i in parallel_dim else 1)
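    # multiply a tensor (or TensorList) by a per-batch scalar, either in the
    # measurement domain (b_domain=True) or in the solution domain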
def scalar(v, alpha, b_domain):
if b_domain:
if isinstance(v, TensorList):
return TensorList(
[
vi * alpha.view(bi_shape)
for vi, bi_shape in zip_strict(v, b_shape)
]
)
else:
return v * alpha.view(b_shape)
else:
return v * alpha.view(Atb_shape)
if eta is None:
eta = 0.0
if not isinstance(eta, Tensor):
eta = torch.tensor(eta, device=device)
    if eta.ndim > 0:  # batched eta: batch size must match that of b
        batch_size = b[0].size(0) if isinstance(b, TensorList) else b.size(0)
        if eta.size(0) != batch_size:
            raise ValueError(
                "If eta is batched, its batch size must match the one of b."
            )
    else:  # scalar eta: squeeze to 0 dimensions
        eta = eta.squeeze()
if torch.any(eta < 0):
raise ValueError(
"Damping parameter eta must be non-negative. LSQR cannot be applied to problems with negative eta."
)
# this should be safe as eta should be non-negative
eta_sqrt = torch.sqrt(eta)
# ctol = 1 / conlim if conlim > 0 else 0
anorm = 0.0
acond = torch.zeros(1, device=device)
dampsq = eta
ddnorm = 0.0
# res2 = 0.0
# xnorm = 0.0
xxnorm = 0.0
z = 0.0
cs2 = -1.0
sn2 = 0.0
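    # initialize the Golub-Kahan bidiagonalization: u_1 = (b - A x_0) / beta_1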
u = b.clone()
bnorm = normf(b)
if x0 is None:
x = zeros_like(xt)
beta = bnorm
else:
        if isinstance(x0, float):
            # broadcast the scalar initial guess x0 to the solution shape
            x = zeros_like(xt) + x0
else:
x = x0.clone()
u -= A(x)
beta = normf(u)
if torch.all(beta > 0):
u = scalar(u, 1 / beta, b_domain=True)
v = AT(u)
alpha = normf(v)
else:
        v = zeros_like(x)  # deepinv's zeros_like also handles TensorList
alpha = torch.zeros(1, device=device)
if torch.all(alpha > 0):
v = scalar(v, 1 / alpha, b_domain=False) # v / view(alpha, Atb_shape)
w = v.clone()
rhobar = alpha
phibar = beta
arnorm = alpha * beta
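    # arnorm = alpha * beta = ||A^T r_0||; if it is zero, x is already optimal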
if torch.any(arnorm == 0):
return x, acond
flag = False
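    # main loop: one Golub-Kahan bidiagonalization step followed by plane
    # rotations that keep the bidiagonal least-squares factorization updated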
for itn in range(max_iter):
u = A(v) - scalar(u, alpha, b_domain=True)
beta = normf(u)
if torch.all(beta > 0):
u = scalar(u, 1 / beta, b_domain=True)
anorm = torch.sqrt(anorm**2 + alpha**2 + beta**2 + dampsq)
v = AT(u) - scalar(v, beta, b_domain=False)
alpha = normf(v)
if torch.all(alpha > 0):
v = scalar(v, 1 / alpha, b_domain=False)
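        # use a plane rotation to eliminate the damping parameter; this
        # alters the diagonal (rhobar) of the lower-bidiagonal matrix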
if torch.any(eta > 0):
rhobar1 = torch.sqrt(rhobar**2 + dampsq)
cs1 = rhobar / rhobar1
sn1 = eta_sqrt / rhobar1
psi = sn1 * phibar
phibar = cs1 * phibar
else:
rhobar1 = rhobar
psi = 0.0
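        # use a plane rotation to eliminate the subdiagonal element (beta),
        # giving an upper-bidiagonal matrix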
cs, sn, rho = _sym_ortho(rhobar1, beta)
theta = sn * alpha
rhobar = -cs * alpha
phi = cs * phibar
phibar = sn * phibar
# tau = sn * phi
t1 = phi / rho
t2 = -theta / rho
dk = scalar(w, 1 / rho, b_domain=False)
x = x + scalar(w, t1, b_domain=False)
w = v + scalar(w, t2, b_domain=False)
ddnorm = ddnorm + normf(dk) ** 2
# if calc_var:
# var = var + dk ** 2
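        # use a plane rotation on the right to eliminate the super-diagonal
        # element (theta) of the upper-bidiagonal matrix, then estimate ||x||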
delta = sn2 * rho
gambar = -cs2 * rho
rhs = phi - delta * z
# zbar = rhs / gambar
# xnorm = torch.sqrt(xxnorm + zbar ** 2)
gamma = torch.sqrt(gambar**2 + theta**2)
cs2 = gambar / gamma
sn2 = theta / gamma
z = rhs / gamma
xxnorm = xxnorm + z**2
acond = anorm * torch.sqrt(ddnorm).mean()
rnorm = torch.sqrt(phibar**2 + psi**2)
# arnorm = alpha * abs(tau)
if torch.all(rnorm <= tol * bnorm):
flag = True
if verbose:
print("LSQR converged at iteration", itn)
break
elif torch.any(acond > conlim):
flag = True
if verbose:
print(f"LSQR reached condition number limit {conlim} at iteration", itn)
break
if not flag and verbose:
print("LSQR did not converge")
    return x, acond  # acond already holds the condition-number estimate
def _sym_ortho(a: Tensor, b: Tensor) -> tuple[Tensor, Tensor, Tensor]:
"""
Stable implementation of Givens rotation.
Adapted from https://github.com/scipy/scipy/blob/v1.15.1/scipy/sparse/linalg/_isolve/lsqr.py
The routine '_sym_ortho' was added for numerical stability. This is
recommended by S.-C. Choi in "Iterative Methods for Singular Linear Equations and Least-Squares
Problems". It removes the unpleasant potential of
``1/eps`` in some important places.
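    The rotation satisfies :math:`c\,a + s\,b = r` and :math:`s\,a - c\,b = 0`,
    elementwise over (broadcast) batched inputs. For example:

    >>> a, b = torch.tensor(3.0), torch.tensor(4.0)
    >>> c, s, r = _sym_ortho(a, b)
    >>> torch.allclose(c * a + s * b, r)
    True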
"""
    a, b = torch.broadcast_tensors(a, b)
    # choose, per element, the larger of |a| and |b| as the denominator so
    # that tau = small / large never overflows
    use_b = b.abs() > a.abs()
    denom = torch.where(use_b, b, a)
    # guard elements with a == b == 0 against 0/0; their rotation is (0, 0, 0)
    safe = torch.where(denom == 0, torch.ones_like(denom), denom)
    tau = torch.where(use_b, a, b) / safe
    base = torch.sign(denom) / torch.sqrt(1 + tau * tau)
    c = torch.where(use_b, base * tau, base)
    s = torch.where(use_b, base, base * tau)
    r = denom / torch.where(base == 0, torch.ones_like(base), base)
    return c, s, r