Image transformations for Equivariant Imaging

This example demonstrates various geometric image transformations implemented in deepinv that can be used in Equivariant Imaging (EI) for self-supervised learning:

  • Shift: integer pixel 2D shift;

  • Rotate: 2D image rotation;

  • Scale: continuous 2D image downscaling;

  • Euclidean: includes continuous translation, rotation, and reflection, forming the group \(\mathbb{E}(2)\);

  • Similarity: as above but includes scale, forming the group \(\text{S}(2)\);

  • Affine: as above but includes shear effects, forming the group \(\text{Aff}(3)\);

  • Homography: as above but includes perspective (i.e. pan and tilt) effects, forming the group \(\text{PGL}(3)\);

  • PanTiltRotate: pure 3D camera rotation, i.e. pan, tilt, and 2D image rotation.

See the docs for the full list of transforms.

These transforms were proposed in the Equivariant Imaging papers; see the deepinv documentation for the full references.
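For reference, the EI loss used below (implemented by dinv.loss.EILoss) trains the reconstruction network \(f_\theta\) to commute with the group action: given a reconstruction \(\hat{x} = f_\theta(y)\) and a random transform \(T_g\) with \(g \sim G\), it penalizes (up to the choice of metric and weighting)

\[\mathcal{L}_{\text{EI}} = \big\| T_g \hat{x} - f_\theta\big(A\, T_g \hat{x}\big) \big\|^2,\]

where \(A\) is the forward operator. During training it is paired with a measurement-consistency term \(\mathcal{L}_{\text{MC}} = \| y - A f_\theta(y) \|^2\) (dinv.loss.MCLoss).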

import deepinv as dinv
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, CenterCrop, Resize
from torchvision.datasets.utils import download_and_extract_archive

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Define the transforms. For those that involve 3D camera rotation (i.e. pan or tilt), we limit theta_max for display.
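A minimal definition consistent with the plot below, assuming the subgroup transforms live under dinv.transform.projective:

transforms = [
    dinv.transform.Shift(),
    dinv.transform.Rotate(),
    dinv.transform.Scale(),
    dinv.transform.Homography(theta_max=10),  # theta_max limits pan/tilt for display
    dinv.transform.projective.Euclidean(),
    dinv.transform.projective.Similarity(),
    dinv.transform.projective.Affine(),
    dinv.transform.projective.PanTiltRotate(theta_max=10),
]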

Plot the transforms on a sample image. Note that, during training, we never have access to these ground-truth images x, only the partial and noisy measurements y.

x = dinv.utils.load_url_image(dinv.utils.demo.get_image_url("celeba_example.jpg"))
dinv.utils.plot(
    [x] + [t(x) for t in transforms],
    ["Orig"] + [t.__class__.__name__ for t in transforms],
)
[Figure: the sample image and its transforms: Orig, Shift, Rotate, Scale, Homography, Euclidean, Similarity, Affine, PanTiltRotate]

Now, we run an inpainting experiment: we reconstruct images from measurements masked by a random pixel mask, without ground truth, using EI. For this example we use the Urban100 dataset of natural urban scenes. As these scenes are imaged with a camera that is free to move and rotate in the world, all of the above transformations are valid invariances that we can impose on the unknown image set \(x\in X\).
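Formally, the invariance assumption underpinning EI is that the signal set is closed under the group action,

\[T_g X = X \quad \text{for all } g \in G,\]

so that a transformed reconstruction is itself a plausible member of \(X\).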

dataset = dinv.datasets.Urban100HR(
    root="Urban100",
    download=True,
    transform=Compose([ToTensor(), Resize(256), CenterCrop(256)]),
)

train_dataset, test_dataset = random_split(dataset, (0.8, 0.2))

train_dataloader = DataLoader(train_dataset, shuffle=True)
test_dataloader = DataLoader(test_dataset)

# Use physics to generate measurements online; a float mask samples a Bernoulli
# mask, so mask=0.6 keeps each pixel with probability 0.6
physics = dinv.physics.Inpainting((3, 256, 256), mask=0.6, device=device)
Dataset has been successfully downloaded.
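As a quick sanity check (not part of the original script), we can visualize one masked measurement:

# Hypothetical check: apply the mask to one training image and display it
x_demo = next(iter(train_dataloader)).to(device)  # shape (1, 3, 256, 256)
y_demo = physics(x_demo)  # ~40% of pixels zeroed out by the mask
dinv.utils.plot([x_demo, y_demo], ["x", "y = physics(x)"])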

For training, we use a small UNet, the Adam optimizer, the EI loss with a homography transform, and the deepinv.Trainer functionality:

Note

We only train for a single epoch in this demo; in practice, training for multiple epochs is recommended.

model = dinv.models.UNet(
    in_channels=3, out_channels=3, scales=2, circular_padding=True, batch_norm=False
).to(device)

losses = [
    dinv.loss.MCLoss(),
    dinv.loss.EILoss(dinv.transform.Homography(theta_max=10, device=device)),
]

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)

model = dinv.Trainer(
    model=model,
    physics=physics,
    online_measurements=True,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=1,
    losses=losses,
    optimizer=optimizer,
    verbose=True,
    show_progress_bar=False,
    save_path=None,
    device=device,
).train()
The model has 444867 trainable parameters
Train epoch 0: MCLoss=0.008, EILoss=0.022, TotalLoss=0.03, PSNR=10.791
Eval epoch 0: PSNR=18.813

Show results of a pretrained model trained using a larger UNet for 40 epochs:

model = dinv.models.UNet(
    in_channels=3, out_channels=3, scales=3, circular_padding=True, batch_norm=False
).to(device)

ckpt = torch.hub.load_state_dict_from_url(
    dinv.models.utils.get_weights_url("ei", "Urban100_inpainting_homography_model.pth"),
    map_location=device,
)

model.load_state_dict(ckpt["state_dict"])

x = next(iter(train_dataloader))
x = x.to(device)
y = physics(x)
x_hat = model(y)

dinv.utils.plot([x, y, x_hat], ["x", "y", "reconstruction"])
[Figure: ground truth x, measurement y, and reconstruction]
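To put a number on the reconstruction quality (an optional extra; assumes the dinv.metric.PSNR metric available in recent deepinv versions):

# Hypothetical quality check using deepinv's PSNR metric
psnr = dinv.metric.PSNR()
print(f"PSNR of measurement:    {psnr(y, x).item():.2f} dB")
print(f"PSNR of reconstruction: {psnr(x_hat, x).item():.2f} dB")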

Total running time of the script: (2 minutes 14.303 seconds)

Gallery generated by Sphinx-Gallery