Self-supervised learning with measurement splitting
We demonstrate self-supervised learning with measurement splitting to train a denoiser network on the MNIST dataset. The physics here is noisy computed tomography, as in Noise2Inverse. Note that this example can also easily be applied to undersampled multicoil MRI, as in SSDU.
Measurement splitting constructs a ground-truth-free loss \(\frac{m}{m_2}\| y_2 - A_2 \hat{x}(y_1,A_1)\|^2\) by splitting the measurement \(y\) (of size \(m\)) and the forward operator \(A\) into two pairs \((y_1,A_1)\) and \((y_2,A_2)\) (of size \(m_2\)) using a randomly generated mask, where \(\hat{x}(y_1,A_1)\) denotes the network reconstruction from the retained measurements. See deepinv.loss.SplittingLoss for full details.
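To make the splitting concrete, here is a toy sketch in plain torch (not the deepinv API) of how a random Bernoulli mask partitions a measurement into the two complementary parts used by the loss:

import torch

m, split_ratio = 100, 0.6
y = torch.randn(m)  # toy measurement of size m
mask = (torch.rand(m) < split_ratio).float()  # random Bernoulli split mask

y1 = mask * y  # subset fed to the network, together with A_1 = mask * A
y2 = (1 - mask) * y  # held-out subset used to compute the loss
m2 = (1 - mask).sum()

# With a reconstruction x_hat = model(y1, A_1), the loss is
# (m / m2) * ||y2 - A_2 x_hat||^2; here a zero prediction stands in for A_2 x_hat.
pred = torch.zeros(m)
loss_val = (m / m2) * ((1 - mask) * (y - pred)).pow(2).sum()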
import deepinv as dinv
from torch.utils.data import DataLoader
import torch
from torchvision import transforms, datasets
from deepinv.models.utils import get_weights_url
torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
Define loss
Our implementation has multiple optional parameters that control how the splitting is performed. For example, you can:

- Use split_ratio to set the ratio of pixels used in the forward pass vs. the loss;
- Define custom masking methods using a mask_generator, such as deepinv.physics.generator.BernoulliSplittingMaskGenerator or deepinv.physics.generator.GaussianSplittingMaskGenerator (see the sketch after this list);
- Use eval_n_samples to set how many realisations of the random mask are used at evaluation time;
- Optionally disable measurement splitting at evaluation time using eval_split_input (as is the case in SSDU);
- Average over both input and output masks at evaluation time using eval_split_output.

See deepinv.loss.SplittingLoss for details.
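For instance, a loss with an explicit Bernoulli mask generator could be built roughly as follows (a sketch based on the generator classes above; exact argument names may differ, and the example below simply uses the default generator):

from deepinv.physics.generator import BernoulliSplittingMaskGenerator

# Sketch: pixelwise Bernoulli masks for (1, 28, 28) measurements,
# keeping 60% of the entries for the network input.
mask_generator = BernoulliSplittingMaskGenerator(
    tensor_size=(1, 28, 28), split_ratio=0.6, device=device
)
custom_loss = dinv.loss.SplittingLoss(
    mask_generator=mask_generator, eval_split_input=True, eval_n_samples=5
)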
Note that after the model has been defined, the loss must also “adapt” the model.
loss = dinv.loss.SplittingLoss(split_ratio=0.6, eval_split_input=True, eval_n_samples=5)
Prepare data
We use the torchvision MNIST dataset, and use noisy tomography physics (with the number of angles equal to the image size) for the forward operator.
Note
We use a subset of the whole training set to reduce the computational load of the example. We recommend using the whole set, by setting train_datapoints=test_datapoints=None, to obtain the best results.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root=".", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=".", train=False, transform=transform, download=True)
physics = dinv.physics.Tomography(
    angles=28,
    img_width=28,
    noise_model=dinv.physics.noise.GaussianNoise(0.1),
    device=device,
)
deepinv_datasets_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    device=device,
    save_dir="MNIST",
    train_datapoints=100,
    test_datapoints=10,
)
train_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=True)
test_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=False)
train_dataloader = DataLoader(train_dataset, shuffle=True)
test_dataloader = DataLoader(test_dataset, shuffle=False)
Dataset has been saved in MNIST
Define model
We use a simple U-Net architecture with 2 scales as the denoiser network.
To reduce training time, we use a pretrained model. Here we demonstrate training with 100 images for 1 epoch, after loading a pretrained model that was trained with 1000 images for 20 epochs.
Note
When using the splitting loss, the model must be “adapted” by the loss, as its forward pass takes only a subset of the pixels, not the full image.
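For intuition, the adapted forward pass behaves roughly like the following simplified sketch (the actual wrapper is created by SplittingLoss.adapt_model and also masks the forward operator consistently):

import torch

def adapted_forward(base_model, y, physics, split_ratio=0.6):
    # draw a random Bernoulli mask over the measurement entries
    mask = (torch.rand_like(y) < split_ratio).float()
    y1 = mask * y  # the network only ever sees this subset of measurements
    # the operator is masked consistently, so that A_1 x = mask * (A x)
    return base_model(y1, physics)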
model = dinv.models.ArtifactRemoval(
    dinv.models.UNet(in_channels=1, out_channels=1, scales=2).to(device), pinv=True
)
model = loss.adapt_model(model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)
# Load pretrained model
file_name = "demo_measplit_mnist_tomography.pth"
url = get_weights_url(model_name="measplit", file_name=file_name)
ckpt = torch.hub.load_state_dict_from_url(
    url, map_location=lambda storage, loc: storage, file_name=file_name
)
model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
Downloading: "https://huggingface.co/deepinv/measplit/resolve/main/demo_measplit_mnist_tomography.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/demo_measplit_mnist_tomography.pth
0%| | 0.00/5.13M [00:00<?, ?B/s]
22%|██▏ | 1.12M/5.13M [00:00<00:00, 11.3MB/s]
44%|████▍ | 2.25M/5.13M [00:00<00:00, 11.5MB/s]
66%|██████▌ | 3.38M/5.13M [00:00<00:00, 10.5MB/s]
88%|████████▊ | 4.50M/5.13M [00:00<00:00, 10.9MB/s]
100%|██████████| 5.13M/5.13M [00:00<00:00, 10.6MB/s]
Train and test network
trainer = dinv.Trainer(
    model=model,
    physics=physics,
    epochs=1,
    losses=loss,
    optimizer=optimizer,
    device=device,
    train_dataloader=train_dataloader,
    plot_images=False,
    save_path=None,
    verbose=True,
    show_progress_bar=False,
    no_learning_method="A_dagger",  # use pseudo-inverse as no-learning baseline
)
model = trainer.train()
The model has 444737 trainable parameters
Train epoch 0: TotalLoss=0.032, PSNR=29.007
Test and visualise the model outputs using a small test set. We set the output to average over 5 realisations of the random mask (see the sketch below). The trained model improves on the no-learning reconstruction by ~7dB.
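Eval-time averaging amounts to something like the following simplified sketch (for intuition only; the adapted model handles this internally):

import torch

def mc_average(base_model, y, physics, n_samples=5, split_ratio=0.6):
    # average reconstructions over several random input-mask realisations
    estimates = [
        base_model((torch.rand_like(y) < split_ratio).float() * y, physics)
        for _ in range(n_samples)
    ]
    return torch.stack(estimates).mean(dim=0)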
trainer.plot_images = True
trainer.test(test_dataloader)
Eval epoch 0: PSNR=31.238, PSNR no learning=24.549
Test results:
PSNR no learning: 24.549 +- 1.052
PSNR: 31.238 +- 2.738
{'PSNR no learning': 24.54879093170166, 'PSNR no learning_std': 1.0523137744435074, 'PSNR': 31.238416862487792, 'PSNR_std': 2.7380147046892569}
Demonstrate the effect of not averaging over multiple realisations of the splitting mask at evaluation time, by setting eval_n_samples=1. This gives worse performance:
model.eval_n_samples = 1
trainer.test(test_dataloader)
Eval epoch 0: PSNR=29.202, PSNR no learning=24.549
Test results:
PSNR no learning: 24.549 +- 1.052
PSNR: 29.202 +- 2.439
{'PSNR no learning': 24.54879093170166, 'PSNR no learning_std': 1.0523137744435074, 'PSNR': 29.201856231689455, 'PSNR_std': 2.4385263457617437}
Furthermore, we can disable measurement splitting at evaluation altogether by setting eval_split_input to False (as is done in SSDU). This is generally worse than Monte Carlo averaging:
model.eval_split_input = False
trainer.test(test_dataloader)
Eval epoch 0: PSNR=31.056, PSNR no learning=24.549
Test results:
PSNR no learning: 24.549 +- 1.052
PSNR: 31.056 +- 2.507
{'PSNR no learning': 24.54879093170166, 'PSNR no learning_std': 1.0523137744435074, 'PSNR': 31.055921363830567, 'PSNR_std': 2.5073731023586285}
Total running time of the script: (0 minutes 9.443 seconds)