least_squares_implicit_backward
- deepinv.optim.utils.least_squares_implicit_backward(physics, y, z=None, init=None, gamma=None, **kwargs)
Least squares solver with O(1) memory backward propagation using implicit differentiation. The function is similar to deepinv.optim.utils.least_squares() for the forward pass, but uses implicit differentiation for the backward pass, which reduces memory consumption to O(1) in the number of iterations. This function supports backpropagation with respect to the inputs \(y\), \(z\) and \(\gamma\), and also with respect to the parameters of the physics operator \(A_\theta\) if they require gradients. See Reducing the memory and computational complexity of unfolded network training and the notes below for more details.
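A minimal usage sketch is shown below; the toy masking operator, tensor shapes, and solver settings are illustrative assumptions, not part of the API:

    import torch
    import deepinv as dinv
    from deepinv.optim.utils import least_squares_implicit_backward

    # Toy masking operator wrapped as a LinearPhysics (diagonal, hence self-adjoint).
    mask = (torch.rand(1, 1, 8, 8) > 0.5).float()
    physics = dinv.physics.LinearPhysics(
        A=lambda x, **kwargs: mask * x,
        A_adjoint=lambda y, **kwargs: mask * y,
    )

    y = torch.randn(1, 1, 8, 8, requires_grad=True)
    z = torch.randn(1, 1, 8, 8, requires_grad=True)
    gamma = torch.tensor(1.0, requires_grad=True)

    # Forward pass behaves like least_squares; the backward pass uses implicit
    # differentiation, so memory does not grow with the number of solver iterations.
    x = least_squares_implicit_backward(physics, y, z=z, gamma=gamma, max_iter=100, tol=1e-6)
    x.pow(2).sum().backward()
    print(y.grad.shape, z.grad.shape, gamma.grad)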
Let \(h(z, y, \theta, \gamma)\) denote the output of the least squares solver, i.e. the solution of the following problem:
\[h(z, y, \theta, \gamma) = \underset{x}{\arg\min} \; \frac{\gamma}{2}\|A_\theta x-y\|^2 + \frac{1}{2}\|x-z\|^2\]

When the forward least-squares solver converges to the exact minimizer, we have the following closed-form expression for \(h(z, y, \theta, \gamma)\):

\[h(z, y, \theta, \gamma) = \left( A_\theta^{\top} A_\theta + \frac{1}{\gamma} I \right)^{-1} \left( A_\theta^{\top} y + \frac{1}{\gamma} z \right)\]

Let \(M\) denote the inverse \(\left( A_\theta^{\top} A_\theta + \frac{1}{\gamma} I \right)^{-1}\). In the backward pass, we need the vector-Jacobian products (VJPs), which can be computed as follows:

\[\begin{split}\left( \frac{\partial h}{\partial z} \right)^{\top} v &= \frac{1}{\gamma} M v \\ \left( \frac{\partial h}{\partial y} \right)^{\top} v &= A_\theta M v \\ \left( \frac{\partial h}{\partial \gamma} \right)^{\top} v &= (h - z)^\top M v / \gamma^2 \\ \left( \frac{\partial h}{\partial \theta} \right)^{\top} v &= \frac{\partial p}{\partial \theta}\end{split}\]

where \(p = (y - A_\theta h)^{\top} A_\theta M v\) and \(\frac{\partial p}{\partial \theta}\) can be computed using the standard backpropagation mechanism (autograd).
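These expressions can be checked numerically on a small dense matrix. The sketch below compares the closed-form \(z\)-VJP with autograd on made-up data; it uses plain PyTorch and does not call deepinv:

    import torch

    torch.manual_seed(0)
    A = torch.randn(5, 3)                    # explicit A_theta for the check
    y = torch.randn(5)
    z = torch.randn(3, requires_grad=True)
    gamma = 2.0

    # Closed-form solution h = M (A^T y + z / gamma), with M = (A^T A + I / gamma)^{-1}.
    M = torch.linalg.inv(A.T @ A + torch.eye(3) / gamma)
    h = M @ (A.T @ y + z / gamma)

    v = torch.randn(3)
    vjp_closed_form = M @ v / gamma          # (dh/dz)^T v = M v / gamma (M is symmetric)
    vjp_autograd = torch.autograd.grad(h, z, grad_outputs=v)[0]
    print(torch.allclose(vjp_closed_form, vjp_autograd, atol=1e-5))  # True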
Note

This function supports only first-order gradients; higher-order gradients are not supported. If you need higher-order gradients, use deepinv.optim.utils.least_squares() instead, but be aware that it stores all intermediate iterates, which can be memory-intensive.
Note

This function also supports implicit gradients with respect to the parameters of the physics operator \(A_\theta\) if they require gradients, which is useful for learning the physics parameters in an end-to-end fashion. The gradients are accumulated in place in the .grad attribute of the physics operator's parameters. To make this work, the function takes the physics operator itself as input (not just its matmul functions) and checks whether any of its parameters require gradients; if so, it triggers the corresponding backward pass, as sketched below.
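A hedged sketch of this mechanism, assuming a hypothetical DiagonalPhysics subclass of deepinv.physics.LinearPhysics that registers its weight as a torch.nn.Parameter:

    import torch
    import deepinv as dinv
    from deepinv.optim.utils import least_squares_implicit_backward

    class DiagonalPhysics(dinv.physics.LinearPhysics):
        """Toy self-adjoint operator A_theta x = w * x with a learnable weight w."""

        def __init__(self, shape):
            super().__init__()
            self.w = torch.nn.Parameter(torch.rand(*shape))

        def A(self, x, **kwargs):
            return self.w * x

        def A_adjoint(self, y, **kwargs):
            return self.w * y

    physics = DiagonalPhysics((1, 1, 8, 8))
    y = torch.randn(1, 1, 8, 8)
    z = torch.randn(1, 1, 8, 8, requires_grad=True)

    x = least_squares_implicit_backward(physics, y, z=z, gamma=1.0, max_iter=200, tol=1e-7)
    x.pow(2).sum().backward()
    print(physics.w.grad.shape)  # gradient accumulated in place in .grad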
Warning

Implicit gradients can be incorrect if the least squares solver does not converge sufficiently. Make sure to set the max_iter and tol parameters of the least squares solver appropriately to ensure convergence. You can monitor convergence by passing verbose=True to the least squares solver via kwargs. If the solver does not converge, the implicit gradients can be very inaccurate and lead to divergence of the training.
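For example, continuing the toy setup from the first sketch above, solver options can be forwarded through kwargs (the values below are illustrative):

    # Tighter tolerance and more iterations give more accurate implicit gradients;
    # verbose=True prints convergence information from the inner solver.
    x = least_squares_implicit_backward(
        physics, y, z=z, gamma=gamma,
        max_iter=500,
        tol=1e-7,
        verbose=True,
    )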
Warning

This function does not support deepinv.utils.TensorList inputs yet. If you use deepinv.utils.TensorList inputs, the function will fall back to standard least squares with full backpropagation.
Tip

If you do not need gradients with respect to the physics parameters, you can set requires_grad=False for all parameters of the physics operator to avoid the additional backward pass. This can save some computation time, as shown below.
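A short sketch, assuming the physics operator exposes its parameters as a torch.nn.Module (as deepinv physics operators do) and continuing the learnable-physics example above:

    # Freeze all physics parameters so the extra backward pass over theta is skipped.
    for p in physics.parameters():
        p.requires_grad_(False)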
Tip

Training unfolded networks with implicit differentiation can reduce memory consumption significantly, especially when using many iterations. On GPU, we can expect a memory reduction factor of about 2x-3x compared to standard backpropagation and a speed-up of about 1.2x-1.5x. The exact numbers depend on the problem and the number of iterations.
- Parameters:
  - physics (deepinv.physics.LinearPhysics) – physics operator.
  - y (torch.Tensor) – input tensor of shape (B, …).
  - z (torch.Tensor) – input tensor of shape (B, …). Default is None, which corresponds to a zero tensor.
  - init (Optional[torch.Tensor]) – optional initial guess, only used for the forward pass. Default is None, which corresponds to a zero initialization.
  - gamma (Optional[float, torch.Tensor]) – regularization parameter \(\gamma > 0\). Default is None.
  - kwargs – additional arguments to be passed to the least squares solver.
- Returns:
  - the solution \(x\) of the least squares problem, of shape (B, …).
- Return type:
  - torch.Tensor