Image transforms for equivariance & augmentations
We demonstrate how to use the deepinv.transform module when
solving imaging problems. Transforms can be used for:
- Data augmentation (similar to torchvision.transforms)
- Building equivariant denoisers (deepinv.models.EquivariantDenoiser) for robust denoising (e.g. from Terris et al. [1])
- Self-supervised learning using Equivariant Imaging from Chen et al. [2]. See Image transformations for Equivariant Imaging and Self-supervised learning with Equivariant Imaging for MRI for thorough examples.
See the docs for the full list of implemented transforms.
1. Data augmentation
We can use deepinv transforms in the same way as torchvision
transforms, and chain them together for data augmentation. Our
transforms are customisable and offer some group-theoretic properties.
We demonstrate a random roto-scale combined with a random masking, and a
constrained pixel-shift with a random color jitter.
Note that all our transforms can easily be inverted using the method transform.inverse().
First, load a sample image.
import deepinv as dinv
import torch
from torchvision.transforms import Compose, ColorJitter, RandomErasing, Resize
x = dinv.utils.load_example("celeba_example.jpg")
# Random roto-scale with random masking
transform = Compose(
    [
        dinv.transform.Rotate() * dinv.transform.Scale(),
        RandomErasing(),
    ]
)
# Constrained pixel-shift with a random color jitter
transform2 = Compose(
    [
        dinv.transform.Shift(shift_max=0.2),
        ColorJitter(hue=0.5),
    ]
)
# Random diffeomorphism
transform3 = dinv.transform.CPABDiffeomorphism()
dinv.utils.plot(
    [x, transform(x), transform2(x), transform3(x)],
    titles=["Orig", "Transform 1", "Transform 2", "Transform 3"],
)
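To make the invertibility remark above concrete, here is a minimal pure-Python sketch of a circular pixel shift and its inverse on a 1D row of pixels. It does not call deepinv; the helper names `shift` and `shift_inverse` are hypothetical and only illustrate the idea that applying a transform and then its inverse with the same parameters recovers the input.

```python
def shift(row, k):
    """Circularly shift a 1D list of pixels k places to the right."""
    k %= len(row)
    return row[-k:] + row[:-k] if k else list(row)


def shift_inverse(row, k):
    """Undo shift(row, k) by shifting k places back to the left."""
    return shift(row, -k)


pixels = [10, 20, 30, 40, 50]
shifted = shift(pixels, 2)             # [40, 50, 10, 20, 30]
restored = shift_inverse(shifted, 2)   # [10, 20, 30, 40, 50]
print(shifted, restored)
```

The inverse needs the same parameter (here `k`) that was drawn for the forward transform, which is why a random transform has to keep track of its sampled parameters.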

By letting n_trans be equal to the full group size, all transforms
are recovered:
reflect = dinv.transform.Reflect(dim=[-2, -1], n_trans=4)
rotate = dinv.transform.Rotate(multiples=90, positive=True, n_trans=4)
dinv.utils.plot(
    [reflect(x), rotate(x)], titles=["Full 2D reflect group", "Full rotate group"]
)

2. Equivariant denoiser or plug-and-play
Suppose we want to make a denoiser equivariant to the rotoreflect group, taken as the group product of the 90 degree rotations (order 4) and 1D reflections (order 2). We can do this with our transform arithmetic (note this results in the full dihedral group \(\text{Dih}_4\) of order 8):
transform = rotate * dinv.transform.Reflect(dim=[-1], n_trans=2)
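The order-8 claim can be checked by hand. The sketch below is plain Python on a tiny 2x2 grid and does not use deepinv; `rot90` and `reflect` are hypothetical helpers standing in for a 90 degree rotation and a flip along the last axis. It enumerates every rotation, with and without a reflection, and counts the distinct results.

```python
def rot90(grid):
    """Rotate a square grid (list of rows) by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*grid)][::-1]


def reflect(grid):
    """Flip the grid along its last axis (left-right)."""
    return [row[::-1] for row in grid]


def as_key(grid):
    """Hashable version of a grid, so that results can be stored in a set."""
    return tuple(tuple(row) for row in grid)


x = [[1, 2], [3, 4]]  # all entries differ, so no two group elements coincide

# The four rotations by 0, 90, 180 and 270 degrees
rotations = []
g = x
for _ in range(4):
    rotations.append(g)
    g = rot90(g)

n_rot = len({as_key(r) for r in rotations})
# Each rotation, optionally followed by a reflection
group_images = {as_key(img) for r in rotations for img in (r, reflect(r))}
print(n_rot, len(group_images))
```

Four rotations times two reflection choices give eight distinct images, matching the order of the dihedral group.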
Let’s simulate some Gaussian noise and turn a pretrained denoiser
(here, dinv.models.RAM) into an equivariant denoiser
(deepinv.models.EquivariantDenoiser):
sigma = 0.1
physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=sigma))
x = Resize(128)(x)
# Put the image in an unusual orientation to show the benefits of equivariance
x = torch.rot90(x, k=1, dims=[-2, -1])
y = physics(x)
model = dinv.models.RAM(pretrained=True)
model_eq = dinv.models.EquivariantDenoiser(model, transform=transform)
with torch.no_grad():
    x_hat = model(y, sigma=sigma)
    x_eq = model_eq(y, sigma=sigma)
psnr_fn = dinv.metric.PSNR()
psnr = psnr_fn(x_hat, x).item()
psnr_eq = psnr_fn(x_eq, x).item()
psnr_y = psnr_fn(y, x).item()
dinv.utils.plot(
    [y, x_hat, x_eq, x],
    ["Measurements", "Regular Denoiser", "Equivariant Denoiser", "Ground truth"],
    subtitles=[
        f"PSNR={psnr_y:.1f}dB",
        f"PSNR={psnr:.1f}dB",
        f"PSNR={psnr_eq:.1f}dB",
        "",
    ],
    fontsize=10,
)

What’s going on under the hood? We use the transform.symmetrize
method to symmetrize the function \(D\) with respect to a projective
transform (with a Monte Carlo approach of n_trans=2 transforms per call):
# Example non-equivariant function
D = lambda x: x[..., [0]] * x
# Example non-linear transform with n=2
t = dinv.transform.projective.PanTiltRotate(n_trans=2, theta_max=10, theta_z_max=0)
# Symmetrize function with respect to transform
D_s = t.symmetrize(D, average=True)
dinv.utils.plot(
    [x, D(x), D_s(x)], titles=["Orig", "$D(x)$", "$\\sum_g T_g^{-1}D(T_g x)$"]
)

Reconstructors can also be made equivariant in a similar way using deepinv.models.EquivariantReconstructor. This amounts to averaging the base reconstructor \(\tilde{R}\) over the transformations to obtain an equivariant reconstructor.
The average is computed by Monte Carlo sampling over a random subset of the transformations, typically a single one at training time and the full set at evaluation time.
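The sampling strategy can be sketched without deepinv. The example below is a plain-Python toy under stated assumptions: the group is the set of circular shifts of a length-4 signal, `base_fn` is a hypothetical non-equivariant function standing in for the base reconstructor, and `averaged` is an illustrative helper that draws `n_trans` shifts without replacement.

```python
import random


def shift(signal, k):
    """Circularly shift a 1D list k places to the right (negative k shifts left)."""
    k %= len(signal)
    return signal[-k:] + signal[:-k] if k else list(signal)


def base_fn(signal):
    """Hypothetical non-equivariant base function: scale by the first sample."""
    return [signal[0] * v for v in signal]


def averaged(fn, signal, n_trans, rng):
    """Average T_g^{-1} fn(T_g x) over n_trans shifts drawn without replacement."""
    shifts = rng.sample(range(len(signal)), n_trans)
    total = [0] * len(signal)
    for k in shifts:
        term = shift(fn(shift(signal, k)), -k)
        total = [t + v for t, v in zip(total, term)]
    return [t / n_trans for t in total]


rng = random.Random(0)
x = [1, 2, 3, 4]
train_out = averaged(base_fn, x, n_trans=1, rng=rng)      # one random shift
eval_out = averaged(base_fn, x, n_trans=len(x), rng=rng)  # the full group
print(train_out)
print(eval_out)  # [2.5, 5.0, 7.5, 10.0]
```

A single sample is cheap, which suits training, but it returns only one term of the average. Using every shift costs one function call per group element and gives an output that commutes exactly with shifts.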
sigma = 0.1
physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=sigma))
y = physics(x)
rotate = dinv.transform.Rotate(multiples=90, positive=True, n_trans=4)
transform = rotate * dinv.transform.Reflect(dim=[-1], n_trans=2)
model = dinv.models.RAM(pretrained=True)
model_eq = dinv.models.EquivariantReconstructor(model, transform=transform)
with torch.no_grad():
    x_hat = model(y, physics=physics)
    x_eq = model_eq(y, physics=physics)
psnr = psnr_fn(x_hat, x).item()
psnr_eq = psnr_fn(x_eq, x).item()
psnr_y = psnr_fn(y, x).item()
dinv.utils.plot(
    [y, x_hat, x_eq, x],
    [
        "Measurements",
        "Regular Reconstructor",
        "Equivariant Reconstructor",
        "Ground truth",
    ],
    subtitles=[
        f"PSNR={psnr_y:.1f}dB",
        f"PSNR={psnr:.1f}dB",
        f"PSNR={psnr_eq:.1f}dB",
        "",
    ],
    fontsize=9,
)

3. Equivariant imaging
We can also use our transforms to create the self-supervised equivariant imaging loss. See Image transformations for Equivariant Imaging and Self-supervised learning with Equivariant Imaging for MRI for examples of self-supervised learning for inpainting and MRI. For example, the EI loss can easily be defined using any combination of transforms:
loss = dinv.loss.EILoss(
    transform=dinv.transform.projective.Affine() | dinv.transform.projective.Euclidean()
)
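As a rough illustration of what an equivariant imaging loss measures, the sketch below implements the basic form from Chen et al., \(\|T_g f(y) - f(A T_g f(y))\|^2\), on a 1D toy inpainting problem in plain Python. It is not the deepinv implementation, which may add weighting and other options: the mask operator `forward`, the zero-filling reconstructor `zero_fill` and the helper `ei_loss` are all assumptions made for this example.

```python
def shift(signal, k):
    """Circularly shift a 1D list k places to the right."""
    k %= len(signal)
    return signal[-k:] + signal[:-k] if k else list(signal)


def forward(signal, mask):
    """Toy inpainting operator A: keep entries where mask is 1, zero the rest."""
    return [m * v for m, v in zip(mask, signal)]


def zero_fill(y):
    """Toy reconstructor f: return the measurements, leaving masked pixels at zero."""
    return list(y)


def ei_loss(y, mask, f, k):
    """Mean squared error between T_g f(y) and f(A T_g f(y)) for a shift by k."""
    x1 = f(y)                   # reconstruct from the measurements
    x2 = shift(x1, k)           # transform the reconstruction
    x3 = f(forward(x2, mask))   # re-measure and reconstruct again
    return sum((a - b) ** 2 for a, b in zip(x2, x3)) / len(x2)


mask = [1, 1, 0, 1]
y = forward([1, 2, 3, 4], mask)  # [1, 2, 0, 4]

print(ei_loss(y, mask, zero_fill, k=0))  # 0.0
print(ei_loss(y, mask, zero_fill, k=1))  # 1.0
```

With no transform (`k=0`) the zero-filling reconstructor is not penalised at all. After a shift, a known pixel moves into the masked position, so the loss becomes positive and pushes the reconstructor to fill in missing pixels.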
- References: