Note
New to DeepInverse? Get started with the basics in the 5 minute quickstart tutorial.
Scan-specific zero-shot SSDU for MRI#
We demonstrate scan-specific self-supervised learning, that is, learning to reconstruct MRI scans from a single accelerated sample without ground truth.
Here, we demonstrate fine-tuning a pretrained model (deepinv.models.RAM) [1]
with the weighted SSDU loss [2][3].
However, note that any of the self-supervised losses can be used here, with varying performance [4].
See, for example, the demo using Equivariant Imaging [5].
Note that, if more data is available, better results can be obtained by fine-tuning on more samples!
import torch
import deepinv as dinv
device = dinv.utils.get_device()
rng = torch.Generator(device=device).manual_seed(0)
rng_cpu = torch.Generator(device="cpu").manual_seed(0)
Selected GPU 0 with 4064.25 MiB free memory
Data#
First, download a demo single brain MRI volume (FLAIR sequence, SIEMENS Trio Tim 3T scanner) from the FastMRI brain dataset [6], via HuggingFace.
Important
By using this dataset, you confirm that you have agreed to and signed the FastMRI data use agreement.
DATA_DIR = dinv.utils.get_data_home() / "fastMRI" / "multicoil_train"
SLICE_DIR = DATA_DIR / "slices"
DATA_DIR.mkdir(parents=True, exist_ok=True)
SLICE_DIR.mkdir(exist_ok=True)
dinv.utils.download_example("demo_fastmri_brain_multicoil.h5", DATA_DIR)
We use the FastMRI slice dataset provided in DeepInverse to load the volume and return all 16 slices.
The data is returned in the format x, y, params where params is a dictionary containing
the acceleration mask (simulated Gaussian mask with acceleration 6) and the
estimated coil sensitivity map.
Note
This loading takes a few seconds per slice, as it must estimate the coil sensitivity map on the fly.
dataset = dinv.datasets.FastMRISliceDataset(
    DATA_DIR,
    slice_index="all",
    transform=dinv.datasets.MRISliceTransform(
        mask_generator=dinv.physics.generator.GaussianMaskGenerator(
            img_size=(256, 256),  # this is overridden internally by true image size
            acceleration=6,
            center_fraction=0.08,
            device="cpu",
            rng=rng_cpu,
        ),
        estimate_coil_maps=True,
    ),
)
100%|██████████| 1/1 [00:00<00:00, 1009.22it/s]
When training with data that is slow to be loaded, it is faster to save the pre-loaded slices:
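The caching step itself is not shown above; a minimal self-contained sketch of the idea, using toy `(x, y, params)` triples in place of the real dataset items and a temporary directory in place of `SLICE_DIR`, might look like:

```python
import tempfile
from pathlib import Path

import torch

# Toy stand-ins for the (x, y, params) triples returned by the slow dataset.
toy_dataset = [
    (torch.randn(2, 8, 8), torch.randn(2, 8, 8), {"mask": torch.ones(8, 8)})
    for _ in range(3)
]

cache_dir = Path(tempfile.mkdtemp())  # stand-in for SLICE_DIR

# Save each slice once, so later epochs skip the slow on-the-fly loading.
for i, sample in enumerate(toy_dataset):
    torch.save(sample, cache_dir / f"{i}.pt")

# Reloading a cached slice is now just a fast torch.load.
x, y, params = torch.load(cache_dir / "0.pt", weights_only=True)
```

The same pattern with the real `dataset` and `SLICE_DIR` only pays the coil-map estimation cost once, instead of on every epoch.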
16it [00:28, 1.79s/it]
Then the dataset can be loaded very quickly. We pre-normalize to bring the data into a more friendly range, and load a rough noise level as a param to be passed into the physics. The ground truth is loaded for evaluation later.
def loader(f):
    x, y, params = torch.load(f, weights_only=True)
    return x * 1e5, y * 1e5, params | {"sigma": 1e-5 * 1e5}
dataset = dinv.datasets.ImageFolder(SLICE_DIR, x_path="*.pt", loader=loader)
Physics#
The multicoil physics is defined very easily:
physics = dinv.physics.MultiCoilMRI(
    device=device, noise_model=dinv.physics.GaussianNoise(0.0)
)
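As a rough illustration of what a multicoil MRI forward operator computes (a toy sketch, not the deepinv implementation): each coil image is the product of the object with that coil's sensitivity map, which is then Fourier transformed and undersampled by the mask:

```python
import torch

def multicoil_forward(x, coil_maps, mask):
    """Toy multicoil MRI forward: y_c = mask * FFT(S_c * x) for each coil c."""
    coil_images = coil_maps * x.unsqueeze(0)  # (C, H, W) complex coil images
    kspace = torch.fft.fftshift(
        torch.fft.fft2(
            torch.fft.ifftshift(coil_images, dim=(-2, -1)), norm="ortho"
        ),
        dim=(-2, -1),
    )
    return mask * kspace  # zero out unsampled k-space locations

H = W = 8
x = torch.randn(H, W, dtype=torch.complex64)             # image
coil_maps = torch.randn(4, H, W, dtype=torch.complex64)  # 4 coil sensitivities
mask = torch.zeros(H, W)
mask[:, ::2] = 1.0                                       # keep every other column

y = multicoil_forward(x, coil_maps, mask)  # (4, 8, 8) undersampled k-space
```

The library's `MultiCoilMRI` handles batching, normalization, and the adjoint/pseudo-inverse for us; the sketch only conveys the structure of the operator.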
Model#
For the model, we fine-tune a pretrained model (deepinv.models.RAM) [1]:
model = dinv.models.RAM(pretrained=True, device=device)
Loss#
We define the weighted SSDU loss by first defining the generator that generates the splitting masks. These splitting masks are multiplied with the measurements as per the original paper. Finally, the weighted SSDU loss requires knowledge of the original physics generator to define the weight.
Tip
Feel free to use any self-supervised loss you like here!
split_generator = dinv.physics.generator.GaussianMaskGenerator(
    (256, 256), acceleration=2, center_fraction=0.0, device=device
)
mask_generator = dinv.physics.generator.MultiplicativeSplittingMaskGenerator(
    (256, 256), split_generator, device=device
)
physics_generator = dinv.physics.generator.GaussianMaskGenerator(
    (256, 256), acceleration=6, center_fraction=0.04, rng=rng, device=device
)
loss = dinv.loss.mri.WeightedSplittingLoss(
    mask_generator=mask_generator, physics_generator=physics_generator
)
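To build intuition for what a splitting loss does (a toy sketch of the SSDU idea, not the deepinv implementation, and without the weighting): the measured k-space is divided by a second mask into an input part and a held-out target part; the network reconstructs from the input part, and the residual is computed only on the held-out part:

```python
import torch

def toy_splitting_loss(y, mask, split_mask, forward, model):
    """SSDU-style self-supervised loss on a single measurement y.

    mask:       original undersampling mask (1 = measured).
    split_mask: further splits the measured locations (1 = network input).
    """
    mask_in = mask * split_mask          # locations fed to the network
    mask_out = mask * (1 - split_mask)   # held-out locations for the loss
    x_hat = model(mask_in * y)           # reconstruct from the input split only
    # Penalize mismatch only where measurements were held out.
    return torch.sum(torch.abs(mask_out * (forward(x_hat) - y)) ** 2)

H = W = 8
mask = torch.ones(H, W)                  # fully sampled, for simplicity
split_mask = torch.zeros(H, W)
split_mask[:, ::2] = 1.0                 # feed even columns, hold out odd ones

forward = lambda x: x                    # identity "physics" for the toy
model = lambda y_in: y_in                # identity "network" for the toy

y = torch.randn(H, W)
loss_val = toy_splitting_loss(y, mask, split_mask, forward, model)
```

With the identity network, the held-out residual is exactly the measurements the network never saw, so the loss equals the energy in the odd columns; a trained network drives this residual down by genuinely inpainting the missing k-space.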
Training#
We train the model using the self-supervised loss. We randomly split the volume into training slices and validation slices for early stopping (up to a maximum of 100 epochs). Because the FastMRI ground truths are cropped magnitude root-sum-of-squares reconstructions, we define a helper metric for evaluation later.
Hint
This example trains on GPU to accelerate training.
def crop(x_net, x):
    """Crop to GT shape then take magnitude."""
    return dinv.utils.MRIMixin().rss(
        dinv.utils.MRIMixin().crop(x_net, shape=x.shape), multicoil=False
    )

class CropPSNR(dinv.metric.PSNR):
    def forward(self, x_net=None, x=None, *args, **kwargs):
        return super().forward(crop(x_net, x), x, *args, **kwargs)

metric = CropPSNR(max_pixel=None)

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, (0.8, 0.2), generator=rng_cpu
)

trainer = dinv.Trainer(
    model=model,
    physics=physics,
    losses=loss,
    metrics=metric,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-6),
    train_dataloader=torch.utils.data.DataLoader(
        train_dataset, shuffle=True, generator=rng_cpu
    ),
    eval_dataloader=torch.utils.data.DataLoader(val_dataset, generator=rng_cpu),
    epochs=0 if str(device) == "cpu" else 100,
    save_path=None,
    early_stop=3,
    early_stop_on_losses=True,
    compute_train_metrics=False,
    compute_eval_losses=True,
    show_progress_bar=False,  # disable progress bar for better vis in sphinx gallery.
    device=device,
)
model = trainer.train()
model = model.eval()
The model has 35618813 trainable parameters
Train epoch 0: TotalLoss=0.316
Eval epoch 0: TotalLoss=0.224, CropPSNR=30.974
Best model saved at epoch 1
Train epoch 1: TotalLoss=0.304
Eval epoch 1: TotalLoss=0.347, CropPSNR=31.01
Best model saved at epoch 2
Train epoch 2: TotalLoss=0.263
Eval epoch 2: TotalLoss=0.259, CropPSNR=31.041
Best model saved at epoch 3
Train epoch 3: TotalLoss=0.288
Eval epoch 3: TotalLoss=0.214, CropPSNR=31.062
Best model saved at epoch 4
Train epoch 4: TotalLoss=0.284
Eval epoch 4: TotalLoss=0.233, CropPSNR=31.088
Best model saved at epoch 5
Train epoch 5: TotalLoss=0.261
Eval epoch 5: TotalLoss=0.346, CropPSNR=31.107
Best model saved at epoch 6
Train epoch 6: TotalLoss=0.227
Eval epoch 6: TotalLoss=0.19, CropPSNR=31.131
Best model saved at epoch 7
Train epoch 7: TotalLoss=0.316
Eval epoch 7: TotalLoss=0.176, CropPSNR=31.151
Best model saved at epoch 8
Train epoch 8: TotalLoss=0.265
Eval epoch 8: TotalLoss=0.285, CropPSNR=31.162
Best model saved at epoch 9
Train epoch 9: TotalLoss=0.253
Eval epoch 9: TotalLoss=0.285, CropPSNR=31.176
Best model saved at epoch 10
Train epoch 10: TotalLoss=0.437
Eval epoch 10: TotalLoss=0.141, CropPSNR=31.178
Best model saved at epoch 11
Train epoch 11: TotalLoss=0.224
Eval epoch 11: TotalLoss=0.508, CropPSNR=31.188
Best model saved at epoch 12
Train epoch 12: TotalLoss=0.35
Eval epoch 12: TotalLoss=0.363, CropPSNR=31.199
Best model saved at epoch 13
Train epoch 13: TotalLoss=0.255
Eval epoch 13: TotalLoss=0.211, CropPSNR=31.211
Best model saved at epoch 14
Train epoch 14: TotalLoss=0.484
Eval epoch 14: TotalLoss=0.191, CropPSNR=31.235
Best model saved at epoch 15
Early stopping triggered at epoch 14 as validation metrics have not improved in the last 3 validation steps. Disable it with early_stop=None, or modify early_stop>0 to wait for more validation steps.
Evaluation#
Now that the model is trained, we test it on 3 samples, plotting the reconstructions alongside the baselines and their evaluation metrics.
from torch.utils.data._utils.collate import default_collate

for i in [len(dataset) // 2 - 1, len(dataset) // 2, len(dataset) // 2 + 1]:
    # Load slice
    x, y, params = default_collate([dataset[i]])
    x, y, params = (
        x,
        y.to(device),
        {
            k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for (k, v) in params.items()
        },
    )
    physics.update(**params)

    # Compute baseline reconstructions
    x_adj = physics.A_adjoint(y).detach().cpu()
    x_dag = physics.A_dagger(y).detach().cpu()

    # Evaluate model
    with torch.no_grad():
        x_hat = model(y, physics).detach().cpu()

    dinv.utils.plot(
        {
            "GT": x,
            "Adjoint": crop(x_adj, x),
            "SENSE": crop(x_dag, x),
            "Trained": crop(x_hat, x),
        },
        subtitles=[
            "",
            f"{metric(x_adj, x).item():.2f} dB",
            f"{metric(x_dag, x).item():.2f} dB",
            f"{metric(x_hat, x).item():.2f} dB",
        ],
    )
References:
Total running time of the script: (2 minutes 3.765 seconds)


