Source code for deepinv.physics.forward

from typing import Union

import torch
from torch import Tensor
from deepinv.physics.noise import GaussianNoise
from deepinv.utils.tensorlist import randn_like, TensorList
from deepinv.optim.utils import least_squares, lsqr


class Physics(torch.nn.Module):  # parent class for forward models
    r"""
    Parent class for forward operators.

    It describes the general forward measurement process

    .. math::

        y = N(A(x))

    where :math:`x` is an image of :math:`n` pixels, :math:`y` is the measurements of size :math:`m`,
    :math:`A:\xset\mapsto \yset` is a deterministic mapping capturing the physics of the acquisition
    and :math:`N:\yset\mapsto \yset` is a stochastic mapping which characterizes the noise affecting
    the measurements.

    :param Callable A: forward operator function which maps an image to the observed measurements
        :math:`x\mapsto y`.
    :param deepinv.physics.NoiseModel, Callable noise_model: function that adds noise to the measurements
        :math:`N(z)`. See the noise module for some predefined functions.
    :param Callable sensor_model: function that incorporates any sensor non-linearities to the sensing
        process, such as quantization or saturation, defined as a function :math:`\eta(z)`, such that
        :math:`y=\eta\left(N(A(x))\right)`. By default, the `sensor_model` is set to the identity
        :math:`\eta(z)=z`.
    :param int max_iter: If the operator does not have a closed form pseudoinverse, the gradient descent
        algorithm is used for computing it, and this parameter fixes the maximum number of gradient
        descent iterations.
    :param float tol: If the operator does not have a closed form pseudoinverse, the gradient descent
        algorithm is used for computing it, and this parameter fixes the absolute tolerance of the
        gradient descent algorithm.
    :param str solver: least squares solver to use. Only gradient descent is available for non-linear
        operators.
    """

    def __init__(
        self,
        A=lambda x, **kwargs: x,
        noise_model=lambda x, **kwargs: x,
        sensor_model=lambda x: x,
        solver="gradient_descent",
        max_iter=50,
        tol=1e-4,
    ):
        super().__init__()
        self.noise_model = noise_model
        self.sensor_model = sensor_model
        self.forw = A
        self.SVD = False  # flag indicating SVD available
        self.max_iter = max_iter
        self.tol = tol
        self.solver = solver

    def __mul__(self, other):  # physics3 = physics1 \circ physics2
        r"""
        Concatenates two forward operators :math:`A = A_1\circ A_2` via the mul operation.

        The resulting operator keeps the noise and sensor models of :math:`A_1`, as well as its
        least squares settings (``solver``, ``max_iter`` and ``tol``).

        :param deepinv.physics.Physics other: Physics operator :math:`A_2`
        :return: (:class:`deepinv.physics.Physics`) concatenated operator
        """

        # The composed operator must accept keyword arguments: ``forward`` and ``A``
        # always forward ``**kwargs`` to the wrapped callable.
        def composed(x, **kwargs):
            return self.A(other.A(x, **kwargs), **kwargs)  # (A' = A_1 A_2)

        return Physics(
            A=composed,
            noise_model=self.noise_model,
            sensor_model=self.sensor_model,
            solver=self.solver,
            max_iter=self.max_iter,
            tol=self.tol,
        )

    def stack(self, other):
        r"""
        Stacks two forward operators :math:`A(x) = \begin{bmatrix} A_1(x) \\ A_2(x) \end{bmatrix}`.

        The measurements produced by the resulting model are :class:`deepinv.utils.TensorList` objects,
        where each entry corresponds to the measurements of the corresponding operator.

        Returns a :class:`deepinv.physics.StackedPhysics` object.

        See :ref:`physics_combining` for more information.

        :param deepinv.physics.Physics other: Physics operator :math:`A_2`
        :return: (:class:`deepinv.physics.StackedPhysics`) stacked operator
        """
        # Inside the method body, ``stack`` resolves to the module-level function.
        return stack(self, other)

    def forward(self, x, **kwargs):
        r"""
        Computes the full measurement process

        .. math::

            y = \eta\left(N(A(x))\right)

        The keyword arguments are passed to both the forward operator and the noise model.

        :param torch.Tensor, list[torch.Tensor] x: signal/image
        :return: (:class:`torch.Tensor`) noisy measurements
        """
        clean = self.A(x, **kwargs)
        noisy = self.noise(clean, **kwargs)
        return self.sensor(noisy)

    def A(self, x, **kwargs):
        r"""
        Computes forward operator :math:`y = A(x)` (without noise and/or sensor non-linearities).

        :param torch.Tensor,list[torch.Tensor] x: signal/image
        :return: (:class:`torch.Tensor`) clean measurements
        """
        return self.forw(x, **kwargs)

    def sensor(self, x):
        r"""
        Computes sensor non-linearities :math:`y = \eta(y)`.

        :param torch.Tensor,list[torch.Tensor] x: measurements
        :return: (:class:`torch.Tensor`) output of the sensor model
        """
        return self.sensor_model(x)

    def set_noise_model(self, noise_model, **kwargs):
        r"""
        Sets the noise model.

        :param Callable noise_model: noise model
        """
        self.noise_model = noise_model

    def noise(self, x, **kwargs) -> Tensor:
        r"""
        Incorporates noise into the measurements :math:`\tilde{y} = N(y)`.

        The keyword arguments are forwarded to the noise model.

        :param torch.Tensor x: clean measurements
        :return: noisy measurements
        """
        return self.noise_model(x, **kwargs)

    def A_dagger(self, y, x_init=None):
        r"""
        Computes an inverse as:

        .. math::

            x^* \in \underset{x}{\arg\min} \quad \|\forw{x}-y\|^2.

        This function uses gradient descent to find the inverse. It can be overwritten by a more
        efficient pseudoinverse in cases where closed form formulas exist.

        :param torch.Tensor y: a measurement :math:`y` to reconstruct via the pseudoinverse.
        :param torch.Tensor x_init: initial guess for the reconstruction. If ``None``,
            ``self.A_adjoint(y)`` is used, which requires the operator to define ``A_adjoint``.
        :return: (:class:`torch.Tensor`) The reconstructed image :math:`x`.
        :raises NotImplementedError: if ``self.solver`` is not ``'gradient_descent'``.
        :raises AttributeError: if ``x_init`` is ``None`` and the operator has no ``A_adjoint``.
        """
        if self.solver != "gradient_descent":
            raise NotImplementedError(
                f"Solver {self.solver} not implemented for A_dagger"
            )

        if x_init is None:
            # This base class does not define A_adjoint; fail with an actionable message.
            if not hasattr(self, "A_adjoint"):
                raise AttributeError(
                    "A_dagger requires x_init when the physics operator does not "
                    "define A_adjoint"
                )
            x_init = self.A_adjoint(y)

        x = x_init
        lr = 1e-1
        loss = torch.nn.MSELoss()
        # Measurements of the current iterate. They are reused for both the stopping
        # test and the next gradient step, so A is evaluated once per iteration.
        Ax = None
        for _ in range(self.max_iter):
            if Ax is None:
                Ax = self.A(x)
            x = x - lr * self.A_vjp(x, Ax - y)
            Ax = self.A(x)
            if loss(Ax, y) < self.tol:
                break

        return x.clone()

    def set_ls_solver(self, solver, max_iter=None, tol=None):
        r"""
        Change default solver for computing the least squares solution:

        .. math::

            x^* \in \underset{x}{\arg\min} \quad \|\forw{x}-y\|^2.

        :param str solver: solver to use. If the physics are non-linear, the only available solver is
            `'gradient_descent'`. For linear operators, the options are `'CG'`, `'lsqr'`, `'BiCGStab'`
            and `'minres'` (see :func:`deepinv.optim.utils.least_squares` for more details).
        :param int max_iter: maximum number of iterations for the solver. Left unchanged if ``None``.
        :param float tol: relative tolerance for the solver, stopping when
            :math:`\|A(x) - y\| < \text{tol} \|y\|`. Left unchanged if ``None``.
        """
        self.solver = solver
        if max_iter is not None:
            self.max_iter = max_iter
        if tol is not None:
            self.tol = tol

    def A_vjp(self, x, v):
        r"""
        Computes the product between a vector :math:`v` and the Jacobian of the forward operator
        :math:`A` evaluated at :math:`x`, defined as:

        .. math::

            A_{vjp}(x, v) = \left. \frac{\partial A}{\partial x}  \right|_x^\top  v.

        By default, the Jacobian is computed using automatic differentiation.

        :param torch.Tensor x: signal/image.
        :param torch.Tensor v: vector.
        :return: (:class:`torch.Tensor`) the VJP product between :math:`v` and the Jacobian.
        """
        _, vjp_fn = torch.func.vjp(self.A, x)
        (product,) = vjp_fn(v)
        return product

    def update(self, **kwargs):
        r"""
        Update the parameters of the forward operator.

        The keyword arguments are passed to ``self.update_parameters`` and, if the noise model
        exposes an ``update_parameters`` method, to the noise model as well.

        :param dict kwargs: dictionary of parameters to update.
        :raises NotImplementedError: if the operator does not define ``update_parameters``.
        """
        if not hasattr(self, "update_parameters"):
            raise NotImplementedError(
                "update_parameters method not implemented for this physics operator"
            )
        self.update_parameters(**kwargs)

        # check if noise model has a method named update_parameters
        if hasattr(self.noise_model, "update_parameters"):
            self.noise_model.update_parameters(**kwargs)
