Note
New to DeepInverse? Get started with the basics with the 5-minute quickstart tutorial.
Scan-specific zero-shot SSDU for MRI#
We demonstrate scan-specific self-supervised learning, that is, learning to reconstruct MRI scans from a single accelerated sample without ground truth.
Here, we fine-tune a pretrained model (deepinv.models.RAM) [1]
with the weighted SSDU loss [2][3].
However, note that any of the self-supervised losses can be used here, with varying performance [4].
For example, see the example using Equivariant Imaging [5].
Note that, if more data is available, better results can be obtained by fine-tuning on more samples!
import torch
import deepinv as dinv
device = dinv.utils.get_device()
rng = torch.Generator(device=device).manual_seed(0)
rng_cpu = torch.Generator(device="cpu").manual_seed(0)
Selected CPU device
Data#
First, download a demo single brain MRI volume (FLAIR sequence, SIEMENS Trio Tim 3T scanner) from the FastMRI brain dataset [6], via HuggingFace.
Important
By using this dataset, you confirm that you have agreed to and signed the FastMRI data use agreement.
DATA_DIR = dinv.utils.get_data_home() / "fastMRI" / "multicoil_train"
SLICE_DIR = DATA_DIR / "slices"
DATA_DIR.mkdir(parents=True, exist_ok=True)
SLICE_DIR.mkdir(exist_ok=True)
dinv.utils.download_example("demo_fastmri_brain_multicoil.h5", DATA_DIR)
We use the FastMRI slice dataset provided in DeepInverse to load the volume and return all 16 slices.
The data is returned in the format x, y, params where params is a dictionary containing
the acceleration mask (simulated Gaussian mask with acceleration 6) and the
estimated coil sensitivity map.
Note
This loading takes a few seconds per slice, as it must estimate the coil sensitivity map on the fly.
dataset = dinv.datasets.FastMRISliceDataset(
DATA_DIR,
slice_index="all",
transform=dinv.datasets.MRISliceTransform(
mask_generator=dinv.physics.generator.GaussianMaskGenerator(
img_size=(256, 256), # this is overridden internally by true image size
acceleration=6,
center_fraction=0.08,
device="cpu",
rng=rng_cpu,
),
estimate_coil_maps=True,
),
)
When training with data that is slow to load, it is faster to save the pre-loaded slices to disk once, then train from the saved files:
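A minimal sketch of this pre-saving step (the per-slice filenames are illustrative, and a dummy sample stands in for the `dataset` defined above so the sketch runs standalone):

```python
import torch
from pathlib import Path

SLICE_DIR = Path("slices")  # stands in for the SLICE_DIR defined above
SLICE_DIR.mkdir(parents=True, exist_ok=True)

# `dataset` yields (x, y, params) tuples as described above; a tiny dummy
# sample stands in here so the sketch is self-contained.
dataset = [(torch.randn(2, 32, 32), torch.randn(2, 32, 32), {"sigma": 1e-5})]

# Save each transformed slice once; later runs load the .pt files directly,
# skipping the slow coil-map estimation.
for i, sample in enumerate(dataset):
    torch.save(sample, SLICE_DIR / f"{i}.pt")
```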
Then the dataset can be loaded very quickly. We pre-normalize to bring the data into a more friendly range, load a rough noise level as a param to be passed into the physics, and load the ground truth for evaluation later.
def loader(f):
x, y, params = torch.load(f, weights_only=True)
return x * 1e5, y * 1e5, params | {"sigma": 1e-5 * 1e5}
dataset = dinv.datasets.ImageFolder(SLICE_DIR, x_path="*.pt", loader=loader)
Physics#
The multicoil MRI physics is defined very easily:
physics = dinv.physics.MultiCoilMRI(
device=device, noise_model=dinv.physics.GaussianNoise(0.0)
)
Model#
For the model, we fine-tune a pretrained model (deepinv.models.RAM) [1].
Loss#
We define the weighted SSDU loss by first constructing a generator for the splitting masks, which are multiplied elementwise with the measurements as in the original paper. Finally, the weighted SSDU loss requires knowledge of the original physics mask generator in order to define its weighting.
Tip
Feel free to use any self-supervised loss you like here!
split_generator = dinv.physics.generator.GaussianMaskGenerator(
(256, 256), acceleration=2, center_fraction=0.0, device=device
)
mask_generator = dinv.physics.generator.MultiplicativeSplittingMaskGenerator(
(256, 256), split_generator, device=device
)
physics_generator = dinv.physics.generator.GaussianMaskGenerator(
(256, 256), acceleration=6, center_fraction=0.04, rng=rng, device=device
)
loss = dinv.loss.mri.WeightedSplittingLoss(
mask_generator=mask_generator, physics_generator=physics_generator
)
Training#
We train the model using the self-supervised loss. We randomly split the volume into training and validation slices for early stopping (up to a maximum of 100 epochs). Because the FastMRI ground truths are cropped magnitude root-sum-of-squares reconstructions, we define a helper metric for evaluation later.
Hint
This example trains on GPU to accelerate training; on CPU, training is skipped (epochs=0) to keep the runtime short.
def crop(x_net, x):
"""Crop to GT shape then take magnitude."""
return dinv.utils.MRIMixin().rss(
dinv.utils.MRIMixin().crop(x_net, shape=x.shape), multicoil=False
)
class CropPSNR(dinv.metric.PSNR):
def forward(self, x_net=None, x=None, *args, **kwargs):
return super().forward(crop(x_net, x), x, *args, **kwargs)
metric = CropPSNR(max_pixel=None)
train_dataset, val_dataset = torch.utils.data.random_split(
dataset, (0.8, 0.2), generator=rng_cpu
)
trainer = dinv.Trainer(
model=model,
physics=physics,
losses=loss,
metrics=metric,
optimizer=torch.optim.Adam(model.parameters(), lr=1e-6),
train_dataloader=torch.utils.data.DataLoader(
train_dataset, shuffle=True, generator=rng_cpu
),
eval_dataloader=torch.utils.data.DataLoader(val_dataset, generator=rng_cpu),
epochs=0 if str(device) == "cpu" else 100,
save_path=None,
early_stop=3,
early_stop_on_losses=True,
compute_train_metrics=False,
compute_eval_losses=True,
show_progress_bar=False, # disable progress bar for better vis in sphinx gallery.
device=device,
)
model = trainer.train()
model = model.eval()
The model has 35618813 trainable parameters
Evaluation#
Now that the model is trained, we test it on 3 central slices: we run the model, then plot the reconstructions alongside baseline methods and their evaluation metrics.
from torch.utils.data._utils.collate import default_collate
for i in [len(dataset) // 2 - 1, len(dataset) // 2, len(dataset) // 2 + 1]:
# Load slice
x, y, params = default_collate([dataset[i]])
x, y, params = (
x,
y.to(device),
{
k: (v.to(device) if isinstance(v, torch.Tensor) else v)
for (k, v) in params.items()
},
)
physics.update(**params)
# Compute baseline reconstructions
x_adj = physics.A_adjoint(y).detach().cpu()
x_dag = physics.A_dagger(y).detach().cpu()
# Evaluate model
with torch.no_grad():
x_hat = model(y, physics).detach().cpu()
dinv.utils.plot(
{
"GT": x,
"Adjoint": crop(x_adj, x),
"SENSE": crop(x_dag, x),
"Trained": crop(x_hat, x),
},
subtitles=[
"",
f"{metric(x_adj, x).item():.2f} dB",
f"{metric(x_dag, x).item():.2f} dB",
f"{metric(x_hat, x).item():.2f} dB",
],
)
Total running time of the script: (0 minutes 37.932 seconds)