# Source code for deepinv.optim.distance

import torch
from deepinv.optim.potential import Potential
from deepinv.utils.tensorlist import TensorList


class Distance(Potential):
    r"""
    Distance :math:`\distance{x}{y}`.

    Base class for a distance :math:`\distance{x}{y}` between a variable :math:`x`
    and an observation :math:`y`. Provides methods to compute the distance and its
    gradient, proximal operator or convex conjugate with respect to the variable
    :math:`x`.

    .. warning::
        All variables have a batch dimension as first dimension.

    :param Callable d: distance function :math:`\distance{x}{y}`. Outputs a tensor
        of size `B`, the size of the batch. Default: None.
    """

    def __init__(self, d=None):
        # The user-supplied callable is registered with the parent Potential class
        # (presumably stored as ``self._fn`` -- see ``fn`` below).
        super().__init__(fn=d)

    def fn(self, x, y, *args, **kwargs):
        r"""
        Computes the distance :math:`\distance{x}{y}`.

        :param torch.Tensor x: Variable :math:`x`.
        :param torch.Tensor y: Observation :math:`y`.
        :return: (:class:`torch.Tensor`) distance :math:`\distance{x}{y}` of size
            `B` with `B` the size of the batch.
        """
        return self._fn(x, y, *args, **kwargs)

    def forward(self, x, y, *args, **kwargs):
        r"""
        Computes the value of the distance :math:`\distance{x}{y}`.

        :param torch.Tensor x: Variable :math:`x`.
        :param torch.Tensor y: Observation :math:`y`.
        :return: (:class:`torch.Tensor`) distance :math:`\distance{x}{y}` of size
            `B` with `B` the size of the batch.
        """
        return self.fn(x, y, *args, **kwargs)
class L2Distance(Distance):
    r"""
    Implementation of :math:`\distancename` as the normalized :math:`\ell_2` norm

    .. math::

        f(x) = \frac{1}{2\sigma^2}\|x-y\|^2

    :param float sigma: normalization parameter. Default: 1.
    """

    def __init__(self, sigma=1.0):
        super().__init__()
        # Precomputed weight 1/sigma^2 applied to the squared norm and its gradient.
        self.norm = 1 / (sigma**2)

    def fn(self, x, y, *args, **kwargs):
        r"""
        Computes the distance :math:`\distance{x}{y}`, i.e.

        .. math::

            \distance{x}{y} = \frac{1}{2\sigma^2}\|x-y\|^2

        :param torch.Tensor x: Variable :math:`x` at which the data fidelity is computed.
        :param torch.Tensor y: Data :math:`y`.
        :return: (:class:`torch.Tensor`) data fidelity of size `B` with `B` the
            size of the batch.
        """
        residual = x - y
        flat = residual.reshape(residual.shape[0], -1)
        return 0.5 * torch.norm(flat, p=2, dim=-1) ** 2 * self.norm

    def grad(self, x, y, *args, **kwargs):
        r"""
        Computes the gradient of :math:`\distancename`, that is
        :math:`\nabla_{x}\distance{x}{y}`, i.e.

        .. math::

            \nabla_{x}\distance{x}{y} = \frac{1}{\sigma^2} (x-y)

        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :param torch.Tensor y: Observation :math:`y`.
        :return: (:class:`torch.Tensor`) gradient of the distance function
            :math:`\nabla_{x}\distance{x}{y}`.
        """
        return (x - y) * self.norm

    def prox(self, x, y, *args, gamma=1.0, **kwargs):
        r"""
        Proximal operator of :math:`\gamma \distance{x}{y} = \frac{\gamma}{2\sigma^2}\|x-y\|^2`.

        Computes :math:`\operatorname{prox}_{\gamma \distancename}`, i.e.

        .. math::

            \operatorname{prox}_{\gamma \distancename} = \underset{u}{\text{argmin}}
            \frac{\gamma}{2\sigma^2}\|u-y\|_2^2 + \frac{1}{2}\|u-x\|_2^2

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.Tensor y: Data :math:`y`.
        :param float gamma: thresholding parameter.
        :return: (:class:`torch.Tensor`) proximity operator
            :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
        """
        # Closed form of the quadratic minimization: a weighted average of x and y.
        weight = self.norm * gamma
        return (x + weight * y) / (1 + weight)
