Self-supervised learning with measurement splitting

We demonstrate self-supervised learning with measurement splitting to train a denoiser network on the MNIST dataset. The physics here is noisy computed tomography, as in Noise2Inverse. This example can also easily be applied to undersampled multicoil MRI, as in SSDU.

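Measurement splitting constructs a ground-truth-free loss \(\frac{m}{m_2}\| y_2 - A_2 \inversef{y_1}{A_1}\|^2\) by splitting the measurement and the forward operator using a randomly generated mask: the network reconstructs from the first part \((y_1, A_1)\), and the loss is evaluated on the held-out part \((y_2, A_2)\).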
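To make the formula concrete, here is a minimal NumPy sketch of the loss on a toy linear problem. It is not the deepinv implementation: the random matrix `A`, the sizes, and the `reconstruct` function (a least-squares stand-in for the network) are assumptions chosen for illustration, and we assume \(m\) and \(m_2\) denote the total number of measurements and the number of held-out measurements.

```python
import numpy as np

rng = np.random.default_rng(0)

n, m = 16, 32                                # toy signal size and number of measurements
A = rng.standard_normal((m, n))              # toy linear forward operator (assumption)
x = rng.standard_normal(n)                   # unknown signal, used only to simulate y
y = A @ x + 0.1 * rng.standard_normal(m)     # noisy measurement

# Random splitting mask: True entries form the network input (y_1),
# the remaining entries are held out for the loss (y_2).
split_ratio = 0.6
mask = rng.random(m) < split_ratio
A1, y1 = A[mask], y[mask]
A2, y2 = A[~mask], y[~mask]
m2 = int((~mask).sum())


def reconstruct(y1, A1):
    """Stand-in for the network: least-squares solution from the input split."""
    return np.linalg.pinv(A1) @ y1


x_hat = reconstruct(y1, A1)

# Splitting loss: rescaled squared error on the held-out measurements only.
splitting_loss = (m / m2) * np.sum((y2 - A2 @ x_hat) ** 2)
print(f"m2 = {m2}, splitting loss = {splitting_loss:.4f}")
```

Note that `x` never enters the loss: only the measurement `y` and the operator `A` are needed, which is what makes the loss usable without ground truth.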

See deepinv.loss.SplittingLoss for full details.

import deepinv as dinv
from torch.utils.data import DataLoader
import torch
from torchvision import transforms, datasets
from deepinv.models.utils import get_weights_url

torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Define loss

Our implementation has multiple optional parameters that control how the splitting is achieved. The ones used in this example are split_ratio, eval_split_input and eval_n_samples.

Note that after the model has been defined, the loss must also “adapt” the model.

loss = dinv.loss.SplittingLoss(split_ratio=0.6, eval_split_input=True, eval_n_samples=5)

Prepare data

We use the torchvision MNIST dataset, and use noisy tomography physics (with number of angles equal to the image size) for the forward operator.

Note

We use a subset of the whole training set to reduce the computational load of the example. We recommend using the whole set by setting train_datapoints=test_datapoints=None to get the best results.

transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root=".", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=".", train=False, transform=transform, download=True)

physics = dinv.physics.Tomography(
    angles=28,
    img_width=28,
    noise_model=dinv.physics.noise.GaussianNoise(0.1),
    device=device,
)

deepinv_datasets_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    device=device,
    save_dir="MNIST",
    train_datapoints=100,
    test_datapoints=10,
)

train_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=True)
test_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=False)

train_dataloader = DataLoader(train_dataset, shuffle=True)
test_dataloader = DataLoader(test_dataset, shuffle=False)

Dataset has been saved in MNIST

Define model

We use a simple U-Net architecture with 2 scales as the denoiser network.

To reduce training time, we use a pretrained model. Here we demonstrate training with 100 images for 1 epoch, after having loaded a pretrained model that was trained with 1000 images for 20 epochs.

Note

When using the splitting loss, the model must be “adapted” by the loss, as its forward pass takes only a subset of the pixels, not the full image.
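The idea behind "adapting" a model can be sketched as a thin wrapper that masks the measurement before calling the underlying network. The class below is hypothetical and only illustrates the concept in NumPy; the name `SplitInputModel`, its arguments, and the identity placeholder network are assumptions, and the object returned by loss.adapt_model may work differently (for instance, it also has to account for the physics).

```python
import numpy as np


class SplitInputModel:
    """Hypothetical wrapper: feeds a random subset of the measurements to a base model."""

    def __init__(self, base_model, split_ratio=0.6, seed=0):
        self.base_model = base_model
        self.split_ratio = split_ratio
        self.rng = np.random.default_rng(seed)
        self.last_mask = None

    def __call__(self, y):
        # Draw a fresh binary mask on every call and zero out the held-out entries.
        mask = self.rng.random(y.shape) < self.split_ratio
        self.last_mask = mask  # a loss would compare against the entries where mask is False
        return self.base_model(y * mask)


def identity_model(y):
    """Placeholder for the network (assumption for illustration)."""
    return y


adapted = SplitInputModel(identity_model)
y = np.ones((4, 4))
out = adapted(y)
print(out)
```

With the identity placeholder, the output simply shows which entries the wrapped model was allowed to see.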

model = dinv.models.ArtifactRemoval(
    dinv.models.UNet(in_channels=1, out_channels=1, scales=2).to(device), pinv=True
)
model = loss.adapt_model(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)

# Load pretrained model
file_name = "demo_measplit_mnist_tomography.pth"
url = get_weights_url(model_name="measplit", file_name=file_name)
ckpt = torch.hub.load_state_dict_from_url(
    url, map_location=lambda storage, loc: storage, file_name=file_name
)

model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])

Train and test network

trainer = dinv.Trainer(
    model=model,
    physics=physics,
    epochs=1,
    losses=loss,
    optimizer=optimizer,
    device=device,
    train_dataloader=train_dataloader,
    plot_images=False,
    save_path=None,
    verbose=True,
    show_progress_bar=False,
    no_learning_method="A_dagger",  # use pseudo-inverse as no-learning baseline
)

model = trainer.train()
The model has 444737 trainable parameters
Train epoch 0: TotalLoss=0.032, PSNR=29.007

Test and visualise the model outputs using a small test set. The output is averaged over 5 realisations of the random mask (eval_n_samples=5). The trained model improves on the no-learning reconstruction by about 6.7 dB.

trainer.plot_images = True
trainer.test(test_dataloader)
Ground truth, No learning, Reconstruction
Eval epoch 0: PSNR=31.238, PSNR no learning=24.549
Test results:
PSNR no learning: 24.549 +- 1.052
PSNR: 31.238 +- 2.738

{'PSNR no learning': 24.54879093170166, 'PSNR no learning_std': 1.0523137744435074, 'PSNR': 31.238416862487792, 'PSNR_std': 2.7380147046892569}

We now demonstrate the effect of not averaging over multiple realisations of the splitting mask at evaluation time, by setting eval_n_samples=1. The PSNR drops by about 2 dB:

Ground truth, No learning, Reconstruction
Eval epoch 0: PSNR=29.202, PSNR no learning=24.549
Test results:
PSNR no learning: 24.549 +- 1.052
PSNR: 29.202 +- 2.439

{'PSNR no learning': 24.54879093170166, 'PSNR no learning_std': 1.0523137744435074, 'PSNR': 29.201856231689455, 'PSNR_std': 2.4385263457617437}
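Why averaging helps can be illustrated with a toy NumPy sketch. It is not the trained network: the "model" below is an assumed stand-in that keeps a random subset of a noisy signal and rescales it, so each output depends strongly on the mask that was drawn. Averaging over several masks reduces that mask-induced variance.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(784)                           # toy "image", 28 x 28 flattened
y = x + 0.1 * rng.standard_normal(x.shape)    # noisy measurement (identity operator)
split_ratio = 0.6


def masked_estimate(y):
    """Stand-in model: rescales the kept entries so the estimate is unbiased over masks."""
    mask = rng.random(y.shape) < split_ratio
    return y * mask / split_ratio


def evaluate(y, n_samples):
    """Average the estimate over n_samples independent mask realisations."""
    return np.mean([masked_estimate(y) for _ in range(n_samples)], axis=0)


err_single = np.mean((evaluate(y, 1) - x) ** 2)
err_avg = np.mean((evaluate(y, 5) - x) ** 2)
print(f"MSE with 1 mask: {err_single:.4f}, MSE with 5 masks: {err_avg:.4f}")
```

The size of the gap in this toy setting says nothing quantitative about the U-Net results; it only shows the direction of the effect.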

Furthermore, we can disable measurement splitting at evaluation altogether by setting eval_split_input to False (this is done in SSDU). This is generally worse than MC averaging; here it is about 0.18 dB below the result obtained with 5 mask realisations:

Ground truth, No learning, Reconstruction
Eval epoch 0: PSNR=31.056, PSNR no learning=24.549
Test results:
PSNR no learning: 24.549 +- 1.052
PSNR: 31.056 +- 2.507

{'PSNR no learning': 24.54879093170166, 'PSNR no learning_std': 1.0523137744435074, 'PSNR': 31.055921363830567, 'PSNR_std': 2.5073731023586285}
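The three evaluation modes can be compared side by side with a few lines of arithmetic. The values below are copied from the test outputs reported above and rounded to three decimals; the dictionary keys are just labels chosen for this sketch.

```python
# PSNR values (dB) copied from the three test runs reported above.
psnr_no_learning = 24.549
psnr = {
    "eval_n_samples=5": 31.238,
    "eval_n_samples=1": 29.202,
    "eval_split_input=False": 31.056,
}

# Gain of each evaluation mode over the pseudo-inverse baseline.
gains = {name: round(value - psnr_no_learning, 3) for name, value in psnr.items()}
for name, gain in gains.items():
    print(f"{name}: +{gain:.3f} dB over no learning")
```

On this small test set, averaging over 5 masks gives the largest gain, followed closely by disabling the split at evaluation, with a single mask realisation clearly behind.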

Total running time of the script: (0 minutes 9.443 seconds)
