Self-supervised learning with Equivariant Splitting#

Equivariant splitting trains a reconstruction model using measurement data only, by minimizing a self-supervised loss [1].

It is based on the same invariance assumption as equivariant imaging (see Self-supervised learning with Equivariant Imaging for MRI). Namely, the distribution of ground-truth images is assumed to be invariant to certain transformations such as translations, rotations and flips.

It also builds on splitting methods, which separate the measurements into inputs and targets \(y = [y_1^\top, y_2^\top]^\top\). The target measurements are withheld from the network and instead guide it to predict information that is absent from the input measurements.

The equivariant splitting loss combines the two approaches as:

\[\mathcal{L}_{\mathrm{ES}} (y, A, f) = \mathbb{E}_g \Big\{ \mathbb{E}_{y_1, A_1 \mid y, A T_g} \Big\{ \underbrace{\| A_1 R(y_1, A_1) - A_1 x \|^2}_{\text{Consistency term}} + \underbrace{\| A_2 R(y_1, A_1) - A_2 x \|^2}_{\text{Prediction term}} \Big\} \Big\}\]

where \(T_g\) denotes a transformation and \(A T_g\) the associated virtual physics, represented in the library by the class deepinv.physics.VirtualLinearPhysics. The loss itself is implemented in deepinv.loss.EquivariantSplittingLoss, and this example shows how to use it to train a reconstruction model in a fully self-supervised way.
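To make the consistency and prediction terms concrete, here is a toy pure-PyTorch computation for a noiseless pixel-wise inpainting problem. The image size and mask ratios below are illustrative only, and the mask-multiplication operators stand in for \(A_1\) and \(A_2\); this is a sketch of the loss structure, not the library implementation.

```python
import torch

torch.manual_seed(0)
m = (torch.rand(1, 1, 8, 8) < 0.7).float()  # inpainting mask A (70% of pixels kept)
s = (torch.rand(1, 1, 8, 8) < 0.9).float()  # Bernoulli splitting mask (ratio 0.9)

x = torch.rand(1, 1, 8, 8)  # ground truth; noiseless case, so y = A x
y = m * x                   # measurements

# Split physics and measurements: A_1 keeps the input part, A_2 the target part
A1 = lambda u: s * m * u
A2 = lambda u: (1 - s) * m * u
y1, y2 = s * y, (1 - s) * y

def es_terms(x_hat):
    consistency = ((A1(x_hat) - y1) ** 2).sum()  # ||A_1 x_hat - A_1 x||^2
    prediction = ((A2(x_hat) - y2) ** 2).sum()   # ||A_2 x_hat - A_2 x||^2
    return consistency + prediction
```

A perfect reconstruction makes both terms vanish, while the prediction term penalizes a reconstructor that merely copies the input measurements, since it is evaluated on measurements the network never saw.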

import os
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

import deepinv as dinv

Setup random seeds, paths and device#

# For reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True

# Setup paths
BASE_DIR = Path(".")
DATA_DIR = BASE_DIR / "measurements"
CKPT_DIR = BASE_DIR / "ckpts"

# Select the device
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Forward model#

First, we define the forward model, here an inpainting problem with a fixed mask keeping 70% of pixels:

Create the imaging dataset#

Using the forward model and a base dataset, here deepinv.datasets.Urban100HR, we generate an imaging dataset that we further split into 80 training samples, 19 evaluation samples and 1 test sample.

Dataset has been saved at Urban100/dinv_dataset0.h5
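The dataset-generation cell itself is not shown. Independently of how the measurement pairs are produced, the 80/19/1 split can be sketched with torch.utils.data.random_split; the tensor dataset and the variable names below are hypothetical stand-ins, not the ones used by this example.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in for the 100-sample imaging dataset of (image, measurement) pairs
full_set = TensorDataset(torch.rand(100, 1, 8, 8), torch.rand(100, 1, 8, 8))

# Reproducible 80/19/1 split into train / eval / test subsets
train_set, eval_set, test_set = random_split(
    full_set, [80, 19, 1], generator=torch.Generator().manual_seed(0)
)
```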

Visualizing the problem#

x, y = test_dataset[0]
x, y = x.unsqueeze(0), y.unsqueeze(0)
x, y = x.to(device), y.to(device)

psnr_fn = dinv.metric.PSNR()
psnr_y = psnr_fn(y, x).item()

dinv.utils.plot(
    [y, x],
    ["Measurements", "Ground truth"],
    subtitles=[f"PSNR={psnr_y:.1f}dB", ""],
    fontsize=10,
)
[Figure: Measurements, Ground truth]

Create the base model#

Here, we fine-tune a pre-trained deepinv.models.RAM model, but any trainable reconstructor would do.

In order to track the improvements brought by fine-tuning the model, we also create a copy of the pre-trained model that will not be fine-tuned.

model = dinv.models.RAM(pretrained=True).to(device)
model_no_learning = dinv.models.RAM(pretrained=True).to(device)

model_no_learning.eval()
with torch.no_grad():
    x_pretrained = model_no_learning(y, physics)

psnr_pretrained = psnr_fn(x_pretrained, x).item()

dinv.utils.plot(
    [y, x_pretrained, x],
    ["Measurements", "RAM (Pre-trained)", "Ground truth"],
    subtitles=[f"PSNR={psnr_y:.1f}dB", f"PSNR={psnr_pretrained:.1f}dB", ""],
    fontsize=10,
)
[Figure: Measurements, RAM (Pre-trained), Ground truth]

Setup the equivariant splitting loss#

We create an instance of deepinv.loss.EquivariantSplittingLoss that implements the equivariant splitting loss.

The equivariant splitting loss requires the definition of a splitting scheme, similar to deepinv.loss.SplittingLoss. Here, we choose a pixel-wise Bernoulli splitting scheme with a split ratio of 0.9 using deepinv.physics.generator.BernoulliSplittingMaskGenerator.

Equivariant splitting requires choosing a set of transformations for which the forward operator is not equivariant. For inpainting, valid choices include shifts, rotations and reflections [1]. Here, we choose rotations and reflections.

Since the base model RAM is not already equivariant to these transformations, we use group averaging by passing transform and eval_transform to the loss. Namely, we swap the base reconstructor \(\tilde{R}\) for the equivariant reconstructor defined by

\[R(y, A) = \frac{1}{|\mathcal{G}|}\sum_{g\in \mathcal{G}} T_g \tilde{R}(y, A T_g)\]

which is estimated by Monte Carlo sampling over a subset of transformations, typically a single transformation at training time and the full set at evaluation time. Internally, the input model is wrapped in a deepinv.models.EquivariantReconstructor when calling deepinv.loss.EquivariantSplittingLoss.adapt_model().
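The averaging itself can be illustrated with a small pure-PyTorch sketch over the four 90° rotations. This is a toy stand-in, not the library implementation: the pass-through base reconstructor and identity physics below exist only to exercise the wrapper.

```python
import torch

def T(x, k):
    """Apply the group action T_g: rotation by k * 90 degrees."""
    return torch.rot90(x, k, dims=(-2, -1))

def equivariant_reconstructor(tilde_R, y, A, n_group=4):
    """Full group average R(y, A) = (1/|G|) sum_g T_g tilde_R(y, A T_g),
    where the virtual physics is (A T_g)(x) = A(T_g(x))."""
    out = 0.0
    for k in range(n_group):
        virtual_A = lambda u, k=k: A(T(u, k))  # A T_g
        out = out + T(tilde_R(y, virtual_A), k)
    return out / n_group

# Toy base reconstructor that ignores the physics and returns y unchanged;
# the wrapper then simply averages the four rotated copies of y
tilde_R = lambda y, A_g: y
A = lambda u: u  # identity physics for this toy check
y = torch.rand(1, 1, 8, 8)
out = equivariant_reconstructor(tilde_R, y, A)
```

At training time, replacing the loop by a single randomly drawn k gives the Monte Carlo estimate mentioned above.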

Note

The equivariant splitting loss consists of a consistency term and a prediction term. In the absence of noise, they are computed exactly using deepinv.loss.MCLoss. In the presence of noise, they can be estimated without bias using denoising losses such as deepinv.loss.R2RLoss and deepinv.loss.SureGaussianLoss.

# Splitting scheme
mask_generator = dinv.physics.generator.BernoulliSplittingMaskGenerator(
    img_size=(1, img_size, img_size),
    split_ratio=0.9,
    pixelwise=True,
    device=device,
)

# Underlying measurement comparison losses
consistency_loss = dinv.loss.MCLoss(metric=dinv.metric.MSE())
prediction_loss = dinv.loss.MCLoss(metric=dinv.metric.MSE())

# A random grid-preserving transformation
train_transform = dinv.transform.Rotate(
    n_trans=1, multiples=90, positive=True
) * dinv.transform.Reflect(n_trans=1, dim=[-1])
# All grid-preserving transformations
eval_transform = dinv.transform.Rotate(
    n_trans=4, multiples=90, positive=True
) * dinv.transform.Reflect(n_trans=2, dim=[-1])

es_loss = dinv.loss.EquivariantSplittingLoss(
    mask_generator=mask_generator,
    consistency_loss=consistency_loss,
    prediction_loss=prediction_loss,
    transform=train_transform,
    eval_transform=eval_transform,
    eval_n_samples=5,
)

# Wrap the model so it takes split measurements as input and applies Reynolds (group) averaging
model = es_loss.adapt_model(model)

Train the model#

We fine-tune the pre-trained model using the equivariant splitting loss and early stopping with the validation loss as criterion. This makes the whole training fully self-supervised and usable when no ground truth image is available.

Note

We skip training and directly load a cached checkpoint to keep the documentation build fast, but you can obtain the same results by running the training locally.

# Cached checkpoint after training to avoid doing the computation over and over
cached_checkpoint = (
    "https://huggingface.co/jscanvic/deepinv/resolve/main/ES/demo/ckp_best.pth.tar"
)

if cached_checkpoint is None:
    epochs = 20
    ckpt_pretrained = None
else:
    epochs = 0
    ckpt_pretrained = (
        dinv.utils.get_data_home() / "examples" / "ES" / "ckp_best.pth.tar"
    )
    os.makedirs(ckpt_pretrained.parent, exist_ok=True)

    # Download if not found
    if not ckpt_pretrained.exists():
        torch.hub.download_url_to_file(cached_checkpoint, ckpt_pretrained)
    else:
        print(f"Checkpoint found at {ckpt_pretrained}, skipping download.")

    # Ignore RNG states from the checkpoint
    ckpt = torch.load(ckpt_pretrained, map_location=device, weights_only=True)
    state_dict = ckpt["state_dict"]
    current_state_dict = model.state_dict()
    for key in list(state_dict.keys()):
        if "initial_random_state" in key:
            state_dict[key] = current_state_dict[key]
    torch.save(ckpt, ckpt_pretrained)

trainer = dinv.Trainer(
    model,
    physics=physics,
    epochs=epochs,
    ckpt_pretrained=ckpt_pretrained,
    ckp_interval=epochs,
    scheduler=None,
    losses=[es_loss],
    optimizer=torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-8),
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    metrics=[
        dinv.metric.PSNR()
    ],  # Supervised oracle metric for monitoring, not used for training and early stopping
    plot_images=False,
    device=device,
    verbose=True,
    show_progress_bar=False,
    no_learning_method=model_no_learning,
    compute_eval_losses=True,  # Compute eval losses
    early_stop=3,  # Patience: stop if there is no improvement for 3 consecutive epochs
    early_stop_on_losses=True,  # Use the evaluation loss as a self-supervised stopping criterion
)

trainer.train()
if epochs > 0:
    trainer.load_best_model()
The model has 35618813 trainable parameters
Model, optimizer, epoch_start successfully loaded from checkpoint: datasets/examples/ES/ckp_best.pth.tar
/local/jtachell/deepinv/deepinv/deepinv/training/trainer.py:556: UserWarning: No training will be done because epochs (0) <= loaded epoch_start (12) from checkpoint.
  warnings.warn(

Evaluation of the trained model#

We can now evaluate the trained model on the test set using the PSNR metric.

We also compare it to the pre-trained model without fine-tuning to see the benefit of fine-tuning with equivariant splitting:

# Compute the performance metrics on the whole test set
trainer.compute_eval_losses = False
trainer.early_stop_on_losses = False
trainer.test(test_dataloader, metrics=dinv.metric.PSNR())

# Display the reconstructions for a single test sample
model.eval()

with torch.no_grad():
    x_hat = model(y, physics)

psnr = psnr_fn(x_hat, x).item()

dinv.utils.plot(
    [y, x_pretrained, x_hat, x],
    ["Measurements", "RAM (Pre-trained)", "Equivariant Splitting", "Ground truth"],
    subtitles=[
        f"PSNR={psnr_y:.1f}dB",
        f"PSNR={psnr_pretrained:.1f}dB",
        f"PSNR={psnr:.1f}dB",
        "",
    ],
    fontsize=10,
)
[Figure: Measurements, RAM (Pre-trained), Equivariant Splitting, Ground truth]
Eval epoch 0: PSNR=30.589, PSNR no learning=26.745
Test results:
PSNR no learning: 26.745 +- 0.001
PSNR: 30.589 +- 0.000
References:

Total running time of the script: (0 minutes 31.454 seconds)

Gallery generated by Sphinx-Gallery