Physics

Introduction

This package contains a large collection of forward operators appearing in imaging applications. The acquisition models are of the form

\[y = N(A(x))\]

where \(x\in\mathcal{X}\) is an image, \(y\in\mathcal{Y}\) are the measurements, \(A:\mathcal{X}\to \mathcal{Y}\) is a deterministic (linear or non-linear) operator capturing the physics of the acquisition, and \(N:\mathcal{Y}\to \mathcal{Y}\) is a mapping that characterizes the noise affecting the measurements.
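
As a minimal sketch of this model in plain PyTorch (it does not use deepinv), the code below takes \(A\) to be a fixed random pixel mask and \(N\) to be additive Gaussian noise. The names `mask`, `A`, `N` and `sigma` are illustrative choices, not part of the library.

```python
import torch

torch.manual_seed(0)

# Hypothetical deterministic operator A: keep roughly 50% of the pixels with a fixed binary mask.
mask = (torch.rand(1, 1, 28, 28) > 0.5).float()


def A(x: torch.Tensor) -> torch.Tensor:
    return mask * x


# Hypothetical noise mapping N: additive white Gaussian noise with standard deviation sigma.
sigma = 0.05


def N(z: torch.Tensor) -> torch.Tensor:
    return z + sigma * torch.randn_like(z)


x = torch.rand(1, 1, 28, 28)  # a random "image"
y = N(A(x))                   # noisy measurements y = N(A(x))
```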

All forward operators inherit the structure of the Physics() class:

deepinv.physics.Physics

Parent class for forward operators

Physics operators are torch.nn.Module subclasses, so calling an operator on an image runs its forward method, which returns the (noisy) measurements.

>>> import torch
>>> import deepinv as dinv
>>> # load an inpainting operator that masks 50% of the pixels and adds Gaussian noise
>>> physics = dinv.physics.Inpainting(mask=.5, tensor_size=(1, 28, 28),
...                    noise_model=dinv.physics.GaussianNoise(sigma=.05))
>>> x = torch.rand(1, 1, 28, 28) # create a random image
>>> y = physics(x) # compute noisy measurements
>>> y2 = physics.A(x) # compute the A operator (no noise)

Linear operators

Linear operators \(A:\mathcal{X}\to \mathcal{Y}\) inherit the structure of the deepinv.physics.LinearPhysics() class. They have specific properties, most importantly the existence of an adjoint \(A^*:\mathcal{Y}\to \mathcal{X}\). Linear operators with a closed-form singular value decomposition are defined via deepinv.physics.DecomposablePhysics(), which enables the efficient computation of their pseudo-inverse and regularized inverse. Compositions and linear combinations of linear operators are still linear operators.

>>> import torch
>>> import deepinv as dinv
>>> # load a CS operator with 300 measurements, acting on 28 x 28 grayscale images.
>>> physics = dinv.physics.CompressedSensing(m=300, img_shape=(1, 28, 28))
>>> x = torch.rand(1, 1, 28, 28) # create a random image
>>> y = physics(x) # compute measurements (noiseless here, since no noise model was set)
>>> y2 = physics.A(x) # compute the linear operator (no noise)
>>> x_adj = physics.A_adjoint(y) # compute the adjoint operator
>>> x_dagger = physics.A_dagger(y) # compute the pseudo-inverse operator
>>> x_prox = physics.prox_l2(x, y, .1) # compute a regularized inverse
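
The claim that compositions and linear combinations stay linear, with adjoints \((AB)^* = B^*A^*\) and \((aA_1 + bA_2)^* = aA_1^* + bA_2^*\) for real scalars, can be checked numerically. The sketch below represents each operator as an explicit matrix (an assumption made for illustration; deepinv operators are matrix-free) and runs the dot-product test \(\langle Cx, y\rangle = \langle x, C^*y\rangle\).

```python
import torch

torch.manual_seed(0)
dt = torch.float64

# Explicit matrices standing in for matrix-free operators (sizes are arbitrary).
A = torch.randn(5, 4, dtype=dt)   # A: R^4 -> R^5
B = torch.randn(4, 3, dtype=dt)   # B: R^3 -> R^4
C2 = torch.randn(5, 3, dtype=dt)  # another operator R^3 -> R^5


def compose(x):
    return A @ (B @ x)  # (A o B)(x)


def compose_adjoint(y):
    return B.T @ (A.T @ y)  # (A o B)^* = B^* o A^*


def combo(x, a=2.0, b=-0.5):
    return a * compose(x) + b * (C2 @ x)  # a (A o B) + b C2


def combo_adjoint(y, a=2.0, b=-0.5):
    return a * compose_adjoint(y) + b * (C2.T @ y)  # adjoint of the linear combination


x = torch.randn(3, dtype=dt)
y = torch.randn(5, dtype=dt)

# Dot-product tests: <C x, y> - <x, C^* y> should vanish up to rounding error.
err_compose = torch.dot(compose(x), y) - torch.dot(x, compose_adjoint(y))
err_combo = torch.dot(combo(x), y) - torch.dot(x, combo_adjoint(y))
```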

More details can be found in the documentation of each class:

deepinv.physics.LinearPhysics

Parent class for linear operators.

deepinv.physics.DecomposablePhysics

Parent class for linear operators with SVD decomposition.
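
To illustrate why a closed-form SVD helps, the sketch below uses an explicit matrix (an assumption; this is not how DecomposablePhysics stores operators) and the thin SVD \(A = U\,\mathrm{diag}(s)\,V^\top\). The pseudo-inverse becomes \(A^\dagger y = V\,\mathrm{diag}(1/s)\,U^\top y\), and the Tikhonov solution of \(\min_x \tfrac12\|Ax-y\|^2 + \tfrac{\lambda}{2}\|x\|^2\) becomes \(V\,\mathrm{diag}\big(s/(s^2+\lambda)\big)\,U^\top y\); both are checked against direct solvers. The weight `lam` is an illustrative choice, and deepinv's own regularized inverse may use a different parametrization.

```python
import torch

torch.manual_seed(0)
dt = torch.float64

# A small wide matrix standing in for a linear operator (m=3 measurements, n=5 unknowns).
A = torch.randn(3, 5, dtype=dt)
y = torch.randn(3, dtype=dt)

# Thin SVD: A = U diag(s) Vh, with U: 3x3, s: 3, Vh = V^T: 3x5.
U, s, Vh = torch.linalg.svd(A, full_matrices=False)

# Pseudo-inverse: A^+ y = V diag(1/s) U^T y (all singular values are nonzero here).
x_dagger = Vh.T @ ((U.T @ y) / s)

# Tikhonov-regularized inverse: a filter s / (s^2 + lam) applied to the singular values.
lam = 0.1
x_reg = Vh.T @ ((U.T @ y) * s / (s**2 + lam))

# Reference solutions computed without the SVD.
x_dagger_ref = torch.linalg.pinv(A) @ y
x_reg_ref = torch.linalg.solve(A.T @ A + lam * torch.eye(5, dtype=dt), A.T @ y)
```

Once the SVD is known, both inverses cost only a few matrix-vector products, which is the practical advantage of operators with a closed-form decomposition.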

Parameter-dependent operators

Many (linear or non-linear) operators depend on (optional) parameters \(\theta\) that describe the imaging system, i.e., \(y = N(A(x, \theta))\), where the forward method can be called with a dictionary of parameters as an extra input. The explicit dependency on \(\theta\) is often useful for blind inverse problems, model identification, imaging system optimization, etc. The following example shows how operators and their parameters can be instantiated and called:

>>> import torch
>>> from deepinv.physics import Blur
>>> x = torch.rand((1, 1, 16, 16))
>>> theta = torch.ones((1, 1, 2, 2)) / 4 # a basic 2x2 averaging filter
>>> # default usage
>>> physics = Blur(filter=theta) # we instantiate a blur operator with its convolution filter
>>> y = physics(x)
>>> theta2 = torch.randn((1, 1, 2, 2)) # a random 2x2 filter
>>> physics.update_parameters(filter=theta2)
>>> y2 = physics(x)
>>>
>>> # A second possibility
>>> physics = Blur() # a blur operator without convolution filter
>>> y = physics(x, filter=theta) # we define the blur by specifying its filter
>>> y = physics(x) # now, the filter is well-defined and this line does the same as above
>>>
>>> # The same can be done by passing in a dictionary including 'filter' as a key
>>> physics = Blur() # a blur operator without convolution filter
>>> dict_params = {'filter': theta, 'dummy': None}
>>> y = physics(x, **dict_params) # we define the blur by passing in the dictionary
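
The pattern above, where parameters passed at call time are stored and reused as defaults, can be sketched with a hypothetical module `ToyBlur` written in plain PyTorch. It only mimics the behaviour described here (the method name `update_parameters` is borrowed from the example); it is not the implementation of deepinv.physics.Blur, and it assumes odd-sized filters with zero padding.

```python
import torch
import torch.nn.functional as F


class ToyBlur(torch.nn.Module):
    """Hypothetical blur operator whose filter theta is an updatable parameter."""

    def __init__(self, filter=None):
        super().__init__()
        self.filter = filter

    def update_parameters(self, filter=None, **kwargs):
        # Overwrite the stored filter only when a new one is given; other keys are ignored.
        if filter is not None:
            self.filter = filter

    def A(self, x, **params):
        self.update_parameters(**params)
        # A 3x3 filter with zero padding of 1 keeps the spatial size unchanged.
        return F.conv2d(x, self.filter, padding=1)

    def forward(self, x, **params):
        return self.A(x, **params)  # noiseless, for simplicity


x = torch.rand(1, 1, 16, 16)
theta = torch.ones(1, 1, 3, 3) / 9  # 3x3 averaging filter
physics = ToyBlur()
y = physics(x, filter=theta)        # defines and stores the filter
y_again = physics(x)                # reuses the stored filter
y_dict = physics(x, **{"filter": theta, "dummy": None})  # unknown keys are ignored
```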

Physics Generators

We provide parameter generators that sample random parameters \(\theta\). Physics generators inherit from the PhysicsGenerator() class:

deepinv.physics.generator.PhysicsGenerator

Base class for parameter generation of physics parameters.

>>> import torch
>>> import deepinv as dinv
>>>
>>> x = torch.rand((1, 1, 8, 8))
>>> physics = dinv.physics.Blur(filter=dinv.physics.blur.gaussian_blur(.2))
>>> y = physics(x) # compute with Gaussian blur
>>> generator = dinv.physics.generator.MotionBlurGenerator(psf_size=(3, 3))
>>> params = generator.step(x.size(0)) # params = {'filter': torch.tensor(...)}
>>> y1 = physics(x, **params) # compute with motion blur
>>> assert not torch.allclose(y, y1) # different blurs, different outputs
>>> y2 = physics(x) # motion kernel is stored in the physics object as default kernel
>>> assert torch.allclose(y1, y2) # same blur, same output

To generate new physics parameters and new noise parameters at the same time, generators can be summed as follows:

>>> mask_generator = dinv.physics.generator.SigmaGenerator() \
...    + dinv.physics.generator.RandomMaskGenerator((32, 32))
>>> params = mask_generator.step(batch_size=4)
>>> print(sorted(params.keys()))
['mask', 'sigma']
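
One way to picture summing generators is as merging the parameter dictionaries they return. The classes below (`ToyGenerator`, `ToySumGenerator`, `ToySigmaGenerator`, `ToyMaskGenerator`) are hypothetical and written only to illustrate that idea; the sigma range and the Bernoulli mask are assumptions and do not reproduce deepinv's SigmaGenerator or RandomMaskGenerator.

```python
import torch


class ToyGenerator:
    """Hypothetical base class: step(batch_size) returns a dict of sampled parameters."""

    def step(self, batch_size=1):
        raise NotImplementedError

    def __add__(self, other):
        return ToySumGenerator(self, other)


class ToySumGenerator(ToyGenerator):
    def __init__(self, g1, g2):
        self.g1, self.g2 = g1, g2

    def step(self, batch_size=1):
        # Sample from both generators and merge their dictionaries.
        return {**self.g1.step(batch_size), **self.g2.step(batch_size)}


class ToySigmaGenerator(ToyGenerator):
    def __init__(self, sigma_min=0.01, sigma_max=0.5):
        self.sigma_min, self.sigma_max = sigma_min, sigma_max

    def step(self, batch_size=1):
        # One noise level per batch element, drawn uniformly in [sigma_min, sigma_max].
        u = torch.rand(batch_size)
        return {"sigma": self.sigma_min + (self.sigma_max - self.sigma_min) * u}


class ToyMaskGenerator(ToyGenerator):
    def __init__(self, img_size, p=0.5):
        self.img_size, self.p = img_size, p

    def step(self, batch_size=1):
        # Independent Bernoulli(p) masks, one per batch element.
        return {"mask": (torch.rand(batch_size, 1, *self.img_size) < self.p).float()}


generator = ToySigmaGenerator() + ToyMaskGenerator((32, 32))
params = generator.step(batch_size=4)
print(sorted(params.keys()))  # ['mask', 'sigma']
```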

It is also possible to mix generators of physics parameters through the GeneratorMixture() class:

deepinv.physics.generator.GeneratorMixture

Base class for mixing multiple physics generators.
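
A mixture can be thought of as drawing one generator at random for each call. The sketch below assumes that behaviour and uses hypothetical classes (`ConstGenerator`, `ToyMixture`) built with the Python standard library only; consult the GeneratorMixture documentation for its actual sampling rule.

```python
import random


class ConstGenerator:
    """Hypothetical generator that always returns the same parameters."""

    def __init__(self, params):
        self.params = params

    def step(self, batch_size=1):
        return dict(self.params)


class ToyMixture:
    """Hypothetical mixture: each step() delegates to one generator drawn with given probabilities."""

    def __init__(self, generators, probs, seed=None):
        if len(generators) != len(probs) or abs(sum(probs) - 1.0) > 1e-8:
            raise ValueError("need one probability per generator, summing to 1")
        self.generators = generators
        self.probs = probs
        self.rng = random.Random(seed)

    def step(self, batch_size=1):
        chosen = self.rng.choices(self.generators, weights=self.probs, k=1)[0]
        return chosen.step(batch_size)


mixture = ToyMixture([ConstGenerator({"sigma": 0.1}), ConstGenerator({"sigma": 0.3})],
                     probs=[0.7, 0.3], seed=0)
draws = [mixture.step()["sigma"] for _ in range(1000)]
frac_low = draws.count(0.1) / len(draws)  # close to 0.7 on average
```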

Forward operators

Various popular forward operators are provided with efficient implementations.

Pixelwise operators

Pixelwise operators operate in the pixel domain and are used for denoising, inpainting, decolorization, etc.

deepinv.physics.Denoising

Forward operator for denoising problems.

deepinv.physics.Inpainting

Inpainting forward operator, keeps a subset of entries.

deepinv.physics.Decolorize

Converts RGB images to grayscale.

deepinv.physics.Demosaicing

Demosaicing operator.

For random inpainting, we also provide generators that create random masks on the fly. These can also be used as splitting masks for deepinv.loss.SplittingLoss and its variants.

deepinv.physics.generator.BernoulliSplittingMaskGenerator

Base generator for splitting/inpainting masks.

deepinv.physics.generator.GaussianSplittingMaskGenerator

Randomly generate Gaussian splitting/inpainting masks.

deepinv.physics.generator.Phase2PhaseSplittingMaskGenerator

Phase2Phase splitting mask generator for dynamic data.

deepinv.physics.generator.Artifact2ArtifactSplittingMaskGenerator

Artifact2Artifact splitting mask generator for dynamic data.
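
The idea behind Bernoulli splitting can be sketched as follows: every sampled entry of a measurement mask is assigned to one of two disjoint masks, for instance one used as network input and one held out for the loss. The helper `bernoulli_split` and the `split_ratio` value are hypothetical; the deepinv generators listed above have their own parameters and conventions.

```python
import torch


def bernoulli_split(mask, split_ratio=0.6, generator=None):
    """Hypothetical helper: split a binary mask into two disjoint masks.

    Each sampled entry (mask == 1) goes to the first mask with probability
    split_ratio, and to the second mask otherwise.
    """
    coin = (torch.rand(mask.shape, generator=generator) < split_ratio).to(mask.dtype)
    mask_in = mask * coin
    mask_out = mask * (1 - coin)
    return mask_in, mask_out


g = torch.Generator().manual_seed(0)
mask = (torch.rand(1, 1, 32, 32, generator=g) < 0.5).float()  # measurement mask
mask_in, mask_out = bernoulli_split(mask, 0.6, generator=g)
```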

Blur & Super-Resolution

Different types of blur operators are available, from simple stationary kernels to space-varying ones.

deepinv.physics.Blur

Blur operator.

deepinv.physics.BlurFFT

FFT-based blur operator.

deepinv.physics.SpaceVaryingBlur

Implements a space varying blur via product-convolution.

deepinv.physics.Downsampling

Downsampling operator for super-resolution problems.
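
Downsampling is commonly modelled as a blur followed by decimation, \(y = (k * x)\downarrow_s\). The sketch below, with a hypothetical helper `downsample`, assumes a 2x2 box filter and no padding, in which case the result coincides with average pooling; deepinv.physics.Downsampling supports other filters and boundary conditions.

```python
import torch
import torch.nn.functional as F


def downsample(x, kernel, factor):
    """Hypothetical sketch: blur with `kernel`, then keep every `factor`-th pixel."""
    blurred = F.conv2d(x, kernel)              # 'valid' convolution (no padding)
    return blurred[..., ::factor, ::factor]    # decimation


x = torch.rand(1, 1, 16, 16)
box = torch.ones(1, 1, 2, 2) / 4               # 2x2 averaging filter
y = downsample(x, box, factor=2)               # shape (1, 1, 8, 8)
```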

We provide the implementation of typical blur kernels such as Gaussian, bilinear, bicubic, etc.

deepinv.physics.blur.gaussian_blur

Gaussian blur filter.

deepinv.physics.blur.bilinear_filter

Bilinear filter.

deepinv.physics.blur.bicubic_filter

Bicubic filter.

deepinv.physics.blur.sinc_filter

Anti-aliasing sinc filter multiplied by a Kaiser window.
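
For reference, an isotropic Gaussian kernel can be built as the outer product of two sampled 1D Gaussians and normalized to sum to one. The helper `gaussian_kernel` and its default support of about plus or minus 3 sigma are assumptions of this sketch; deepinv.physics.blur.gaussian_blur may choose the kernel size and orientation differently.

```python
import torch


def gaussian_kernel(sigma, size=None):
    """Hypothetical sketch of an isotropic 2D Gaussian blur kernel that sums to 1."""
    if size is None:
        size = 2 * int(3 * sigma) + 1          # odd support covering about +/- 3 sigma
    r = torch.arange(size, dtype=torch.float64) - (size - 1) / 2
    g1d = torch.exp(-r**2 / (2 * sigma**2))
    k = torch.outer(g1d, g1d)                  # separable: k[i, j] = g(i) * g(j)
    return (k / k.sum()).view(1, 1, size, size)


k = gaussian_kernel(1.0)  # shape (1, 1, 7, 7), usable as a conv2d weight
```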

We also provide a set of generators to simulate various types of blur, which can be used to train blind or semi-blind deblurring networks.

deepinv.physics.generator.MotionBlurGenerator

Random motion blur generator.

deepinv.physics.generator.DiffractionBlurGenerator

Diffraction limited blur generator.

deepinv.physics.generator.DiffractionBlurGenerator3D

Generates 3D diffraction-limited kernels in optics using the Zernike decomposition of the phase mask (Fresnel/Fraunhofer diffraction theory).

deepinv.physics.generator.ProductConvolutionBlurGenerator

Generates parameters of space-varying blurs.

Magnetic Resonance Imaging

In MRI, the Fourier transform is sampled on a grid (FFT) or off the grid, with a single coil or multiple coils. We provide 2D and dynamic 2D+t MRI physics.

deepinv.physics.MRI

Single-coil accelerated magnetic resonance imaging.

deepinv.physics.DynamicMRI

Single-coil accelerated dynamic magnetic resonance imaging.

deepinv.physics.SequentialMRI

Single-coil accelerated magnetic resonance imaging using sequential sampling.
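
A single-coil Cartesian MRI operator can be sketched as \(A(x) = M \odot \mathcal{F}x\), where \(\mathcal{F}\) is the orthonormal 2D FFT and \(M\) a binary k-space mask; its adjoint is then \(A^*(y) = \mathcal{F}^{-1}(M \odot y)\). This sketch works with complex tensors and an uncentered FFT, which are assumptions: deepinv's MRI class may use a different tensor layout and centering convention. The dot-product test confirms the adjoint.

```python
import torch

torch.manual_seed(0)
n = 16
# Hypothetical k-space sampling mask keeping about 30% of the frequencies.
mask = (torch.rand(n, n) < 0.3).to(torch.float64)


def A(x):
    # Orthonormal 2D FFT followed by masking of k-space.
    return mask * torch.fft.fft2(x, norm="ortho")


def A_adjoint(y):
    # The mask is real and diagonal, hence self-adjoint; the adjoint of the
    # orthonormal FFT is the orthonormal inverse FFT.
    return torch.fft.ifft2(mask * y, norm="ortho")


x = torch.randn(n, n, dtype=torch.complex128)
y = torch.randn(n, n, dtype=torch.complex128)
lhs = torch.vdot(A(x).flatten(), y.flatten())          # <A x, y>
rhs = torch.vdot(x.flatten(), A_adjoint(y).flatten())  # <x, A^* y>
```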

We provide generators for creating random and non-random acceleration masks using Cartesian sampling, for both static (k) and dynamic (k-t) accelerated MRI:

deepinv.physics.generator.BaseMaskGenerator

Base generator for MRI acceleration masks.

deepinv.physics.generator.GaussianMaskGenerator

Generator for MRI Cartesian acceleration masks using Gaussian undersampling.

deepinv.physics.generator.RandomMaskGenerator

Generator for MRI Cartesian acceleration masks using random uniform undersampling.

deepinv.physics.generator.EquispacedMaskGenerator

Generator for MRI Cartesian acceleration masks using uniform (equispaced) non-random undersampling with random offset.
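
An equispaced Cartesian mask keeps every R-th phase-encoding line (R being the acceleration factor) plus a fully sampled band of low frequencies. In the hypothetical helper below, the offset is passed explicitly, whereas a generator would draw it at random (for instance with random.randrange(acceleration)); the 8% center fraction is also an assumption.

```python
import torch


def equispaced_mask(n_lines, acceleration, center_fraction=0.08, offset=0):
    """Hypothetical sketch: 1D Cartesian mask over phase-encoding lines."""
    mask = torch.zeros(n_lines, dtype=torch.bool)
    mask[offset::acceleration] = True                 # every `acceleration`-th line
    n_center = max(1, round(n_lines * center_fraction))
    start = (n_lines - n_center) // 2
    mask[start:start + n_center] = True               # fully sampled low frequencies
    return mask


m = equispaced_mask(128, acceleration=4, offset=1)
mask_2d = m[None, :].expand(128, 128)  # the same sampled columns in every row of k-space
```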

Tomography

Tomography is based on the Radon transform, which computes line integrals of the image along many directions.

deepinv.physics.Tomography

(Computed) Tomography operator.
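
For axis-aligned angles, the line integrals of a pixel image reduce to sums along columns or rows, which gives a quick sanity check of the idea. This toy example does not compute a full Radon transform, and which angle corresponds to which axis is a convention assumed here.

```python
import torch

x = torch.zeros(8, 8)
x[2:5, 3:6] = 1.0  # a 3x3 white square

# Axis-aligned projections: line integrals become sums along one image axis.
proj_0 = x.sum(dim=0)   # integrate along vertical lines (one value per column)
proj_90 = x.sum(dim=1)  # integrate along horizontal lines (one value per row)
```

Every projection preserves the total mass of the image, which is a useful check for any Radon implementation.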

Remote Sensing

Remote sensing operators are used to simulate the acquisition of satellite data.

deepinv.physics.Pansharpen

Pansharpening forward operator.

Compressive operators

Compressive operators are implemented in the following classes:

deepinv.physics.CompressedSensing

Compressed Sensing forward operator.

deepinv.physics.SinglePixelCamera

Single pixel imaging camera.
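
A classical single-pixel camera measures inner products between the scene and a sequence of plus/minus one patterns, often taken from a Hadamard matrix. The sketch below builds a Sylvester Hadamard matrix and keeps its first m rows; that ordering and the sizes are assumptions, and deepinv.physics.SinglePixelCamera may order or normalize its patterns differently. Because the retained rows are orthogonal with squared norm n, the pseudo-inverse reduces to \(A^\top/n\).

```python
import torch


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n must be a power of 2)."""
    H = torch.ones(1, 1, dtype=torch.float64)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H


n = 64   # an 8 x 8 image, flattened
m = 16   # number of single-pixel measurements
H = hadamard(n)
A = H[:m]  # keep m patterns (rows with entries +/-1)

torch.manual_seed(0)
x = torch.rand(n, dtype=torch.float64)
y = A @ x  # each measurement is the inner product of x with one pattern

# A A^T = n I, so the pseudo-inverse is simply A^T / n.
x_dagger = A.T @ y / n
```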

Radio interferometric imaging

The radio interferometric imaging operator is implemented in the following class:

deepinv.physics.RadioInterferometry

Radio Interferometry measurement operator.

Single-photon lidar

Single-photon lidar is a popular technique for depth ranging and imaging.

deepinv.physics.SinglePhotonLidar

Single photon lidar operator for depth ranging.

Dehazing

Haze operators are used to capture the physics of light scattering in the atmosphere.

deepinv.physics.Haze

Standard haze model
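
A common way to model haze is the atmospheric scattering model \(I = J\,t + a\,(1 - t)\) with transmission \(t = \exp(-\beta d)\), where \(J\) is the clear image, \(d\) the scene depth and \(a\) the airlight. The sketch below implements this textbook model with illustrative values of `beta` and `airlight`; whether deepinv.physics.Haze uses exactly this parametrization should be checked in its documentation.

```python
import torch


def haze(J, depth, beta=0.5, airlight=0.9):
    """Sketch of the scattering model I = J * t + a * (1 - t), with t = exp(-beta * depth)."""
    t = torch.exp(-beta * depth)  # transmission map in (0, 1]
    return J * t + airlight * (1 - t)


J = torch.rand(1, 3, 8, 8)                            # clear RGB image
depth = torch.linspace(0, 20, 8).expand(1, 1, 8, 8)   # depth increasing from left to right
I = haze(J, depth)                                    # distant pixels fade towards the airlight
```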

Phase retrieval

Operators where \(A:\mathcal{X}\to \mathcal{Y}\) is of the form \(A(x) = |Bx|^2\) with \(B\) a linear operator.

deepinv.physics.PhaseRetrieval

Phase Retrieval base class corresponding to the operator \(A(x) = |Bx|^2\).

deepinv.physics.RandomPhaseRetrieval

Random Phase Retrieval forward operator.
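
The sketch below uses a random complex Gaussian matrix for \(B\) (an assumption in the spirit of RandomPhaseRetrieval, not its exact construction) and shows that the measurements \(|Bx|^2\) are unchanged when \(x\) is multiplied by a global phase \(e^{i\varphi}\), so \(x\) can only be recovered up to such a factor.

```python
import torch

torch.manual_seed(0)
n, m = 8, 32
# Hypothetical random complex Gaussian matrix B (m measurements, n unknowns).
B = torch.randn(m, n, dtype=torch.complex128) / m**0.5


def A(x):
    # Intensity-only measurements: the phase of Bx is lost.
    return (B @ x).abs() ** 2


x = torch.randn(n, dtype=torch.complex128)
phase = torch.exp(1j * torch.tensor(0.7, dtype=torch.float64))  # global phase e^{0.7 i}
y = A(x)
y_shifted = A(phase * x)  # identical measurements
```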

Noise distributions

Noise mappings \(N:\mathcal{Y}\to \mathcal{Y}\) are simple torch.nn.Module instances. The noise model of a forward operator can be set at construction time or afterwards, as follows:

>>> import torch
>>> import deepinv as dinv
>>> # load a CS operator with 300 measurements, acting on 28 x 28 grayscale images.
>>> physics = dinv.physics.CompressedSensing(m=300, img_shape=(1, 28, 28))
>>> physics.noise_model = dinv.physics.GaussianNoise(sigma=.05) # set up the noise

deepinv.physics.GaussianNoise

Gaussian noise \(y=z+\epsilon\) where \(\epsilon\sim \mathcal{N}(0,I\sigma^2)\).

deepinv.physics.LogPoissonNoise

Log-Poisson noise \(y = -\frac{1}{\mu} \log\left(\frac{\mathcal{P}(\exp(-\mu x) N_0)}{N_0}\right)\).

deepinv.physics.PoissonNoise

Poisson noise \(y = \mathcal{P}(\frac{x}{\gamma})\) with gain \(\gamma>0\).

deepinv.physics.PoissonGaussianNoise

Poisson-Gaussian noise \(y = \gamma z + \epsilon\) where \(z\sim\mathcal{P}(\frac{x}{\gamma})\) and \(\epsilon\sim\mathcal{N}(0, I \sigma^2)\).

deepinv.physics.UniformNoise

Uniform noise \(y = x + \epsilon\) where \(\epsilon\sim\mathcal{U}(-a,a)\).

deepinv.physics.UniformGaussianNoise

Gaussian noise \(y=z+\epsilon\) where \(\epsilon\sim \mathcal{N}(0,I\sigma^2)\) and \(\sigma \sim\mathcal{U}(\sigma_{\text{min}}, \sigma_{\text{max}})\)

deepinv.physics.GammaNoise

Gamma noise \(y = \mathcal{G}(\ell, x/\ell)\)
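
To make one of these formulas concrete, the sketch below samples the Poisson-Gaussian model \(y = \gamma z + \epsilon\), \(z\sim\mathcal{P}(x/\gamma)\), \(\epsilon\sim\mathcal{N}(0,\sigma^2)\) directly with torch.poisson and checks its first two moments: the mean is \(x\) and the variance is \(\gamma x + \sigma^2\). The helper name and the values of `gain` and `sigma` are illustrative; deepinv.physics.PoissonGaussianNoise is the library implementation.

```python
import torch


def poisson_gaussian(x, gain=0.1, sigma=0.05, generator=None):
    """Sketch of y = gain * z + eps, with z ~ Poisson(x / gain) and eps ~ N(0, sigma^2)."""
    z = torch.poisson(x / gain, generator=generator)
    eps = sigma * torch.randn(x.shape, generator=generator, dtype=x.dtype)
    return gain * z + eps


g = torch.Generator().manual_seed(0)
x = torch.full((200_000,), 0.5)  # constant signal, many independent samples
y = poisson_gaussian(x, generator=g)
# Expected: mean 0.5, variance 0.1 * 0.5 + 0.05**2 = 0.0525
```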

The parameters of noise distributions can also be sampled with a deepinv.physics.generator.PhysicsGenerator(), which is useful for training and evaluating methods under various noise conditions.

deepinv.physics.generator.SigmaGenerator

Generator for the noise level \(\sigma\) in the Gaussian noise model.

Defining new operators

Defining a new forward operator is relatively simple. Create a new class that inherits from the appropriate physics class: deepinv.physics.Physics() for non-linear operators, deepinv.physics.LinearPhysics() for linear operators, and deepinv.physics.DecomposablePhysics() for linear operators with a closed-form singular value decomposition. The only requirement is to define the deepinv.physics.Physics.A method, which computes the forward operator. See the example "Creating a forward operator" for more details.

You can also inherit from mixin classes to provide useful methods for your physics:

deepinv.physics.TimeMixin

Base class for temporal capabilities for physics and models.

Defining a new linear operator also requires defining deepinv.physics.LinearPhysics.A_adjoint. Alternatively, the adjoint can be computed automatically with autograd using

deepinv.physics.adjoint_function

Provides the adjoint function of a linear operator \(A\), i.e., \(A^{\top}\).

Note however that coding a closed form adjoint is generally more efficient.
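
The autograd trick rests on the identity \(\nabla_x \langle A(x), y\rangle = A^{\top} y\), valid for any real linear \(A\) and any point \(x\). A minimal sketch of that idea follows; the helper `autograd_adjoint` is hypothetical and its signature is not that of deepinv.physics.adjoint_function. Complex-valued operators would need extra care with conjugation.

```python
import torch


def autograd_adjoint(A, y, x_shape, dtype=torch.float64):
    """Hypothetical sketch: for a real linear A, grad_x <A(x), y> = A^T y at any x."""
    x = torch.zeros(x_shape, dtype=dtype, requires_grad=True)
    (grad,) = torch.autograd.grad((A(x) * y).sum(), x)
    return grad


torch.manual_seed(0)
M = torch.randn(6, 4, dtype=torch.float64)


def A(x):
    return M @ x  # a linear operator given by an explicit matrix


y = torch.randn(6, dtype=torch.float64)
At_y = autograd_adjoint(A, y, (4,))  # should equal M.T @ y
```

Each call costs one forward and one backward pass through \(A\), which is why a hand-written adjoint is usually faster.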

Functional

The toolbox is based on efficient PyTorch implementations of basic operations such as diagonal multipliers, Fourier transforms, convolutions, product-convolutions, the Radon transform and interpolation mappings. Mirroring the structure of PyTorch, they are available within deepinv.physics.functional.

deepinv.physics.functional.conv2d

A helper function performing the 2d convolution of images x and filter.

deepinv.physics.functional.conv_transpose2d

A helper function performing the 2d transposed convolution of x and filter.

deepinv.physics.functional.conv2d_fft

A helper function performing the 2d convolution of images x and filter using FFT.

deepinv.physics.functional.conv_transpose2d_fft

A helper function performing the 2d transposed convolution of x and filter using FFT.

deepinv.physics.functional.conv3d_fft

A helper function performing the 3d convolution of x and filter using FFT.

deepinv.physics.functional.conv_transpose3d_fft

A helper function performing the 3d transposed convolution of y and filter using FFT.

deepinv.physics.functional.product_convolution2d

Product-convolution operator in 2d.

deepinv.physics.functional.multiplier

Implements diagonal matrices (multipliers), i.e., the elementwise product of \(x\) and mult.

deepinv.physics.functional.multiplier_adjoint

Implements the adjoint of diagonal matrices (multipliers) applied to \(x\) with mult.

deepinv.physics.functional.Radon

Sparse Radon transform operator.

deepinv.physics.functional.IRadon

Inverse sparse Radon transform operator.

deepinv.physics.functional.histogramdd

Computes the multidimensional histogram of a tensor.

deepinv.physics.functional.histogram

Computes the histogram of a tensor.

>>> import torch
>>> import deepinv as dinv

>>> x = torch.zeros((1, 1, 16, 16)) # Define black image of size 16x16
>>> x[:, :, 8, 8] = 1 # Define one white pixel in the middle
>>> filter = torch.ones((1, 1, 3, 3)) / 4
>>>
>>> padding = "circular"
>>> Ax = dinv.physics.functional.conv2d(x, filter, padding)
>>> print(Ax[:, :, 7:10, 7:10])
tensor([[[[0.2500, 0.2500, 0.2500],
          [0.2500, 0.2500, 0.2500],
          [0.2500, 0.2500, 0.2500]]]])
>>>
>>> _ = torch.manual_seed(0)
>>> y = torch.randn_like(Ax)
>>> z = dinv.physics.functional.conv_transpose2d(y, filter, padding)
>>> # adjoint (dot-product) test: <Ax, y> - <x, A^T y> is zero up to rounding error
>>> print((Ax * y).sum(dim=(1, 2, 3)) - (x * z).sum(dim=(1, 2, 3)))
tensor([5.9605e-08])
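
The equivalence between circular convolution and pointwise multiplication in the Fourier domain, which FFT-based helpers such as conv2d_fft build on, can be checked in a few lines. In this sketch the kernel is anchored at the top-left corner of the zero-padded grid; deepinv's functions may center the kernel instead, in which case outputs differ by a circular shift. Note also that torch.nn.functional.conv2d computes a cross-correlation rather than a flipped convolution.

```python
import torch

torch.manual_seed(0)
n = 16
x = torch.rand(n, n, dtype=torch.float64)
k = torch.rand(3, 3, dtype=torch.float64)

# Direct circular convolution: (k * x)[i, j] = sum_{a, b} k[a, b] x[(i - a) mod n, (j - b) mod n].
# torch.roll(x, shifts=(a, b)) returns the image whose entry (i, j) is x[i - a, j - b] (mod n).
direct = sum(k[a, b] * torch.roll(x, shifts=(a, b), dims=(0, 1))
             for a in range(3) for b in range(3))

# FFT version: zero-pad k to n x n (anchored at the origin) and multiply the spectra.
k_pad = torch.zeros(n, n, dtype=torch.float64)
k_pad[:3, :3] = k
via_fft = torch.fft.ifft2(torch.fft.fft2(x) * torch.fft.fft2(k_pad)).real
```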