.. _optim:

Optimization
============

This module contains a collection of routines that optimize

.. math::

    \begin{equation}
    \label{eq:min_prob}
    \tag{1}
    \underset{x}{\arg\min} \quad \datafid{x}{y} + \lambda \reg{x},
    \end{equation}

where the first term :math:`\datafidname:\xset\times\yset \mapsto \mathbb{R}_{+}` enforces data-fidelity, the second
term :math:`\regname:\xset\mapsto \mathbb{R}_{+}` acts as a regularization, and :math:`\lambda > 0` is a
regularization parameter. More precisely, the data-fidelity term penalizes the discrepancy between the data
:math:`y` and the forward operator :math:`A` applied to the variable :math:`x`, as

.. math::

    \datafid{x}{y} = \distance{A(x)}{y},

where :math:`\distance{\cdot}{\cdot}` is a distance function and :math:`A:\xset\mapsto \yset` is the forward
operator (see :class:`deepinv.physics.Physics`).

.. note::

    The regularization term often (but not always) depends on a hyperparameter :math:`\sigma` that can be either
    fixed or estimated. For example, if the regularization is implicitly defined by a denoiser, the hyperparameter
    is the noise level.

A typical example of optimization problem is the :math:`\ell_1`-regularized least squares problem, where the
data-fidelity term is the squared :math:`\ell_2`-norm and the regularization term is the :math:`\ell_1`-norm.
In this case, a possible algorithm for solving the problem is the Proximal Gradient Descent (PGD) algorithm,
which reads

.. math::

    \qquad x_{k+1} = \operatorname{prox}_{\gamma \lambda \regname} \left( x_k - \gamma \nabla \datafidname(x_k, y) \right),

where :math:`\operatorname{prox}_{\lambda \regname}` is the proximity operator of the regularization term,
:math:`\gamma` is the step size of the algorithm, and :math:`\nabla \datafidname` is the gradient of the
data-fidelity term.

The following example illustrates the implementation of the PGD algorithm with DeepInverse to solve the
:math:`\ell_1`-regularized least squares problem.

.. doctest::

    >>> import torch
    >>> import deepinv as dinv
    >>> from deepinv.optim import L2, TVPrior
    >>>
    >>> # Forward operator, here inpainting
    >>> mask = torch.ones((1, 2, 2))
    >>> mask[0, 0, 0] = 0
    >>> physics = dinv.physics.Inpainting(img_size=mask.shape, mask=mask)
    >>> # Generate data
    >>> x = torch.ones((1, 1, 2, 2))
    >>> y = physics(x)
    >>> data_fidelity = L2()  # The data fidelity term
    >>> prior = TVPrior()  # The prior term
    >>> lambd = 0.1  # Regularization parameter
    >>> # Compute the squared norm of the operator A
    >>> norm_A2 = physics.compute_norm(y, tol=1e-4, verbose=False).item()
    >>> stepsize = 1/norm_A2  # stepsize for the PGD algorithm
    >>>
    >>> # PGD algorithm
    >>> max_iter = 20  # number of iterations
    >>> x_k = torch.zeros_like(x)  # initial guess
    >>>
    >>> for it in range(max_iter):
    ...     u = x_k - stepsize*data_fidelity.grad(x_k, y, physics)  # Gradient step
    ...     x_k = prior.prox(u, gamma=lambd*stepsize)  # Proximal step
    ...     cost = data_fidelity(x_k, y, physics) + lambd*prior(x_k)  # Compute the cost
    ...
    >>> print(cost < 1e-5)
    tensor([True])
    >>> print('Estimated solution: ', x_k.flatten())
    Estimated solution:  tensor([1.0000, 1.0000, 1.0000, 1.0000])


.. _potentials:

Potentials
----------

The class :class:`deepinv.optim.Potential` implements potential scalar functions :math:`h : \xset \to \mathbb{R}`
used to define optimization problems. For example, both :math:`f` and :math:`\regname` are potentials.
This class comes with methods for computing operators useful for optimization, such as its proximal operator
:math:`\operatorname{prox}_{h}`, its gradient :math:`\nabla h`, its convex conjugate :math:`h^*`, etc.
The following classes inherit from :class:`deepinv.optim.Potential`:

.. list-table::
   :header-rows: 1

   * - Class
     - :math:`h(x)`
     - Requires
   * - :class:`deepinv.optim.Bregman`
     - :math:`\phi(x)` with :math:`\phi` convex
     - None
   * - :class:`deepinv.optim.Distance`
     - :math:`d(x,y)`
     - :math:`y`
   * - :class:`deepinv.optim.DataFidelity`
     - :math:`d(A(x),y)` where :math:`d` is a distance
     - :math:`y` & operator :math:`A`
   * - :class:`deepinv.optim.Prior`
     - :math:`g_{\sigma}(x)`
     - optional denoising level :math:`\sigma`

.. _data-fidelity:

Data Fidelity
~~~~~~~~~~~~~

The base class :class:`deepinv.optim.DataFidelity` implements data-fidelity terms :math:`\distance{A(x)}{y}`,
where :math:`A` is the forward operator, :math:`x\in\xset` is a variable, :math:`y\in\yset` is the data, and
:math:`d` is a distance function from the class :class:`deepinv.optim.Distance`.
The class :class:`deepinv.optim.Distance` is implemented as a child class of :class:`deepinv.optim.Potential`.
The data-fidelity class thus comes with useful methods, such as :math:`\operatorname{prox}_{\distancename\circ A}`
and :math:`\nabla (\distancename \circ A)` (among others), which are used by most optimization algorithms.

.. list-table:: Data Fidelity Overview
   :header-rows: 1

   * - Data Fidelity
     - :math:`d(A(x), y)`
   * - :class:`deepinv.optim.L1`
     - :math:`\|A(x) - y\|_1`
   * - :class:`deepinv.optim.L2`
     - :math:`\|A(x) - y\|_2^2`
   * - :class:`deepinv.optim.IndicatorL2`
     - Indicator function of :math:`\|A(x) - y\|_2 \leq \epsilon`
   * - :class:`deepinv.optim.PoissonLikelihood`
     - :math:`-y^{\top} \log(A(x)+\beta)+1^{\top}A(x)`
   * - :class:`deepinv.optim.LogPoissonLikelihood`
     - :math:`N_0 (1^{\top} \exp(-\mu A(x))+ \mu \exp(-\mu y)^{\top}A(x))`
   * - :class:`deepinv.optim.AmplitudeLoss`
     - :math:`\sum_{i=1}^{m}{(\sqrt{|b_i^{\top} x|^2}-\sqrt{y_i})^2}`
   * - :class:`deepinv.optim.ZeroFidelity`
     - :math:`0`
   * - :class:`deepinv.optim.ItohFidelity`
     - :math:`\|Dx - w_t(Dy)\|_2^2`, where :math:`D` is a finite difference operator and :math:`w_t` is the modulo operator
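
For concreteness, here is a minimal sketch (not run as a doctest) that evaluates an :class:`deepinv.optim.L2`
data-fidelity term together with its gradient and proximal operator. The ``__call__`` and ``grad`` signatures follow
the doctest at the top of this page; the ``gamma`` keyword of ``prox`` is assumed to behave as for priors.

.. code-block:: python

    import torch
    import deepinv as dinv

    # Simple denoising physics (the forward operator A is the identity plus noise on y).
    physics = dinv.physics.Denoising()
    x = torch.rand((1, 1, 8, 8))
    y = physics(x)

    data_fidelity = dinv.optim.L2()

    cost = data_fidelity(x, y, physics)        # evaluates d(A(x), y)
    grad = data_fidelity.grad(x, y, physics)   # gradient of d(A(.), y) at x
    x_prox = data_fidelity.prox(x, y, physics, gamma=1.0)  # prox of gamma * d(A(.), y) at x (gamma kwarg assumed)
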
.. _priors:

Priors
~~~~~~

Prior functions are defined as :math:`\reg{x}`, where :math:`x\in\xset` is a variable and :math:`\regname` is a
function. The base class :class:`deepinv.optim.Prior` is implemented as a child class of
:class:`deepinv.optim.Potential`, and therefore comes with methods for computing operators such as
:math:`\operatorname{prox}_{\regname}` and :math:`\nabla \regname`. This base class is used to implement
user-defined differentiable priors (e.g., Tikhonov regularization) but also implicit priors (e.g., plug-and-play
methods).

.. list-table:: Priors Overview
   :header-rows: 1

   * - Prior
     - :math:`\reg{x}`
     - Explicit :math:`\regname`
   * - :class:`deepinv.optim.PnP`
     - :math:`\operatorname{prox}_{\gamma \regname}(x) = \operatorname{D}_{\sigma}(x)`
     - No
   * - :class:`deepinv.optim.RED`
     - :math:`\nabla \reg{x} = x - \operatorname{D}_{\sigma}(x)`
     - No
   * - :class:`deepinv.optim.ScorePrior`
     - :math:`\nabla \reg{x}=\left(x-\operatorname{D}_{\sigma}(x)\right)/\sigma^2`
     - No
   * - :class:`deepinv.optim.ZeroPrior`
     - :math:`\reg{x} = 0`
     - Yes
   * - :class:`deepinv.optim.Tikhonov`
     - :math:`\reg{x}=\|x\|_2^2`
     - Yes
   * - :class:`deepinv.optim.L1Prior`
     - :math:`\reg{x}=\|x\|_1`
     - Yes
   * - :class:`deepinv.optim.WaveletPrior`
     - :math:`\reg{x} = \|\Psi x\|_{p}` where :math:`\Psi` is a wavelet transform
     - Yes
   * - :class:`deepinv.optim.TVPrior`
     - :math:`\reg{x}=\|Dx\|_{1,2}` where :math:`D` is a finite difference operator
     - Yes
   * - :class:`deepinv.optim.PatchPrior`
     - :math:`\reg{x} = \sum_i h(P_i x)` for some prior :math:`h(x)` on the space of patches
     - Yes
   * - :class:`deepinv.optim.PatchNR`
     - Patch prior via normalizing flows
     - Yes
   * - :class:`deepinv.optim.L12Prior`
     - :math:`\reg{x} = \sum_i \| x_i \|_2`
     - Yes
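
As with data-fidelity terms, priors expose the :class:`deepinv.optim.Potential` interface. The following minimal
sketch evaluates an :math:`\ell_1` prior and its proximal operator (soft-thresholding), and the gradient of a
Tikhonov prior; the calls mirror those used in the doctest at the top of the page.

.. code-block:: python

    import torch
    import deepinv as dinv

    x = torch.randn((1, 1, 8, 8))

    l1 = dinv.optim.L1Prior()
    val = l1(x)                      # evaluates g(x) = ||x||_1
    x_thres = l1.prox(x, gamma=0.1)  # prox of 0.1 * ||.||_1, i.e. soft-thresholding

    tikhonov = dinv.optim.Tikhonov()
    grad = tikhonov.grad(x)          # gradient of the quadratic prior at x
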
.. _optim_iterators:

Predefined Algorithms
---------------------

Optimization algorithms inherit from the base class :class:`deepinv.optim.BaseOptim`, which serves as a common
interface for all predefined optimization algorithms.
Classical optimization algorithms are already implemented as subclasses of :class:`deepinv.optim.BaseOptim`, such as
:class:`deepinv.optim.GD`, :class:`deepinv.optim.PGD`, :class:`deepinv.optim.ADMM`, etc.
For example, we can create the same proximal gradient algorithm as the one at the beginning of this page, in one
line of code:

.. doctest::

    >>> model = dinv.optim.PGD(prior=prior, data_fidelity=data_fidelity, stepsize=stepsize, lambda_reg=lambd, max_iter=max_iter)
    >>> x_hat = model(y, physics)
    >>> dinv.utils.plot([x, y, x_hat], ["signal", "measurement", "estimate"], rescale_mode='clip')

Some predefined optimizers are provided:

.. list-table::
   :header-rows: 1

   * - Algorithm
     - Iteration
   * - :class:`deepinv.optim.GD`
     - | :math:`v_{k} = \nabla f(x_k) + \lambda \nabla \reg{x_k}`
       | :math:`x_{k+1} = x_k-\gamma v_{k}`
   * - :class:`deepinv.optim.PGD`
     - | :math:`u_{k} = x_k - \gamma \nabla f(x_k)`
       | :math:`x_{k+1} = \operatorname{prox}_{\gamma \lambda \regname}(u_k)`
   * - :class:`deepinv.optim.FISTA`
     - | :math:`u_{k} = z_k - \gamma \nabla f(z_k)`
       | :math:`x_{k+1} = \operatorname{prox}_{\gamma \lambda \regname}(u_k)`
       | :math:`z_{k+1} = x_{k+1} + \alpha_k (x_{k+1} - x_k)`
   * - :class:`deepinv.optim.HQS`
     - | :math:`u_{k} = \operatorname{prox}_{\gamma f}(x_k)`
       | :math:`x_{k+1} = \operatorname{prox}_{\sigma \lambda \regname}(u_k)`
   * - :class:`deepinv.optim.ADMM`
     - | :math:`u_{k+1} = \operatorname{prox}_{\gamma f}(x_k - z_k)`
       | :math:`x_{k+1} = \operatorname{prox}_{\gamma \lambda \regname}(u_{k+1} + z_k)`
       | :math:`z_{k+1} = z_k + \beta (u_{k+1} - x_{k+1})`
   * - :class:`deepinv.optim.DRS`
     - | :math:`u_{k+1} = \operatorname{prox}_{\gamma f}(z_k)`
       | :math:`x_{k+1} = \operatorname{prox}_{\gamma \lambda \regname}(2u_{k+1}-z_k)`
       | :math:`z_{k+1} = z_k + \beta (x_{k+1} - u_{k+1})`
   * - :class:`deepinv.optim.PDCP`
     - | :math:`u_{k+1} = \operatorname{prox}_{\sigma F^*}(u_k + \sigma K z_k)`
       | :math:`x_{k+1} = \operatorname{prox}_{\tau \lambda G}(x_k-\tau K^\top u_{k+1})`
       | :math:`z_{k+1} = x_{k+1} + \beta(x_{k+1}-x_k)`
   * - :class:`deepinv.optim.MD`
     - | :math:`v_{k} = \nabla f(x_k) + \lambda \nabla \reg{x_k}`
       | :math:`x_{k+1} = \nabla h^*(\nabla h(x_k) - \gamma v_{k})`
   * - :class:`deepinv.optim.PMD`
     - | :math:`v_{k} = \nabla f(x_k) + \lambda \nabla \reg{x_k}`
       | :math:`u_{k} = \nabla h^*(\nabla h(x_k) - \gamma v_{k})`
       | :math:`x_{k+1} = \operatorname{prox}^h_{\gamma \lambda \regname}(u_k)`
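
Since the predefined algorithms share the :class:`deepinv.optim.BaseOptim` interface, the handcrafted PGD example
above can, in principle, be swapped for another solver from the table. The sketch below assumes that
:class:`deepinv.optim.ADMM` accepts the same constructor arguments as :class:`deepinv.optim.PGD`, and it reuses the
variables defined in the doctests above (as those doctests do between themselves).

.. code-block:: python

    # Assumption: ADMM shares the constructor arguments shown above for PGD;
    # `prior`, `data_fidelity`, `stepsize`, `lambd` and `max_iter` come from the first doctest.
    model_admm = dinv.optim.ADMM(
        prior=prior,
        data_fidelity=data_fidelity,
        stepsize=stepsize,
        lambda_reg=lambd,
        max_iter=max_iter,
    )
    x_admm = model_admm(y, physics)
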
.. _initialization:

Initialization
~~~~~~~~~~~~~~

By default, in these predefined algorithms, the iterates are initialized with the adjoint applied to the measurement
:math:`A^{\top}y` when the adjoint is defined, and with the observation :math:`y` if the adjoint is not defined.

Custom initialization can be defined in two ways:

1. When calling the model, via the ``init`` argument of the ``forward`` method of :class:`deepinv.optim.BaseOptim`.
   In this case, ``init`` can be either a fixed initialization or a callable of the form ``init(y, physics)`` that
   takes as input the measurement :math:`y` and the physics ``physics``. The output of the function or the fixed
   initialization can be either:

   - a tuple of tensors :math:`(x_0, z_0)` (where :math:`x_0` and :math:`z_0` are the initial primal and dual variables),
   - a single tensor :math:`x_0` (if no dual variables :math:`z_0` are used), or
   - a dictionary of the form ``X = {'est': (x_0, z_0)}``.

2. When creating the optim model, via the ``custom_init`` argument. In this case, it must be set as a callable
   function ``custom_init(y, physics)`` that takes as input the measurement :math:`y` and the physics :math:`A`, and
   returns the initialization in the same form as in case 1.

For example, for initializing the above PGD algorithm with the pseudo-inverse of the measurement operator
:math:`A^{\dagger}y`, one can either use the ``init`` argument when calling the standard ``PGD`` model:

.. doctest::

    >>> x_hat = model(y, physics, init=physics.A_dagger(y))
    >>> dinv.utils.plot([x, y, x_hat], ["signal", "measurement", "estimate"], rescale_mode='clip')

or one can define a custom initialization function and pass it to the ``custom_init`` argument when creating the
optimization model:

.. doctest::

    >>> def pseudo_inverse_init(y, physics):
    ...     return physics.A_dagger(y)
    >>> model = dinv.optim.PGD(custom_init=pseudo_inverse_init, prior=prior, data_fidelity=data_fidelity, stepsize=stepsize, lambda_reg=lambd, max_iter=max_iter)
    >>> x_hat = model(y, physics)
    >>> dinv.utils.plot([x, y, x_hat], ["signal", "measurement", "estimate"], rescale_mode='clip')

.. _optim-params:

Optimization Parameters
~~~~~~~~~~~~~~~~~~~~~~~

The parameters of generic optimization algorithms, such as the stepsize, the regularization parameter, or the
standard deviation of a denoiser prior, can be passed as arguments to the constructor of the optimization algorithm.
Alternatively, the parameters can be defined via the dictionary ``params_algo``, whose keys are strings
corresponding to the parameter names.

.. list-table::
   :header-rows: 1

   * - Parameter name
     - Meaning
     - Recommended values
   * - ``"stepsize"``
     - Step size of the optimization algorithm.
     - | Should be positive. Depending on the algorithm,
       | needs to be small enough for convergence;
       | e.g. for PGD with ``g_first=False``,
       | should be smaller than :math:`1/(\|A\|_2^2)`.
   * - ``"lambda_reg"`` (or ``"lambda"`` if passed via the dictionary ``params_algo``)
     - | Regularization parameter :math:`\lambda`
       | multiplying the regularization term.
     - Should be positive.
   * - ``"g_param"`` or ``"sigma_denoiser"``
     - | Optional prior hyperparameter on which :math:`\regname` depends.
       | For priors based on denoisers,
       | corresponds to the noise level :math:`\sigma`.
     - Should be positive.
   * - ``"beta"``
     - | Relaxation parameter used in
       | ADMM, DRS, CP.
     - Should be positive.
   * - ``"stepsize_dual"``
     - | Step size of the dual update in the
       | Primal-Dual algorithm (only required by CP).
     - Should be positive.

Each parameter can be given either as an iterable (e.g., a list) with a distinct value for each iteration, or as a
single float (same parameter value for every iteration).
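
As an illustration, the sketch below passes a ``params_algo`` dictionary with a per-iteration stepsize. It assumes
the :func:`deepinv.optim.optim_builder` factory, which returns a :class:`deepinv.optim.BaseOptim` model, as the
entry point accepting ``params_algo``; the chosen denoiser and parameter values are only placeholders.

.. code-block:: python

    import torch
    import deepinv as dinv

    physics = dinv.physics.Denoising()
    x = torch.rand((1, 1, 8, 8))
    y = physics(x)

    # One value per iteration for the stepsize, scalars for the other parameters.
    params_algo = {
        "stepsize": [1.0] * 10 + [0.5] * 10,  # decrease the stepsize after 10 iterations
        "lambda": 0.1,                        # regularization parameter
        "g_param": 0.05,                      # noise level passed to the denoiser prior
    }

    model = dinv.optim.optim_builder(
        iteration="PGD",
        prior=dinv.optim.PnP(denoiser=dinv.models.MedianFilter()),
        data_fidelity=dinv.optim.L2(),
        params_algo=params_algo,
        max_iter=20,
    )
    x_hat = model(y, physics)
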
Moreover, backtracking can be used to automatically adapt the stepsize at each iteration. Backtracking consists in
choosing the largest stepsize :math:`\tau` such that, at each iteration, a sufficient decrease of the cost function
:math:`F` is achieved. More precisely, given :math:`\gamma \in (0,1/2)`, :math:`\eta \in (0,1)` and an initial
stepsize :math:`\tau > 0`, the following update rule is applied at each iteration :math:`k`:

.. math::

    \text{ while } F(x_k) - F(x_{k+1}) < \frac{\gamma}{\tau} \| x_{k+1} - x_k \|^2, \,\, \text{ do } \tau \leftarrow \eta \tau

In order to use backtracking, the ``backtracking`` argument of :class:`deepinv.optim.BaseOptim` must be an instance
of :class:`deepinv.optim.BacktrackingConfig`, which defines the parameters of the backtracking line search.
The :class:`deepinv.optim.BacktrackingConfig` dataclass has the following attributes and default values:

.. code-block:: python

    @dataclass
    class BacktrackingConfig:
        gamma: float = 0.1  # Armijo-like parameter controlling sufficient decrease
        eta: float = 0.9  # Step reduction factor
        max_iter: int = 10  # Maximum number of backtracking iterations

By default, backtracking is disabled (i.e., ``backtracking=None``); as soon as ``backtracking`` is not ``None``, the
above ``BacktrackingConfig`` defaults are used.

.. note::

    To use backtracking, the optimized function (i.e., both the data-fidelity and the prior) must be explicit and
    provide a computable cost for the current iterate. If the prior is not explicit (e.g., a denoiser), i.e. the
    ``explicit_prior`` attribute of the prior :class:`deepinv.optim.Prior` is ``False``, or if the ``has_cost``
    argument of the class :class:`deepinv.optim.BaseOptim` is ``False``, backtracking is automatically disabled.

.. _bregman:

Bregman
~~~~~~~

Bregman potentials are defined as :math:`\phi(x)`, where :math:`x\in\xset` is a variable and :math:`\phi` is a
convex scalar function, and are implemented via the base class :class:`deepinv.optim.Bregman`.
In addition to the methods inherited from :class:`deepinv.optim.Potential` (gradient :math:`\nabla \phi`, conjugate
:math:`\phi^*` and its gradient :math:`\nabla \phi^*`), this class provides the Bregman divergence
:math:`D(x,y) = \phi(x) - \phi^*(y) - x^{\top} y`, and is well suited for performing Mirror Descent.

.. list-table:: Bregman potentials
   :header-rows: 1

   * - Class
     - Bregman potential :math:`\phi(x)`
   * - :class:`deepinv.optim.bregman.BregmanL2`
     - :math:`\|x\|_2^2`
   * - :class:`deepinv.optim.bregman.BurgEntropy`
     - :math:`- \sum_i \log x_i`
   * - :class:`deepinv.optim.bregman.NegEntropy`
     - :math:`\sum_i x_i \log x_i`
   * - :class:`deepinv.optim.bregman.Bregman_ICNN`
     - Convolutional Input Convex Neural Network

.. _optim-utils:

Utils
-----

We provide some useful routines for optimization algorithms:

- :func:`deepinv.optim.utils.conjugate_gradient` implements the conjugate gradient algorithm for solving linear systems (see the sketch below).
- :func:`deepinv.optim.utils.gradient_descent` implements the gradient descent algorithm.
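
For example, a symmetric positive-definite system can be solved with the conjugate gradient routine. The sketch
below assumes that ``conjugate_gradient`` takes a callable operator ``A`` and a right-hand side tensor ``b``, with
``max_iter`` and ``tol`` keyword arguments.

.. code-block:: python

    import torch
    from deepinv.optim.utils import conjugate_gradient

    # Symmetric positive-definite operator: A(x) = 2x + mean(x), i.e. 2*Id + (1/n) 1 1^T.
    def A(x):
        return 2 * x + x.mean() * torch.ones_like(x)

    b = torch.rand((1, 1, 8, 8))

    # Solve A(x) = b; the keyword names are assumed from the library's internal usage.
    x = conjugate_gradient(A, b, max_iter=100, tol=1e-6)
    print(torch.allclose(A(x), b, atol=1e-4))
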