Scan-specific zero-shot SSDU for MRI#

We demonstrate scan-specific self-supervised learning, that is, learning to reconstruct MRI scans from a single accelerated sample without ground truth.

Here, we demonstrate fine-tuning a pretrained model (deepinv.models.RAM) [1] with the weighted SSDU loss [2][3]. Note that any of the self-supervised losses can be used for this, with varying performance [4]; see, for instance, the example using Equivariant Imaging [5].

Note that if more data is available, better results can be obtained by fine-tuning on more samples.

import torch
import deepinv as dinv

device = dinv.utils.get_device()
rng = torch.Generator(device=device).manual_seed(0)
rng_cpu = torch.Generator(device="cpu").manual_seed(0)
Selected CPU device

Data#

First, download a demo single brain MRI volume (FLAIR sequence, SIEMENS Trio Tim 3T scanner) from the FastMRI brain dataset [6], via HuggingFace.

Important

By using this dataset, you confirm that you have agreed to and signed the FastMRI data use agreement.

DATA_DIR = dinv.utils.get_data_home() / "fastMRI" / "multicoil_train"
SLICE_DIR = DATA_DIR / "slices"
DATA_DIR.mkdir(parents=True, exist_ok=True)
SLICE_DIR.mkdir(exist_ok=True)

dinv.utils.download_example("demo_fastmri_brain_multicoil.h5", DATA_DIR)

We use the FastMRI slice dataset provided in DeepInverse to load the volume and return all 16 slices. The data is returned in the format x, y, params where params is a dictionary containing the acceleration mask (simulated Gaussian mask with acceleration 6) and the estimated coil sensitivity map.

Note

This loading takes a few seconds per slice, as it must estimate the coil sensitivity map on the fly.

dataset = dinv.datasets.FastMRISliceDataset(
    DATA_DIR,
    slice_index="all",
    transform=dinv.datasets.MRISliceTransform(
        mask_generator=dinv.physics.generator.GaussianMaskGenerator(
            img_size=(256, 256),  # this is overridden internally by true image size
            acceleration=6,
            center_fraction=0.08,
            device="cpu",
            rng=rng_cpu,
        ),
        estimate_coil_maps=True,
    ),
)

Since this loading is slow, it is faster to cache the pre-processed slices to disk once and reload them in later runs:

from tqdm import tqdm

if not any(SLICE_DIR.iterdir()):
    for i, (x, y, params) in tqdm(enumerate(dataset)):
        torch.save([x, y, params], SLICE_DIR / f"{i}.pt")
16it [00:29,  1.87s/it]

The cached dataset now loads very quickly. We pre-normalize to bring the data into a friendlier numerical range, and attach a rough noise level to the params passed into the physics. The ground truth is loaded for evaluation later.

def loader(f):
    x, y, params = torch.load(f, weights_only=True)
    return x * 1e5, y * 1e5, params | {"sigma": 1e-5 * 1e5}


dataset = dinv.datasets.ImageFolder(SLICE_DIR, x_path="*.pt", loader=loader)

Physics#

The multicoil MRI physics is defined very easily:

Model#

For the model, we fine-tune the pretrained deepinv.models.RAM [1].

Loss#

We define the weighted SSDU loss by first defining the generator that generates the splitting masks. These splitting masks are multiplied with the measurements as per the original paper. Finally, the weighted SSDU loss requires knowledge of the original physics generator to define the weight.

Tip

Feel free to use any self-supervised loss you like here!

Training#

We train the model using the self-supervised loss. We randomly split the volume into training slices and validation slices for early stopping (up to a maximum of 100 epochs). Because the FastMRI ground truths are cropped magnitude root-sum-of-squares reconstructions, we define a helper metric for evaluation later.

Hint

This example trains on GPU to accelerate training; on CPU, the number of epochs is set to 0 and training is skipped.

def crop(x_net, x):
    """Crop to GT shape then take magnitude."""
    return dinv.utils.MRIMixin().rss(
        dinv.utils.MRIMixin().crop(x_net, shape=x.shape), multicoil=False
    )


class CropPSNR(dinv.metric.PSNR):
    def forward(self, x_net=None, x=None, *args, **kwargs):
        return super().forward(crop(x_net, x), x, *args, **kwargs)


metric = CropPSNR(max_pixel=None)

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, (0.8, 0.2), generator=rng_cpu
)

trainer = dinv.Trainer(
    model=model,
    physics=physics,
    losses=loss,
    metrics=metric,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-6),
    train_dataloader=torch.utils.data.DataLoader(
        train_dataset, shuffle=True, generator=rng_cpu
    ),
    eval_dataloader=torch.utils.data.DataLoader(val_dataset, generator=rng_cpu),
    epochs=0 if str(device) == "cpu" else 100,
    save_path=None,
    early_stop=3,
    early_stop_on_losses=True,
    compute_train_metrics=False,
    compute_eval_losses=True,
    show_progress_bar=False,  # disable progress bar for better vis in sphinx gallery.
    device=device,
)

model = trainer.train()
model = model.eval()
The model has 35618813 trainable parameters
UserWarning: No training will be done because epochs (0) <= loaded epoch_start (0) from checkpoint.

Evaluation#

Now that the model is trained, we evaluate it on 3 slices, plotting the reconstructions alongside baseline methods together with their evaluation metrics.

from torch.utils.data._utils.collate import default_collate

for i in [len(dataset) // 2 - 1, len(dataset) // 2, len(dataset) // 2 + 1]:
    # Load slice
    x, y, params = default_collate([dataset[i]])
    x, y, params = (
        x,
        y.to(device),
        {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for (k, v) in params.items()
        },
    )

    physics.update(**params)

    # Compute baseline reconstructions
    x_adj = physics.A_adjoint(y).detach().cpu()
    x_dag = physics.A_dagger(y).detach().cpu()

    # Evaluate model
    with torch.no_grad():
        x_hat = model(y, physics).detach().cpu()

    dinv.utils.plot(
        {
            "GT": x,
            "Adjoint": crop(x_adj, x),
            "SENSE": crop(x_dag, x),
            "Trained": crop(x_hat, x),
        },
        subtitles=[
            "",
            f"{metric(x_adj, x).item():.2f} dB",
            f"{metric(x_dag, x).item():.2f} dB",
            f"{metric(x_hat, x).item():.2f} dB",
        ],
    )
  • Three figures (one per slice), each with panels: GT, Adjoint, SENSE, Trained
References:

Total running time of the script: (0 minutes 37.932 seconds)
