Scan-specific zero-shot SSDU for MRI#

We demonstrate scan-specific self-supervised learning, that is, learning to reconstruct MRI scans from a single accelerated sample without ground truth.

Here, we fine-tune a pretrained model (deepinv.models.RAM) [1] with the weighted SSDU loss [2][3]. Note, however, that any self-supervised loss can be used for this, with varying performance [4]; see, for instance, the example using Equivariant Imaging [5].

Note that, if more data is available, better results can be obtained by fine-tuning on more samples!

import torch
import deepinv as dinv

device = dinv.utils.get_device()
rng = torch.Generator(device=device).manual_seed(0)
rng_cpu = torch.Generator(device="cpu").manual_seed(0)
Selected GPU 0 with 4064.25 MiB free memory

Data#

First, download a single demo brain MRI volume (FLAIR sequence, SIEMENS Trio Tim 3T scanner) from the FastMRI brain dataset [6], via HuggingFace.

Important

By using this dataset, you confirm that you have agreed to and signed the FastMRI data use agreement.

DATA_DIR = dinv.utils.get_cache_home() / "datasets" / "fastMRI" / "multicoil_train"
SLICE_DIR = DATA_DIR / "slices"
DATA_DIR.mkdir(parents=True, exist_ok=True)
SLICE_DIR.mkdir(exist_ok=True)

dinv.utils.download_example("demo_fastmri_brain_multicoil.h5", DATA_DIR)

We use the FastMRI slice dataset provided in DeepInverse to load the volume and return all 16 slices. The data is returned in the format x, y, params where params is a dictionary containing the acceleration mask (simulated Gaussian mask with acceleration 6) and the estimated coil sensitivity map.

Note

This loading takes a few seconds per slice, as it must estimate the coil sensitivity map on the fly.

dataset = dinv.datasets.FastMRISliceDataset(
    DATA_DIR,
    slice_index="all",
    transform=dinv.datasets.MRISliceTransform(
        mask_generator=dinv.physics.generator.GaussianMaskGenerator(
            img_size=(256, 256),  # this is overridden internally by true image size
            acceleration=6,
            center_fraction=0.08,
            device="cpu",
            rng=rng_cpu,
        ),
        estimate_coil_maps=True,
    ),
)
  0%|          | 0/1 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 1107.55it/s]
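
To give a feel for what an acceleration-6 mask with an 8% fully sampled center means, here is a minimal NumPy sketch. It is not the DeepInverse implementation: the function `gaussian_line_mask` and its `std_fraction` parameter are hypothetical, and it assumes the mask selects whole k-space columns, with a fully sampled central band plus columns drawn from a Gaussian density.

```python
import numpy as np


def gaussian_line_mask(width, acceleration, center_fraction, rng, std_fraction=0.25):
    """Return a 1D boolean mask over k-space columns (True = sampled).

    Hypothetical helper for illustration only; `std_fraction` (the Gaussian
    width relative to the image width) is an assumed parameter.
    """
    n_total = round(width / acceleration)  # columns kept overall
    n_center = round(width * center_fraction)  # fully sampled low frequencies
    start = (width - n_center) // 2

    mask = np.zeros(width, dtype=bool)
    mask[start : start + n_center] = True

    # Draw the remaining columns without replacement, favoring the center.
    candidates = np.flatnonzero(~mask)
    density = np.exp(-0.5 * ((candidates - width / 2) / (std_fraction * width)) ** 2)
    extra = rng.choice(
        candidates, size=n_total - n_center, replace=False, p=density / density.sum()
    )
    mask[extra] = True
    return mask


mask = gaussian_line_mask(
    256, acceleration=6, center_fraction=0.08, rng=np.random.default_rng(0)
)
print(int(mask.sum()), "of", mask.size, "columns sampled")
```

Only about one column in six is kept, so most high-frequency k-space lines are never observed for a given slice.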

When the data is slow to load, it is faster to save the pre-loaded slices to disk once and reuse them during training:

from tqdm import tqdm

if not any(SLICE_DIR.iterdir()):
    for i, (x, y, params) in tqdm(enumerate(dataset)):
        torch.save([x, y, params], SLICE_DIR / f"{i}.pt")

Then the dataset can be loaded very quickly. We pre-normalize the data (scaling by 1e5) to bring it into a friendlier range, and add a rough noise level, sigma, as a parameter to be passed into the physics. The ground truth is loaded for evaluation later.

def loader(f):
    x, y, params = torch.load(f, weights_only=True)
    return x * 1e5, y * 1e5, params | {"sigma": 1e-5 * 1e5}


dataset = dinv.datasets.ImageFolder(SLICE_DIR, x_path="*.pt", loader=loader)

Physics#

The multicoil physics is defined very easily.
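
As a hedged illustration of what a multicoil MRI operator computes, the NumPy sketch below applies coil sensitivity maps, a centered orthonormal 2D FFT, and an undersampling mask, then checks that the adjoint is consistent. The helper names (`fft2c`, `ifft2c`, `forward`, `adjoint`) and the toy sizes are assumptions for illustration, not the DeepInverse API.

```python
import numpy as np


def fft2c(x):
    """Centered, orthonormal 2D FFT over the last two axes."""
    x = np.fft.ifftshift(x, axes=(-2, -1))
    k = np.fft.fft2(x, axes=(-2, -1), norm="ortho")
    return np.fft.fftshift(k, axes=(-2, -1))


def ifft2c(k):
    """Inverse of fft2c (also its adjoint, since the transform is unitary)."""
    k = np.fft.ifftshift(k, axes=(-2, -1))
    x = np.fft.ifft2(k, axes=(-2, -1), norm="ortho")
    return np.fft.fftshift(x, axes=(-2, -1))


def forward(x, coil_maps, mask):
    """y = M F S x: image (H, W) -> masked multicoil k-space (C, H, W)."""
    return mask * fft2c(coil_maps * x)


def adjoint(y, coil_maps, mask):
    """x = S^H F^H M y: multicoil k-space (C, H, W) -> image (H, W)."""
    return np.sum(np.conj(coil_maps) * ifft2c(mask * y), axis=0)


# Toy problem with assumed sizes: 4 coils, 8x6 image, every other column sampled.
rng = np.random.default_rng(0)
C, H, W = 4, 8, 6
x = rng.standard_normal((H, W)) + 1j * rng.standard_normal((H, W))
coil_maps = rng.standard_normal((C, H, W)) + 1j * rng.standard_normal((C, H, W))
mask = np.zeros((H, W))
mask[:, ::2] = 1.0
y = rng.standard_normal((C, H, W)) + 1j * rng.standard_normal((C, H, W))

# Adjoint (dot-product) test: <A x, y> should equal <x, A^H y>.
lhs = np.vdot(forward(x, coil_maps, mask), y)
rhs = np.vdot(x, adjoint(y, coil_maps, mask))
print(np.isclose(lhs, rhs))
```

The adjoint applied to real measurements gives the zero-filled, coil-combined image that serves as the "Adjoint" baseline in the evaluation.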

Model#

As the model, we fine-tune the pretrained deepinv.models.RAM [1].

Loss#

We define the weighted SSDU loss by first defining the generator that generates the splitting masks. These splitting masks are multiplied with the measurements as per the original paper. Finally, the weighted SSDU loss requires knowledge of the original physics generator to define the weight.
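
The multiplicative splitting idea can be sketched in NumPy as follows. This is a simplified, unweighted illustration under stated assumptions: the split here is a plain Bernoulli draw rather than the generator used in the example, the function names `split_measurements` and `ssdu_loss` are hypothetical, and the weighting term of the weighted SSDU loss is not reproduced.

```python
import numpy as np


def split_measurements(y, mask, split_ratio, rng):
    """Partition the sampled k-space locations into input and target sets.

    A random binary mask is multiplied with the acquisition mask, so the two
    resulting masks are disjoint and together cover exactly the sampled points.
    """
    keep = (rng.random(mask.shape) < split_ratio).astype(mask.dtype)
    mask_in = keep * mask  # locations shown to the network
    mask_out = (1 - keep) * mask  # held-out locations used by the loss
    return mask_in * y, mask_in, mask_out


def ssdu_loss(y_pred, y, mask_out):
    """Mean squared error on the held-out sampled locations only (unweighted)."""
    residual = mask_out * (y_pred - y)
    return np.sum(np.abs(residual) ** 2) / max(mask_out.sum(), 1)


# Toy single-coil k-space with roughly 30% of locations sampled (assumed sizes).
rng = np.random.default_rng(0)
mask = (rng.random((8, 8)) < 0.3).astype(float)
y = mask * (rng.standard_normal((8, 8)) + 1j * rng.standard_normal((8, 8)))

y_in, mask_in, mask_out = split_measurements(y, mask, split_ratio=0.6, rng=rng)

# A real model would reconstruct from y_in and re-apply the physics; here the
# zero-filled input stands in for the predicted k-space.
print(ssdu_loss(y_in, y, mask_out))
```

Because the network never sees the held-out locations, it cannot minimize the loss by simply copying its input, which is what makes training without ground truth possible.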

Tip

Feel free to use any self-supervised loss you like here!

Training#

We train the model using the self-supervised loss. We randomly split the volume's slices into training (80%) and validation (20%) sets, and use the validation loss for early stopping, up to a maximum of 100 epochs. Because the FastMRI ground-truth images are cropped magnitude root-sum-of-squares reconstructions, we define a helper metric that crops the network output and takes its magnitude before computing the PSNR.

Hint

This example trains on a GPU for speed; on CPU, training is skipped (epochs is set to 0).
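
To make the helper metric concrete, here is a small NumPy sketch of "crop, take magnitude, then compute PSNR". It assumes a center crop, a two-channel (real, imaginary) network output, and a PSNR peak equal to the maximum of the ground truth; the function names are hypothetical, and the behavior of the DeepInverse helpers may differ in detail.

```python
import numpy as np


def center_crop(img, shape):
    """Crop the last two axes of img to shape = (h, w) around the center."""
    h, w = shape
    top = (img.shape[-2] - h) // 2
    left = (img.shape[-1] - w) // 2
    return img[..., top : top + h, left : left + w]


def crop_magnitude_psnr(x_net, x):
    """PSNR between the cropped magnitude of x_net (2, H, W) and x (h, w)."""
    magnitude = np.sqrt(np.sum(center_crop(x_net, x.shape) ** 2, axis=0))
    mse = np.mean((magnitude - x) ** 2)
    return 10 * np.log10(x.max() ** 2 / mse)  # peak assumed to be max of x


# Toy check: ground truth of ones, reconstruction with magnitude 0.9 everywhere.
x = np.ones((4, 4))
x_net = np.stack([np.full((8, 6), 0.54), np.full((8, 6), 0.72)])
psnr = crop_magnitude_psnr(x_net, x)
print(f"{psnr:.2f} dB")
```

A uniform error of 0.1 against a peak of 1 gives a mean squared error of 0.01, so the PSNR is 20 dB.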

def crop(x_net, x):
    """Crop to GT shape then take magnitude."""
    return dinv.utils.MRIMixin().rss(
        dinv.utils.MRIMixin().crop(x_net, shape=x.shape), multicoil=False
    )


class CropPSNR(dinv.metric.PSNR):
    def forward(self, x_net=None, x=None, *args, **kwargs):
        return super().forward(crop(x_net, x), x, *args, **kwargs)


metric = CropPSNR(max_pixel=None)

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, (0.8, 0.2), generator=rng_cpu
)

trainer = dinv.Trainer(
    model=model,
    physics=physics,
    losses=loss,
    metrics=metric,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-6),
    train_dataloader=torch.utils.data.DataLoader(
        train_dataset, shuffle=True, generator=rng_cpu
    ),
    eval_dataloader=torch.utils.data.DataLoader(val_dataset, generator=rng_cpu),
    epochs=0 if str(device) == "cpu" else 100,
    save_path=None,
    early_stop=3,
    early_stop_on_losses=True,
    compute_train_metrics=False,
    compute_eval_losses=True,
    show_progress_bar=False,  # disable progress bar for better vis in sphinx gallery.
    device=device,
)

model = trainer.train()
model = model.eval()
/local/jtachell/deepinv/deepinv/deepinv/training/trainer.py:1356: UserWarning: non_blocking_transfers=True but DataLoader.pin_memory=False; set pin_memory=True to overlap host-device copies with compute.
  self.setup_train()
The model has 35618813 trainable parameters
/local/jtachell/deepinv/deepinv/deepinv/physics/forward.py:978: UserWarning: At least one input physics is a DecomposablePhysics, but resulting physics will not be decomposable. `A_dagger` and `prox_l2` will fall back to approximate methods, which may impact performance.
  warnings.warn(
/local/jtachell/deepinv/deepinv/deepinv/loss/mri/measplit.py:111: UserWarning: WeightedSplittingLoss detected new y width 213 in forward pass. Recalculating weight (this may take a second)...
  warn(
Train epoch 0: TotalLoss=0.276
Eval epoch 0: TotalLoss=0.256, CropPSNR=30.676
Best model saved at epoch 1
Train epoch 1: TotalLoss=0.321
Eval epoch 1: TotalLoss=0.403, CropPSNR=30.72
Best model saved at epoch 2
Train epoch 2: TotalLoss=0.211
Eval epoch 2: TotalLoss=0.296, CropPSNR=30.766
Best model saved at epoch 3
Train epoch 3: TotalLoss=0.265
Eval epoch 3: TotalLoss=0.194, CropPSNR=30.814
Best model saved at epoch 4
Train epoch 4: TotalLoss=0.372
Eval epoch 4: TotalLoss=0.241, CropPSNR=30.844
Best model saved at epoch 5
Train epoch 5: TotalLoss=0.232
Eval epoch 5: TotalLoss=0.421, CropPSNR=30.866
Best model saved at epoch 6
Train epoch 6: TotalLoss=0.232
Eval epoch 6: TotalLoss=0.284, CropPSNR=30.889
Best model saved at epoch 7
Train epoch 7: TotalLoss=0.181
Eval epoch 7: TotalLoss=0.296, CropPSNR=30.915
Best model saved at epoch 8
Early stopping triggered at epoch 7 as validation metrics have not improved in the last 3 validation steps. Disable it with early_stop=None, or modify early_stop>0 to wait for more validation steps.
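
The patience-based early stopping seen in the log can be sketched in plain Python. This is a generic illustration with made-up loss values; the function `early_stopping_epoch` is hypothetical, and the exact counting convention used by `dinv.Trainer` may differ.

```python
def early_stopping_epoch(eval_losses, patience):
    """Return (stop_epoch, best_epoch) for a sequence of validation losses.

    Training stops once the loss has failed to improve on its best value for
    `patience` consecutive validation steps.
    """
    best_loss = float("inf")
    best_epoch = None
    steps_without_improvement = 0

    for epoch, loss in enumerate(eval_losses):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
            steps_without_improvement = 0
        else:
            steps_without_improvement += 1
            if steps_without_improvement >= patience:
                return epoch, best_epoch

    # Patience never exhausted: training ran for all epochs.
    return len(eval_losses) - 1, best_epoch


# Hypothetical validation losses, not taken from the run above.
losses = [0.30, 0.25, 0.27, 0.26, 0.28, 0.20]
stop_epoch, best_epoch = early_stopping_epoch(losses, patience=3)
print(stop_epoch, best_epoch)
```

With these values the best loss occurs at epoch 1, the next three epochs fail to improve on it, and training stops at epoch 4 before reaching the final, lower loss. A small patience can therefore stop noisy runs prematurely.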

Evaluation#

Now that the model is trained, we test it on three slices from the middle of the volume, comparing it against the adjoint and SENSE baselines and plotting each reconstruction with its PSNR.

from torch.utils.data import default_collate  # public import path

for i in [len(dataset) // 2 - 1, len(dataset) // 2, len(dataset) // 2 + 1]:
    # Load slice
    x, y, params = default_collate([dataset[i]])
    x, y, params = (
        x,
        y.to(device),
        {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for (k, v) in params.items()
        },
    )

    physics.update(**params)

    # Compute baseline reconstructions
    x_adj = physics.A_adjoint(y).detach().cpu()
    x_dag = physics.A_dagger(y).detach().cpu()

    # Evaluate model
    with torch.no_grad():
        x_hat = model(y, physics).detach().cpu()

    dinv.utils.plot(
        {
            "GT": x,
            "Adjoint": crop(x_adj, x),
            "SENSE": crop(x_dag, x),
            "Trained": crop(x_hat, x),
        },
        subtitles=[
            "",
            f"{metric(x_adj, x).item():.2f} dB",
            f"{metric(x_dag, x).item():.2f} dB",
            f"{metric(x_hat, x).item():.2f} dB",
        ],
    )
Figures (one per evaluated slice): GT, Adjoint, SENSE, Trained

References:

Total running time of the script: (0 minutes 57.194 seconds)
