Optim
This package contains a collection of routines that solve the minimization problem
\[
\underset{x}{\arg\min} \quad \datafid{x}{y} + \lambda \reg{x},
\]
where the first term \(\datafidname:\xset\times\yset \to \mathbb{R}_{+}\) enforces data fidelity, the second term \(\regname:\xset\to \mathbb{R}_{+}\) acts as a regularizer, and \(\lambda > 0\) is a regularization parameter. More precisely, the data-fidelity term penalizes the discrepancy between the data \(y\) and the forward operator \(A\) applied to the variable \(x\), as
\[
\datafid{x}{y} = \distance{A(x)}{y},
\]
where \(\distance{\cdot}{\cdot}\) is a distance function and \(A:\xset\to \yset\) is the forward operator (see deepinv.physics.Physics()).
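For illustration, take \(\distancename\) to be half the squared Euclidean distance, so that \(\datafid{x}{y} = \tfrac{1}{2}\|Ax - y\|_2^2\), whose gradient is \(A^\top(Ax - y)\) when \(A\) is linear. The NumPy sketch below is independent of the DeepInverse API: the function names are hypothetical and a small random matrix stands in for the forward operator. It checks the gradient formula against central finite differences.

```python
import numpy as np


def data_fidelity(x, y, A):
    """f(x, y) = 0.5 * ||A x - y||_2^2 for a matrix A."""
    r = A @ x - y
    return 0.5 * float(r @ r)


def data_fidelity_grad(x, y, A):
    """Gradient of f with respect to x: A^T (A x - y)."""
    return A.T @ (A @ x - y)


rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))   # stand-in for the forward operator
x = rng.standard_normal(3)
y = rng.standard_normal(5)

# Central finite differences along each coordinate direction e_i.
eps = 1e-6
fd = np.array([
    (data_fidelity(x + eps * e, y, A) - data_fidelity(x - eps * e, y, A)) / (2 * eps)
    for e in np.eye(3)
])
print(np.allclose(fd, data_fidelity_grad(x, y, A), atol=1e-5))
```

Since \(\datafidname\) is quadratic here, central differences agree with the analytic gradient up to rounding error.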
Note
The regularization term often (but not always) depends on a hyperparameter \(\sigma\) that can be either fixed or estimated. For example, if the regularization is implicitly defined by a denoiser, the hyperparameter is the noise level.
A typical example of an optimization problem is the \(\ell_1\)-regularized least squares problem, where the data-fidelity term is the squared \(\ell_2\)-norm and the regularization term is the \(\ell_1\)-norm. In this case, a possible algorithm to solve the problem is the Proximal Gradient Descent (PGD) algorithm, which reads
\[
x_{k+1} = \operatorname{prox}_{\gamma \lambda \regname}\left(x_k - \gamma \nabla \datafidname(x_k, y)\right),
\]
where \(\operatorname{prox}_{\gamma \lambda \regname}\) is the proximity operator of \(\gamma \lambda \regname\), \(\gamma\) is the step size of the algorithm, and \(\nabla \datafidname\) is the gradient of the data-fidelity term with respect to \(x\).
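A minimal NumPy sketch of this iteration for the \(\ell_1\)-regularized least squares problem, assuming \(\datafid{x}{y} = \tfrac{1}{2}\|Ax - y\|_2^2\) with \(A\) a matrix, so that the proximal step reduces to elementwise soft-thresholding with threshold \(\gamma\lambda\). The function names are hypothetical and do not belong to DeepInverse; the step size \(\gamma = 1/\|A\|_2^2\) mirrors the choice made in the DeepInverse example.

```python
import numpy as np


def soft_threshold(u, t):
    """Proximity operator of t * ||.||_1: elementwise soft-thresholding."""
    return np.sign(u) * np.maximum(np.abs(u) - t, 0.0)


def pgd_lasso(A, y, lambd, max_iter=500):
    """PGD for min_x 0.5 * ||A x - y||_2^2 + lambd * ||x||_1."""
    gamma = 1.0 / np.linalg.norm(A, 2) ** 2      # step size 1 / ||A||_2^2
    x = np.zeros(A.shape[1])
    for _ in range(max_iter):
        u = x - gamma * A.T @ (A @ x - y)        # gradient step on the data-fidelity term
        x = soft_threshold(u, gamma * lambd)     # proximal step on lambda * ||.||_1
    return x


A = np.eye(3)                       # identity operator: the solution can be checked by hand
y = np.array([3.0, -0.05, 1.0])
print(pgd_lasso(A, y, lambd=0.1))
```

With \(A\) equal to the identity, a single iteration already returns the soft-thresholded data \((2.9, 0, 0.9)\), which is the exact minimizer; entries whose magnitude is below \(\lambda\) are set to zero.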
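The following example illustrates the implementation of the PGD algorithm with DeepInverse on a closely related problem, in which the regularization term is the total variation (TV) \(\| D x \|_{1,2}\), i.e. a mixed \(\ell_{1,2}\)-norm of the image gradient \(Dx\), instead of the \(\ell_1\)-norm of \(x\) itself.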
>>> import torch
>>> import deepinv as dinv
>>> from deepinv.optim import L2, TVPrior
>>>
>>> # Forward operator, here inpainting
>>> mask = torch.ones((1, 2, 2))
>>> mask[0, 0, 0] = 0
>>> physics = dinv.physics.Inpainting(tensor_size=mask.shape, mask=mask)
>>> # Generate data
>>> x = torch.ones((1, 1, 2, 2))
>>> y = physics(x)
>>> data_fidelity = L2() # The data fidelity term
>>> prior = TVPrior() # The prior term
>>> lambd = 0.1 # Regularization parameter
>>> # Compute the squared norm of the operator A
>>> norm_A2 = physics.compute_norm(y, tol=1e-4, verbose=False).item()
>>> stepsize = 1/norm_A2 # stepsize for the PGD algorithm
>>>
>>> # PGD algorithm
>>> max_iter = 20 # number of iterations
>>> x_k = torch.zeros_like(x) # initial guess
>>>
>>> for it in range(max_iter):
...     u = x_k - stepsize*data_fidelity.grad(x_k, y, physics) # Gradient step
...     x_k = prior.prox(u, gamma=lambd*stepsize) # Proximal step
...     cost = data_fidelity(x_k, y, physics) + lambd*prior(x_k) # Compute the cost
...
>>> print(cost < 1e-5)
tensor([True])
>>> print('Estimated solution: ', x_k.flatten())
Estimated solution: tensor([1.0000, 1.0000, 1.0000, 1.0000])
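Optimization algorithms such as the one above can be written as fixed point algorithms, i.e. for \(k=1,2,\ldots\)
\[
(x_{k+1}, z_{k+1}) = \operatorname{FixedPoint}(x_k, z_k, \datafidname, \regname, A, y, \ldots),
\]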
where \(x\) is a variable converging to the solution of the minimization problem, and \(z\) is an additional (dual) variable that may be required in the computation of the fixed point operator.
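The fixed-point view can be sketched generically as a loop that repeatedly applies an operator to the pair \((x_k, z_k)\) and stops once \(x\) no longer changes. Below, `fixed_point` and `pgd_operator` are hypothetical names, not DeepInverse functions, and PGD is cast in this form by storing its gradient step in \(z\), although PGD itself does not need an auxiliary variable.

```python
import numpy as np


def fixed_point(operator, x0, z0, max_iter=1000, tol=1e-10):
    """Iterate (x, z) <- operator(x, z) until the change in x is below tol
    (relative to ||x||, or absolute when ||x|| < 1)."""
    x, z = x0, z0
    for k in range(max_iter):
        x_new, z_new = operator(x, z)
        converged = np.linalg.norm(x_new - x) <= tol * max(np.linalg.norm(x_new), 1.0)
        x, z = x_new, z_new
        if converged:
            break
    return x, z, k + 1   # k + 1 = number of operator applications


# Toy problem: min_x 0.5 * ||A x - y||^2 + lambd * ||x||_1 with a diagonal A.
A = np.diag([2.0, 1.0])
y = np.array([4.0, 0.05])
lambd = 0.1
gamma = 1.0 / np.linalg.norm(A, 2) ** 2   # = 1/4


def pgd_operator(x, z):
    # The incoming z is not used by PGD; it is overwritten with the gradient step.
    z = x - gamma * A.T @ (A @ x - y)                             # gradient step, stored in z
    x = np.sign(z) * np.maximum(np.abs(z) - gamma * lambd, 0.0)   # proximal step (soft-thresholding)
    return x, z


x_star, _, n_iter = fixed_point(pgd_operator, np.zeros(2), np.zeros(2))
print(x_star, n_iter)
```

For this diagonal toy problem the minimizer is \((1.975, 0)\), and the loop reaches it within a couple of iterations; algorithms such as ADMM or Chambolle-Pock would use \(z\) for a genuine dual variable instead.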
The function deepinv.optim.optim_builder() returns an instance of deepinv.optim.BaseOptim() with the optimization algorithm of choice, either a predefined one ("PGD", "ADMM", "HQS", etc.) or a user-defined one.
Helper function for building an instance of the
Optimization algorithms inherit from the base class deepinv.optim.BaseOptim(), which serves as a common interface for all optimization algorithms.
Class for optimization algorithms, which consists of iterating a fixed-point operator.
Data Fidelity
This is the base class for the data fidelity term \(\distance{A(x)}{y}\), where \(A\) is the forward operator, \(x\in\xset\) is a variable, \(y\in\yset\) is the data, and \(\distancename\) is a convex function.
This class comes with methods, such as \(\operatorname{prox}_{\distancename\circ A}\) and \(\nabla (\distancename \circ A)\) (among others), on which optimization algorithms rely.
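As a sketch of what such a method computes, the proximity operator of the quadratic data-fidelity \(\datafid{x}{y} = \tfrac{1}{2}\|Ax - y\|_2^2\) has the closed form \(\operatorname{prox}_{\gamma \datafidname}(z) = (I + \gamma A^\top A)^{-1}(z + \gamma A^\top y)\). The NumPy code below assumes \(A\) is a small dense matrix, so it uses a direct linear solve; the function name is hypothetical. It verifies the optimality condition \(\gamma \nabla \datafidname(p) + p - z = 0\) at \(p = \operatorname{prox}_{\gamma \datafidname}(z)\).

```python
import numpy as np


def prox_l2_fidelity(z, y, A, gamma):
    """prox_{gamma f}(z) for f(x) = 0.5 * ||A x - y||_2^2.

    Solves (I + gamma * A^T A) x = z + gamma * A^T y with a dense linear solve.
    """
    n = A.shape[1]
    return np.linalg.solve(np.eye(n) + gamma * A.T @ A, z + gamma * A.T @ y)


rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3))
y = rng.standard_normal(4)
z = rng.standard_normal(3)
gamma = 0.5
p = prox_l2_fidelity(z, y, A, gamma)

# Optimality condition of the prox: gamma * A^T (A p - y) + (p - z) = 0.
print(np.allclose(gamma * A.T @ (A @ p - y) + (p - z), 0.0))
```

For large operators, forming \(A^\top A\) is impractical; the same system can instead be solved iteratively using only applications of \(A\) and \(A^\top\).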
Data fidelity term \(\datafid{x}{y}=\distance{\forw{x}}{y}\).
\(\ell_1\) data fidelity term.
Implementation of \(\distancename\) as the normalized \(\ell_2\) norm.
Indicator of \(\ell_2\) ball with radius \(r\).
Poisson negative log-likelihood.
Log-Poisson negative log-likelihood.
Amplitude loss as the data fidelity term for
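For the \(\ell_2\)-ball indicator, assuming the ball is centered at the data \(y\), the proximity operator of \(\distancename\) (for any step size) is the Euclidean projection onto the ball. This NumPy sketch covers \(\distancename\) alone; the proximity operator of \(\distancename \circ A\) has no such closed form for a general \(A\). The function name is hypothetical.

```python
import numpy as np


def project_l2_ball(u, y, radius):
    """Euclidean projection of u onto {v : ||v - y||_2 <= radius}."""
    diff = u - y
    norm = np.linalg.norm(diff)
    if norm <= radius:
        return u.copy()                    # already inside the ball
    return y + diff * (radius / norm)      # pull back onto the sphere along the ray from y


y = np.zeros(2)
print(project_l2_ball(np.array([3.0, 4.0]), y, radius=1.0))
print(project_l2_ball(np.array([0.3, 0.4]), y, radius=1.0))
```

The first point lies outside the unit ball and is rescaled to \((0.6, 0.8)\); the second is already inside and is returned unchanged.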
Priors
This is the base class for implementing prior functions \(\reg{x}\), where \(x\in\xset\) is a variable and \(\regname\) is a function.
Similarly to the deepinv.optim.DataFidelity() class, this class comes with methods for computing \(\operatorname{prox}_{\regname}\) and \(\nabla \regname\). This base class is used to implement user-defined differentiable priors, such as Tikhonov regularization, but also implicit priors. For instance, in PnP methods, the method computing the proximity operator is overwritten by a method performing denoising.
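The difference between explicit and implicit priors can be sketched with two toy classes that expose the same `prox` method, so that an algorithm calling only `prox` cannot tell them apart. These classes and the moving-average "denoiser" are hypothetical NumPy stand-ins, not DeepInverse classes; a real PnP prior would use a trained denoiser \(\operatorname{D}_\sigma\).

```python
import numpy as np


class L1PriorSketch:
    """Explicit prior g(x) = ||x||_1 with a closed-form proximity operator."""

    def __call__(self, x):
        return float(np.abs(x).sum())

    def prox(self, x, gamma, sigma=None):
        # Soft-thresholding; sigma is unused by this explicit prior.
        return np.sign(x) * np.maximum(np.abs(x) - gamma, 0.0)


class PnPPriorSketch:
    """Implicit prior: prox is replaced by a denoiser D_sigma; g is never evaluated."""

    def __init__(self, denoiser):
        self.denoiser = denoiser

    def prox(self, x, gamma, sigma=None):
        # gamma is ignored: the denoising strength is controlled by sigma alone.
        return self.denoiser(x, sigma)


def moving_average_denoiser(x, sigma):
    """Toy 1D 'denoiser': 3-tap moving average with zero padding (sigma ignored)."""
    return np.convolve(x, np.ones(3) / 3.0, mode="same")


x = np.array([0.0, 3.0, 0.0, -0.2])
for prior in (L1PriorSketch(), PnPPriorSketch(moving_average_denoiser)):
    print(type(prior).__name__, prior.prox(x, gamma=0.5, sigma=0.1))
```

The design choice is that the optimizer depends only on the `prox` interface, which is what allows a denoiser to be substituted for a proximity operator without changing the algorithm.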
Prior term \(\reg{x}\).
Plug-and-play prior \(\operatorname{prox}_{\gamma \regname}(x) = \operatorname{D}_{\sigma}(x)\).
Regularization-by-Denoising (RED) prior \(\nabla \reg{x} = x - \operatorname{D}_{\sigma}(x)\).
Score via MMSE denoiser \(\nabla \reg{x}=\left(x-\operatorname{D}_{\sigma}(x)\right)/\sigma^2\).
Tikhonov regularizer \(\reg{x} = \frac{1}{2}\| x \|_2^2\).
\(\ell_1\) prior \(\reg{x} = \| x \|_1\).
Wavelet prior \(\reg{x} = \|\Psi x\|_{p}\).
Total variation (TV) prior \(\reg{x} = \| D x \|_{1,2}\).
Patch prior \(\reg{x} = \sum_i h(P_i x)\) for some prior \(h(x)\) on the space of patches.
Patch prior via normalizing flows.
\(\ell_{1,2}\) prior \(\reg{x} = \sum_i\| x_i \|_2\).
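Several of these priors have simple proximity operators. As one example, for the \(\ell_{1,2}\) prior the proximity operator of \(\gamma \sum_i \|x_i\|_2\) is group soft-thresholding: each group \(x_i\) is shrunk towards zero by \(\gamma\) in norm, and set to zero if \(\|x_i\|_2 \le \gamma\). In this NumPy sketch the groups are assumed to be the rows of a 2D array, an assumption made for illustration, and the function name is hypothetical.

```python
import numpy as np


def prox_l12(x, gamma):
    """prox of gamma * sum_i ||x_i||_2, with the groups x_i taken as the rows of x."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    # Shrink each row by gamma in norm; rows with norm <= gamma are set to zero.
    # The small floor avoids dividing by zero for all-zero rows.
    scale = np.maximum(1.0 - gamma / np.maximum(norms, 1e-12), 0.0)
    return scale * x


x = np.array([[3.0, 4.0], [0.3, 0.4]])
print(prox_l12(x, gamma=1.0))
```

The first row, of norm 5, is scaled by \(1 - 1/5 = 0.8\) to \((2.4, 3.2)\); the second row, of norm 0.5, is set to zero.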
Parameters
The parameters of the optimization algorithm, such as the stepsize, regularization parameter, denoising standard deviation, etc., are stored in a dictionary "params_algo", whose typical entries are:
Key | Meaning | Recommended Values |
---|---|---|
| Step size of the optimization algorithm. | Should be positive. Depending on the algorithm, it needs to be small enough for convergence; e.g. for PGD with g_first=False, it should be smaller than \(1/\lVert A\rVert_2^2\). |
| Regularization parameter \(\lambda\) multiplying the regularization term. | Should be positive. |
| Optional parameter \(\sigma\) on which \(\regname\) depends. For priors based on denoisers, it corresponds to the noise level. | Should be positive. |
| Relaxation parameter used in ADMM, DRS, CP. | Should be positive. |
| Step size of the dual update in the primal-dual algorithm (only required by CP). | Should be positive. |
Each value of the dictionary can be either an iterable (i.e., a list with a distinct value for each iteration) or a single float (same value for each iteration).
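A minimal sketch of how such a dictionary can be read at iteration \(k\), assuming that a list holding fewer entries than iterations keeps reusing its last entry. The helper `param_at` and the key names are hypothetical and only illustrate the float-versus-list convention; they are not the DeepInverse implementation.

```python
def param_at(value, k):
    """Value of a parameter at iteration k (hypothetical helper).

    A float is used at every iteration; a list or tuple gives one value per
    iteration, and its last entry is reused once k runs past its end (assumption).
    """
    if isinstance(value, (list, tuple)):
        return value[min(k, len(value) - 1)]
    return value


# Hypothetical keys, chosen only to illustrate the float-versus-list convention.
params = {
    "step_size": 0.5,                       # single float: same value at every iteration
    "reg_weight": [1.0, 0.5, 0.25, 0.1],    # list: one value per iteration
}
for k in range(3):
    print(k, {key: param_at(value, k) for key, value in params.items()})
```

A decreasing list, as for the second key, is one way to express a continuation schedule where the regularization weight shrinks over the iterations.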
Iterators
An optim iterator is an object that implements a fixed point iteration for minimizing the sum of two functions \(F = \datafidname + \lambda \regname\), where \(\datafidname\) is a data-fidelity term, which depends on the forward operator \(A\) modeled by an instance of physics, and \(\regname\) is a regularizer. The fixed point iteration takes the form
\[
(x_{k+1}, z_{k+1}) = \operatorname{FixedPoint}(x_k, z_k, \datafidname, \regname, A, y, \ldots),
\]
where \(x\) is a variable converging to the solution of the minimization problem, and \(z\) is an additional variable that may be required in the computation of the fixed point operator.
Fixed-point iterations module.
The implementation of the fixed point algorithm in deepinv.optim, following standard optimization theory, is split into two steps:
\[
\begin{aligned}
z_{k+1} &= \operatorname{step}_{\datafidname}(x_k, z_k, y, A, \ldots),\\
x_{k+1} &= \operatorname{step}_{\regname}(x_k, z_{k+1}, y, A, \ldots),
\end{aligned}
\]
where \(\operatorname{step}_{\datafidname}\) and \(\operatorname{step}_{\regname}\) are gradient and/or proximal steps on \(\datafidname\) and \(\regname\), which may use additional inputs such as \(A\) and \(y\), but also stepsizes, relaxation parameters, etc.
The fStep and gStep classes precisely implement these steps.
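As an illustration of this split, half-quadratic splitting (HQS) alternates a proximal step on \(\datafidname\) and a proximal step on \(\lambda\regname\). The NumPy sketch below assumes \(\datafid{x}{y} = \tfrac{1}{2}\|Ax - y\|_2^2\) with a matrix \(A\) and \(\regname = \|\cdot\|_1\); the functions `f_step` and `g_step` are hypothetical plain functions, not the fStep and gStep classes.

```python
import numpy as np


def f_step(x, y, A, gamma):
    """Step on f: z = prox_{gamma f}(x) for f(x) = 0.5 * ||A x - y||_2^2."""
    n = A.shape[1]
    return np.linalg.solve(np.eye(n) + gamma * A.T @ A, x + gamma * A.T @ y)


def g_step(z, gamma, lambd):
    """Step on g: x = prox_{gamma * lambd * ||.||_1}(z), i.e. soft-thresholding."""
    return np.sign(z) * np.maximum(np.abs(z) - gamma * lambd, 0.0)


def hqs(A, y, lambd, gamma=1.0, max_iter=200):
    x = np.zeros(A.shape[1])
    for _ in range(max_iter):
        z = f_step(x, y, A, gamma)       # z_{k+1}: uses A and y
        x = g_step(z, gamma, lambd)      # x_{k+1}: uses z_{k+1}
    return x


x_hat = hqs(np.eye(2), np.array([3.0, 0.05]), lambd=0.1)
print(x_hat)
```

With \(A = I\), \(y = (3, 0.05)\), \(\lambda = 0.1\) and \(\gamma = 1\), the iterates converge to \((2.8, 0)\), whereas the exact minimizer of \(\tfrac{1}{2}\|x - y\|_2^2 + \lambda\|x\|_1\) is \((2.9, 0)\): with a fixed \(\gamma\), HQS solves the relaxed problem \(\min_{x,z} \datafidname(z) + \lambda\regname(x) + \tfrac{1}{2\gamma}\|x - z\|_2^2\) rather than the original one.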
Generic Optimizers
The following files contain the base classes for implementing generic optimizers:
Base class for all
Iterator for Gradient Descent.
Iterator for proximal gradient descent.
Iterator for fast iterative soft-thresholding (FISTA).
Iterator for Chambolle-Pock.
Iterator for alternating direction method of multipliers.
Iterator for Douglas-Rachford Splitting.
Single iteration of half-quadratic splitting.
Iterator for Spectral Methods for
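As a sketch of what distinguishes FISTA from proximal gradient descent, the NumPy code below adds the standard extrapolation (momentum) step to the PGD iteration for \(\tfrac{1}{2}\|Ax - y\|_2^2 + \lambda\|x\|_1\). It assumes \(A\) is a matrix and uses the classical update \(t_{k+1} = (1 + \sqrt{1 + 4t_k^2})/2\); the function name is hypothetical and unrelated to the DeepInverse iterator.

```python
import numpy as np


def fista_lasso(A, y, lambd, max_iter=500):
    """FISTA for min_x 0.5 * ||A x - y||_2^2 + lambd * ||x||_1."""
    gamma = 1.0 / np.linalg.norm(A, 2) ** 2
    x = np.zeros(A.shape[1])
    v = x.copy()          # extrapolated point where the gradient is evaluated
    t = 1.0
    for _ in range(max_iter):
        u = v - gamma * A.T @ (A @ v - y)                               # gradient step at v
        x_new = np.sign(u) * np.maximum(np.abs(u) - gamma * lambd, 0.0)  # proximal step
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        v = x_new + ((t - 1.0) / t_new) * (x_new - x)                   # momentum extrapolation
        x, t = x_new, t_new
    return x


x_hat = fista_lasso(np.diag([2.0, 1.0]), np.array([4.0, 0.05]), lambd=0.1)
print(x_hat)
```

The only change with respect to PGD is that the gradient is evaluated at the extrapolated point \(v\) rather than at the previous iterate, which improves the worst-case objective convergence rate from \(O(1/k)\) to \(O(1/k^2)\).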
Utils
We provide some useful utilities for optimization algorithms.
Standard conjugate gradient algorithm.
Standard gradient descent algorithm.
Gaussian mixture model including parameter estimation.
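As a sketch of the conjugate gradient method, the NumPy code below solves \(Hx = b\) for a symmetric positive definite \(H\) that is available only through the map \(v \mapsto Hv\). The usage example applies it matrix-free to \(H = I + \gamma A^\top A\), the system arising in the proximity operator of the quadratic data-fidelity \(\tfrac{1}{2}\|Ax - y\|_2^2\). The function name and signature are hypothetical, not those of the DeepInverse utility.

```python
import numpy as np


def conjugate_gradient(apply_H, b, x0=None, max_iter=100, tol=1e-10):
    """Solve H x = b for symmetric positive definite H, given only v -> H v."""
    x = np.zeros_like(b) if x0 is None else x0.copy()
    r = b - apply_H(x)            # residual
    p = r.copy()                  # search direction
    rs_old = r @ r
    for _ in range(max_iter):
        if np.sqrt(rs_old) <= tol:
            break
        Hp = apply_H(p)
        alpha = rs_old / (p @ Hp)          # exact line search along p
        x = x + alpha * p
        r = r - alpha * Hp
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p      # new direction, H-conjugate to the previous ones
        rs_old = rs_new
    return x


rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4))
y = rng.standard_normal(6)
z = rng.standard_normal(4)
gamma = 0.5
b = z + gamma * A.T @ y

# Matrix-free operator v -> (I + gamma A^T A) v; the matrix is never formed.
x_cg = conjugate_gradient(lambda v: v + gamma * A.T @ (A @ v), b)
x_ref = np.linalg.solve(np.eye(4) + gamma * A.T @ A, b)
print(np.allclose(x_cg, x_ref))
```

In exact arithmetic, conjugate gradient terminates in at most \(n\) iterations for an \(n \times n\) system; here \(n = 4\), and only products with \(A\) and \(A^\top\) are needed.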