Image transformations for Equivariant Imaging
This example demonstrates various geometric image transformations implemented in deepinv that can be used in Equivariant Imaging (EI) for self-supervised learning:
Shift: integer pixel 2D shift;
Rotate: 2D image rotation;
Scale: continuous 2D image downscaling;
Euclidean: includes continuous translation, rotation, and reflection, forming the group \(\mathbb{E}(2)\);
Similarity: as above but includes scale, forming the group \(\text{S}(2)\);
Affine: as above but includes shear effects, forming the group \(\text{Aff}(3)\);
Homography: as above but includes perspective (i.e. pan and tilt) effects, forming the group \(\text{PGL}(3)\);
PanTiltRotate: pure 3D camera rotation, i.e. pan, tilt, and 2D image rotation.
See the docs for the full list.
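The transformation hierarchy above can be made concrete with 3x3 matrices in homogeneous coordinates. The following is a minimal NumPy sketch, not deepinv's implementation. The helper names (euclidean, similarity, affine, homography, apply) are hypothetical, and they act on point coordinates rather than on image tensors. Each level relaxes one constraint of the previous one, and only the homography has a non-trivial last row.

```python
import numpy as np

# Hypothetical helpers (not the deepinv API): 3x3 homogeneous matrices for
# each level of the planar transformation hierarchy.

def euclidean(theta, tx, ty, reflect=False):
    """Rotation by theta (radians), optional reflection, then translation."""
    c, s = np.cos(theta), np.sin(theta)
    R = np.array([[c, -s], [s, c]])
    if reflect:
        R = R @ np.diag([1.0, -1.0])  # reflect across the x-axis, then rotate
    H = np.eye(3)
    H[:2, :2] = R
    H[:2, 2] = [tx, ty]
    return H

def similarity(theta, tx, ty, scale):
    """Euclidean transform plus an isotropic scale."""
    H = euclidean(theta, tx, ty)
    H[:2, :2] *= scale
    return H

def affine(theta, tx, ty, scale, shear):
    """Similarity transform plus a horizontal shear."""
    H = similarity(theta, tx, ty, scale)
    H[:2, :2] = H[:2, :2] @ np.array([[1.0, shear], [0.0, 1.0]])
    return H

def homography(theta, tx, ty, scale, shear, px, py):
    """Affine transform plus perspective terms in the last row."""
    H = affine(theta, tx, ty, scale, shear)
    H[2, :2] = [px, py]  # a non-zero last row produces perspective effects
    return H

def apply(H, pts):
    """Map an (N, 2) array of points through H via homogeneous coordinates."""
    pts = np.asarray(pts, dtype=float)
    hom = np.hstack([pts, np.ones((len(pts), 1))]) @ H.T
    return hom[:, :2] / hom[:, 2:3]  # divide by the homogeneous coordinate
```

Euclidean maps preserve distances, and similarities scale all distances by the same factor. Affine maps with shear preserve parallel lines but not angles. The perspective row of a homography makes the mapping non-linear in pixel coordinates.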
These were proposed in the papers:
Shift, Rotate: Chen et al., Equivariant Imaging: Learning Beyond the Range Space
Scale: Scanvic et al., Self-Supervised Learning for Image Super-Resolution and Deblurring
Homography and the projective geometry framework: Wang et al., Perspective-Equivariant Imaging: an Unsupervised Framework for Multispectral Pansharpening
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, CenterCrop, Resize
from torchvision.datasets.utils import download_and_extract_archive
import deepinv as dinv
from deepinv.utils.demo import get_data_home
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
ORIGINAL_DATA_DIR = get_data_home() / "Urban100"
Define the transforms. For the transforms that involve 3D camera rotation (i.e. pan or tilt), we limit the maximum angle theta_max so that the transformed images remain easy to inspect.
transforms = [
dinv.transform.Shift(),
dinv.transform.Rotate(),
dinv.transform.Scale(),
dinv.transform.Homography(theta_max=10),
dinv.transform.projective.Euclidean(),
dinv.transform.projective.Similarity(),
dinv.transform.projective.Affine(),
dinv.transform.projective.PanTiltRotate(theta_max=10),
]
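To see why pan and tilt fit into the projective framework, recall the standard pinhole-camera result: a camera rotating about its own centre relates images by the homography H = K R K^{-1}, where K holds the intrinsics and R is the 3D rotation. The following is a minimal NumPy sketch under these assumptions:

- An ideal pinhole camera with focal length f in pixels and principal point (cx, cy).
- Angles drawn uniformly from [-theta_max, theta_max].

The function names are hypothetical, and this is not necessarily how deepinv samples or parameterises its transforms.

```python
import numpy as np

def rotation_homography(pan_deg, tilt_deg, roll_deg, f, cx, cy):
    """Homography induced by a pure camera rotation: H = K R K^{-1}.

    Assumes an ideal pinhole camera (hypothetical parameters f, cx, cy)."""
    K = np.array([[f, 0.0, cx], [0.0, f, cy], [0.0, 0.0, 1.0]])
    p, t, r = np.deg2rad([pan_deg, tilt_deg, roll_deg])
    # Pan: rotation about the camera's vertical (y) axis
    Ry = np.array([[np.cos(p), 0.0, np.sin(p)], [0.0, 1.0, 0.0], [-np.sin(p), 0.0, np.cos(p)]])
    # Tilt: rotation about the camera's horizontal (x) axis
    Rx = np.array([[1.0, 0.0, 0.0], [0.0, np.cos(t), -np.sin(t)], [0.0, np.sin(t), np.cos(t)]])
    # In-plane (2D image) rotation about the optical (z) axis
    Rz = np.array([[np.cos(r), -np.sin(r), 0.0], [np.sin(r), np.cos(r), 0.0], [0.0, 0.0, 1.0]])
    R = Rz @ Rx @ Ry
    return K @ R @ np.linalg.inv(K)

def sample_pan_tilt(theta_max, rng):
    """Draw (pan, tilt) uniformly in [-theta_max, theta_max] degrees (assumed scheme)."""
    return rng.uniform(-theta_max, theta_max, size=2)

rng = np.random.default_rng(0)
pan, tilt = sample_pan_tilt(10, rng)
H = rotation_homography(pan, tilt, 0.0, f=256.0, cx=128.0, cy=128.0)
```

With zero angles H is the identity. A pure pan by an angle p moves the principal point horizontally by f·tan(p) pixels, so bounding theta_max directly bounds how strongly the image is warped.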
Plot the transforms on a sample image. Note that during training we never have access to ground truth images x, only to partial and noisy measurements y.
x = dinv.utils.load_url_image(dinv.utils.demo.get_image_url("celeba_example.jpg"))
dinv.utils.plot(
[x] + [t(x) for t in transforms],
["Orig"] + [t.__class__.__name__ for t in transforms],
)
data:image/s3,"s3://crabby-images/f46cf/f46cf4146e15d901c59574eb5b929e3b6a14d1b0" alt="Orig, Shift, Rotate, Scale, Homography, Euclidean, Similarity, Affine, PanTiltRotate"
Next, we run an inpainting experiment: we reconstruct images observed through a random mask, without ground truth, using EI. For this example we use the Urban100 dataset of natural urban scenes. As these scenes are imaged with a camera that is free to move and rotate in the world, all of the above transformations are valid invariances that we can impose on the unknown image set \(x\in X\).
dataset = dinv.datasets.Urban100HR(
root=ORIGINAL_DATA_DIR,
download=True,
transform=Compose([ToTensor(), Resize(256), CenterCrop(256)]),
)
train_dataset, test_dataset = random_split(dataset, (0.8, 0.2))
train_dataloader = DataLoader(train_dataset, shuffle=True)
test_dataloader = DataLoader(test_dataset)
# Use physics to generate data online
physics = dinv.physics.Inpainting((3, 256, 256), mask=0.6, device=device)
Dataset has been successfully downloaded.
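Inpainting is a simple linear forward model, y = m ⊙ x, with a binary mask m. The following is a minimal NumPy sketch of it, assuming that mask=0.6 means each entry is observed independently with probability 0.6. This is an assumption about the semantics of deepinv's Inpainting, and the helper names are hypothetical. The sketch also shows why measurements alone cannot determine the image: any perturbation supported only on the unobserved pixels lies in the nullspace of the operator and leaves y unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_mask(shape, p_keep, rng):
    """Bernoulli mask: each entry is observed with probability p_keep (assumed semantics)."""
    return (rng.random(shape) < p_keep).astype(np.float64)

def inpaint(x, mask):
    """Forward operator y = mask * x: unobserved entries are set to zero."""
    return mask * x

shape = (3, 256, 256)
mask = make_mask(shape, 0.6, rng)
x = rng.random(shape)
y = inpaint(x, mask)

# A perturbation supported only on unobserved entries is in the nullspace of A,
# so x and x + n yield exactly the same measurements.
n = (1.0 - mask) * rng.standard_normal(shape)
assert np.array_equal(inpaint(x + n, mask), y)
```

This nullspace ambiguity is what EI addresses: the invariance to the transforms above provides information about the unobserved pixels that the measurements alone do not.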
For training, we use a small UNet, the Adam optimizer, a measurement consistency (MC) loss combined with the EI loss under the homography transform, and the deepinv.Trainer functionality:
Note
We only train for a single epoch in the demo, but it is recommended to train for multiple epochs in practice.
model = dinv.models.UNet(
in_channels=3, out_channels=3, scales=2, circular_padding=True, batch_norm=False
).to(device)
losses = [
dinv.loss.MCLoss(),
dinv.loss.EILoss(dinv.transform.Homography(theta_max=10, device=device)),
]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)
model = dinv.Trainer(
model=model,
physics=physics,
online_measurements=True,
train_dataloader=train_dataloader,
eval_dataloader=test_dataloader,
epochs=1,
losses=losses,
optimizer=optimizer,
verbose=True,
show_progress_bar=False,
save_path=None,
device=device,
).train()
The model has 444867 trainable parameters
Train epoch 0: MCLoss=0.008, EILoss=0.024, TotalLoss=0.032, PSNR=9.754
Eval epoch 0: PSNR=17.21
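The two losses used above can be sketched as follows, following the formulation of Chen et al. The MC loss asks that the reconstruction f(y) be consistent with the measurements. The EI loss transforms the reconstruction, re-measures it, reconstructs again, and asks that the result match the transformed image. This is a minimal NumPy sketch, not deepinv's MCLoss/EILoss code, which may differ in details such as the metric or how measurements are regenerated. The identity "network", the mask and the fixed shift are hypothetical stand-ins chosen so the sketch runs.

```python
import numpy as np

def mc_loss(f, A, y):
    """Measurement consistency: mean of (A(f(y)) - y)^2."""
    return np.mean((A(f(y)) - y) ** 2)

def ei_loss(f, A, T, y):
    """Equivariant imaging: x1 = f(y), x2 = T(x1), x3 = f(A(x2)); penalise x3 vs x2."""
    x1 = f(y)
    x2 = T(x1)
    x3 = f(A(x2))
    return np.mean((x3 - x2) ** 2)

rng = np.random.default_rng(0)
mask = (rng.random((8, 8)) < 0.6).astype(np.float64)
A = lambda x: mask * x                                 # inpainting operator
f = lambda y: y                                        # hypothetical trivial "network"
T = lambda x: np.roll(x, shift=(2, 3), axis=(0, 1))    # fixed integer shift

x = rng.random((8, 8))
y = A(x)
mc = mc_loss(f, A, y)
ei = ei_loss(f, A, T, y)
total = mc + ei
```

The trivial network that returns its masked input satisfies the MC loss exactly, yet it is penalised by the EI loss. This is why MC alone cannot learn the missing pixels, and why the two losses are combined.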
Show the results of a pretrained model trained using a larger UNet (scales=3) for 40 epochs:
model = dinv.models.UNet(
in_channels=3, out_channels=3, scales=3, circular_padding=True, batch_norm=False
).to(device)
ckpt = torch.hub.load_state_dict_from_url(
dinv.models.utils.get_weights_url("ei", "Urban100_inpainting_homography_model.pth"),
map_location=device,
)
model.load_state_dict(ckpt["state_dict"])
x = next(iter(train_dataloader))
x = x.to(device)
y = physics(x)
x_hat = model(y)
dinv.utils.plot([x, y, x_hat], ["x", "y", "reconstruction"])
data:image/s3,"s3://crabby-images/f1657/f1657f6ee1d486bcf8d6515493e4a19da60a8bbd" alt="x, y, reconstruction"
Downloading: "https://huggingface.co/deepinv/ei/resolve/main/Urban100_inpainting_homography_model.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/Urban100_inpainting_homography_model.pth
0%| | 0.00/7.90M [00:00<?, ?B/s]
14%|█▍ | 1.12M/7.90M [00:00<00:00, 10.5MB/s]
28%|██▊ | 2.25M/7.90M [00:00<00:00, 11.1MB/s]
43%|████▎ | 3.38M/7.90M [00:00<00:00, 10.4MB/s]
55%|█████▌ | 4.38M/7.90M [00:00<00:00, 10.4MB/s]
68%|██████▊ | 5.38M/7.90M [00:00<00:00, 10.4MB/s]
81%|████████ | 6.38M/7.90M [00:00<00:00, 10.4MB/s]
93%|█████████▎| 7.38M/7.90M [00:00<00:00, 10.4MB/s]
100%|██████████| 7.90M/7.90M [00:00<00:00, 10.5MB/s]
Total running time of the script: (2 minutes 18.391 seconds)