# Source code for deepinv.optim.optim_iterators.spectral_methods

import torch

from .optim_iterator import OptimIterator, fStep, gStep


class SMIteration(OptimIterator):
    r"""
    Iterator for Spectral Methods for :class:`deepinv.physics.PhaseRetrieval`.

    Class for a single iteration of the Spectral Methods algorithm to find the principal eigenvector of the regularized weighted covariance matrix:

    .. math::
        \begin{equation*}
        M = \conj{B} \text{diag}(T(y)) B + \lambda I,
        \end{equation*}

    where :math:`B` is the linear operator of the phase retrieval class, :math:`T(\cdot)` is a preprocessing function for the measurements, and :math:`I` is the identity matrix of corresponding dimensions. Parameter :math:`\lambda` tunes the strength of regularization.

    The iteration is given by

    .. math::
        \begin{equation*}
        \begin{aligned}
        x_{k+1} &= M x_k \\
        x_{k+1} &= \operatorname{prox}_{\gamma g}(x_{k+1}),
        \end{aligned}
        \end{equation*}

    where :math:`\gamma` is a stepsize that should satisfy :math:`\lambda \gamma \leq 2/\operatorname{Lip}(\|\nabla f\|)`.

    :param float lamb: regularization weight :math:`\lambda` (weight of the identity term in :math:`M`).
    :param int n_iter: number of iterations; stored as ``self.n_iter``.
    :param Callable preprocessing: the preprocessing function :math:`T(\cdot)` applied to the measurements.
        Defaults to :math:`T(y) = \max(1 - 1/y, -5)` (elementwise).
    :param kwargs: extra keyword arguments forwarded to both the f-step and the g-step modules.
    """

    def __init__(
        self,
        lamb=10,
        n_iter=50,
        preprocessing=lambda x: torch.max(1 - 1 / x, torch.tensor(-5.0)),
        **kwargs,
    ):
        super(SMIteration, self).__init__()
        self.n_iter = n_iter
        self.f_step = fStepSM(lamb, preprocessing=preprocessing, **kwargs)
        self.g_step = gStepSM(**kwargs)

    def forward(self, x, cur_prior, cur_params, y, physics, *args):
        r"""
        Single iteration of the spectral method.

        :param torch.Tensor x: the current iterate :math:`x_k`.
        :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        :param torch.Tensor y: Input data (the measurements).
        :param deepinv.physics.Physics physics: Instance of the physics containing the forward operator.
            It must expose ``B`` and ``B_adjoint``.
        :param args: ignored; accepted for compatibility with the generic iterator call signature.
        :return: The new iterate :math:`x_{k+1}`.
        :raises AssertionError: if ``physics`` lacks ``B`` or ``B_adjoint``.
        """
        # Explicit check instead of ``assert`` so the validation is not stripped under
        # ``python -O``; AssertionError is kept so callers see the same exception type.
        for attr in ("B", "B_adjoint"):
            if not hasattr(physics, attr):
                raise AssertionError(
                    "The physics should inherit from the PhaseRetrieval class."
                )
        x = self.f_step(x, y, physics)
        x = self.g_step(x, cur_prior, cur_params)
        return x
def _normalize_per_sample(x: torch.Tensor) -> torch.Tensor:
    r"""
    Scale every sample of a batch to unit Euclidean norm.

    Vectorized equivalent of ``torch.stack([s / s.norm() for s in x])``: each ``x[i]`` is divided by
    the 2-norm of all its entries (modulus for complex values). A sample that is identically zero
    yields NaNs, exactly as in the loop form.

    :param torch.Tensor x: tensor whose first dimension indexes the samples of the batch.
    :return: tensor of the same shape as ``x``, each sample with unit norm.
    """
    norms = torch.linalg.vector_norm(x.reshape(x.shape[0], -1), dim=1)
    # Reshape to (batch, 1, ..., 1) so the per-sample norm broadcasts over the sample dims.
    return x / norms.reshape((-1,) + (1,) * (x.ndim - 1))


class fStepSM(fStep):
    r"""
    Spectral Methods fStep module.

    Performs one regularized power-iteration step
    :math:`x \mapsto \conj{B}\, \text{diag}(T(y))\, B x + \lambda x`, with per-sample normalization
    of the input and of the output.

    :param float lamb: weight :math:`\lambda` of the identity term.
    :param Callable preprocessing: the function :math:`T(\cdot)` applied to the measurements after
        they are rescaled to have mean 1 per sample. Defaults to :math:`T(y) = \max(1 - 1/y, -5)`.
    :param kwargs: forwarded to the :class:`fStep` base class.
    """

    def __init__(
        self,
        lamb=10,
        preprocessing=lambda x: torch.max(1 - 1 / x, torch.tensor(-5.0)),
        **kwargs,
    ):
        super(fStepSM, self).__init__(**kwargs)
        self.preprocessing = preprocessing
        self.lamb = lamb

    def forward(self, x: torch.Tensor, y: torch.Tensor, physics):
        r"""
        Single power iteration step for spectral methods.

        :param torch.Tensor x: Current iterate :math:`x_k`; cast to ``torch.cfloat``.
        :param torch.Tensor y: Measurements; the first dimension indexes the batch.
        :param deepinv.physics.Physics physics: Instance of the physics modeling the forward matrix,
            exposing ``B`` and ``B_adjoint``.
        :return: the updated iterate, with each sample rescaled to unit norm.
        """
        x = _normalize_per_sample(x.to(torch.cfloat))
        # y should have mean 1 for each sample. Average over *all* non-batch dims: averaging only
        # over dim 1 is wrong for measurements with more than two dims (e.g. (B, 1, H, W) would
        # become identically 1). For (B, m) measurements this is the same as a mean over dim 1.
        y = y / torch.mean(y, dim=tuple(range(1, y.ndim)), keepdim=True)
        diag_T = self.preprocessing(y).to(torch.cfloat)
        res = physics.B_adjoint(diag_T * physics.B(x))
        return _normalize_per_sample(res + self.lamb * x)


class gStepSM(gStep):
    r"""
    Spectral Methods gStep module.

    Applies the proximal operator of the prior to the current iterate.

    :param kwargs: forwarded to the :class:`gStep` base class.
    """

    def __init__(self, **kwargs):
        super(gStepSM, self).__init__(**kwargs)

    def forward(self, x: torch.Tensor, cur_prior, cur_params):
        r"""
        Single iteration step on the prior term :math:`g`.

        :param torch.Tensor x: Current iterate :math:`x_k`.
        :param deepinv.optim.Prior cur_prior: Prior whose ``prox`` method is applied.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm;
            the keys ``"g_param"``, ``"lambda"`` and ``"stepsize"`` are read.
        :return: ``cur_prior.prox(x, g_param, gamma=lambda * stepsize)``.
        """
        return cur_prior.prox(
            x,
            cur_params["g_param"],
            gamma=cur_params["lambda"] * cur_params["stepsize"],
        )