conjugate_gradient

deepinv.optim.utils.conjugate_gradient(A, b, max_iter=1e2, tol=1e-5, eps=1e-8, parallel_dim=0, init=None, verbose=False)

Standard conjugate gradient algorithm.

It solves the linear system \(Ax=b\), where \(A\) is a square linear operator and \(b\) is a tensor. Convergence is only guaranteed when \(A\) is symmetric (Hermitian) positive definite.

For more details see: http://en.wikipedia.org/wiki/Conjugate_gradient_method
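The iteration itself fits in a few lines of PyTorch. This is a minimal sketch, not the deepinv implementation: it assumes real-valued tensors, an operator `A` that is symmetric positive definite for every sample, batching along dimension 0 only, and the absolute-residual stopping rule described for `tol` below, with `eps` added to the denominators. The name `cg_sketch` is hypothetical.

```python
import torch


def cg_sketch(A, b, max_iter=100, tol=1e-5, eps=1e-8, init=None):
    """Hypothetical sketch: solve A(x) = b by conjugate gradient,
    one independent system per index along dimension 0.

    Assumes real-valued tensors and an A that is symmetric positive
    definite for every sample.
    """
    # Inner products sum over every dimension except the batch dimension 0.
    dims = tuple(range(1, b.ndim))

    def dot(u, v):
        # One inner product per sample, kept broadcastable against b.
        return (u * v).sum(dim=dims, keepdim=True)

    x = torch.zeros_like(b) if init is None else init.clone()
    r = b - A(x)        # initial residual
    p = r.clone()       # first search direction
    rs_old = dot(r, r)  # squared residual norm per sample
    for _ in range(int(max_iter)):
        Ap = A(p)
        alpha = rs_old / (dot(p, Ap) + eps)  # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = dot(r, r)
        # Stop once every sample's residual norm is below the absolute tolerance.
        if torch.all(rs_new.sqrt() < tol):
            break
        # Next direction: new residual plus a multiple of the previous direction.
        p = r + (rs_new / (rs_old + eps)) * p
        rs_old = rs_new
    return x


# Usage: a batch of 4 random 8x8 SPD systems applied through a callable.
torch.manual_seed(0)
B, n = 4, 8
M = torch.randn(B, n, n, dtype=torch.float64)
Q = M @ M.transpose(1, 2) + n * torch.eye(n, dtype=torch.float64)  # SPD per sample


def A(v):
    # Batched matrix-vector product: (B, n, n) x (B, n) -> (B, n).
    return torch.einsum("bij,bj->bi", Q, v)


b = torch.randn(B, n, dtype=torch.float64)
x = cg_sketch(A, b, tol=1e-4)
print(torch.linalg.norm(A(x) - b, dim=1))  # per-sample residual norms
```

Because `eps` is added to the inner products in the denominators, it biases the step sizes once those inner products shrink toward `eps`; a tolerance that is loose relative to `eps` (here `1e-4`) keeps that effect negligible.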

Parameters:
  • A (Callable) – Linear operator given as a callable function; it must be square.

  • b (torch.Tensor) – right-hand side tensor of shape (B, …)

  • max_iter (int) – maximum number of CG iterations

  • tol (float) – absolute tolerance for stopping the CG algorithm.

  • eps (float) – a small value for numerical stability

  • parallel_dim (None, int, list[int]) – dimensions to be considered as batch dimensions. If None, all dimensions are considered as batch dimensions.

  • init (torch.Tensor) – Optional initial guess.

  • verbose (bool) – Print progress information to the console.
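The `parallel_dim` argument decides which dimensions index separate problems. A minimal sketch, assuming that each batch index defines its own system, so the CG inner products are summed over the remaining dimensions only. `reduce_dims` and `batched_dot` are hypothetical helpers, not deepinv functions, and the `None` case is not covered here.

```python
import torch


def reduce_dims(ndim, parallel_dim):
    """Hypothetical helper: dimensions an inner product sums over, given the
    batch dimensions (an int or a list of non-negative ints)."""
    if isinstance(parallel_dim, int):
        parallel_dim = [parallel_dim]
    return [d for d in range(ndim) if d not in parallel_dim]


def batched_dot(u, v, parallel_dim):
    # Sum u * v over the non-batch dimensions, keeping one value per batch index.
    return (u * v).sum(dim=reduce_dims(u.ndim, parallel_dim), keepdim=True)


u = torch.ones(2, 3, 4, 5)
print(reduce_dims(u.ndim, 0))           # [1, 2, 3]
print(batched_dot(u, u, 0).shape)       # torch.Size([2, 1, 1, 1])
print(batched_dot(u, u, [0, 1]).shape)  # torch.Size([2, 3, 1, 1])
```

With the default `parallel_dim=0`, a tensor of shape (B, …) yields B separate inner products, so each sample gets its own step sizes.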

Returns:

torch.Tensor \(x\) of shape (B, …) approximately verifying \(Ax=b\), up to the stopping tolerance.

Return type:

Tensor