BaseDEQ
- class deepinv.unfolded.BaseDEQ(*args, max_iter_backward=50, anderson_acceleration_backward=False, history_size_backward=5, beta_anderson_acc_backward=1.0, eps_anderson_acc_backward=0.0001, **kwargs)
Bases: BaseUnfold
Base class for deep equilibrium (DEQ) algorithms. Child of deepinv.unfolded.BaseUnfold.
This class turns any fixed-point algorithm into a DEQ algorithm, i.e., an algorithm that can be virtually unrolled an infinite number of times by leveraging the implicit function theorem. The backward pass is performed using fixed-point iterations to find the solution \(v\) of the fixed-point equation
\[\begin{equation} v = \left(\frac{\partial \operatorname{FixedPoint}(x^\star)}{\partial x^\star} \right)^T v + u, \end{equation}\]
where \(u\) is the incoming gradient from the backward pass and \(x^\star\) is the equilibrium point of the forward pass.
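To make the backward equation concrete, here is a minimal PyTorch sketch on a hypothetical toy map \(\operatorname{FixedPoint}(x) = 0.5\tanh(Wx) + b\), which stands in for one iteration of an optimization algorithm. The product \(J^T v\), with \(J\) the Jacobian of the map at \(x^\star\), is obtained with torch.autograd.grad as a vector-Jacobian product, so the Jacobian is never formed. The toy map, its contraction factor and the iteration counts are illustrative assumptions, not deepinv's implementation.

```python
import torch

torch.manual_seed(0)
n = 5

# Hypothetical toy fixed-point map FixedPoint(x) = 0.5 * tanh(W x) + b.
# With spectral norm ||W|| = 1 its Jacobian has norm <= 0.5, so it is a contraction.
W = torch.randn(n, n, dtype=torch.float64)
W = W / torch.linalg.matrix_norm(W, ord=2)
b = torch.randn(n, dtype=torch.float64)


def fixed_point(x):
    return 0.5 * torch.tanh(W @ x) + b


# Forward pass: iterate the map until it reaches the equilibrium x*.
x = torch.zeros(n, dtype=torch.float64)
for _ in range(200):
    x = fixed_point(x)

# Backward pass: solve v = J^T v + u, with J = dFixedPoint(x*)/dx*.
x_star = x.detach().requires_grad_()
fx = fixed_point(x_star)
u = torch.randn(n, dtype=torch.float64)  # incoming gradient
v = u.clone()
for _ in range(50):  # plays the role of max_iter_backward
    # J^T v computed as a vector-Jacobian product (no explicit Jacobian)
    JTv = torch.autograd.grad(fx, x_star, grad_outputs=v, retain_graph=True)[0]
    v = JTv + u

# Check against the closed form v = (I - J^T)^{-1} u (only affordable for tiny n).
J = torch.autograd.functional.jacobian(fixed_point, x_star.detach())
v_ref = torch.linalg.solve(torch.eye(n, dtype=torch.float64) - J.T, u)
print(torch.allclose(v, v_ref, atol=1e-10))
```

Each backward iteration costs one vector-Jacobian product, which is why the number of backward iterations (max_iter_backward) directly controls the cost of the backward pass.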
See this tutorial for more details.
Currently, DEQ is only supported for the PGD, HQS and GD optimization algorithms.
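PGD fits this framework because a single PGD iteration is itself a fixed-point map, and the minimizer of the variational problem is its fixed point. The sketch below is a hypothetical example in plain PyTorch rather than deepinv's building blocks: it assumes a random operator A, random measurements y and an \(\ell_1\) prior (so the proximal step is soft-thresholding), and iterates the PGD map until the fixed-point residual vanishes.

```python
import torch

torch.manual_seed(0)
m, n = 40, 10
A = torch.randn(m, n, dtype=torch.float64)  # hypothetical forward operator
y = torch.randn(m, dtype=torch.float64)     # hypothetical measurements
lam = 0.1                                    # weight of the l1 prior
gamma = 1.0 / torch.linalg.matrix_norm(A, ord=2) ** 2  # stepsize 1/L with L = ||A||^2


def soft_threshold(z, t):
    # Proximal operator of t * ||.||_1
    return torch.sign(z) * torch.clamp(z.abs() - t, min=0.0)


def pgd_step(x):
    # One PGD iteration for min_x 0.5 * ||Ax - y||^2 + lam * ||x||_1:
    # gradient step on the data-fidelity term, then proximal step on the prior.
    return soft_threshold(x - gamma * A.T @ (A @ x - y), gamma * lam)


x = torch.zeros(n, dtype=torch.float64)
for _ in range(2000):
    x = pgd_step(x)

# At the equilibrium x* = pgd_step(x*), the fixed-point residual is numerically zero.
residual = torch.linalg.norm(pgd_step(x) - x).item()
print(residual)
```

In a DEQ, pgd_step would contain learnable parts (for instance a learned denoiser in place of soft-thresholding), and its equilibrium is exactly the point \(x^\star\) at which the backward equation is solved.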
- Parameters:
max_iter_backward (int) – Maximum number of backward iterations. Default: 50.
anderson_acceleration_backward (bool) – If True, Anderson acceleration is used at each iteration of the fixed-point algorithm that computes the backward pass. Default: False.
history_size_backward (int) – Size of the history used by the Anderson acceleration for the backward pass. Default: 5.
beta_anderson_acc_backward (float) – Momentum of the Anderson acceleration step for the backward pass. Default: 1.0.
eps_anderson_acc_backward (float) – Regularization parameter of the Anderson acceleration step for the backward pass. Default: 1e-4.
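The Anderson parameters above can be illustrated with a generic Anderson acceleration routine for a fixed-point problem \(v = g(v)\). This is a sketch of the common regularized formulation used in the DEQ literature, not deepinv's code: the function name anderson and its arguments are hypothetical, with history_size, beta and eps playing the roles of history_size_backward, beta_anderson_acc_backward and eps_anderson_acc_backward. The usage applies it to a hypothetical linear backward problem \(v = M^T v + u\) with spectral radius 0.9, and passes a smaller eps than the documented default for a tighter solve.

```python
import torch


def anderson(g, v0, history_size=5, beta=1.0, eps=1e-4, max_iter=50, tol=1e-10):
    """Anderson acceleration for v = g(v) (illustrative sketch).

    Returns the last iterate and the number of evaluations of g.
    """
    X, G = [], []  # recent iterates v_k and their images g(v_k)
    v = v0
    for it in range(1, max_iter + 1):
        X.append(v)
        G.append(g(v))
        X, G = X[-history_size:], G[-history_size:]  # keep a sliding window
        R = torch.stack(G) - torch.stack(X)  # residuals g(v_k) - v_k, shape (k, n)
        if torch.linalg.norm(R[-1]) < tol:
            break
        k = R.shape[0]
        # Mixing weights: minimize ||R^T alpha||^2 + eps * ||alpha||^2 s.t. sum(alpha) = 1,
        # via the KKT system [[0, 1^T], [1, R R^T + eps I]] [nu; alpha] = [1; 0].
        H = torch.zeros(k + 1, k + 1, dtype=v0.dtype)
        H[0, 1:] = 1.0
        H[1:, 0] = 1.0
        H[1:, 1:] = R @ R.T + eps * torch.eye(k, dtype=v0.dtype)
        rhs = torch.zeros(k + 1, dtype=v0.dtype)
        rhs[0] = 1.0
        alpha = torch.linalg.solve(H, rhs)[1:]
        # beta = 1 mixes only the images g(v_k); beta < 1 also mixes the iterates.
        v = beta * (alpha @ torch.stack(G)) + (1.0 - beta) * (alpha @ torch.stack(X))
    return v, it


# Usage on a hypothetical linear backward problem v = M^T v + u (M symmetric, eigenvalues in [0, 0.9]).
torch.manual_seed(0)
n = 50
Q, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
M = Q @ torch.diag(torch.linspace(0.0, 0.9, n, dtype=torch.float64)) @ Q.T
u = torch.randn(n, dtype=torch.float64)
v_exact = torch.linalg.solve(torch.eye(n, dtype=torch.float64) - M.T, u)

v_aa, n_evals = anderson(lambda v: M.T @ v + u, torch.zeros_like(u), eps=1e-12, max_iter=200)
print(n_evals, torch.linalg.norm(v_aa - v_exact).item())
```

With history_size=1 and beta=1 the weights reduce to alpha = [1] and the routine falls back to plain fixed-point iteration; a longer history typically pays off when the spectral radius is close to 1, where plain iteration converges slowly. The eps term keeps the small linear system well posed when the stored residuals become nearly collinear.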
- forward(y, physics, x_gt=None, compute_metrics=False)
The forward pass of the DEQ algorithm. Compared to deepinv.unfolded.BaseUnfold, the backward pass is computed using fixed-point iterations.
- Parameters:
y (torch.Tensor) – Input tensor.
physics (deepinv.Physics) – Physics object.
x_gt (torch.Tensor) – (optional) Ground truth image, used to plot the PSNR across optimization iterations.
compute_metrics (bool) – Whether to compute the metrics. Default: False.
- Returns:
If compute_metrics is False, returns (torch.Tensor) the output of the algorithm. Otherwise, returns (torch.Tensor, dict) the output of the algorithm and the metrics.
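The mechanism described for forward (fixed-point iterations without storing a graph, then a fixed-point solve of the backward equation instead of backpropagating through every iteration) can be sketched with a PyTorch tensor hook, a pattern commonly used in DEQ tutorials. ToyDEQ, its map f and the iteration counts are hypothetical; this is not deepinv's code and it ignores physics, x_gt and compute_metrics. The usage checks the implicit gradient against backpropagation through all unrolled iterations.

```python
import torch


class ToyDEQ(torch.nn.Module):
    """Hypothetical DEQ layer whose output z* solves z = f(z, y) = 0.5 * tanh(W z) + y."""

    def __init__(self, n, max_iter=200, max_iter_backward=50):
        super().__init__()
        W = torch.randn(n, n, dtype=torch.float64)
        # Spectral norm 1 keeps f a contraction in z (Lipschitz constant <= 0.5).
        self.W = torch.nn.Parameter(W / torch.linalg.matrix_norm(W, ord=2))
        self.max_iter = max_iter
        self.max_iter_backward = max_iter_backward

    def f(self, z, y):
        return 0.5 * torch.tanh(self.W @ z) + y

    def forward(self, y):
        # 1) Forward fixed-point iterations, without storing an autograd graph.
        with torch.no_grad():
            z = torch.zeros_like(y)
            for _ in range(self.max_iter):
                z = self.f(z, y)
        # 2) One differentiable application of f at the equilibrium, so gradients reach W.
        z = self.f(z, y)
        # 3) Detached copy used to compute vector-Jacobian products J^T v at z*.
        z0 = z.detach().requires_grad_()
        f0 = self.f(z0, y)

        def backward_hook(u):
            # Replace the incoming gradient u by the solution of v = J^T v + u.
            v = u
            for _ in range(self.max_iter_backward):
                v = torch.autograd.grad(f0, z0, v, retain_graph=True)[0] + u
            return v

        if z.requires_grad:
            z.register_hook(backward_hook)
        return z


torch.manual_seed(0)
n = 6
model = ToyDEQ(n)
y = torch.randn(n, dtype=torch.float64)

# Gradient through the DEQ (implicit differentiation via the hook).
(model(y) ** 2).sum().backward()
grad_deq = model.W.grad.clone()

# Reference: backpropagate through all unrolled forward iterations.
model.W.grad = None
z = torch.zeros(n, dtype=torch.float64)
for _ in range(200):
    z = model.f(z, y)
(z ** 2).sum().backward()
grad_unrolled = model.W.grad.clone()
print(torch.allclose(grad_deq, grad_unrolled, atol=1e-8))
```

Memory in the DEQ version does not grow with the number of forward iterations, because autograd only records the final application of f; the unrolled reference, by contrast, stores all 200 iterations.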
Examples using BaseDEQ:
Deep Equilibrium (DEQ) algorithms for image deblurring