Transform

class deepinv.transform.Transform(*args, n_trans: int = 1, rng: Generator | None = None, constant_shape: bool = True, flatten_video_input: bool = True, **kwargs)

Bases: Module, TimeMixin

Base class for image transforms.

The base transform implements transform arithmetic and other methods to invert transforms and symmetrize functions.

All transforms must implement _get_params() to randomly generate parameters (e.g. rotation angles in degrees or shifts in pixels), and _transform() to deterministically transform an image given those parameters.

To implement a new transform, reimplement _get_params() and _transform() (the latter with a **kwargs argument). See the respective methods for details.

Transforms also handle deterministic (non-random) transformations: pass in fixed parameter values instead of drawing random ones.
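The division of labour between _get_params() and _transform() can be illustrated without deepinv. The following is a minimal NumPy sketch under stated assumptions. ToyTransform and ToyShift are hypothetical classes, not part of deepinv. They only mimic the contract described above: one parameter set per transformation, outputs concatenated along the batch dimension, and random parameters drawn only when none are passed in.

```python
import numpy as np


class ToyTransform:
    """Hypothetical stand-in for the Transform base class (not deepinv code)."""

    def __init__(self, n_trans=1, rng=None):
        self.n_trans = n_trans
        # Create a fresh generator when none is supplied.
        self.rng = np.random.default_rng() if rng is None else rng

    def _get_params(self, x):
        raise NotImplementedError

    def _transform(self, x, **params):
        raise NotImplementedError

    def forward(self, x, **params):
        # Draw random params only if none were passed in; fixed params
        # therefore give a deterministic transformation.
        return self._transform(x, **(params or self._get_params(x)))


class ToyShift(ToyTransform):
    """Circular horizontal shift, one shift amount per transformation."""

    def _get_params(self, x):
        W = x.shape[-1]
        shifts = self.rng.integers(-W + 1, W, size=self.n_trans)
        return {"x_shift": [int(s) for s in shifts]}

    def _transform(self, x, x_shift=(), **kwargs):
        # **kwargs accepts and ignores any extra params.
        # One shifted copy of the batch per shift, stacked along axis 0.
        return np.concatenate([np.roll(x, s, axis=-1) for s in x_shift], axis=0)


x = np.arange(4.0).reshape(1, 1, 2, 2)  # (B, C, H, W)
t = ToyShift(n_trans=3, rng=np.random.default_rng(0))
print(t.forward(x).shape)  # (3, 1, 2, 2)

# Deterministic round trip with fixed params.
y = t.forward(t.forward(x, x_shift=[1]), x_shift=[-1])
print(np.array_equal(y, x))  # True
```

Keeping all randomness in _get_params() lets a single _transform() serve both random augmentation and exact, repeatable transformations.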

All transforms automatically handle video input (5D tensors of shape (B,C,T,H,W)) by flattening the time dimension before transforming.
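The text does not spell out which axis the time dimension is flattened into. The NumPy sketch below makes an assumption purely for illustration: time is folded into the channel axis. Under that assumption, each clip is treated as one multi-channel image, and all of its frames share the same random parameters. flatten_time and unflatten_time are hypothetical helpers, not deepinv functions.

```python
import numpy as np


def flatten_time(x):
    """(B, C, T, H, W) -> (B, C*T, H, W): fold time into the channel axis."""
    B, C, T, H, W = x.shape
    return x.reshape(B, C * T, H, W)


def unflatten_time(x, T):
    """Inverse of flatten_time, given the number of frames T."""
    B, CT, H, W = x.shape
    return x.reshape(B, CT // T, T, H, W)


video = np.random.default_rng(0).random((2, 1, 3, 4, 4))  # (B, C, T, H, W)
frames = flatten_time(video)                     # (2, 3, 4, 4): an ordinary 4D batch
k = 1                                            # one rotation parameter per clip,
rotated = np.rot90(frames, k=k, axes=(-2, -1))   # applied identically to every frame
out = unflatten_time(rotated, T=video.shape[2])
print(out.shape)  # (2, 1, 3, 4, 4)
```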


Examples:

Randomly transform an image:

>>> import torch
>>> from deepinv.transform import Shift, Rotate
>>> x = torch.rand((1, 1, 2, 2)) # Define random image (B,C,H,W)
>>> transform = Shift() # Define random shift transform
>>> transform(x).shape
torch.Size([1, 1, 2, 2])

Deterministically transform an image:

>>> y = transform(transform(x, x_shift=[1]), x_shift=[-1])
>>> torch.all(x == y)
tensor(True)

Accept video input of shape (B,C,T,H,W):

>>> transform(torch.rand((1, 1, 3, 2, 2))).shape
torch.Size([1, 1, 3, 2, 2])

Multiply transforms to create compound transforms (direct product of groups) - similar to torchvision.transforms.Compose:

>>> rotoshift = Rotate() * Shift() # Chain rotate and shift transforms
>>> rotoshift(x).shape
torch.Size([1, 1, 2, 2])

Sum transforms to create stacks of transformed images (along the batch dimension):

>>> transform = Rotate() + Shift() # Stack rotate and shift transforms
>>> transform(x).shape
torch.Size([2, 1, 2, 2])

Randomly select from transforms - similar to torchvision.transforms.RandomChoice:

>>> transform = Rotate() | Shift() # Randomly select rotate or shift transforms
>>> transform(x).shape
torch.Size([1, 1, 2, 2])
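Random selection can be sketched as a wrapper that draws an index on each call. random_choice is a hypothetical helper written with NumPy for illustration. It is not how deepinv implements the | operator.

```python
import numpy as np


def random_choice(*transforms, rng):
    """Hypothetical helper: each call applies one transform chosen uniformly."""
    def chosen(x):
        idx = int(rng.integers(len(transforms)))
        return transforms[idx](x)
    return chosen


def rot90(x):
    return np.rot90(x, k=1, axes=(-2, -1))


def shift(x):
    return np.roll(x, 1, axis=-1)


x = np.arange(4.0).reshape(1, 1, 2, 2)
t = random_choice(rot90, shift, rng=np.random.default_rng(0))
y = t(x)
print(y.shape)  # (1, 1, 2, 2): a single output, unlike stacking with +
```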

Symmetrize a function by averaging over the group (also known as Reynolds averaging):

>>> f = lambda x: x[..., [0]] * x # Function to be symmetrized
>>> f_s = rotoshift.symmetrize(f)
>>> f_s(x).shape
torch.Size([1, 1, 2, 2])

Parameters:
  • n_trans (int) – number of transformed versions generated per input image, defaults to 1

  • rng (torch.Generator) – random number generator; if None, torch.Generator() is used. Defaults to None.

  • constant_shape (bool) – if True, transformed images are assumed to have the same shape as the input. For most transforms this is not an issue, since automatic cropping/padding keeps all outputs the same shape. If False, certain transforms, including deepinv.transform.Rotate, will try to switch off automatic cropping/padding, which can result in errors. symmetrize will still work, but performs the transformations one by one (i.e. without collating over the batch, which is less efficient).

  • flatten_video_input (bool) – if True, accept video (5D) input of shape (B,C,T,H,W) by flattening the time dimension before transforming and unflattening it after all operations. Defaults to True.
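Plain NumPy shows why constant_shape matters. Rotating a non-square image by 90 degrees without cropping or padding swaps its height and width. The transformed copies can then no longer be concatenated along the batch dimension, so a one-by-one fallback is needed. This illustrates the shape problem only; it is not deepinv's code.

```python
import numpy as np

x = np.zeros((1, 1, 2, 3))  # non-square image (B, C, H, W)

# Rotations by 0 and 90 degrees without cropping/padding give different shapes.
outputs = [np.rot90(x, k=k, axes=(-2, -1)) for k in (0, 1)]
print([o.shape for o in outputs])  # [(1, 1, 2, 3), (1, 1, 3, 2)]

try:
    np.concatenate(outputs, axis=0)  # collating along the batch axis fails
    collated = True
except ValueError:
    collated = False
print(collated)  # False

# One-by-one fallback: process each transformed image separately.
results = [float(o.sum()) for o in outputs]
```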

__add__(other: Transform)

Stacks two transforms via the + operation.

Parameters:

other (deepinv.transform.Transform) – other transform

Returns:

(deepinv.transform.Transform) operator which produces stacked transformed images
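Here is a minimal sketch of the stacking semantics. It assumes each transform is a plain function on (B,C,H,W) NumPy arrays. stack is a hypothetical helper, not deepinv's implementation.

```python
import numpy as np


def stack(t1, t2):
    """Hypothetical sketch of t1 + t2: concatenate both outputs on the batch axis."""
    def stacked(x):
        return np.concatenate([t1(x), t2(x)], axis=0)
    return stacked


def rot90(x):
    return np.rot90(x, k=1, axes=(-2, -1))


def shift(x):
    return np.roll(x, 1, axis=-1)


x = np.arange(4.0).reshape(1, 1, 2, 2)
y = stack(rot90, shift)(x)
print(y.shape)  # (2, 1, 2, 2), as in the Rotate() + Shift() example
```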

__mul__(other: Transform)

Chains two transforms via the * operation.

Parameters:

other (deepinv.transform.Transform) – other transform

Returns:

(deepinv.transform.Transform) chained operator
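Chaining matters most for inversion, because the inverse of a chain must undo its steps in reverse order. The NumPy sketch below assumes, for illustration only, that the left operand is applied first. rotate, flip, chained and chained_inverse are hypothetical helpers. Rotations and flips do not commute, so undoing them in the wrong order yields a 180 degree rotation instead of the original image.

```python
import numpy as np


def rotate(x, k):
    return np.rot90(x, k=k, axes=(-2, -1))


def flip(x):
    return x[..., ::-1]  # horizontal flip; it is its own inverse


def chained(x, k):
    """Sketch of a rotate-then-flip chain (order assumed for illustration)."""
    return flip(rotate(x, k))


def chained_inverse(y, k):
    # Undo the steps in reverse order: flip back first, then rotate back.
    return rotate(flip(y), -k)


x = np.arange(9.0).reshape(1, 1, 3, 3)
y = chained(x, k=1)
print(np.array_equal(chained_inverse(y, k=1), x))  # True
print(np.array_equal(flip(rotate(y, -1)), x))      # False: wrong order
```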

forward(x: Tensor, **params) → Tensor

Perform a random transformation on an image.

Calls get_params() to generate random parameters for the image, then transform() to transform it deterministically.

For a purely deterministic transformation, pass in custom params; get_params() is then skipped.

Parameters:

x (torch.Tensor) – input image of shape (B,C,H,W)

Return torch.Tensor:

randomly transformed images concatenated along the first dimension

get_params(x: Tensor) → dict

Randomly generate transform parameters, one set per transformation (n_trans sets in total).

Params are represented as tensors whose first dimension indexes the batch and n_trans. Params store, e.g., rotation angles in degrees or shift amounts.

Params may be any Tensor-like object. For inverse transforms, params are negated by default. To change this behaviour (e.g. to take the reciprocal for the inverse), wrap the param in the TransformParam class: p = TransformParam(p, neg=lambda x: 1/x)

Parameters:

x (torch.Tensor) – input image

Return dict:

keyword args of transform parameters e.g. {'theta': 30}
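Here is a hedged sketch of parameter generation, using NumPy's Generator in place of torch.Generator. get_rotation_params is a hypothetical function that draws n_trans distinct multiples of 90 degrees, one entry per transformation. Seeding the generator, as the rng argument allows, makes the draw reproducible.

```python
import numpy as np


def get_rotation_params(n_trans, rng):
    """Hypothetical sketch: draw n_trans distinct multiples of 90 degrees."""
    angles = np.array([0, 90, 180, 270])
    return {"theta": rng.permutation(angles)[:n_trans]}


p1 = get_rotation_params(3, np.random.default_rng(42))
p2 = get_rotation_params(3, np.random.default_rng(42))
print(p1["theta"].shape)                         # (3,): one entry per transformation
print(np.array_equal(p1["theta"], p2["theta"]))  # True: same seed, same params
```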

identity(x: Tensor, average: bool = False) → Tensor

Sanity check function that should do nothing.

This performs the forward transform followed by the inverse transform, which recovers the exact original up to interpolation and padding effects.

Interpolation and padding effects will be visible in non-pixelwise transformations, such as arbitrary rotation, scale or projective transformation.

Parameters:
  • x (torch.Tensor) – input image

  • average (bool) – average over the n_trans transformed versions to get the same number of output images as input images. No effect when n_trans=1.

Return torch.Tensor:

\(T_g^{-1}T_g x=x\)
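The difference between pixelwise and interpolating transforms can be checked numerically. In this NumPy sketch, frac_shift is a hypothetical circular shift with linear interpolation. An integer shift followed by its inverse returns x exactly, while a half-pixel round trip blurs it. The last lines mimic average=True, assuming the n_trans outputs are stacked as consecutive blocks of B images.

```python
import numpy as np


def frac_shift(x, s):
    """Circular shift along W by s pixels, |s| <= 1, with linear interpolation."""
    step = int(np.sign(s))
    return (1 - abs(s)) * x + abs(s) * np.roll(x, step, axis=-1)


x = np.arange(4.0).reshape(1, 1, 1, 4)

# Integer shifts are pixelwise: the round trip recovers x exactly.
exact = frac_shift(frac_shift(x, 1.0), -1.0)

# Half-pixel shifts interpolate: the round trip blurs x instead.
blurred = frac_shift(frac_shift(x, 0.5), -0.5)
print(blurred.ravel())  # [1. 1. 2. 2.]

# average=True-style reduction: n_trans outputs stacked as n_trans blocks
# of B images along axis 0, averaged back to B images.
n_trans, B = 3, x.shape[0]
stacked = np.concatenate([exact] * n_trans, axis=0)
averaged = stacked.reshape(n_trans, B, *x.shape[1:]).mean(axis=0)
```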

inverse(x: Tensor, batchwise=True, **params) → Tensor

Perform a random inverse transformation on an image (i.e. when not a group).

For a purely deterministic transformation, pass in custom params; get_params() is then skipped.

Parameters:
  • x (torch.Tensor) – input image

  • batchwise (bool) – if True, dimension 0 of the output expands to size len(x) * len(param) for the params of interest. If False, params are matched to the images in the batch one by one (where possible), keeping len(out) = len(x). No effect when n_trans=1.

Return torch.Tensor:

randomly transformed images
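The two batchwise modes differ only in how parameters are paired with images. inverse_shift below is a hypothetical NumPy sketch that uses circular shifts with negated parameters; it is not deepinv's inverse. It shows the resulting output sizes, and it shows that batchwise=False undoes a per-image forward shift.

```python
import numpy as np


def inverse_shift(x, x_shift, batchwise=True):
    """Hypothetical sketch of the two batchwise modes for an inverse shift."""
    inv = [-s for s in x_shift]  # default inversion: negate each param
    if batchwise:
        # Every inverse param acts on the whole batch: len(out) = len(x) * len(inv).
        return np.concatenate([np.roll(x, s, axis=-1) for s in inv], axis=0)
    if len(inv) != len(x):
        raise ValueError("batchwise=False needs one parameter per image")
    # Inverse param i acts on image i only: len(out) = len(x).
    return np.concatenate(
        [np.roll(xi[None], s, axis=-1) for xi, s in zip(x, inv)], axis=0
    )


x = np.arange(8.0).reshape(2, 1, 1, 4)  # B = 2
shifts = [1, 2]                          # one shift per image
print(inverse_shift(x, shifts).shape)                   # (4, 1, 1, 4)
print(inverse_shift(x, shifts, batchwise=False).shape)  # (2, 1, 1, 4)

# Shift image i by shifts[i], then undo it image by image.
y = np.concatenate([np.roll(xi[None], s, axis=-1) for xi, s in zip(x, shifts)], axis=0)
print(np.array_equal(inverse_shift(y, shifts, batchwise=False), x))  # True
```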

invert_params(params: dict) → dict

Invert transformation parameters. Pass a variable of type TransformParam to override negation (e.g. to take the reciprocal).

Parameters:

params (dict) – transform parameters as dict

Return dict:

inverted parameters.
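Here is a hedged sketch of the override mechanism. Param is a hypothetical stand-in for TransformParam, and the real class may be structured differently. Param carries a value and a custom neg function, while plain values are simply negated.

```python
import numpy as np


class Param:
    """Hypothetical stand-in for TransformParam: a value plus a custom inverse."""

    def __init__(self, value, neg=lambda v: -v):
        self.value = value
        self.neg = neg


def invert_params(params):
    """Negate plain values; defer to the wrapped neg function otherwise."""
    return {
        k: Param(v.neg(v.value), v.neg) if isinstance(v, Param) else -v
        for k, v in params.items()
    }


params = {
    "theta": np.array([30.0]),                          # negated by default
    "scale": Param(np.array([2.0]), neg=lambda v: 1 / v),  # inverted by reciprocal
}
inv = invert_params(params)
print(inv["theta"], inv["scale"].value)  # [-30.] [0.5]
```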

symmetrize(f: Callable[[Tensor, Any], Tensor], average: bool = False, collate_batch: bool = True) → Callable[[Tensor, Any], Tensor]

Symmetrize a function with a transform and its inverse.

Given a function \(f(\cdot):X\rightarrow X\) and a transform \(T_g\), returns the group-averaged function \(\frac{1}{N}\sum_{i=1}^N T_{g_i}^{-1} f(T_{g_i} \cdot)\), where \(N\) is the number of random transformations.

This is useful, for example, for Reynolds averaging a function over a group. Set average=True to average over the n_trans transformations. For instance, use Rotate(n_trans=4, positive=True, multiples=90).symmetrize(f) to symmetrize f over the entire group.
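Reynolds averaging can be verified on the four-element group of 90 degree rotations. In this NumPy sketch, symmetrize_c4 is a hypothetical helper implementing \(\frac{1}{4}\sum_{k=0}^{3} R^{-k} f(R^k x)\) for a deliberately non-equivariant f. The averaged function commutes with rotations even though f itself does not.

```python
import numpy as np


def rot(x, k):
    return np.rot90(x, k=k, axes=(-2, -1))


def f(x):
    # Deliberately not rotation-equivariant: scales x by its top-left pixel.
    return x * x[..., :1, :1]


def symmetrize_c4(f):
    """Reynolds average of f over all four 90 degree rotations."""
    def f_s(x):
        return sum(rot(f(rot(x, k)), -k) for k in range(4)) / 4
    return f_s


x = np.random.default_rng(0).random((1, 1, 3, 3))
f_s = symmetrize_c4(f)
print(np.allclose(f(rot(x, 1)), rot(f(x), 1)))      # False: f is not equivariant
print(np.allclose(f_s(rot(x, 1)), rot(f_s(x), 1)))  # True: f_s is equivariant
```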

Parameters:
  • f (Callable[[torch.Tensor, Any], torch.Tensor]) – function acting on tensors.

  • average (bool) – take a Monte Carlo average over all random transformations (n_trans in total) when symmetrizing, to get the same number of output images as input images. No effect when n_trans=1.

  • collate_batch (bool) – if True, collect the n_trans transformed images along the batch dimension and evaluate f only once; this requires roughly n_trans times the memory. If False, evaluate f separately for each transformation. Always False when the transformed images are not of constant shape.

Return Callable[[torch.Tensor, Any], torch.Tensor]:

decorated function.
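collate_batch trades memory for fewer calls to f. The NumPy sketch below computes the same average both ways. It uses a hypothetical symmetrize helper over 90 degree rotations of square images, not deepinv's method. The two results agree whenever f treats each image in the batch independently.

```python
import numpy as np


def rot(x, k):
    return np.rot90(x, k=k, axes=(-2, -1))


def f(x):
    # Acts on each image independently (no mixing across the batch axis).
    return np.tanh(x * x[..., :1, :1])


def symmetrize(f, x, n=4, collate_batch=True):
    """Average of rot^{-k} f(rot^k x) over k = 0..n-1, two evaluation strategies."""
    B = x.shape[0]
    if collate_batch:
        # Stack all n transformed copies (n*B images in memory) and call f once.
        out = f(np.concatenate([rot(x, k) for k in range(n)], axis=0))
        chunks = [rot(out[k * B:(k + 1) * B], -k) for k in range(n)]
    else:
        # Call f once per transformation on B images at a time.
        chunks = [rot(f(rot(x, k)), -k) for k in range(n)]
    return sum(chunks) / n


x = np.random.default_rng(1).random((2, 1, 4, 4))
print(np.allclose(symmetrize(f, x), symmetrize(f, x, collate_batch=False)))  # True
```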

transform(x: Tensor, **params) → Tensor

Transform image given transform parameters.

Given randomly generated params (e.g. rotation degrees), deterministically transform the image x.

Parameters:
  • x (torch.Tensor) – input image of shape (B,C,H,W)

  • params – parameters e.g. degrees or shifts provided as keyword args.

Returns:

torch.Tensor: transformed image.

Examples using Transform:

Image transforms for equivariance & augmentations

Image transformations for Equivariant Imaging

Self-supervised learning with Equivariant Imaging for MRI.