Self-supervised learning with Equivariant Imaging for MRI.#

This example shows you how to train a reconstruction network for an MRI inverse problem in a fully self-supervised way, i.e., using measurement data only.

The equivariant imaging loss is presented in “Equivariant Imaging: Learning Beyond the Range Space”.

import deepinv as dinv
from torch.utils.data import DataLoader
import torch
from pathlib import Path
from torchvision import transforms
from deepinv.optim.prior import PnP
from deepinv.utils.demo import load_dataset, load_degradation, demo_mri_model
from deepinv.models.utils import get_weights_url

Setup paths for data loading and results.#

BASE_DIR = Path(".")
DATA_DIR = BASE_DIR / "measurements"
CKPT_DIR = BASE_DIR / "ckpts"

# Set the global random seed from pytorch to ensure reproducibility of the example.
torch.manual_seed(0)

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Load base image datasets and degradation operators.#

In this example, we use a subset of the single-coil FastMRI dataset as the base image dataset. It consists of 973 knee images of size 320x320.

Note

We reduce the size to 128x128 for faster training in the demo.

operation = "MRI"
train_dataset_name = "fastmri_knee_singlecoil"
img_size = 128

transform = transforms.Compose([transforms.Resize(img_size)])

train_dataset = load_dataset(train_dataset_name, transform, train=True)
test_dataset = load_dataset(train_dataset_name, transform, train=False)

Generate a dataset of knee images and load it.#

mask = load_degradation("mri_mask_128x128.npy")

# define the physics
physics = dinv.physics.MRI(mask=mask, device=device)

# Use a parallel dataloader if using a GPU to speed up training;
# otherwise, since all computations run on the CPU, use synchronous data loading.
num_workers = 4 if torch.cuda.is_available() else 0
n_images_max = (
    900 if torch.cuda.is_available() else 5
)  # number of images used for training
# (the dataset has 973 images, but we use only 900 here)

my_dataset_name = "demo_equivariant_imaging"
measurement_dir = DATA_DIR / train_dataset_name / operation
deepinv_datasets_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    device=device,
    save_dir=measurement_dir,
    train_datapoints=n_images_max,
    num_workers=num_workers,
    dataset_filename=str(my_dataset_name),
)

train_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=True)
test_dataset = dinv.datasets.HDF5Dataset(path=deepinv_datasets_path, train=False)
mri_mask_128x128.npy degradation downloaded in datasets
Dataset has been saved at measurements/fastmri_knee_singlecoil/MRI/demo_equivariant_imaging0.h5

Set up the reconstruction network#

As a reconstruction network, we use an unrolled network (half-quadratic splitting) with a trainable denoising prior based on the DnCNN architecture as an example of a model-based deep learning architecture from MoDL. See deepinv.utils.demo.demo_mri_model() for details.

model = demo_mri_model(device=device)
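To illustrate the idea behind such an unrolled architecture, here is a minimal sketch of one half-quadratic splitting iteration with a learned denoising prior. The class name `HQSStep`, the gradient-form data step, and the small convolutional denoiser are illustrative simplifications, not the actual internals of `demo_mri_model`:

```python
import torch

class HQSStep(torch.nn.Module):
    """One unrolled half-quadratic splitting step (illustrative sketch).

    Alternates a data-consistency step for measurements y = A(x) with a
    learned denoising prior. Assumes A is linear and self-adjoint here
    (e.g., a sampling mask), so the data step is a simple gradient update.
    """

    def __init__(self, denoiser, step_size=1.0):
        super().__init__()
        self.denoiser = denoiser
        self.step_size = torch.nn.Parameter(torch.tensor(step_size))

    def forward(self, x, y, A):
        # Data-consistency: gradient of 0.5 * ||A(x) - y||^2 when A = A^T.
        x = x - self.step_size * A(A(x) - y)
        # Learned prior: denoise the intermediate estimate.
        return self.denoiser(x)


# Stacking a few steps gives an unrolled reconstruction network.
denoiser = torch.nn.Conv2d(1, 1, 3, padding=1)  # stand-in for DnCNN
steps = torch.nn.ModuleList(HQSStep(denoiser) for _ in range(3))

mask = (torch.rand(1, 1, 16, 16) > 0.5).float()
A = lambda x: mask * x  # toy masking operator, stand-in for MRI physics
y = A(torch.rand(2, 1, 16, 16))

x = y.clone()  # initialize from the (zero-filled) measurements
for step in steps:
    x = step(x, y, A)
```

In the actual network, the denoiser weights and step sizes are trained end-to-end through the unrolled iterations.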

Set up the training parameters#

We choose a self-supervised training scheme with two losses: the measurement consistency (MC) loss and the equivariant imaging (EI) loss. The EI loss requires a group of transformations to which the forward model is not equivariant. Here we use the group of four 90-degree rotations, since the accelerated MRI acquisition is not equivariant to rotations (while it is equivariant to translations).

See the docs for the full list of available transforms.
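To make the two loss terms concrete, here is a minimal sketch of computing them by hand, using a toy masking operator and a tiny network as stand-ins for the MRI physics and the reconstruction model (illustration only; the actual example uses `dinv.loss.MCLoss` and `dinv.loss.EILoss`):

```python
import torch

torch.manual_seed(0)

# Toy "physics": a fixed undersampling mask applied in image space
# (stand-in for the MRI forward operator).
mask = (torch.rand(1, 1, 32, 32) > 0.5).float()
A = lambda x: mask * x

# Tiny reconstruction network f (stand-in for the unrolled model).
f = torch.nn.Conv2d(1, 1, 3, padding=1)

y = A(torch.rand(4, 1, 32, 32))  # measurements only, no ground truth

x_hat = f(y)

# Measurement consistency: A f(y) should reproduce the measurements y.
mc_loss = torch.nn.functional.mse_loss(A(x_hat), y)

# Equivariant imaging: transform the estimate, re-measure, re-reconstruct;
# the result should match the transformed estimate.
x_rot = torch.rot90(x_hat, k=1, dims=(-2, -1))
ei_loss = torch.nn.functional.mse_loss(f(A(x_rot)), x_rot)

loss = mc_loss + ei_loss
```

Note that neither term ever touches a ground-truth image, which is what makes the scheme fully self-supervised.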

Note

We use a pretrained model to reduce training time. You can get the same results by training from scratch for 150 epochs.

epochs = 1  # choose training epochs
learning_rate = 5e-4
batch_size = 16 if torch.cuda.is_available() else 1

# choose self-supervised training losses
# generates 4 random rotations per image in the batch
losses = [dinv.loss.MCLoss(), dinv.loss.EILoss(dinv.transform.Rotate(n_trans=4))]

# choose optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(epochs * 0.8) + 1)

# start with a pretrained model to reduce training time
file_name = "new_demo_ei_ckp_150_v3.pth"
url = get_weights_url(model_name="demo", file_name=file_name)
ckpt = torch.hub.load_state_dict_from_url(
    url,
    map_location=lambda storage, loc: storage,
    file_name=file_name,
)
# load a checkpoint to reduce training time
model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
Downloading: "https://huggingface.co/deepinv/demo/resolve/main/new_demo_ei_ckp_150_v3.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/new_demo_ei_ckp_150_v3.pth

  0%|          | 0.00/2.17M [00:00<?, ?B/s]
 52%|█████▏    | 1.12M/2.17M [00:00<00:00, 10.7MB/s]
100%|██████████| 2.17M/2.17M [00:00<00:00, 10.8MB/s]

Train the network#

verbose = True  # print training information
wandb_vis = False  # plot curves and images in Weights & Biases

train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True
)
test_dataloader = DataLoader(
    test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

# Initialize the trainer
trainer = dinv.Trainer(
    model,
    physics=physics,
    epochs=epochs,
    scheduler=scheduler,
    losses=losses,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    plot_images=True,
    device=device,
    save_path=str(CKPT_DIR / operation),
    verbose=verbose,
    wandb_vis=wandb_vis,
    show_progress_bar=False,  # disable progress bar for better vis in sphinx gallery.
    ckp_interval=10,
)

model = trainer.train()
Ground truth, Measurement, Reconstruction
The model has 187019 trainable parameters
Train epoch 0: MCLoss=0.0, EILoss=0.0, TotalLoss=0.0, PSNR=37.645

Test the network#

trainer.test(test_dataloader)
Ground truth, Measurement, No learning, Reconstruction
Eval epoch 0: PSNR=38.288, PSNR no learning=29.389
Test results:
PSNR no learning: 29.389 +- 3.411
PSNR: 38.288 +- 2.265

{'PSNR no learning': np.float64(29.38880230629281), 'PSNR no learning_std': np.float64(3.411393798566601), 'PSNR': np.float64(38.28786186322774), 'PSNR_std': np.float64(2.2652720747247512)}
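The reported PSNR values follow the standard definition, 10 log10(max² / MSE); a minimal sketch, assuming images scaled to [0, 1]:

```python
import torch

def psnr(x_hat, x, max_val=1.0):
    # Peak signal-to-noise ratio in dB between estimate and reference.
    mse = torch.mean((x_hat - x) ** 2)
    return 10.0 * torch.log10(max_val**2 / mse)

x = torch.rand(1, 1, 8, 8)
noisy = x + 0.05 * torch.randn_like(x)
print(float(psnr(noisy, x)))  # higher is better
```

"PSNR no learning" above is the PSNR of the zero-filled (unlearned) reconstruction, which the trained network improves by roughly 9 dB.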

Total running time of the script: (0 minutes 19.627 seconds)

Gallery generated by Sphinx-Gallery