Image transformations for Equivariant Imaging

This example demonstrates various geometric image transformations implemented in deepinv that can be used in Equivariant Imaging (EI) for self-supervised learning:

  • Shift: integer pixel 2D shift;

  • Rotate: 2D image rotation;

  • Scale: continuous 2D image downscaling;

  • Euclidean: includes continuous translation, rotation, and reflection, forming the group \(\mathbb{E}(2)\);

  • Similarity: as above but includes scale, forming the group \(\text{S}(2)\);

  • Affine: as above but includes shear effects, forming the group \(\text{Aff}(3)\);

  • Homography: as above but includes perspective (i.e., pan and tilt) effects, forming the group \(\text{PGL}(3)\);

  • PanTiltRotate: pure 3D camera rotation, i.e., pan, tilt, and 2D image rotation.

See the documentation for the full list of transforms.

These were proposed in the papers:

  • Shift, Rotate: Chen et al. [1].

  • Scale: Scanvic et al. [2].

  • Homography and the projective geometry framework: Wang and Davies [3].
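The core idea behind EI is that the set of plausible images is invariant to a group of transformations \(T_g\): if \(x\) is a natural image, then so is \(T_g x\). Roughly speaking (glossing over the exact implementation of the deepinv loss classes used below), given a measurement \(y = Ax\) and a reconstruction \(\hat{x} = f_\theta(y)\), training combines a measurement consistency term with an equivariance term that compares the reconstruction of a transformed, re-measured image against the transformed reconstruction itself:

\[\mathcal{L} = \big\lVert A\hat{x} - y \big\rVert^2 + \big\lVert f_\theta\!\left(A\,T_g\hat{x}\right) - T_g\hat{x} \big\rVert^2,\]

which allows the network to be trained from the measurements alone.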

import torch
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Compose, ToTensor, CenterCrop, Resize

import deepinv as dinv
from deepinv.utils.demo import get_data_home

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

ORIGINAL_DATA_DIR = get_data_home() / "Urban100"

Define the transforms. For the transforms that involve 3D camera rotation (i.e., pan or tilt), we limit theta_max so that the effect remains visible in the plots below.
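A minimal sketch of the transform definitions, assuming the transform classes are available under dinv.transform (the projective transforms may instead live under dinv.transform.projective) and that the default constructor arguments are acceptable; exact names and arguments may differ, see the deepinv docs:

# Sketch only: module paths and defaults are assumptions, not the verbatim example code.
transforms = [
    dinv.transform.Shift(),
    dinv.transform.Rotate(),
    dinv.transform.Scale(),
    dinv.transform.Homography(theta_max=10),  # limit pan/tilt angle for display
    dinv.transform.projective.Euclidean(),
    dinv.transform.projective.Similarity(),
    dinv.transform.projective.Affine(),
    dinv.transform.projective.PanTiltRotate(theta_max=10),  # limit pan/tilt angle for display
]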

UserWarning: The default interpolation mode will be changed to bilinear interpolation in the near future. Please specify the interpolation mode explicitly if you plan to keep using nearest interpolation.

Plot transforms on a sample image. Note that, during training, we never have access to these ground truth images x, only partial and noisy measurements y.

x = dinv.utils.load_example("celeba_example.jpg")
dinv.utils.plot(
    [x] + [t(x) for t in transforms],
    ["Orig"] + [t.__class__.__name__ for t in transforms],
    fontsize=24,
)
[Figure: Orig, Shift, Rotate, Scale, Homography, Euclidean, Similarity, Affine, PanTiltRotate]

Now, we run an inpainting experiment to reconstruct images from measurements masked with a random mask, without ground truth, using EI. For this example we use the Urban100 dataset of natural urban scenes. As these scenes are imaged with a camera free to move and rotate in the world, all of the above transformations are valid invariances that we can impose on the unknown image set \(x\in X\).

dataset = dinv.datasets.Urban100HR(
    root=ORIGINAL_DATA_DIR,
    download=True,
    transform=Compose([ToTensor(), Resize(256), CenterCrop(256)]),
)

train_dataset, test_dataset = random_split(dataset, (0.8, 0.2))

train_dataloader = DataLoader(train_dataset, shuffle=True)
test_dataloader = DataLoader(test_dataset)

# Use physics to generate data online
physics = dinv.physics.Inpainting((3, 256, 256), mask=0.6, device=device)
Dataset has been successfully downloaded.
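Since online_measurements=True is passed to the trainer below, measurements are generated on the fly by applying the physics to each image. As a quick sanity check (a sketch, for illustration only), we can do this manually for one training image:

x_sample = next(iter(train_dataloader)).to(device)  # ground truth, used here only for illustration
y_sample = physics(x_sample)  # randomly masked measurement y = A(x)
dinv.utils.plot([x_sample, y_sample], ["x", "y"])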

For training, we use a small UNet, the Adam optimizer, the EI loss with a homography transform, and the deepinv.Trainer functionality:

Note

We only train for a single epoch in the demo, but it is recommended to train multiple epochs in practice.

model = dinv.models.UNet(
    in_channels=3, out_channels=3, scales=2, circular_padding=True, batch_norm=False
).to(device)

losses = [
    dinv.loss.MCLoss(),
    dinv.loss.EILoss(dinv.transform.Homography(theta_max=10, device=device)),
]

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)

model = dinv.Trainer(
    model=model,
    physics=physics,
    online_measurements=True,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=1,
    losses=losses,
    optimizer=optimizer,
    verbose=True,
    show_progress_bar=False,
    save_path=None,
    device=device,
).train()
The model has 444867 trainable parameters
Train epoch 0: MCLoss=0.007, EILoss=0.022, TotalLoss=0.03, PSNR=11.139
Eval epoch 0: PSNR=19.338
Best model saved at epoch 1

Show results of a pretrained model trained using a larger UNet for 40 epochs:

model = dinv.models.UNet(
    in_channels=3, out_channels=3, scales=3, circular_padding=True, batch_norm=False
).to(device)

ckpt = torch.hub.load_state_dict_from_url(
    dinv.models.utils.get_weights_url("ei", "Urban100_inpainting_homography_model.pth"),
    map_location=device,
)

model.load_state_dict(ckpt["state_dict"])

x = next(iter(train_dataloader))
x = x.to(device)
y = physics(x)
x_hat = model(y)

dinv.utils.plot([x, y, x_hat], ["x", "y", "reconstruction"])
[Figure: x, y, reconstruction]
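To put a number on the reconstruction quality, we can also compute the PSNR by hand (a minimal sketch assuming images are scaled to [0, 1]; deepinv also ships metric utilities for this):

mse = torch.mean((x_hat - x) ** 2)   # mean squared error against the ground truth
psnr = 10 * torch.log10(1.0 / mse)   # peak signal-to-noise ratio with peak value 1.0
print(f"PSNR: {psnr.item():.2f} dB")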
References:

[1] Chen, Tachella, Davies, "Equivariant Imaging: Learning Beyond the Range Space", ICCV 2021.

[2] Scanvic, Davies, Abry, Tachella, "Self-Supervised Learning for Image Super-Resolution and Deblurring", arXiv, 2023.

[3] Wang, Davies, "Perspective-Equivariant Imaging: an Unsupervised Framework for Multispectral Pansharpening", arXiv, 2024.

Total running time of the script: (2 minutes 19.451 seconds)
