Unfolded Algorithms
This package contains a collection of routines that turn the optimization algorithms defined in Optim into unfolded architectures. Recall that optimization algorithms aim at solving problems of the form \(\min_x \datafid{x}{y} + \reg{x}\), where \(\datafid{\cdot}{\cdot}\) is a data-fidelity term and \(\reg{\cdot}\) is a regularization term. The resulting fixed-point algorithms for solving these problems are of the form (see Optim)
\[
\begin{aligned}
z_{k+1} &= \operatorname{step}_f(x_k, z_k, y, A, ...)\\
x_{k+1} &= \operatorname{step}_g(x_k, z_k, y, A, ...)
\end{aligned}
\]
where \(y\) is the measurement, \(A\) is the forward operator, and \(\operatorname{step}_f\) and \(\operatorname{step}_g\) are gradient and/or proximal steps on the data-fidelity term \(f\) and the regularization term \(g\) respectively.
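The gradient and proximal steps mentioned above can be illustrated on a toy problem. The following is a minimal NumPy sketch of proximal gradient descent, assuming a quadratic data-fidelity term \(\frac{1}{2}\|Ax - y\|^2\) and an L1 regularizer. The helper names `step_f`, `step_g` and `objective` are chosen for illustration only and are not part of the deepinv API.

```python
import numpy as np

def step_f(x, y, A, stepsize):
    """Gradient step on the data-fidelity term f(x) = 0.5 * ||A x - y||^2."""
    return x - stepsize * A.T @ (A @ x - y)

def step_g(z, stepsize, lam):
    """Proximal step on g(x) = lam * ||x||_1, i.e. soft-thresholding."""
    return np.sign(z) * np.maximum(np.abs(z) - stepsize * lam, 0.0)

def objective(x, y, A, lam):
    """Value of f(x) + g(x) for the toy problem."""
    return 0.5 * np.sum((A @ x - y) ** 2) + lam * np.sum(np.abs(x))

# Toy sparse recovery problem (all values are arbitrary illustration choices)
rng = np.random.default_rng(0)
A = rng.standard_normal((20, 10))
x_true = np.zeros(10)
x_true[[1, 4, 7]] = [1.0, -2.0, 0.5]
y = A @ x_true

lam = 0.1
stepsize = 1.0 / np.linalg.norm(A, 2) ** 2  # 1 / Lipschitz constant of grad f

x = np.zeros(10)
history = [objective(x, y, A, lam)]
for k in range(200):
    z = step_f(x, y, A, stepsize)   # gradient step on f
    x = step_g(z, stepsize, lam)    # proximal step on g
    history.append(objective(x, y, A, lam))
```

With this step size the objective does not increase from one iteration to the next, which is the behaviour a fixed, hand-tuned algorithm relies on before any parameter is made learnable.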
Unfolded architectures (sometimes called ‘unrolled architectures’) are obtained by replacing parts of these algorithms by learnable modules. In turn, they can be trained in an end-to-end fashion to solve inverse problems.
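To make the idea of replacing parts of an algorithm by learnable modules concrete, here is a small sketch in plain PyTorch. It is not how deepinv implements unfolding: the class `UnrolledPGD` is hypothetical, and it assumes a quadratic data-fidelity term and a soft-thresholding step whose step sizes and thresholds are learned separately for each of the unrolled iterations.

```python
import torch
import torch.nn as nn

class UnrolledPGD(nn.Module):
    """Fixed number of proximal gradient iterations with per-iteration learnable parameters."""

    def __init__(self, num_iter=5):
        super().__init__()
        # One learnable step size and one learnable threshold per unrolled iteration
        self.stepsize = nn.Parameter(torch.full((num_iter,), 0.1))
        self.threshold = nn.Parameter(torch.full((num_iter,), 0.01))

    def forward(self, y, A):
        x = torch.zeros(A.shape[1], dtype=y.dtype)
        for gamma, tau in zip(self.stepsize, self.threshold):
            z = x - gamma * A.T @ (A @ x - y)              # gradient step on 0.5 * ||A x - y||^2
            x = torch.sign(z) * torch.relu(z.abs() - tau)  # soft-thresholding step
        return x

# Toy supervised problem with a single (measurement, ground truth) pair
torch.manual_seed(0)
A = torch.randn(20, 10) / 20 ** 0.5
x_true = torch.zeros(10)
x_true[[1, 4, 7]] = torch.tensor([1.0, -2.0, 0.5])
y = A @ x_true

model = UnrolledPGD(num_iter=5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

loss_before = torch.mean((model(y, A) - x_true) ** 2).item()
for _ in range(200):
    optimizer.zero_grad()
    loss = torch.mean((model(y, A) - x_true) ** 2)
    loss.backward()   # gradients flow through all unrolled iterations
    optimizer.step()
loss_after = torch.mean((model(y, A) - x_true) ** 2).item()
```

Because every iteration is an ordinary differentiable operation, the reconstruction error can be backpropagated through the whole unrolled algorithm, which is what end-to-end training refers to here.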
Unfolded
The deepinv.unfolded.unfolded_builder function is a generic helper for building unfolded architectures. It returns a trainable reconstruction network using either a pre-existing optimizer (e.g., “PGD”) or an iterator defined by the user. The user can choose which parameters (e.g., prior denoiser, step size, regularization parameter, etc.) are learnable and which are not.
deepinv.unfolded.unfolded_builder | Helper function for building an unfolded architecture. |
The builder depends on the backbone class for unfolded architectures, deepinv.unfolded.BaseUnfold.
deepinv.unfolded.BaseUnfold | Base class for unfolded algorithms. |
In the following example, we create an unfolded architecture of 5 proximal gradient steps using a DnCNN plug-and-play prior and a standard L2 data-fidelity term. The network can be trained end-to-end and evaluated with any forward model (e.g., denoising, deconvolution, inpainting, etc.).
>>> import torch
>>> import deepinv as dinv
>>>
>>> # Create a trainable unfolded architecture
>>> model = dinv.unfolded.unfolded_builder(
...     iteration="PGD",
...     data_fidelity=dinv.optim.L2(),
...     prior=dinv.optim.PnP(dinv.models.DnCNN()),
...     params_algo={"stepsize": 1.0, "g_param": 1.0},
...     trainable_params=["stepsize", "g_param"],
...     max_iter=5,  # number of unrolled proximal gradient steps
... )
>>> # Forward pass
>>> x = torch.randn(1, 3, 16, 16)
>>> physics = dinv.physics.Denoising()
>>> y = physics(x)
>>> x_hat = model(y, physics)
Deep Equilibrium
Deep Equilibrium models (DEQ) are a particular class of unfolded architectures where the backward pass is performed via fixed-point iterations. DEQ algorithms can virtually unroll infinitely many layers by leveraging the implicit function theorem. The backward pass consists in looking for solutions of the fixed-point equation
\[
v = \left(\frac{\partial \operatorname{FixedPoint}(x^\star)}{\partial x^\star}\right)^{\top} v + u
\]
where \(\operatorname{FixedPoint}\) denotes one iteration of the algorithm, \(u\) is the incoming gradient from the backward pass, and \(x^\star\) is the equilibrium point of the forward pass. See this tutorial for more details.
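The backward fixed-point equation can be checked numerically on a small example. The sketch below uses NumPy and an assumed toy map \(x \mapsto \tanh(Wx + b)\) in place of an actual optimization iteration; the names `fixed_point_map`, `W`, `b` and `J` are illustrative and unrelated to the deepinv API. It solves the equation for \(v\) by iterating and compares the result with a direct linear solve.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
W = rng.standard_normal((n, n))
W *= 0.5 / np.linalg.norm(W, 2)  # spectral norm 0.5, so the map below is a contraction
b = rng.standard_normal(n)

def fixed_point_map(x):
    """Toy stand-in for one iteration of the forward algorithm."""
    return np.tanh(W @ x + b)

# Forward pass: iterate until the equilibrium x_star = fixed_point_map(x_star)
x_star = np.zeros(n)
for _ in range(100):
    x_star = fixed_point_map(x_star)

# Jacobian of the map at the equilibrium: diag(1 - tanh^2(W x_star + b)) @ W
J = (1.0 - np.tanh(W @ x_star + b) ** 2)[:, None] * W

# Backward pass: solve v = J^T v + u by fixed-point iteration
u = rng.standard_normal(n)  # incoming gradient (arbitrary values here)
v = np.zeros(n)
for _ in range(100):
    v = J.T @ v + u

# Reference solution of the same linear system, (I - J^T) v = u
v_direct = np.linalg.solve(np.eye(n) - J.T, u)
```

The iteration only needs products with the transposed Jacobian, so no intermediate iterates of the forward pass have to be stored, regardless of how many forward iterations were run.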
The deepinv.unfolded.DEQ_builder function is a generic helper for building Deep Equilibrium (DEQ) architectures.
Helper function for building an instance of the |
The builder depends on the backbone class for DEQs, deepinv.unfolded.BaseDEQ.
deepinv.unfolded.BaseDEQ | Base class for deep equilibrium (DEQ) algorithms. |
Utils
Some more specific unfolded architectures are also available.
Primal block for the Primal-Dual unfolding model. |
Dual block for the Primal-Dual unfolding model. |