class LinearPhysics(Physics):
    r"""
    Parent class for linear operators.

    It describes the linear forward measurement process of the form

    .. math::

        y = N(A(x))

    where :math:`x` is an image of :math:`n` pixels, :math:`y` is the measurements of size :math:`m`,
    :math:`A:\xset\mapsto \yset` is a deterministic linear mapping capturing the physics of the acquisition
    and :math:`N:\yset\mapsto \yset` is a stochastic mapping which characterizes the noise affecting
    the measurements.

    :param Callable A: forward operator function which maps an image to the observed measurements
        :math:`x\mapsto y`. It is recommended to normalize it to have unit norm.
    :param Callable A_adjoint: transpose of the forward operator, which should verify the adjointness test.

        .. note::

            A_adjoint can be generated automatically using the :func:`deepinv.physics.adjoint_function`
            method which relies on automatic differentiation, at the cost of a few extra computations
            per adjoint call.

    :param Callable noise_model: function that adds noise to the measurements :math:`N(z)`.
        See the noise module for some predefined functions.
    :param Callable sensor_model: function that incorporates any sensor non-linearities to the sensing
        process, such as quantization or saturation, defined as a function :math:`\eta(z)`, such that
        :math:`y=\eta\left(N(A(x))\right)`. By default, the sensor_model is set to the identity
        :math:`\eta(z)=z`.
    :param int max_iter: If the operator does not have a closed form pseudoinverse, a least squares
        algorithm is used for computing it, and this parameter fixes its maximum number of iterations.
    :param float tol: If the operator does not have a closed form pseudoinverse, a least squares algorithm
        is used for computing it, and this parameter fixes the relative tolerance of the least squares
        algorithm.
    :param str solver: least squares solver to use. Choose between `'CG'`, `'lsqr'`, `'BiCGStab'` and
        `'minres'`. See :func:`deepinv.optim.utils.least_squares` for more details.

    |sep|

    :Examples:

        Blur operator with a basic averaging filter applied to a 32x32 black image with
        a single white pixel in the center:

        >>> from deepinv.physics.blur import Blur, Downsampling
        >>> x = torch.zeros((1, 1, 32, 32)) # Define black image of size 32x32
        >>> x[:, :, 16, 16] = 1 # Define one white pixel in the middle
        >>> w = torch.ones((1, 1, 3, 3)) / 9 # Basic 3x3 averaging filter
        >>> physics = Blur(filter=w)
        >>> y = physics(x)

        Linear operators can also be stacked. The measurements produced by the resulting
        model are :class:`deepinv.utils.TensorList` objects, where each entry corresponds to the
        measurements of the corresponding operator (see :ref:`physics_combining` for more information):

        >>> physics1 = Blur(filter=w)
        >>> physics2 = Downsampling(img_size=((1, 32, 32)), filter="gaussian", factor=4)
        >>> physics = physics1.stack(physics2)
        >>> y = physics(x)

        Linear operators can also be composed by multiplying them:

        >>> physics = physics1 * physics2
        >>> y = physics(x)

        Linear operators also come with an adjoint, a pseudoinverse, and proximal operators in a given norm:

        >>> from deepinv.loss.metric import PSNR
        >>> physics = Blur(filter=w, padding='circular')
        >>> y = physics(x) # Compute measurements
        >>> x_dagger = physics.A_dagger(y) # Compute linear pseudoinverse
        >>> x_prox = physics.prox_l2(torch.zeros_like(x), y, 1.) # Compute prox at x=0
        >>> PSNR()(x, x_prox) > PSNR()(x, y) # Should be closer to the original
        tensor([True])

        The adjoint can be generated automatically using the :func:`deepinv.physics.adjoint_function`
        method which relies on automatic differentiation, at the cost of a few extra computations per
        adjoint call:

        >>> from deepinv.physics import LinearPhysics, adjoint_function
        >>> A = lambda x: torch.roll(x, shifts=(1,1), dims=(2,3)) # Shift image by one pixel
        >>> physics = LinearPhysics(A=A, A_adjoint=adjoint_function(A, (4, 1, 5, 5)))
        >>> x = torch.randn((4, 1, 5, 5))
        >>> y = physics(x)
        >>> torch.allclose(physics.A_adjoint(y), x) # We have A^T(A(x)) = x
        True

    """

    def __init__(
        self,
        A=lambda x, **kwargs: x,
        A_adjoint=lambda x, **kwargs: x,
        noise_model=lambda x, **kwargs: x,
        sensor_model=lambda x: x,
        max_iter=50,
        tol=1e-4,
        solver="CG",
        **kwargs,
    ):
        super().__init__(
            A=A,
            noise_model=noise_model,
            sensor_model=sensor_model,
            max_iter=max_iter,
            solver=solver,
            tol=tol,
        )
        self.A_adj = A_adjoint

    def A_adjoint(self, y, **kwargs):
        r"""
        Computes transpose of the forward operator :math:`\tilde{x} = A^{\top}y`.
        If :math:`A` is linear, it should be the exact transpose of the forward matrix.

        .. note::

            If the problem is non-linear, there is not a well-defined transpose operation,
            but defining one can be useful for some reconstruction networks, such as
            ``deepinv.models.ArtifactRemoval``.

        :param torch.Tensor y: measurements.
        :param kwargs: optional additional parameters forwarded to the adjoint callable.
        :return: (:class:`torch.Tensor`) linear reconstruction :math:`\tilde{x} = A^{\top}y`.
        """
        return self.A_adj(y, **kwargs)

    def A_vjp(self, x, v):
        r"""
        Computes the product between a vector :math:`v` and the Jacobian of the forward operator
        :math:`A` evaluated at :math:`x`, defined as:

        .. math::

            A_{vjp}(x, v) = \left. \frac{\partial A}{\partial x}  \right|_x^\top  v = \conj{A} v.

        For a linear operator the result does not depend on :math:`x`; the adjoint is applied to
        :math:`v` directly.

        :param torch.Tensor x: signal/image.
        :param torch.Tensor v: vector.
        :return: (:class:`torch.Tensor`) the VJP product between :math:`v` and the Jacobian.
        """
        return self.A_adjoint(v)

    def A_A_adjoint(self, y, **kwargs):
        r"""
        A helper function that computes :math:`A A^{\top}y`.

        This function can speed up computation when :math:`A A^{\top}` is available in closed form.
        Otherwise it just calls :func:`deepinv.physics.Physics.A` and
        :func:`deepinv.physics.LinearPhysics.A_adjoint`.

        :param torch.Tensor y: measurement.
        :return: (:class:`torch.Tensor`) the product :math:`AA^{\top}y`.
        """
        back_projected = self.A_adjoint(y, **kwargs)
        return self.A(back_projected, **kwargs)

    def A_adjoint_A(self, x, **kwargs):
        r"""
        A helper function that computes :math:`A^{\top}Ax`.

        This function can speed up computation when :math:`A^{\top}A` is available in closed form.
        Otherwise it just calls :func:`deepinv.physics.Physics.A` and
        :func:`deepinv.physics.LinearPhysics.A_adjoint`.

        :param torch.Tensor x: signal/image.
        :return: (:class:`torch.Tensor`) the product :math:`A^{\top}Ax`.
        """
        measured = self.A(x, **kwargs)
        return self.A_adjoint(measured, **kwargs)

    def __mul__(self, other):
        r"""
        Concatenates two linear forward operators :math:`A = A_1\circ A_2` via the * operation.

        The resulting linear operator keeps the noise and sensor models of :math:`A_1`, as well as
        its least squares settings (``solver``, ``max_iter`` and ``tol``).

        :param deepinv.physics.LinearPhysics other: Physics operator :math:`A_2`
        :return: (:class:`deepinv.physics.LinearPhysics`) concatenated operator
        """

        def composed(x, **kwargs):
            return self.A(other.A(x, **kwargs), **kwargs)  # (A' = A_1 A_2)

        def composed_adjoint(x, **kwargs):
            # (A_1 A_2)^T = A_2^T A_1^T
            return other.A_adjoint(self.A_adjoint(x, **kwargs), **kwargs)

        return LinearPhysics(
            A=composed,
            A_adjoint=composed_adjoint,
            noise_model=self.noise_model,
            sensor_model=self.sensor_model,
            max_iter=self.max_iter,
            tol=self.tol,
            solver=self.solver,
        )

    def stack(self, other):
        r"""
        Stacks forward operators :math:`A = \begin{bmatrix} A_1 \\ A_2 \end{bmatrix}`.

        The measurements produced by the resulting model are :class:`deepinv.utils.TensorList` objects,
        where each entry corresponds to the measurements of the corresponding operator.

        .. note::

            When using the ``stack`` operator between two noise objects, the operation will retain
            only the second noise.

        See :ref:`physics_combining` for more information.

        :param deepinv.physics.Physics other: Physics operator :math:`A_2`
        :return: (:class:`deepinv.physics.StackedPhysics`) stacked operator
        """
        # Inside the method body, ``stack`` resolves to the module-level function.
        return stack(self, other)

    def compute_norm(self, x0, max_iter=100, tol=1e-3, verbose=True, **kwargs):
        r"""
        Computes the spectral :math:`\ell_2` norm (Lipschitz constant) of the operator
        :math:`A^{\top}A`, i.e., :math:`\|A^{\top}A\|`,
        using the `power method <https://en.wikipedia.org/wiki/Power_iteration>`_.

        The iterations stop as soon as the estimate varies by less than ``tol`` between two
        consecutive iterations, whether or not ``verbose`` is set.

        :param torch.Tensor x0: tensor used as a template for the random starting point.
        :param int max_iter: maximum number of iterations
        :param float tol: convergence threshold on the variation of the estimate between two
            consecutive iterations.
        :param bool verbose: print information
        :returns z: (float) spectral norm of :math:`\conj{A} A`, i.e., :math:`\|\conj{A} A\|`.
        """
        x = torch.randn_like(x0)
        x /= torch.norm(x)
        zold = torch.zeros_like(x)
        for it in range(max_iter):
            y = self.A(x, **kwargs)
            y = self.A_adjoint(y, **kwargs)
            # Rayleigh quotient <x, A^T A x> / ||x||^2
            z = torch.matmul(x.conj().reshape(-1), y.reshape(-1)) / torch.norm(x) ** 2

            rel_var = torch.norm(z - zold)
            if rel_var < tol:
                # Only the message is conditional on ``verbose``; stopping is not.
                if verbose:
                    print(
                        f"Power iteration converged at iteration {it}, value={z.item():.2f}"
                    )
                break
            zold = z
            x = y / torch.norm(y)

        return z.real

    def adjointness_test(self, u, **kwargs):
        r"""
        Numerically check that :math:`A^{\top}` is indeed the adjoint of :math:`A`.

        A random :math:`v` is drawn in the measurement space and the difference between
        :math:`\langle Au, v\rangle` and :math:`\langle u, A^{\top}v\rangle` is returned.

        :param torch.Tensor u: initialisation point of the adjointness test method

        :return: (float) a quantity that should be theoretically 0. In practice, it should be of the
            order of the chosen dtype precision (i.e. single or double).
        """
        u_in = u  # .type(self.dtype)
        Au = self.A(u_in, **kwargs)

        if isinstance(Au, (tuple, list)):
            # One random vector per measurement entry.
            V = [randn_like(au) for au in Au]
            Atv = self.A_adjoint(V, **kwargs)
            s1 = 0
            for au, v in zip(Au, V):
                s1 += (v.conj() * au).flatten().sum()
        else:
            v = randn_like(Au)
            Atv = self.A_adjoint(v, **kwargs)
            s1 = (v.conj() * Au).flatten().sum()

        s2 = (Atv * u_in.conj()).flatten().sum()

        return s1.conj() - s2

    def condition_number(self, x, max_iter=500, tol=1e-6, verbose=False, **kwargs):
        r"""
        Computes an approximation of the condition number of the linear operator :math:`A`.

        Uses the LSQR algorithm, see :func:`deepinv.optim.utils.lsqr` for more details.

        :param torch.Tensor x: Any input tensor (e.g. random)
        :param int max_iter: maximum number of iterations
        :param float tol: relative variation criterion for convergence
        :param bool verbose: print information

        :return: (:class:`torch.Tensor`) condition number of the operator
        """
        y = self.A(x, **kwargs)
        # lsqr returns a pair; only the second entry (the condition number) is used.
        _, cond = lsqr(
            self.A,
            self.A_adjoint,
            y,
            max_iter=max_iter,
            verbose=verbose,
            tol=tol,
            parallel_dim=None,
            **kwargs,
        )

        return cond

    def prox_l2(
        self, z, y, gamma, solver="CG", max_iter=None, tol=None, verbose=False, **kwargs
    ):
        r"""
        Computes proximal operator of :math:`f(x) = \frac{1}{2}\|Ax-y\|^2`, i.e.,

        .. math::

            \underset{x}{\arg\min} \; \frac{\gamma}{2}\|Ax-y\|^2 + \frac{1}{2}\|x-z\|^2

        Any of ``solver``, ``max_iter`` and ``tol`` that is not ``None`` is stored on the operator
        and therefore also applies to subsequent calls.

        :param torch.Tensor z: signal tensor
        :param torch.Tensor y: measurements tensor
        :param float gamma: hyperparameter of the proximal operator
        :param str solver: least squares solver to use. If ``None``, the solver currently stored
            on the operator is used.
        :param int max_iter: maximum number of solver iterations. If ``None``, the stored value is used.
        :param float tol: solver tolerance. If ``None``, the stored value is used.
        :param bool verbose: passed to the least squares solver.
        :return: (:class:`torch.Tensor`) estimated signal tensor
        """
        if max_iter is not None:
            self.max_iter = max_iter
        if tol is not None:
            self.tol = tol
        if solver is not None:
            self.solver = solver

        return least_squares(
            self.A,
            self.A_adjoint,
            y,
            # Use the stored solver so that ``solver=None`` falls back to the configured
            # one instead of handing ``None`` to the least squares routine.
            solver=self.solver,
            gamma=gamma,
            verbose=verbose,
            init=z,
            z=z,
            parallel_dim=[0],
            ATA=self.A_adjoint_A,
            AAT=self.A_A_adjoint,
            max_iter=self.max_iter,
            tol=self.tol,
            **kwargs,
        )

    def A_dagger(
        self, y, solver="CG", max_iter=None, tol=None, verbose=False, **kwargs
    ):
        r"""
        Computes the solution in :math:`x` to :math:`y = Ax` using a least squares solver.

        This function can be overwritten by a more efficient pseudoinverse in cases where closed
        form formulas exist.

        Any of ``solver``, ``max_iter`` and ``tol`` that is not ``None`` is stored on the operator
        and therefore also applies to subsequent calls.

        :param torch.Tensor y: a measurement :math:`y` to reconstruct via the pseudoinverse.
        :param str solver: least squares solver to use. Choose between `'CG'`, `'lsqr'`, `'BiCGStab'`
            and `'minres'`. See :func:`deepinv.optim.utils.least_squares` for more details.
            If ``None``, the solver currently stored on the operator is used.
        :param int max_iter: maximum number of solver iterations. If ``None``, the stored value is used.
        :param float tol: solver tolerance. If ``None``, the stored value is used.
        :param bool verbose: passed to the least squares solver.
        :return: (:class:`torch.Tensor`) The reconstructed image :math:`x`.
        """
        if max_iter is not None:
            self.max_iter = max_iter
        if tol is not None:
            self.tol = tol
        if solver is not None:
            self.solver = solver

        return least_squares(
            self.A,
            self.A_adjoint,
            y,
            parallel_dim=[0],
            AAT=self.A_A_adjoint,
            verbose=verbose,
            ATA=self.A_adjoint_A,
            max_iter=self.max_iter,
            tol=self.tol,
            solver=self.solver,
            **kwargs,
        )
