Self-supervised learning from incomplete measurements of multiple operators.
This example shows how to train a reconstruction network for an inpainting inverse problem in a fully self-supervised way, i.e., using measurement data only.
The dataset consists of pairs \((y_i, A_{g_i})\) where \(y_i\) are the measurements and \(A_{g_i}\) is one of \(G\) binary sampling operators (i.e., \(g_i\in \{1,\dots,G\}\)).
This self-supervised learning approach is presented in “Unsupervised Learning From Incomplete Measurements for Inverse Problems”, and minimizes the loss function:

\[\mathcal{L}(\theta) = \sum_{i} \left( \|y_i - A_{g_i}\hat{x}_{i,\theta}\|_2^2 + \|\hat{x}_{i,\theta} - R_{\theta}(A_s\hat{x}_{i,\theta}, A_s)\|_2^2 \right)\]

where \(R_{\theta}\) is a reconstruction network with parameters \(\theta\), \(y_i\) are the measurements, \(A_s\) is a binary sampling operator drawn at random from the \(G\) available operators, and \(\hat{x}_{i,\theta} = R_{\theta}(y_i,A_{g_i})\).
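To make the two terms concrete, here is a minimal sketch that evaluates them by hand for a single measurement, assuming a toy 28×28 inpainting setup and the (y, physics) calling convention used later in this example; the actual MCLoss and MOILoss classes used below handle batching and operator sampling internally.

# Minimal sketch (not the library loss implementation): evaluating the
# measurement-consistency and multi-operator terms by hand.
import torch
import deepinv as dinv

# a few toy inpainting operators and an untrained reconstruction network
physics = [
    dinv.physics.Inpainting(mask=0.5, tensor_size=(1, 28, 28)) for _ in range(4)
]
model = dinv.models.ArtifactRemoval(
    backbone_net=dinv.models.UNet(in_channels=1, out_channels=1, scales=3)
)

x = torch.rand(1, 1, 28, 28)   # unknown image (only used here to simulate y_i)
g = 0                          # index of the operator that acquired the data
y = physics[g](x)              # measurement y_i = A_{g_i} x_i
x_hat = model(y, physics[g])   # \hat{x}_{i,\theta} = R_\theta(y_i, A_{g_i})

# first term: measurement consistency on the observed operator
mc_term = ((y - physics[g](x_hat)) ** 2).sum()
# second term: remeasure the reconstruction with a randomly chosen operator A_s
s = torch.randint(len(physics), (1,)).item()
moi_term = ((x_hat - model(physics[s](x_hat), physics[s])) ** 2).sum()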
import deepinv as dinv
from torch.utils.data import DataLoader
import torch
from pathlib import Path
from torchvision import transforms
from deepinv.models.utils import get_weights_url
from torchvision import datasets
Set up paths for data loading and results.
BASE_DIR = Path(".")
ORIGINAL_DATA_DIR = BASE_DIR / "datasets"
DATA_DIR = BASE_DIR / "measurements"
CKPT_DIR = BASE_DIR / "ckpts"
# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
Load base image datasets and degradation operators.
In this example, we use the MNIST dataset for training and testing.
transform = transforms.Compose([transforms.ToTensor()])
train_base_dataset = datasets.MNIST(
    root=ORIGINAL_DATA_DIR, train=True, transform=transform, download=True
)
test_base_dataset = datasets.MNIST(
    root=ORIGINAL_DATA_DIR, train=False, transform=transform, download=True
)
Generate a dataset of subsampled images and load it.
We generate 10 different inpainting operators, each one with a different random mask.
If deepinv.datasets.generate_dataset() receives a list of physics operators, it generates a dataset for each operator and returns a list of paths to the generated datasets.
Note
We only use 10 training images per operator to reduce the computational time of this example. You can use the whole dataset by setting n_images_max = None.
number_of_operators = 10
# define the physics operators
physics = [
dinv.physics.Inpainting(mask=0.5, tensor_size=(1, 28, 28), device=device)
for _ in range(number_of_operators)
]
# Use a parallel dataloader if using a GPU to reduce training time,
# otherwise, since all computations are on the CPU, use synchronous data loading.
num_workers = 4 if torch.cuda.is_available() else 0
n_images_max = (
None if torch.cuda.is_available() else 50
) # number of images used for training (uses the whole dataset if you have a gpu)
operation = "inpainting"
my_dataset_name = "demo_multioperator_imaging"
measurement_dir = DATA_DIR / "MNIST" / operation
deepinv_datasets_path = dinv.datasets.generate_dataset(
train_dataset=train_base_dataset,
test_dataset=test_base_dataset,
physics=physics,
device=device,
save_dir=measurement_dir,
train_datapoints=n_images_max,
test_datapoints=10,
num_workers=num_workers,
dataset_filename=str(my_dataset_name),
)
train_dataset = [
dinv.datasets.HDF5Dataset(path=path, train=True) for path in deepinv_datasets_path
]
test_dataset = [
dinv.datasets.HDF5Dataset(path=path, train=False) for path in deepinv_datasets_path
]
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging0.h5
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging1.h5
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging2.h5
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging3.h5
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging4.h5
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging5.h5
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging6.h5
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging7.h5
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging8.h5
Dataset has been saved at measurements/MNIST/inpainting/demo_multioperator_imaging9.h5
Set up the reconstruction network
As a reconstruction network, we use a simple artifact removal network based on a U-Net. The network is defined as \(R_{\theta}(y,A)=\phi_{\theta}(A^{\top}y)\), where \(\phi_{\theta}\) is the U-Net.
# Define the trainable artifact removal model.
model = dinv.models.ArtifactRemoval(
backbone_net=dinv.models.UNet(in_channels=1, out_channels=1, scales=3)
)
model = model.to(device)
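The forward pass of this model first backprojects the measurements and then cleans up the result with the U-Net. For inpainting, the backprojection \(A^{\top}y\) simply keeps the observed pixels and zero-fills the missing ones. A rough sketch of what happens under the hood (the A_adjoint call and the (y, physics) calling convention are the assumed interface):

# Rough sketch of R_theta(y, A) = phi_theta(A^T y) for one inpainting operator
y = physics[0](torch.rand(1, 1, 28, 28, device=device))  # simulate a measurement
x_backproj = physics[0].A_adjoint(y)  # A^T y: observed pixels kept, missing ones set to 0
x_hat = model(y, physics[0])          # the model applies the U-Net to this backprojection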
Set up the training parameters
We choose a self-supervised training scheme with two losses: the measurement consistency loss (MC) and the multi-operator imaging loss (MOI). Necessary and sufficient conditions on the number of operators and measurements are described in the original paper.
Note
We use a pretrained model to reduce training time. You can get the same results by training from scratch for 100 epochs.
epochs = 1
learning_rate = 5e-4
batch_size = 64 if torch.cuda.is_available() else 1
# choose self-supervised training losses:
# MCLoss enforces measurement consistency, while MOILoss remeasures the
# reconstruction with a randomly chosen operator from the list
losses = [dinv.loss.MCLoss(), dinv.loss.MOILoss(physics)]
# choose optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(epochs * 0.8) + 1)
# start with a pretrained model to reduce training time
file_name = "demo_moi_ckp_10.pth"
url = get_weights_url(model_name="demo", file_name=file_name)
ckpt = torch.hub.load_state_dict_from_url(
url, map_location=lambda storage, loc: storage, file_name=file_name
)
# load a checkpoint to reduce training time
model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
Downloading: "https://huggingface.co/deepinv/demo/resolve/main/demo_moi_ckp_10.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/demo_moi_ckp_10.pth
100%|██████████| 23.8M/23.8M [00:02<00:00, 10.5MB/s]
Train the network
verbose = True # print training information
wandb_vis = False  # plot curves and images in Weights & Biases
train_dataloader = [
DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
for dataset in train_dataset
]
test_dataloader = [
DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
for dataset in test_dataset
]
# Initialize the trainer
trainer = dinv.Trainer(
model=model,
epochs=epochs,
scheduler=scheduler,
losses=losses,
optimizer=optimizer,
physics=physics,
device=device,
train_dataloader=train_dataloader,
eval_dataloader=test_dataloader,
save_path=str(CKPT_DIR / operation),
verbose=verbose,
plot_images=True,
show_progress_bar=False, # disable progress bar for better vis in sphinx gallery.
wandb_vis=wandb_vis,
ckp_interval=10,
)
# Train the network
model = trainer.train()
The model has 2069441 trainable parameters
Train epoch 0: MCLoss=0.0, MOILoss=0.0, TotalLoss=0.0, PSNR=15.125
Eval epoch 0: PSNR=13.906
Test the network
trainer.test(test_dataloader)
Eval epoch 0: PSNR=13.906, PSNR no learning=13.689
Test results:
PSNR no learning: 13.689 +- 2.375
PSNR: 13.906 +- 2.371
{'PSNR no learning': np.float64(13.68945770263672), 'PSNR no learning_std': np.float64(2.375037058737495), 'PSNR': np.float64(13.906121826171875), 'PSNR_std': np.float64(2.3711266469715837)}
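As a final check, one can reconstruct and plot a single test image with the trained model; the (image, measurement) item ordering and the dinv.utils.plot call below are assumptions based on the rest of this example, not output of the original script.

# Sketch: reconstruct and plot one test image from the first operator
x, y = test_dataset[0][0]
x, y = x.unsqueeze(0).to(device), y.unsqueeze(0).to(device)
with torch.no_grad():
    x_hat = model(y, physics[0])
dinv.utils.plot([y, x_hat, x], titles=["measurement", "reconstruction", "ground truth"])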
Total running time of the script: (0 minutes 5.931 seconds)