Unfolded Algorithms

This package contains a collection of routines turning the optimization algorithms defined in Optim into unfolded architectures. Recall that optimization algorithms aim at solving problems of the form \(\min_x f(x, y) + g(x)\), where \(f(\cdot, y)\) is a data-fidelity term and \(g(\cdot)\) is a regularization term. The resulting fixed-point algorithms for solving these problems are of the form (see Optim)

\[\begin{aligned} z_{k+1} &= \operatorname{step}_f(x_k, z_k, y, A, \ldots)\\ x_{k+1} &= \operatorname{step}_g(x_k, z_k, y, A, \ldots) \end{aligned}\]

where \(\operatorname{step}_f\) and \(\operatorname{step}_g\) are gradient and/or proximal steps on \(f\) and \(g\) respectively.
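
For instance, for proximal gradient descent (PGD) with step size \(\gamma\), \(\operatorname{step}_f\) is a gradient step on \(f\) and \(\operatorname{step}_g\) is a proximal step on \(g\), so one iteration reads

\[x_{k+1} = \operatorname{prox}_{\gamma g}\left(x_k - \gamma \nabla_x f(x_k, y)\right).\]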

Unfolded architectures (sometimes called ‘unrolled architectures’) are obtained by replacing parts of these algorithms with learnable modules. In turn, they can be trained end-to-end to solve inverse problems.
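
As a minimal illustration of the idea, independent of the deepinv API (the denoiser placeholder, the L2 data-fidelity, and the fixed number of steps are assumptions of this sketch), an unrolled PGD network keeps the algorithmic structure but turns the step sizes into trainable parameters and replaces the proximal operator with a learnable denoiser:

import torch
import torch.nn as nn

class UnrolledPGD(nn.Module):
    """Sketch of an unrolled PGD network; not the deepinv implementation."""

    def __init__(self, denoiser: nn.Module, n_steps: int = 5):
        super().__init__()
        self.denoiser = denoiser  # learnable module replacing prox_g
        self.stepsizes = nn.Parameter(torch.ones(n_steps))  # one step size per iteration

    def forward(self, y, A, At):
        x = At(y)  # initialize with a back-projection of the measurements
        for gamma in self.stepsizes:
            grad = At(A(x) - y)                  # gradient step on the L2 data-fidelity
            x = self.denoiser(x - gamma * grad)  # learned step replacing prox_g
        return x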

Unfolded

The deepinv.unfolded.unfolded_builder function is a generic helper for building unfolded architectures. It provides a trainable reconstruction network using either a pre-existing optimizer (e.g., “PGD”) or an iterator defined by the user. The user can choose which parameters (e.g., prior denoiser, step size, regularization parameter, etc.) are learnable and which are not.

deepinv.unfolded.unfolded_builder

Helper function for building an unfolded architecture.

The builder depends on the backbone class for unfolded architectures, deepinv.unfolded.BaseUnfold.

deepinv.unfolded.BaseUnfold

Base class for unfolded algorithms.

In the following example, we create an unfolded architecture with 5 proximal gradient steps, using a DnCNN plug-and-play prior and a standard L2 data-fidelity term. The network can be trained end-to-end and evaluated with any forward model (e.g., denoising, deconvolution, inpainting, etc.).

>>> import torch
>>> import deepinv as dinv
>>>
>>> # Create a trainable unfolded architecture
>>> model = dinv.unfolded.unfolded_builder(
...     iteration="PGD",
...     data_fidelity=dinv.optim.data_fidelity.L2(),
...     prior=dinv.optim.PnP(dinv.models.DnCNN()),
...     params_algo={"stepsize": 1.0, "g_param": 1.0},
...     trainable_params=["stepsize", "g_param"],
...     max_iter=5,  # number of unfolded iterations
... )
>>> # Forward pass
>>> x = torch.randn(1, 3, 16, 16)
>>> physics = dinv.physics.Denoising()
>>> y = physics(x)
>>> x_hat = model(y, physics)
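
Since the returned model is a standard torch.nn.Module, its trainable parameters can be optimized end-to-end with any PyTorch training loop; for instance (the choice of Adam and a mean-squared error against the ground truth x is just one possible setup):

>>> # One illustrative training step on the (x, y) pair above
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
>>> optimizer.zero_grad()
>>> loss = torch.nn.functional.mse_loss(model(y, physics), x)
>>> loss.backward()
>>> optimizer.step()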

Deep Equilibrium

Deep Equilibrium models (DEQ) are a particular class of unfolded architectures where the backward pass is performed via fixed-point iterations. DEQ algorithms can virtually unroll infinitely many layers by leveraging the implicit function theorem. The backward pass consists in finding a solution of the fixed-point equation

\[v = \left(\frac{\partial \operatorname{FixedPoint}(x^\star)}{\partial x^\star} \right)^{\top} v + u,\]

where \(u\) is the incoming gradient from the backward pass, and \(x^\star\) is the equilibrium point of the forward pass. See this tutorial for more details.
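
Concretely, this linear system is typically solved with the same kind of fixed-point iteration as the forward pass, using vector-Jacobian products computed by automatic differentiation. Below is a minimal sketch of the principle, not the deepinv implementation (here f denotes the fixed-point map and x_star its equilibrium):

import torch

def deq_backward(f, x_star, u, n_iter=50):
    # Solve v = J^T v + u, with J the Jacobian of the fixed-point map f at x_star.
    x_star = x_star.detach().requires_grad_()
    fx = f(x_star)  # build the graph of one application of the map
    v = u.clone()
    for _ in range(n_iter):
        # vector-Jacobian product J^T v via autograd
        (jtv,) = torch.autograd.grad(fx, x_star, grad_outputs=v, retain_graph=True)
        v = jtv + u
    # backpropagating fx through the network parameters with grad_outputs=v
    # then yields the gradients of the loss with respect to the weights
    return v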

The deepinv.unfolded.DEQ_builder function is a generic helper for building Deep Equilibrium (DEQ) architectures.

deepinv.unfolded.DEQ_builder

Helper function for building an instance of the BaseDEQ() class.

The builder depends on the backbone class for DEQs, deepinv.unfolded.BaseDEQ.

deepinv.unfolded.BaseDEQ

Base class for deep equilibrium (DEQ) algorithms.
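
Its usage mirrors that of deepinv.unfolded.unfolded_builder shown above; for instance (assuming the same arguments are accepted, and reusing y and physics from the previous example):

>>> model = dinv.unfolded.DEQ_builder(
...     iteration="PGD",
...     data_fidelity=dinv.optim.data_fidelity.L2(),
...     prior=dinv.optim.PnP(dinv.models.DnCNN()),
...     params_algo={"stepsize": 1.0, "g_param": 1.0},
... )
>>> x_hat = model(y, physics)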

Utils

Some more specific unfolded architectures are also available.

deepinv.models.PDNet_PrimalBlock

Primal block for the Primal-Dual unfolding model.

deepinv.models.PDNet_DualBlock

Dual block for the Primal-Dual unfolding model.
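
These blocks alternate updates on primal and dual variables, as in learned primal-dual schemes. The following sketch shows how such blocks are typically composed; the module interfaces are illustrative assumptions, not the actual signatures of deepinv.models.PDNet_PrimalBlock and deepinv.models.PDNet_DualBlock:

import torch
import torch.nn as nn

class LearnedPrimalDual(nn.Module):
    """Generic sketch of a learned primal-dual network (illustrative interfaces)."""

    def __init__(self, primal_blocks, dual_blocks):
        super().__init__()
        self.primal_blocks = nn.ModuleList(primal_blocks)
        self.dual_blocks = nn.ModuleList(dual_blocks)

    def forward(self, y, A, At):
        x = At(y)                # primal variable (image space)
        u = torch.zeros_like(y)  # dual variable (measurement space)
        for primal, dual in zip(self.primal_blocks, self.dual_blocks):
            u = dual(u, A(x), y)    # learned dual update from the measurements
            x = primal(x, At(u))    # learned primal update from the back-projection
        return x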