Transforms
This package contains transforms that can be used for data augmentation or together with the equivariant losses.
We implement various geometric transforms, ranging from Euclidean to homography and diffeomorphisms, some of which offer group-theoretic properties.
See Image transforms for equivariance & augmentations for example usage and visualisations.
Transforms inherit from deepinv.transform.Transform. Transforms can also be stacked by summing them (+), chained by multiplying them (*, i.e. product group), or joined via | to randomly select one of them.
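The three ways of combining transforms can be illustrated with a small NumPy sketch. ToyTransform is a hypothetical stand-in written for this illustration; it is not the library's implementation, and the order in which a chained pair is applied is an arbitrary choice of the sketch.

```python
import random

import numpy as np


class ToyTransform:
    """Hypothetical stand-in for a transform acting on (B, C, H, W) arrays."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    def __add__(self, other):
        # Stack: apply both transforms and concatenate along the batch axis.
        return ToyTransform(lambda x: np.concatenate([self(x), other(x)], axis=0))

    def __mul__(self, other):
        # Chain: compose the two transforms (order chosen arbitrarily here).
        return ToyTransform(lambda x: self(other(x)))

    def __or__(self, other):
        # Select: pick one of the two transforms at random on every call.
        return ToyTransform(lambda x: random.choice([self, other])(x))


rot90 = ToyTransform(lambda x: np.rot90(x, k=1, axes=(-2, -1)))
shift = ToyTransform(lambda x: np.roll(x, shift=1, axis=-1))

x = np.random.rand(1, 1, 2, 2)
stacked_shape = (rot90 + shift)(x).shape  # batch dimension doubles
chained_shape = (rot90 * shift)(x).shape  # batch dimension unchanged
either_shape = (rot90 | shift)(x).shape   # batch dimension unchanged
```

Stacking is the only combination that changes the output shape, since it returns one transformed copy per operand.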
There are numerous other parameters, e.g. to randomly transform multiple times at once or to constrain the parameters to a range.
Transforms can also be used to make a denoiser equivariant using deepinv.models.EquivariantDenoiser by performing Reynolds averaging using symmetrize().
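As a rough illustration of Reynolds averaging, the sketch below averages g^{-1} f(g x) over the four rotations by multiples of 90 degrees, using NumPy only. The helper names rot and symmetrize_rot90 are made up for this example, and the code is a conceptual sketch rather than what symmetrize() does internally.

```python
import numpy as np


def rot(x, k):
    """Rotate the last two axes (H, W) by k * 90 degrees."""
    return np.rot90(x, k=k, axes=(-2, -1))


def symmetrize_rot90(f):
    """Return the Reynolds average of f over the group of 90-degree rotations."""

    def f_s(x):
        # Average g^{-1} f(g x) over the four group elements g.
        return sum(rot(f(rot(x, k)), -k) for k in range(4)) / 4

    return f_s


# Same toy function as in the usage example; it is not rotation-equivariant.
f = lambda x: x[..., [0]] * x
f_s = symmetrize_rot90(f)

rng = np.random.default_rng(0)
x = rng.random((1, 1, 4, 4))

# Equivariance: transforming the input transforms the output in the same way.
is_equivariant = np.allclose(f_s(rot(x, 1)), rot(f_s(x), 1))
```

Because the group is closed under composition, shifting the summation index by one rotation leaves the average unchanged, which is why the averaged function commutes with every rotation in the group.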
They can also be used for equivariant imaging (EI) using the deepinv.loss.EILoss loss.
See Image transformations for Equivariant Imaging and Self-supervised learning with Equivariant Imaging for MRI for examples.
If needed, transforms can also be made deterministic by passing in specified parameters to the forward method.
This allows every transform to have its own deterministic inverse using transform.inverse().
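The idea of a deterministic transform with an exact inverse can be sketched with a circular shift in NumPy. The functions shift and shift_inverse are hypothetical helpers for this illustration; the library's own inverse() may be implemented differently.

```python
import numpy as np


def shift(x, x_shift, y_shift=0):
    """Circularly shift the last two axes (H, W) by fixed pixel amounts."""
    return np.roll(x, shift=(y_shift, x_shift), axis=(-2, -1))


def shift_inverse(x, x_shift, y_shift=0):
    """Undo `shift` by applying the negated parameters."""
    return shift(x, -x_shift, -y_shift)


x = np.arange(16.0).reshape(1, 1, 4, 4)
y = shift(x, x_shift=1, y_shift=2)              # parameters fixed, so no randomness
x_rec = shift_inverse(y, x_shift=1, y_shift=2)  # recovers x exactly
```

Fixing the parameters is what makes the inverse well defined: the same values that produced y are negated to map it back to x.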
Transforms can also be seamlessly integrated with existing torchvision transforms.
Transforms can also accept video (5D) input.
deepinv.transform.Transform: Base class for image transforms.
For example, random transforms can be used as follows:
>>> import torch
>>> from deepinv.transform import Shift, Rotate
>>> x = torch.rand((1, 1, 2, 2)) # Define random image (B,C,H,W)
>>> transform = Shift() # Define random shift transform
>>> transform(x).shape
torch.Size([1, 1, 2, 2])
>>> y = transform(transform(x, x_shift=[1]), x_shift=[-1]) # Deterministic transform
>>> torch.all(x == y)
tensor(True)
>>> transform(torch.rand((1, 1, 3, 2, 2))).shape # Accepts video input of shape (B,C,T,H,W)
torch.Size([1, 1, 3, 2, 2])
>>> transform = Rotate() + Shift() # Stack rotate and shift transforms
>>> transform(x).shape
torch.Size([2, 1, 2, 2])
>>> rotoshift = Rotate() * Shift() # Chain rotate and shift transforms
>>> rotoshift(x).shape
torch.Size([1, 1, 2, 2])
>>> transform = Rotate() | Shift() # Randomly select rotate or shift transforms
>>> transform(x).shape
torch.Size([1, 1, 2, 2])
>>> f = lambda x: x[..., [0]] * x # Function to be symmetrized
>>> f_s = rotoshift.symmetrize(f)
>>> f_s(x).shape
torch.Size([1, 1, 2, 2])
Simple transforms
2D Rotations.
Fast integer 2D translations.
2D Scaling.
Reflect (flip) in random multiple axes.
Advanced transforms
We implement the following further geometric transforms.
The projective transformations formulate the image transformations using the pinhole camera model, from which various transformation subgroups can be derived.
See Image transformations for Equivariant Imaging for a demonstration. Note that these require kornia to be installed.
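To make the pinhole camera formulation concrete, the sketch below uses the standard result that a pure camera rotation R induces the image homography H = K R K^{-1}, where K is the intrinsic matrix. The focal length, principal point and helper names are assumptions chosen for illustration; the library's parameterisation may differ.

```python
import numpy as np


def intrinsics(focal, cx, cy):
    """Pinhole intrinsic matrix K with square pixels and no skew."""
    return np.array([[focal, 0.0, cx], [0.0, focal, cy], [0.0, 0.0, 1.0]])


def rotation_z(theta):
    """Rotation about the optical axis (in-plane roll)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotation_y(theta):
    """Rotation about the vertical axis (pan)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def homography(K, R):
    """Homography induced on the image plane by a pure camera rotation R."""
    return K @ R @ np.linalg.inv(K)


def warp_point(H, u, v):
    """Apply H to pixel (u, v) in homogeneous coordinates and dehomogenise."""
    p = H @ np.array([u, v, 1.0])
    return p[:2] / p[2]


K = intrinsics(focal=100.0, cx=32.0, cy=32.0)
H_roll = homography(K, rotation_z(np.pi / 2))  # Euclidean: rotation about the centre
H_pan = homography(K, rotation_y(0.1))         # projective: non-zero perspective row
centre = warp_point(H_roll, 32.0, 32.0)        # principal point stays fixed under roll
```

A roll about the optical axis yields a plain 2D rotation about the principal point, whereas a pan introduces perspective terms in the last row of H; this is one way the Euclidean case can be seen as a special case of the projective one.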
Random projective transformations (homographies).
Random Euclidean image transformations using projective transformation framework.
Random 2D similarity image transformations using projective transformation framework.
Random affine image transformations using projective transformation framework.
Random 3D camera rotation image transformations using projective transformation framework.
Continuous Piecewise-Affine-based Diffeomorphism.
Video transforms
While all geometric transforms accept video input, the following transforms work specifically in the time dimension.
These can be easily compounded with geometric transformations using the * operation.
Shift a video in time with reflective padding.
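One possible reading of a time shift with reflective padding is sketched below in NumPy for a (B, C, T, H, W) video. The function shift_time is a hypothetical helper, and the exact padding convention used by the library is not confirmed here; the sketch uses NumPy's "reflect" mode, which mirrors frames without repeating the edge frame.

```python
import numpy as np


def shift_time(x, k):
    """Shift a (B, C, T, H, W) video by k frames along T, reflect-padding the gap.

    Assumes abs(k) < T, which NumPy's "reflect" mode needs for a single reflection.
    """
    if k == 0:
        return x.copy()
    T = x.shape[2]
    pad = [(0, 0)] * x.ndim
    # Pad at the start for a forward shift, at the end for a backward shift.
    pad[2] = (k, 0) if k > 0 else (0, -k)
    padded = np.pad(x, pad, mode="reflect")
    # Keep T frames: the first T when padded at the start, the last T otherwise.
    return padded[:, :, :T] if k > 0 else padded[:, :, -T:]


# Four frames whose pixel value equals the frame index: 0, 1, 2, 3.
video = np.arange(4.0).reshape(1, 1, 4, 1, 1)
forward = shift_time(video, 1).ravel()    # frames become 1, 0, 1, 2
backward = shift_time(video, -1).ravel()  # frames become 1, 2, 3, 2
```

Only the time axis is touched, so a shift of this kind can be composed with any spatial transform applied to each frame.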