class IndicatorL2Distance(Distance):
    r"""
    Indicator of the :math:`\ell_2` ball with radius :math:`r`.

    The indicator function of the :math:`\ell_2` ball of radius :math:`r` centered
    at :math:`y`, denoted :math:`\iota_{\mathcal{B}_2(y,r)}(x)`, is defined as

    .. math::

          \iota_{\mathcal{B}_2(y,r)}(x) =
          \begin{cases}
            0 & \text{if } \|x-y\|_2 \leq r, \\
            +\infty & \text{else.}
          \end{cases}

    :param float radius: radius of the ball. Default: None.
    """

    def __init__(self, radius=None):
        super().__init__()
        self.radius = radius

    def fn(self, x, y, *args, radius=None, **kwargs):
        r"""
        Computes the batched indicator of the :math:`\ell_2` ball with radius
        `radius`, i.e. :math:`\iota_{\mathcal{B}_2(y,r)}(x)`.

        :param torch.Tensor x: Variable :math:`x` at which the indicator is computed,
            of shape (B, ...) where B is the batch size.
        :param torch.Tensor y: Observation :math:`y` of the same dimension as :math:`x`.
        :param float radius: radius of the :math:`\ell_2` ball. If `radius` is None,
            the radius of the ball is set to `self.radius`. Default: None.
        :return: (:class:`torch.Tensor`) indicator of the ball of size `B`. If the
            point is inside the ball, the output is 0, else it is 1e16.
        """
        diff = x - y
        dist = torch.norm(diff.reshape(diff.shape[0], -1), p=2, dim=-1)
        radius = self.radius if radius is None else radius
        # 1e16 is a large finite stand-in for +infinity (keeps arithmetic finite).
        loss = (dist > radius) * 1e16
        return loss

    def prox(self, x, y, *args, radius=None, gamma=None, **kwargs):
        r"""
        Proximal operator of the indicator of the :math:`\ell_2` ball with radius
        `radius`, i.e. the projection

        .. math::

            \operatorname{prox}_{\iota_{\mathcal{B}_2(y,r)}}(x) =
            \operatorname{proj}_{\mathcal{B}_2(y, r)}(x)

        where :math:`\operatorname{proj}_{C}(x)` denotes the projection on the
        closed convex set :math:`C`.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.Tensor y: Observation :math:`y` of the same dimension as :math:`x`.
        :param float radius: radius of the :math:`\ell_2` ball.
        :param float gamma: step-size. Note that this parameter is not used in this function.
        :return: (:class:`torch.Tensor`) projection on the :math:`\ell_2` ball of
            radius `radius` and centered in `y`.
        """
        radius = self.radius if radius is None else radius
        diff = x - y
        dist = torch.norm(diff.reshape(diff.shape[0], -1), p=2, dim=-1)
        # Per-batch shrink factor min(radius, dist) / dist; the epsilon avoids 0/0
        # when x == y. clamp(max=...) avoids allocating a tensor for the scalar radius.
        factor = dist.clamp(max=radius) / (dist + 1e-12)
        # Reshape to (B, 1, ..., 1) so the factor broadcasts over any number of
        # trailing dimensions. (The previous version hard-coded view(-1, 1, 1, 1),
        # which silently mis-broadcast for inputs that are not 4D.)
        factor = factor.view(-1, *([1] * (x.dim() - 1)))
        return y + diff * factor
