conjugate_gradient

class deepinv.optim.utils.conjugate_gradient(A: Callable, b: Tensor, max_iter: float = 100.0, tol: float = 1e-05, eps: float = 1e-08)

Standard conjugate gradient algorithm.

It solves the linear system \(Ax=b\), where \(A\) is a (square) linear operator and \(b\) is a tensor.

For more details see: http://en.wikipedia.org/wiki/Conjugate_gradient_method
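
For reference, the standard textbook CG recurrences are sketched below. They assume that \(A\) is symmetric positive definite and that the initial guess is \(x_0 = 0\); neither assumption is stated on this page, and the library's implementation may differ in details such as how eps and tol enter the updates.

\[
\begin{aligned}
r_0 &= b - A x_0, \qquad p_0 = r_0,\\
\alpha_k &= \frac{r_k^\top r_k}{p_k^\top A p_k},\\
x_{k+1} &= x_k + \alpha_k p_k,\\
r_{k+1} &= r_k - \alpha_k A p_k,\\
\beta_k &= \frac{r_{k+1}^\top r_{k+1}}{r_k^\top r_k},\\
p_{k+1} &= r_{k+1} + \beta_k p_k.
\end{aligned}
\]

The iteration stops once the residual norm \(\|r_{k+1}\|\) is small enough or the iteration budget is exhausted.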

Parameters:
  • A (Callable) – linear operator given as a callable; it must be square.

  • b (torch.Tensor) – right-hand side tensor of shape (B, …)

  • max_iter (int) – maximum number of CG iterations

  • tol (float) – absolute tolerance for stopping the CG algorithm.

  • eps (float) – a small value for numerical stability

Returns:

torch.Tensor \(x\) of shape (B, …) satisfying \(Ax=b\) up to the stopping tolerance.
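
To make the roles of the parameters concrete, here is a minimal, self-contained sketch of a batched CG loop in PyTorch. It is an illustration only, not the deepinv implementation: the name cg_sketch, the test matrix M, and the exact way eps (added to denominators) and tol (threshold on the per-sample residual norm) are used are assumptions. It also assumes the operator is symmetric positive definite and acts independently on each batch element.

```python
import torch


def cg_sketch(A, b, max_iter=100, tol=1e-5, eps=1e-8):
    """Illustrative CG solver for A(x) = b (hypothetical helper, not deepinv's code)."""

    def dot(u, v):
        # Per-sample inner product, reshaped so it broadcasts against b.
        s = (u * v).reshape(u.shape[0], -1).sum(dim=1)
        return s.view(-1, *([1] * (u.dim() - 1)))

    x = torch.zeros_like(b)   # initial guess x_0 = 0
    r = b.clone()             # residual b - A(x_0)
    p = r.clone()             # first search direction
    rs_old = dot(r, r)
    for _ in range(int(max_iter)):
        Ap = A(p)
        # Assumption: eps guards the divisions against vanishing denominators.
        alpha = rs_old / (dot(p, Ap) + eps)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = dot(r, r)
        # Assumption: tol is an absolute threshold on the worst residual norm.
        if rs_new.sqrt().max() < tol:
            break
        p = r + (rs_new / (rs_old + eps)) * p
        rs_old = rs_new
    return x


# Usage: a small symmetric positive definite system, batch of 4 right-hand sides.
torch.manual_seed(0)
Q = torch.randn(8, 8, dtype=torch.float64)
M = Q.T @ Q / 8 + torch.eye(8, dtype=torch.float64)  # SPD by construction


def A(v):
    # Applies M to each row of v; M is symmetric, so v @ M equals (M v^T)^T.
    return v @ M


b = torch.randn(4, 8, dtype=torch.float64)
x = cg_sketch(A, b)
print("max residual:", (A(x) - b).abs().max().item())
```

The operator is passed as a callable rather than a matrix, so the same loop works for matrix-free operators (for example, a composition of a forward operator with its adjoint) as long as the output has the same shape as the input.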