conjugate_gradient
- deepinv.optim.utils.conjugate_gradient(A: Callable, b: Tensor, max_iter: float = 100.0, tol: float = 1e-05, eps: float = 1e-08)
Standard conjugate gradient algorithm.
It solves the linear system \(Ax=b\), where \(A\) is a symmetric positive definite (in particular, square) linear operator and \(b\) is a tensor.
For more details, see: http://en.wikipedia.org/wiki/Conjugate_gradient_method
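The iteration behind this function can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not deepinv's implementation: it works on a single flat list of floats rather than a batched torch.Tensor of shape (B, …), starts from \(x_0 = 0\), treats tol as an absolute bound on the residual norm, and adds eps only to the step-size denominator. The names conjugate_gradient_sketch, dot and matvec are hypothetical.

```python
from typing import Callable, List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Euclidean inner product <u, v>."""
    return sum(ui * vi for ui, vi in zip(u, v))


def conjugate_gradient_sketch(
    A: Callable[[Vector], Vector],
    b: Vector,
    max_iter: float = 100.0,
    tol: float = 1e-5,
    eps: float = 1e-8,
) -> Vector:
    """Approximately solve A x = b for a symmetric positive definite A (sketch)."""
    x = [0.0] * len(b)   # initial guess x_0 = 0 (assumption)
    r = list(b)          # residual r_0 = b - A x_0 = b
    p = list(r)          # first search direction p_0 = r_0
    rs_old = dot(r, r)
    for _ in range(int(max_iter)):  # int() accepts the float default
        if rs_old ** 0.5 < tol:     # assumed: absolute tolerance on ||r||
            break
        Ap = A(p)
        # Step size along p; eps guards against a zero denominator (assumed placement).
        alpha = rs_old / (dot(p, Ap) + eps)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        # rs_old > tol**2 > 0 here, so this division is safe.
        beta = rs_new / rs_old
        # New direction, A-conjugate to the previous ones in exact arithmetic.
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Usage: a small SPD matrix wrapped as a callable linear operator.
M = [[4.0, 1.0], [1.0, 3.0]]


def matvec(v: Vector) -> Vector:
    return [dot(row, v) for row in M]


x = conjugate_gradient_sketch(matvec, [1.0, 2.0])
print(x)  # approximately [1/11, 7/11], i.e. about [0.0909, 0.6364]
```

For an \(n \times n\) SPD system, CG reaches the exact solution in at most \(n\) iterations in exact arithmetic, so this 2 x 2 example stops after two updates.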
- Parameters:
A (callable) – linear operator given as a callable; it must be square and, for CG to converge, symmetric positive definite.
b (torch.Tensor) – right-hand side tensor of shape (B, …)
max_iter (int) – maximum number of CG iterations
tol (float) – absolute tolerance for stopping the CG algorithm.
eps (float) – a small value for numerical stability
- Returns:
torch.Tensor \(x\) of shape (B, …) approximately satisfying \(Ax=b\), up to the stopping tolerance tol or the iteration limit max_iter.
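In inverse problems the forward model is often non-square, so it cannot be passed as A directly. One standard way to obtain a valid operator is the regularized normal equations \((H^\top H + \gamma I)x = H^\top y\), whose operator is symmetric positive definite for \(\gamma > 0\). The plain-Python sketch below builds such a callable and checks both properties numerically. H, y, gamma, normal_operator, apply_H and apply_Ht are hypothetical names for illustration, and a real call would operate on batched tensors rather than lists.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def dot(u: Vector, v: Vector) -> float:
    """Euclidean inner product <u, v>."""
    return sum(ui * vi for ui, vi in zip(u, v))


def apply_H(H: Matrix, v: Vector) -> Vector:
    """Forward operator: H v."""
    return [dot(row, v) for row in H]


def apply_Ht(H: Matrix, w: Vector) -> Vector:
    """Adjoint operator: H^T w, computed without forming the transpose."""
    return [sum(H[i][j] * w[i] for i in range(len(H))) for j in range(len(H[0]))]


def normal_operator(H: Matrix, gamma: float) -> Callable[[Vector], Vector]:
    """Return the callable v -> (H^T H + gamma I) v, which is SPD when gamma > 0."""
    def op(v: Vector) -> Vector:
        HtHv = apply_Ht(H, apply_H(H, v))
        return [a + gamma * vi for a, vi in zip(HtHv, v)]
    return op


H = [[1.0, 2.0, 0.0], [0.0, 1.0, 3.0]]  # non-square (2 x 3) forward model
y = [1.0, 2.0]                         # observed measurements
op = normal_operator(H, gamma=0.1)     # square 3 x 3 operator, usable as A
rhs = apply_Ht(H, y)                   # matching right-hand side H^T y, usable as b

# Numerical sanity checks of symmetry and positive definiteness.
u, v = [1.0, 0.0, -1.0], [0.5, 2.0, 1.0]
print(dot(u, op(v)), dot(op(u), v))  # both about -10.55 (symmetry)
print(dot(v, op(v)) > 0)             # True (positive definite)
```

Here \(\langle v, (H^\top H + \gamma I)v\rangle = \|Hv\|^2 + \gamma\|v\|^2\), which is strictly positive for any nonzero \(v\) once \(\gamma > 0\); this is why the regularization term is needed when \(H^\top H\) alone is singular.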