class PoissonLikelihoodDistance(Distance):
    r"""
    (Negative) Log-likelihood of the Poisson distribution.

    .. math::

        \distance{y}{x} =  \sum_i y_i \log(y_i / x_i) + x_i - y_i

    .. note::

        The function is not Lipschitz smooth w.r.t. :math:`x` in the absence of
        background (:math:`\beta=0`).

    :param float gain: gain of the measurement :math:`y`. Default: 1.0.
    :param float bkg: background level :math:`\beta`. Default: 0.
    :param bool denormalize: if True, the measurement is divided by the gain. By
        default, in :class:`deepinv.physics.PoissonNoise`, the measurements are
        multiplied by the gain after being sampled by the Poisson distribution.
        Default: False.
    """

    def __init__(self, gain=1.0, bkg=0, denormalize=False):
        super().__init__()
        self.bkg = bkg
        self.gain = gain
        self.denormalize = denormalize

    def fn(self, x, y, *args, **kwargs):
        r"""
        Computes the Kullback-Leibler divergence.

        :param torch.Tensor x: Variable :math:`x` at which the distance is computed.
        :param torch.Tensor y: Observation :math:`y`.
        :return: (:class:`torch.Tensor`) distance of size `B` with `B` the batch size.
        """
        if self.denormalize:
            y = y / self.gain
        # Both terms are reduced per batch element so the output has size B.
        # (The previous version reduced the log term over the *whole* batch with
        # flatten().sum(), mixing batch elements together whenever B > 1.)
        log_term = (
            (-y * torch.log(x / self.gain + self.bkg))
            .reshape(x.shape[0], -1)
            .sum(dim=1)
        )
        lin_term = ((x / self.gain) + self.bkg - y).reshape(x.shape[0], -1).sum(dim=1)
        return log_term + lin_term

    def grad(self, x, y, *args, **kwargs):
        r"""
        Gradient of the Kullback-Leibler divergence.

        :param torch.Tensor x: signal :math:`x` at which the function is computed.
        :param torch.Tensor y: measurement :math:`y`.
        :return: (:class:`torch.Tensor`) gradient of :math:`\distance{x}{y}` w.r.t. :math:`x`.
        """
        if self.denormalize:
            y = y / self.gain
        # Differentiating fn: d/dx [-y*log(x/g + b) + x/g] = (1/g) * (1 - y/(x/g + b)).
        # (The previous version multiplied by the gain instead of dividing, which is
        # inconsistent with fn() whenever gain != 1; both agree for gain == 1.)
        return (torch.ones_like(x) - y / (x / self.gain + self.bkg)) / self.gain

    def prox(self, x, y, *args, gamma=1.0, **kwargs):
        r"""
        Proximal operator of the Kullback-Leibler divergence, i.e. the closed-form
        solution of

        .. math::

            \underset{u}{\text{argmin}} \,\, \gamma \distance{u}{y} + \frac{1}{2}\|u-x\|_2^2

        :param torch.Tensor x: signal :math:`x` at which the function is computed.
        :param torch.Tensor y: measurement :math:`y`.
        :param float gamma: proximity operator step size.
        :return: (:class:`torch.Tensor`) proximity operator at `x`.
        """
        if self.denormalize:
            y = y / self.gain
        # Stationarity of gamma*d(u,y) + 0.5*(u-x)^2 (with bkg = 0) gives
        #   u^2 + (gamma/gain - x) u - gamma*y = 0,
        # whose positive root is the prox. (The previous expression computed
        # x - c*sqrt(...), which yields negative outputs for positive data, e.g.
        # x = y = gamma = gain = 1 gave -0.5 instead of the fixed point 1.)
        # NOTE(review): this closed form ignores the background term; it assumes
        # self.bkg == 0 -- confirm against callers using a nonzero background.
        shift = x - gamma / self.gain
        return (shift + (shift.pow(2) + 4 * gamma * y).sqrt()) / 2
class L1Distance(Distance):
    r"""
    :math:`\ell_1` distance

    .. math::

        f(x) = \|x-y\|_1.
    """

    def __init__(self):
        super().__init__()

    def fn(self, x, y, *args, **kwargs):
        r"""
        Computes the :math:`\ell_1` distance between :math:`x` and :math:`y`.

        :param torch.Tensor x: Variable :math:`x`.
        :param torch.Tensor y: Observation :math:`y`.
        :return: (:class:`torch.Tensor`) per-batch :math:`\ell_1` norm of ``x - y``,
            of size `B` with `B` the batch size.
        """
        residual = x - y
        return torch.norm(residual.reshape(residual.shape[0], -1), p=1, dim=-1)

    def grad(self, x, y, *args, **kwargs):
        r"""
        (Sub)gradient of the :math:`\ell_1` norm, i.e.

        .. math::

            \partial \datafid(x) = \operatorname{sign}(x-y)

        .. note::
            The gradient is not defined at :math:`x=y`.

        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :param torch.Tensor y: Data :math:`y` of the same dimension as :math:`x`.
        :return: (:class:`torch.Tensor`) gradient of the :math:`\ell_1` norm at `x`.
        """
        return torch.sign(x - y)

    def prox(self, u, y, *args, gamma=1.0, **kwargs):
        r"""
        Proximal operator of the :math:`\ell_1` norm, i.e.

        .. math::

            \operatorname{prox}_{\gamma \ell_1}(x) = \underset{z}{\text{argmin}}
            \,\, \gamma \|z-y\|_1 + \frac{1}{2}\|z-x\|_2^2

        also known as the soft-thresholding operator.

        :param torch.Tensor u: Variable :math:`u` at which the proximity operator is computed.
        :param torch.Tensor y: Data :math:`y` of the same dimension as :math:`x`.
        :param float gamma: stepsize (or soft-thresholding parameter).
        :return: (:class:`torch.Tensor`) soft-thresholding of `u` with parameter `gamma`.
        """
        shifted = u - y
        zero = torch.tensor([0]).to(shifted.device)
        # Soft-threshold the shifted variable, then translate back by y.
        return y + torch.sign(shifted) * torch.maximum(shifted.abs() - gamma, zero)
class AmplitudeLossDistance(Distance):
    r"""
    Amplitude loss for :class:`deepinv.physics.PhaseRetrieval` reconstruction,
    defined as

    .. math::

        f(u) = \sum_{i=1}^{m}{(\sqrt{u_i}-\sqrt{y_i})^2},

    where :math:`y_i` is the i-th entry of the measurements, and :math:`m` is the
    number of measurements.
    """

    def __init__(self):
        super().__init__()

    def fn(self, u, y, *args, **kwargs):
        r"""
        Computes the amplitude loss.

        :param torch.Tensor u: estimated measurements.
        :param torch.Tensor y: true measurements.
        :return: (:class:`torch.Tensor`) the amplitude loss of shape B where B is
            the batch size.
        """
        diff = torch.sqrt(u) - torch.sqrt(y)
        return torch.norm(diff.reshape(diff.shape[0], -1), p=2, dim=-1) ** 2

    def grad(self, u, y, *args, epsilon=1e-12, **kwargs):
        r"""
        Computes the gradient of the amplitude loss :math:`\distance{u}{y}`, i.e.,

        .. math::

            \nabla_{u}\distance{u}{y} = \frac{\sqrt{u}-\sqrt{y}}{\sqrt{u}}

        :param torch.Tensor u: Variable :math:`u` at which the gradient is computed.
        :param torch.Tensor y: Data :math:`y`.
        :param float epsilon: small value to avoid division by zero.
        :return: (:class:`torch.Tensor`) gradient of the amplitude loss function.
        """
        # epsilon regularizes sqrt(u) near u == 0 so the quotient stays finite.
        stabilized = torch.sqrt(u + epsilon)
        return (stabilized - torch.sqrt(y)) / stabilized
class LogPoissonLikelihoodDistance(Distance):
    r"""
    Log-Poisson negative log-likelihood.

    .. math::

        \distance{z}{y} =  N_0 (1^{\top} \exp(-\mu z) + \mu \exp(-\mu y)^{\top} z)

    Corresponds to LogPoissonNoise with the same arguments :math:`N_0` and
    :math:`\mu`. There is no closed-form of the prox known.

    :param float N0: average number of photons
    :param float mu: normalization constant
    """

    def __init__(self, N0=1024.0, mu=1 / 50.0):
        super().__init__()
        self.mu = mu
        self.N0 = N0

    def fn(self, x, y, *args, **kwargs):
        r"""
        Computes the Log-Poisson negative log-likelihood, reduced per batch element.

        :param torch.Tensor x: Variable :math:`x` at which the distance is computed.
        :param torch.Tensor y: Observation :math:`y`.
        :return: (:class:`torch.Tensor`) distance of size `B` with `B` the batch size.
        """
        # First term: N0 * exp(-mu*x); second term: N0 * exp(-mu*y) * mu * x.
        attenuation = torch.exp(-x * self.mu) * self.N0
        weighted = torch.exp(-y * self.mu) * self.N0 * (x * self.mu)
        return (attenuation + weighted).reshape(x.shape[0], -1).sum(dim=1)