class DecomposablePhysics(LinearPhysics):
    r"""
    Parent class for linear operators with SVD decomposition.

    The singular value decomposition is expressed as

    .. math::

        A = U\text{diag}(s)V^{\top} \in \mathbb{R}^{m\times n}

    where :math:`U\in\mathbb{C}^{m\times m}` and :math:`V\in\mathbb{C}^{n\times n}`
    are orthonormal linear transformations and :math:`s` are the singular values.

    :param Callable U: orthonormal transformation
    :param Callable U_adjoint: transpose of U
    :param Callable V: orthonormal transformation
    :param Callable V_adjoint: transpose of V
    :param torch.nn.parameter.Parameter, float mask: Singular values of the transform

    |sep|

    :Examples:

        Recreation of the Inpainting operator using the DecomposablePhysics class:

        >>> from deepinv.physics import DecomposablePhysics
        >>> seed = torch.manual_seed(0)  # Random seed for reproducibility
        >>> tensor_size = (1, 1, 3, 3)  # Input size
        >>> mask = torch.tensor([[1, 0, 1], [1, 0, 1], [1, 0, 1]])  # Binary mask
        >>> U = lambda x: x  # U is the identity operation
        >>> U_adjoint = lambda x: x  # U_adjoint is the identity operation
        >>> V = lambda x: x  # V is the identity operation
        >>> V_adjoint = lambda x: x  # V_adjoint is the identity operation
        >>> mask_svd = mask.float().unsqueeze(0).unsqueeze(0)  # Convert the mask to torch.Tensor and adjust its dimensions
        >>> physics = DecomposablePhysics(U=U, U_adjoint=U_adjoint, V=V, V_adjoint=V_adjoint, mask=mask_svd)

        Apply the operator to a random tensor:

        >>> x = torch.randn(tensor_size)
        >>> with torch.no_grad():
        ...     physics.A(x)  # Apply the masking
        tensor([[[[ 1.5410, -0.0000, -2.1788],
                  [ 0.5684, -0.0000, -1.3986],
                  [ 0.4033,  0.0000, -0.7193]]]])

    """

    def __init__(
        self,
        U=lambda x: x,
        U_adjoint=lambda x: x,
        V=lambda x: x,
        V_adjoint=lambda x: x,
        mask=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._V = V
        self._U = U
        self._U_adjoint = U_adjoint
        self._V_adjoint = V_adjoint
        # Non-tensor singular values (e.g. a float) are stored as a tensor.
        mask = torch.tensor(mask) if not isinstance(mask, torch.Tensor) else mask
        self.mask = mask

    def A(self, x, mask=None, **kwargs) -> Tensor:
        r"""
        Applies the forward operator :math:`y = A(x)`.

        If a mask/singular values is provided, it is used to apply the forward operator,
        and also stored as the current mask/singular values.

        :param torch.Tensor x: input tensor
        :param torch.nn.parameter.Parameter, float mask: singular values.
        :return: output tensor
        """
        self.update_parameters(mask=mask, **kwargs)

        return self.U(self.mask * self.V_adjoint(x))

    def A_adjoint(self, y, mask=None, **kwargs) -> Tensor:
        r"""
        Computes the adjoint of the forward operator :math:`\tilde{x} = A^{\top}y`.

        If a mask/singular values is provided, it is used to apply the adjoint operator,
        and also stored as the current mask/singular values.

        :param torch.Tensor y: input tensor
        :param torch.nn.parameter.Parameter, float mask: singular values.
        :return: output tensor
        """
        self.update_parameters(mask=mask, **kwargs)

        if isinstance(self.mask, float):
            mask = self.mask
        else:
            mask = torch.conj(self.mask)

        return self.V(mask * self.U_adjoint(y))

    def A_A_adjoint(self, y, mask=None, **kwargs):
        r"""
        A helper function that computes :math:`A A^{\top}y`.

        Using the SVD decomposition, we have :math:`A A^{\top} = U\text{diag}(s^2)U^{\top}`.

        :param torch.Tensor y: measurement.
        :param torch.nn.parameter.Parameter, float mask: singular values; stored if provided.
        :return: (:class:`torch.Tensor`) the product :math:`AA^{\top}y`.
        """
        self.update_parameters(mask=mask, **kwargs)
        return self.U(self.mask.conj() * self.mask * self.U_adjoint(y))

    def A_adjoint_A(self, x, mask=None, **kwargs):
        r"""
        A helper function that computes :math:`A^{\top} A x`.

        Using the SVD decomposition, we have :math:`A^{\top}A = V\text{diag}(s^2)V^{\top}`.

        :param torch.Tensor x: signal/image.
        :param torch.nn.parameter.Parameter, float mask: singular values; stored if provided.
        :return: (:class:`torch.Tensor`) the product :math:`A^{\top}Ax`.
        """
        self.update_parameters(mask=mask, **kwargs)
        return self.V(self.mask.conj() * self.mask * self.V_adjoint(x))

    def U(self, x):
        r"""
        Applies the :math:`U` operator of the SVD decomposition.

        .. note::

            This method should be overwritten by the user to define its custom
            `DecomposablePhysics` operator.

        :param torch.Tensor x: input tensor
        """
        return self._U(x)

    def V(self, x):
        r"""
        Applies the :math:`V` operator of the SVD decomposition.

        .. note::

            This method should be overwritten by the user to define its custom
            `DecomposablePhysics` operator.

        :param torch.Tensor x: input tensor
        """
        return self._V(x)

    def U_adjoint(self, x):
        r"""
        Applies the :math:`U^{\top}` operator of the SVD decomposition.

        .. note::

            This method should be overwritten by the user to define its custom
            `DecomposablePhysics` operator.

        :param torch.Tensor x: input tensor
        """
        return self._U_adjoint(x)

    def V_adjoint(self, x):
        r"""
        Applies the :math:`V^{\top}` operator of the SVD decomposition.

        .. note::

            This method should be overwritten by the user to define its custom
            `DecomposablePhysics` operator.

        :param torch.Tensor x: input tensor
        """
        return self._V_adjoint(x)

    def prox_l2(self, z, y, gamma, **kwargs):
        r"""
        Computes proximal operator of :math:`f(x)=\frac{\gamma}{2}\|Ax-y\|^2`
        in an efficient manner leveraging the singular vector decomposition.

        Keyword arguments (e.g. ``mask``) are passed to :meth:`update_parameters` before the
        computation, as done by the other methods of this class.

        :param torch.Tensor, float z: signal tensor
        :param torch.Tensor y: measurements tensor
        :param float gamma: hyperparameter :math:`\gamma` of the proximal operator
        :return: (:class:`torch.Tensor`) estimated signal tensor
        """
        # Apply parameter updates first so the adjoint and the scaling below share them.
        self.update_parameters(**kwargs)

        b = self.A_adjoint(y) + 1 / gamma * z
        if isinstance(self.mask, float):
            scaling = self.mask**2 + 1 / gamma
        else:
            scaling = torch.conj(self.mask) * self.mask + 1 / gamma
        x = self.V(self.V_adjoint(b) / scaling)
        return x

    def A_dagger(self, y, mask=None, **kwargs):
        r"""
        Computes :math:`A^{\dagger}y = x` in an efficient manner leveraging the singular vector
        decomposition.

        Singular values whose magnitude is not larger than ``1e-5`` are treated as zero and are
        not inverted.

        :param torch.Tensor y: a measurement :math:`y` to reconstruct via the pseudoinverse.
        :param torch.nn.parameter.Parameter, float mask: singular values; stored if provided.
        :return: (:class:`torch.Tensor`) The reconstructed image :math:`x`.
        """
        self.update_parameters(mask=mask, **kwargs)

        # avoid division by singular value = 0
        if not isinstance(self.mask, float):
            # Threshold on the magnitude: ordering comparisons are undefined for complex
            # tensors, and a non-zero negative entry is still invertible.
            nonzero = self.mask.abs() > 1e-5
            mask = torch.zeros_like(self.mask)
            mask[nonzero] = 1 / self.mask[nonzero]
        else:
            mask = 1 / self.mask

        return self.V(self.U_adjoint(y) * mask)

    def update_parameters(self, **kwargs):
        r"""
        Updates the singular values of the operator.

        Every keyword whose value is a :class:`torch.Tensor` and whose name matches an existing
        attribute is stored as a :class:`torch.nn.Parameter` with ``requires_grad=False``.
        Other entries are ignored.

        :param dict kwargs: dictionary of parameters to update.
        """
        for key, value in kwargs.items():
            if (
                value is not None
                and hasattr(self, key)
                and isinstance(value, torch.Tensor)
            ):
                setattr(self, key, torch.nn.Parameter(value, requires_grad=False))
class Denoising(DecomposablePhysics):
    r"""
    Forward operator for denoising problems.

    The linear operator is just the identity mapping :math:`A(x)=x`

    :param torch.nn.Module noise_model: noise distribution, e.g., ``deepinv.physics.GaussianNoise``,
        or a user-defined torch.nn.Module. If ``None`` (default), a new
        ``GaussianNoise(sigma=0.1)`` is created for this operator.

    |sep|

    :Examples:

        Denoising operator with Gaussian noise with standard deviation 0.1:

        >>> from deepinv.physics import Denoising, GaussianNoise
        >>> seed = torch.manual_seed(0) # Random seed for reproducibility
        >>> x = 0.5*torch.randn(1, 1, 3, 3) # Define random 3x3 image
        >>> physics = Denoising(GaussianNoise(sigma=0.1))
        >>> with torch.no_grad():
        ...     physics(x)
        tensor([[[[ 0.7302, -0.2064, -1.0712],
                  [ 0.1985, -0.4322, -0.8064],
                  [ 0.2139,  0.3624, -0.3223]]]])

    """

    def __init__(self, noise_model=None, **kwargs):
        # Build the default noise model per instance. A default written in the signature
        # is evaluated once at definition time and would be shared by every operator.
        if noise_model is None:
            noise_model = GaussianNoise(sigma=0.1)
        super().__init__(noise_model=noise_model, **kwargs)
def adjoint_function(A, input_size, device="cpu", dtype=torch.float):
    r"""
    Provides the adjoint function of a linear operator :math:`A`, i.e., :math:`A^{\top}`.

    The adjoint is obtained through automatic differentiation of :math:`A`, evaluated once on a
    tensor of ones of size ``input_size``. The generated function can be simply called as
    ``A_adjoint(y)``, for example:

    >>> import torch
    >>> from deepinv.physics.forward import adjoint_function
    >>> A = lambda x: torch.roll(x, shifts=(1,1), dims=(2,3)) # shift image by one pixel
    >>> x = torch.randn((4, 1, 5, 5))
    >>> y = A(x)
    >>> A_adjoint = adjoint_function(A, (4, 1, 5, 5))
    >>> torch.allclose(A_adjoint(y), x) # we have A^T(A(x)) = x
    True

    :param Callable A: linear operator :math:`A`.
    :param tuple input_size: size of the input tensor e.g. (B, C, H, W).
        The first dimension, i.e. batch size, should be equal or lower than the batch size B
        of the input tensor to the adjoint operator.
    :param str device: device where the adjoint operator is computed.
    :param torch.dtype dtype: data type of the tensor used to trace the operator.
    :return: (Callable) function that computes the adjoint of :math:`A`.
    """
    probe = torch.ones(input_size, device=device, dtype=dtype)
    _, vjp_fn = torch.func.vjp(A, probe)
    max_batch = probe.shape[0]

    def adjoint(y):
        batch = y.shape[0]
        if batch > max_batch:
            raise ValueError("Batch size of A_adjoint input is larger than expected")
        if batch == max_batch:
            return vjp_fn(y)[0]

        # Smaller batch: zero-pad up to the traced batch size, then keep only the
        # entries that correspond to the actual inputs.
        padded = torch.zeros(
            (max_batch,) + y.shape[1:], device=y.device, dtype=y.dtype
        )
        padded[:batch, ...] = y
        return vjp_fn(padded)[0][:batch, ...]

    return adjoint
def stack(*physics: Union[Physics, LinearPhysics]):
    r"""
    Stacks multiple forward operators
    :math:`A = \begin{bmatrix} A_1(x) \\ A_2(x) \\ \vdots \\ A_n(x) \end{bmatrix}`.

    The measurements produced by the resulting model are :class:`deepinv.utils.TensorList` objects,
    where each entry corresponds to the measurements of the corresponding operator.

    A :class:`StackedLinearPhysics` is returned when every operator is a :class:`LinearPhysics`,
    and a :class:`StackedPhysics` otherwise.

    :param deepinv.physics.Physics physics: Physics operators :math:`A_i` to be stacked.
    """
    for phys in physics:
        # A single non-linear operator makes the whole stack non-linear.
        if not isinstance(phys, LinearPhysics):
            return StackedPhysics(physics)
    return StackedLinearPhysics(physics)
class StackedPhysics(Physics):
    r"""
    Stacks multiple physics operators into a single operator.

    The measurements produced by the resulting model are :class:`deepinv.utils.TensorList` objects,
    where each entry corresponds to the measurements of the corresponding operator.

    See :ref:`physics_combining` for more information.

    :param list[deepinv.physics.Physics] physics_list: list of physics operators to stack.
        Entries that are themselves :class:`StackedPhysics` are flattened into the list.
    """

    def __init__(self, physics_list: list[Physics], **kwargs):
        super(StackedPhysics, self).__init__()

        self.physics_list = []
        for physics in physics_list:
            # Flatten nested stacks so that the list only holds individual operators.
            self.physics_list.extend(
                [physics]
                if not isinstance(physics, StackedPhysics)
                else physics.physics_list
            )

    def A(self, x: Tensor, **kwargs) -> TensorList:
        r"""
        Computes forward of stacked operator

        .. math::

            y = \begin{bmatrix} A_1(x) \\ A_2(x) \\ \vdots \\ A_n(x) \end{bmatrix}

        :param torch.Tensor x: signal/image
        :return: measurements
        """
        return TensorList([physics.A(x, **kwargs) for physics in self.physics_list])

    def __str__(self):
        r"""
        Returns a textual description listing each stacked operator on its own line.
        """
        # ``sum`` cannot concatenate strings (it starts from the integer 0); use join.
        return "StackedPhysics(" + "".join(f"{p}\n" for p in self.physics_list) + ")"

    def __repr__(self):
        r"""
        Returns the same description as :meth:`__str__`.
        """
        return self.__str__()

    def __getitem__(self, item):
        r"""
        Returns the physics operator at index `item`.

        :param int item: index of the physics operator
        """
        return self.physics_list[item]

    def sensor(self, y: TensorList, **kwargs) -> TensorList:
        r"""
        Applies sensor non-linearities to the measurements per physics operator in the
        stacked operator.

        The entries of ``y`` are replaced in place and ``y`` itself is returned.

        :param deepinv.utils.TensorList y: measurements
        :return: measurements
        """
        for i, physics in enumerate(self.physics_list):
            y[i] = physics.sensor(y[i], **kwargs)
        return y

    def __len__(self):
        r"""
        Returns the number of physics operators in the stacked operator.
        """
        return len(self.physics_list)

    def noise(self, y: TensorList, **kwargs) -> TensorList:
        r"""
        Applies noise to the measurements per physics operator in the stacked operator.

        The entries of ``y`` are replaced in place and ``y`` itself is returned.

        :param deepinv.utils.TensorList y: measurements
        :return: noisy measurements
        """
        for i, physics in enumerate(self.physics_list):
            y[i] = physics.noise(y[i], **kwargs)
        return y

    def set_noise_model(self, noise_model, item=0):
        r"""
        Sets the noise model for the physics operator at index `item`.

        :param Callable, deepinv.physics.NoiseModel noise_model: noise model for the physics operator.
        :param int item: index of the physics operator
        """
        self.physics_list[item].set_noise_model(noise_model)

    def update_parameters(self, **kwargs):
        r"""
        Updates the parameters of the stacked operator.

        The same keyword arguments are passed to ``update_parameters`` of every stacked operator.

        :param dict kwargs: dictionary of parameters to update.
        """
        for physics in self.physics_list:
            physics.update_parameters(**kwargs)
class StackedLinearPhysics(StackedPhysics, LinearPhysics):
    r"""
    Stacks multiple linear physics operators into a single operator.

    The measurements produced by the resulting model are :class:`deepinv.utils.TensorList` objects,
    where each entry corresponds to the measurements of the corresponding operator.

    See :ref:`physics_combining` for more information.

    :param list[deepinv.physics.Physics] physics_list: list of physics operators to stack.
    :param str reduction: how to combine tensorlist outputs of adjoint operators into single
        adjoint output. Choose between ``sum``, ``mean`` or ``None``.
    """

    def __init__(self, physics_list, reduction="sum", **kwargs):
        super(StackedLinearPhysics, self).__init__(physics_list, **kwargs)

        def _mean(outputs):
            return sum(outputs) / len(outputs)

        def _identity(outputs):
            return outputs

        if reduction == "sum":
            reducer = sum
        elif reduction == "mean":
            reducer = _mean
        elif reduction in ("none", None):
            reducer = _identity
        else:
            raise ValueError("reduction must be either sum, mean or none.")
        self.reduction = reducer

    def A_adjoint(self, y: TensorList, **kwargs) -> torch.Tensor:
        r"""
        Computes the adjoint of the stacked operator, defined as

        .. math::

            A^{\top}y = \sum_{i=1}^{n} A_i^{\top}y_i.

        The per-operator adjoints are combined with the reduction chosen at construction.

        :param deepinv.utils.TensorList y: measurements
        """
        adjoints = []
        for i, physics in enumerate(self.physics_list):
            adjoints.append(physics.A_adjoint(y[i], **kwargs))
        return self.reduction(adjoints)

    def update_parameters(self, **kwargs):
        r"""
        Updates the parameters of the stacked operator.

        The same keyword arguments are passed to ``update_parameters`` of every stacked operator.

        :param dict kwargs: dictionary of parameters to update.
        """
        for operator in self.physics_list:
            operator.update_parameters(**kwargs)