Transform
- class deepinv.transform.Transform(*args, n_trans: int = 1, rng: Generator | None = None, constant_shape: bool = True, flatten_video_input: bool = True, **kwargs)
Base class for image transforms.
The base transform implements transform arithmetic and other methods to invert transforms and symmetrize functions.
All transforms must implement _get_params() to randomly generate parameters, e.g. rotation degrees or shift pixels, and _transform() to deterministically transform an image given the params.
To implement a new transform, reimplement _get_params() and _transform() (with a **kwargs argument). See the respective methods for details.
Transforms also handle deterministic (non-random) transformations when fixed parameter values are passed in.
All transforms automatically handle video input (5D of shape (B,C,T,H,W)) by flattening the time dimension.
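The split between a random parameter draw and a deterministic transform can be sketched in plain Python. This is a toy illustration only: ToyShift, its shift parameter and the use of lists in place of tensors are hypothetical and not part of deepinv.

```python
import random


class ToyShift:
    """Hypothetical stand-in for a transform; not part of deepinv."""

    def __init__(self, n_trans=1, seed=0):
        self.n_trans = n_trans
        self.rng = random.Random(seed)

    def _get_params(self, x):
        # Random part: draw one shift amount per transformed version.
        return {"shift": [self.rng.randrange(len(x)) for _ in range(self.n_trans)]}

    def _transform(self, x, shift=(), **kwargs):
        # Deterministic part: one circularly shifted copy of x per shift value.
        n = len(x)
        return [x[(n - s) % n:] + x[:(n - s) % n] for s in shift]

    def __call__(self, x, **params):
        # Fixed params skip the random draw, giving a deterministic transform.
        if not params:
            params = self._get_params(x)
        return self._transform(x, **params)


x = [1, 2, 3, 4]
random_versions = ToyShift(n_trans=2)(x)        # two randomly shifted copies
shifted = ToyShift()(x, shift=[1])[0]           # deterministic shift by one
restored = ToyShift()(shifted, shift=[-1])[0]   # undo it with the negated param
```

Keeping all randomness in the parameter draw is what lets the same parameters be reused, or negated, to replay or invert a transformation.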
- Examples:
Randomly transform an image:
>>> import torch
>>> from deepinv.transform import Shift, Rotate
>>> x = torch.rand((1, 1, 2, 2)) # Define random image (B,C,H,W)
>>> transform = Shift() # Define random shift transform
>>> transform(x).shape
torch.Size([1, 1, 2, 2])
Deterministically transform an image:
>>> y = transform(transform(x, x_shift=[1]), x_shift=[-1])
>>> torch.all(x == y)
tensor(True)
Accept video input of shape (B,C,T,H,W):
>>> transform(torch.rand((1, 1, 3, 2, 2))).shape
torch.Size([1, 1, 3, 2, 2])
Multiply transforms to create compound transforms (direct product of groups), similar to torchvision.transforms.Compose:
>>> rotoshift = Rotate() * Shift() # Chain rotate and shift transforms
>>> rotoshift(x).shape
torch.Size([1, 1, 2, 2])
Sum transforms to create stacks of transformed images (along the batch dimension):
>>> transform = Rotate() + Shift() # Stack rotate and shift transforms
>>> transform(x).shape
torch.Size([2, 1, 2, 2])
Randomly select from transforms, similar to torchvision.transforms.RandomApply:
>>> transform = Rotate() | Shift() # Randomly select rotate or shift transforms
>>> transform(x).shape
torch.Size([1, 1, 2, 2])
Symmetrize a function by averaging over the group (also known as Reynolds averaging):
>>> f = lambda x: x[..., [0]] * x # Function to be symmetrized
>>> f_s = rotoshift.symmetrize(f)
>>> f_s(x).shape
torch.Size([1, 1, 2, 2])
- Parameters:
n_trans (int) – number of transformed versions generated per input image, defaults to 1
rng (torch.Generator) – random number generator; if None, use torch.Generator(). Defaults to None
constant_shape (bool) – if True, transformed images are assumed to be the same shape as the input. For most transforms, this will not be an issue as automatic cropping/padding should mean all outputs are the same shape. If False, for certain transforms including deepinv.transform.Rotate, transform will try to switch off automatic cropping/padding, resulting in errors. However, symmetrize will still work but perform one-by-one (i.e. without collating over batch, which is less efficient).
flatten_video_input (bool) – accept video (5D) input of shape (B,C,T,H,W) by flattening the time dim before transforming and unflattening after all operations.
- __add__(other: Transform)
Stacks two transforms via the + operation.
- Parameters:
other (deepinv.transform.Transform) – other transform
- Returns:
(deepinv.transform.Transform) operator which produces stacked transformed images
- __mul__(other: Transform)
Chains two transforms via the * operation.
- Parameters:
other (deepinv.transform.Transform) – other transform
- Returns:
(deepinv.transform.Transform) chained operator
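The difference between chaining and stacking can be sketched with a toy model in which a batch is a list of images. ToyTransform and ToyStack are hypothetical names, and the left-to-right order used for chaining is an assumption of this sketch rather than something stated above.

```python
class ToyTransform:
    """Hypothetical transform acting on a batch (a list of images)."""

    def __init__(self, fn):
        self.fn = fn  # maps one image to one transformed image

    def __call__(self, batch):
        return [self.fn(img) for img in batch]

    def __mul__(self, other):
        # Chain: apply self, then other; the batch size is unchanged.
        # (Left-to-right order is an assumption of this sketch.)
        return ToyTransform(lambda img: other.fn(self.fn(img)))

    def __add__(self, other):
        # Stack: concatenate both outputs along the batch dimension.
        return ToyStack(self, other)


class ToyStack:
    """Hypothetical container that stacks the outputs of several transforms."""

    def __init__(self, *transforms):
        self.transforms = transforms

    def __call__(self, batch):
        out = []
        for t in self.transforms:
            out += t(batch)
        return out


flip = ToyTransform(lambda img: img[::-1])
negate = ToyTransform(lambda img: [-v for v in img])

batch = [[1, 2, 3]]
chained = (flip * negate)(batch)   # one image out per image in
stacked = (flip + negate)(batch)   # two images out per image in
```

As in the doctest examples for this class, chaining preserves the batch size while stacking multiplies it by the number of transforms.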
- forward(x: Tensor, **params) -> Tensor
Perform random transformation on image.
Calls get_params to generate random params for image, then transform to deterministically transform.
For purely deterministic transformation, pass in custom params and get_params will be ignored.
- Parameters:
x (torch.Tensor) – input image of shape (B,C,H,W)
- Return torch.Tensor:
randomly transformed images concatenated along the first dimension
- get_params(x: Tensor) -> dict
Randomly generate transform parameters, one set per n_trans.
Params are represented as tensors where the first dimension indexes batch and n_trans. Params store e.g. rotation degrees or shift amounts.
Params may be any Tensor-like object. For inverse transforms, params are negated by default. To change this behaviour (e.g. calculate reciprocal for inverse), wrap the param in a TransformParam class: p = TransformParam(p, neg=lambda x: 1/x)
- Parameters:
x (torch.Tensor) – input image
- Return dict:
keyword args of transform parameters e.g. {'theta': 30}
- identity(x: Tensor, average: bool = False) -> Tensor
Sanity check function that should do nothing.
This performs the forward and inverse transform, which recovers the original image up to interpolation and padding effects.
Interpolation and padding effects will be visible in non-pixelwise transformations, such as arbitrary rotation, scale or projective transformation.
- Parameters:
x (torch.Tensor) – input image
average (bool) – average over n_trans transformed versions to get the same number of output images as input images. No effect when n_trans=1.
- Return torch.Tensor:
\(T_g^{-1}T_g x=x\)
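The distinction between pixelwise and non-pixelwise transforms can be illustrated with a 1D toy signal. This sketch does not use deepinv: the shift and identity functions below are hypothetical, and linear interpolation is an assumed interpolation scheme chosen for simplicity.

```python
import math


def shift(x, s):
    """Circularly shift x by s samples; fractional s uses linear interpolation."""
    n = len(x)
    out = []
    for i in range(n):
        pos = (i - s) % n            # source position for output sample i
        lo = math.floor(pos)
        frac = pos - lo
        out.append((1 - frac) * x[lo % n] + frac * x[(lo + 1) % n])
    return out


def identity(x, s):
    # Forward transform followed by the inverse (negated parameter).
    return shift(shift(x, s), -s)


x = [0.0, 0.0, 4.0, 0.0]
exact = identity(x, 1)       # pixelwise shift: recovers x exactly
blurred = identity(x, 0.5)   # half-sample shift: interpolation smooths x
```

The integer shift round-trips exactly, whereas the half-sample shift spreads the single peak over its neighbours, which is the kind of interpolation effect described above.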
- inverse(x: Tensor, batchwise=True, **params) -> Tensor
Perform random inverse transformation on image (i.e. when not a group).
For purely deterministic transformation, pass in custom params and get_params will be ignored.
- Parameters:
x (torch.Tensor) – input image
batchwise (bool) – if True, the output dim 0 expands to be of size len(x) * len(param) for the params of interest. If False, params will attempt to match each image in batch to keep constant len(out)=len(x). No effect when n_trans==1
- Return torch.Tensor:
randomly transformed images
- invert_params(params: dict) -> dict
Invert transformation parameters. Pass variable of type TransformParam to override negation (e.g. to take reciprocal).
- Parameters:
params (dict) – transform parameters as dict
- Return dict:
inverted parameters.
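The inversion rule described above (negate by default, override the negation per parameter) can be mimicked in plain Python. ToyParam and the free function invert_params below are hypothetical illustrations built on float, not the deepinv TransformParam class or method.

```python
class ToyParam(float):
    """Hypothetical float wrapper whose unary minus can be overridden."""

    def __new__(cls, value, neg=None):
        obj = super().__new__(cls, value)
        # Default inverse is plain negation, as for rotation angles or shifts.
        obj._neg = neg if neg is not None else (lambda v: -v)
        return obj

    def __neg__(self):
        # Apply the custom rule to the raw float and keep the rule attached.
        return ToyParam(self._neg(float(self)), neg=self._neg)


def invert_params(params):
    # "Negate" every parameter; wrapped params decide what negation means.
    return {name: -value for name, value in params.items()}


params = {
    "theta": ToyParam(30.0),                       # inverse is -30
    "scale": ToyParam(2.0, neg=lambda v: 1 / v),   # inverse is the reciprocal
}
inverted = invert_params(params)
round_trip = invert_params(inverted)
```

Because the rule travels with the value, inverting twice returns the original parameters for both the additive and the multiplicative case.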
- symmetrize(f: Callable[[Tensor, Any], Tensor], average: bool = False, collate_batch: bool = True) -> Callable[[Tensor, Any], Tensor]
Symmetrize a function with a transform and its inverse.
Given a function \(f(\cdot):X\rightarrow X\) and a transform \(T_g\), returns the group averaged function \(\sum_{i=1}^N T_{g_i}^{-1} f(T_{g_i} \cdot)\) where \(N\) is the number of random transformations.
For example, this is useful for Reynolds averaging a function over a group. Set average=True to average over n_trans. For example, use Rotate(n_trans=4, positive=True, multiples=90).symmetrize(f) to symmetrize f over the entire group.
- Parameters:
f (Callable[[torch.Tensor, Any], torch.Tensor]) – function acting on tensors.
average (bool) – Monte Carlo average over all random transformations (in range n_trans) when symmetrizing, to get the same number of output images as input images. No effect when n_trans=1.
collate_batch (bool) – if True, collect n_trans transformed images in batch dim and evaluate f only once. However, this requires n_trans times the memory. If False, evaluate f for each transformation. Always False when transformed images aren't constant shape.
- Return Callable[[torch.Tensor, Any], torch.Tensor]:
decorated function.
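Group averaging can be checked by hand on a small example. The sketch below uses cyclic shifts of a Python list as the group and averages over all of its elements; roll and this standalone symmetrize are hypothetical helpers for illustration, not the deepinv method.

```python
def roll(x, s):
    """Circularly shift list x by s positions (the group action T_g)."""
    n = len(x)
    k = s % n
    return x[n - k:] + x[:n - k]


def symmetrize(f, n):
    """Return the average over all n cyclic shifts of T_g^{-1} f(T_g x)."""
    def f_s(x):
        total = [0.0] * len(x)
        for s in range(n):
            y = roll(f(roll(x, s)), -s)   # transform, apply f, transform back
            total = [t + v for t, v in zip(total, y)]
        return [t / n for t in total]
    return f_s


f = lambda x: [x[0] * v for v in x]   # not shift-equivariant on its own
x = [1, 2, 3, 4]
f_s = symmetrize(f, len(x))
```

Here f depends on the first entry of its input and so does not commute with shifts, while the averaged function f_s does: shifting the input of f_s shifts its output by the same amount.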
- transform(x: Tensor, **params) -> Tensor
Transform image given transform parameters.
Given randomly generated params (e.g. rotation degrees), deterministically transform the image x.
- Parameters:
x (torch.Tensor) – input image of shape (B,C,H,W)
params – parameters e.g. degrees or shifts provided as keyword args.
- Returns:
torch.Tensor: transformed image.
Examples using Transform:
Image transforms for equivariance & augmentations
Image transformations for Equivariant Imaging
Self-supervised learning with Equivariant Imaging for MRI.