Tour of MRI functionality in DeepInverse
This example presents the various datasets, forward physics and models available in DeepInverse for Magnetic Resonance Imaging (MRI) problems:
Physics: deepinv.physics.MRI, deepinv.physics.MultiCoilMRI, deepinv.physics.DynamicMRI
Datasets: the full FastMRI dataset deepinv.datasets.FastMRISliceDataset and a lightweight, easy-to-use subset deepinv.datasets.SimpleFastMRISliceDataset
Models: deepinv.models.VarNet (VarNet/E2E-VarNet), deepinv.utils.demo.demo_mri_model (a simple MoDL unrolled model)
Contents:
Get started with FastMRI (singlecoil + multicoil)
Train an accelerated MRI problem with neural networks
Load raw FastMRI data (singlecoil + multicoil)
Train using raw data
Explore 3D MRI
Explore dynamic MRI
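import deepinv as dinv
import torch
import torchvision

# Use the least-loaded GPU if one is available, otherwise fall back to the CPU
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
# Seeded random generator so that the sampled masks are reproducible
rng = torch.Generator(device=device).manual_seed(0)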
1. Get started with FastMRI
You can get started with our simple FastMRI mini slice subsets, which provide quick, easy-to-use, in-memory datasets for simulation experiments.
Important
By using this dataset, you confirm that you have agreed to and signed the FastMRI data use agreement.
See also
- Datasets: deepinv.datasets.FastMRISliceDataset, deepinv.datasets.SimpleFastMRISliceDataset
We provide convenient datasets to easily load both raw and reconstructed FastMRI images. You can download more data on the FastMRI site.
Load the mini demo knee and brain datasets (the original data is 320x320, but we resize it to 128x128 for speed):
transform = torchvision.transforms.Resize(128)
knee_dataset = dinv.datasets.SimpleFastMRISliceDataset(
dinv.utils.get_data_home(),
anatomy="knee",
transform=transform,
train=True,
download=True,
)
brain_dataset = dinv.datasets.SimpleFastMRISliceDataset(
dinv.utils.get_data_home(),
anatomy="brain",
transform=transform,
train=True,
download=True,
)
img_size = knee_dataset[0].shape[-2:] # (128, 128)
dinv.utils.plot({"knee": knee_dataset[0], "brain": brain_dataset[0]})
Let’s start with single-coil MRI. We can define a constant Cartesian 4x undersampling mask by sampling once from a physics generator. The mask, data and measurements will all be of shape (B, C, H, W), where C=2 holds the real and imaginary parts.
physics_generator = dinv.physics.generator.GaussianMaskGenerator(
img_size=img_size, acceleration=4, rng=rng, device=device
)
mask = physics_generator.step()["mask"]
physics = dinv.physics.MRI(mask=mask, img_size=img_size, device=device)
dinv.utils.plot(
{
"x": (x := knee_dataset[0].unsqueeze(0)),
"mask": mask,
"y": physics(x).clamp(-1, 1),
}
)
print("Shapes:", x.shape, physics.mask.shape)
Shapes: torch.Size([1, 2, 128, 128]) torch.Size([1, 2, 128, 128])
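To make the single-coil forward model concrete, here is a minimal pure-Python sketch of y = M * F(x) and its zero-filled adjoint F^H(M * y). It is an illustration, not deepinv's implementation: it assumes an orthonormal, uncentred 2D DFT and a Cartesian mask that keeps whole phase-encode rows, whereas deepinv's operator works on batched tensors and may additionally apply FFT shifts. All helper names are hypothetical.

```python
import cmath


def dft1d(v, inverse=False):
    """Orthonormal 1D DFT (naive O(n^2) sum, fine for tiny examples)."""
    n = len(v)
    sign = 1 if inverse else -1
    scale = 1 / n ** 0.5
    return [
        scale * sum(v[j] * cmath.exp(sign * 2j * cmath.pi * k * j / n) for j in range(n))
        for k in range(n)
    ]


def dft2d(img, inverse=False):
    """Orthonormal 2D DFT: 1D transform of every row, then of every column."""
    rows = [dft1d(row, inverse) for row in img]
    cols = [dft1d(list(col), inverse) for col in zip(*rows)]
    return [list(row) for row in zip(*cols)]


def forward(x, mask):
    """Single-coil measurements y = M * F(x); unsampled k-space entries are zero."""
    k = dft2d(x)
    return [[kv * m for kv, m in zip(k_row, m_row)] for k_row, m_row in zip(k, mask)]


def adjoint(y, mask):
    """Zero-filled reconstruction F^H(M * y)."""
    masked = [[yv * m for yv, m in zip(y_row, m_row)] for y_row, m_row in zip(y, mask)]
    return dft2d(masked, inverse=True)


def to_two_channels(z):
    """Complex H x W image -> [real, imag] channels, i.e. the C=2 convention."""
    return [[[v.real for v in row] for row in z], [[v.imag for v in row] for row in z]]


# Toy 4x4 image and a 2x Cartesian mask keeping phase-encode rows 0 and 2
x = [[float(4 * i + j) for j in range(4)] for i in range(4)]
mask = [[1 if i in (0, 2) else 0 for _ in range(4)] for i in range(4)]
y = forward(x, mask)
x_zero_filled = to_two_channels(adjoint(y, mask))
```

With a fully sampled mask the adjoint inverts the forward operator exactly, because the orthonormal DFT is unitary; undersampling is what makes the zero-filled image an approximation.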
We can next generate an accelerated single-coil MRI measurement dataset. Let’s use knees for training and brains for testing.
We can also use the physics generator to randomly sample a new mask per sample, and save the masks alongside the measurements:
dataset_path = dinv.datasets.generate_dataset(
train_dataset=knee_dataset,
test_dataset=brain_dataset,
val_dataset=None,
physics=physics,
physics_generator=physics_generator,
save_physics_generator_params=True,
overwrite_existing=False,
device=device,
save_dir=dinv.utils.get_data_home(),
batch_size=1,
)
train_dataset = dinv.datasets.HDF5Dataset(
dataset_path, split="train", load_physics_generator_params=True
)
test_dataset = dinv.datasets.HDF5Dataset(
dataset_path, split="test", load_physics_generator_params=True
)
dinv.utils.plot(
{
"x0": train_dataset[0][0],
"mask0": train_dataset[0][2]["mask"],
"x1": train_dataset[1][0],
"mask1": train_dataset[1][2]["mask"],
}
)
Dataset has been saved at datasets/dinv_dataset0.h5
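The idea of drawing a fresh mask for each sample can be sketched without deepinv. The function below is a hypothetical, simplified variable-density sampler: it always keeps a fully sampled band of central phase-encode lines and draws the remaining lines with a Gaussian density centred on k-space, using one seeded generator per sample. It is not GaussianMaskGenerator's actual algorithm; the centre fraction and the width sigma = H/4 are illustrative assumptions.

```python
import math
import random


def gaussian_line_mask(height, acceleration, center_fraction=0.08, seed=0):
    """Hypothetical variable-density Cartesian sampler returning sorted line indices.

    Keeps height // acceleration phase-encode lines in total: a fully sampled
    central band plus lines drawn (without replacement) with Gaussian weights.
    """
    rng = random.Random(seed)
    n_lines = height // acceleration
    center = height // 2
    # Central band, never larger than the total line budget
    n_center = min(n_lines, max(1, round(height * center_fraction)))
    start = center - n_center // 2
    chosen = set(range(start, start + n_center))
    # Remaining candidates, weighted by a Gaussian around the k-space centre
    sigma = height / 4
    candidates = [i for i in range(height) if i not in chosen]
    weights = [math.exp(-((i - center) ** 2) / (2 * sigma**2)) for i in candidates]
    while len(chosen) < n_lines:
        idx = rng.choices(range(len(candidates)), weights=weights, k=1)[0]
        chosen.add(candidates.pop(idx))
        weights.pop(idx)
    return sorted(chosen)


# One independent mask per sample, reproducible through per-sample seeds
masks = [gaussian_line_mask(128, acceleration=4, seed=s) for s in range(2)]
```

Storing the seed (or the mask itself) alongside each measurement, as save_physics_generator_params=True does for the real dataset, is what lets the exact forward operator be rebuilt at training time.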
We can also simulate multicoil MRI data. Either pass in ground-truth coil maps, or pass an integer to simulate simple birdcage coil maps. The measurements y are now of shape (B, C, N, H, W), where N is the coil dimension.
mc_physics = dinv.physics.MultiCoilMRI(img_size=img_size, coil_maps=3, device=device)
dinv.utils.plot(
{
"x": x,
"mask": mask,
"coil_map_0": mc_physics.coil_maps.abs()[:, 0, ...],
"coil_map_1": mc_physics.coil_maps.abs()[:, 1, ...],
"coil_map_2": mc_physics.coil_maps.abs()[:, 2, ...],
"RSS": mc_physics.A_adjoint_A(x, mask=mask, rss=True),
}
)
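Under the usual SENSE-type multicoil model, which we assume here, coil n measures y_n = M * F(S_n * x), and coil images are combined by root-sum-of-squares (RSS), sqrt(sum_n |x_n|^2). The pure-Python sketch below (hypothetical helper names) forms the coil-weighted images and their RSS; when the coil maps are normalised so that sum_n |S_n|^2 = 1 at every pixel, the RSS recovers |x| exactly.

```python
import cmath
import math


def coil_images(x, coil_maps):
    """Coil-weighted images S_n * x (element-wise) for every coil map S_n."""
    return [
        [[s * v for s, v in zip(s_row, x_row)] for s_row, x_row in zip(smap, x)]
        for smap in coil_maps
    ]


def rss(images):
    """Root-sum-of-squares combination sqrt(sum_n |x_n|^2), pixel by pixel."""
    h, w = len(images[0]), len(images[0][0])
    return [
        [math.sqrt(sum(abs(img[i][j]) ** 2 for img in images)) for j in range(w)]
        for i in range(h)
    ]


# Toy 2x2 complex image and 3 constant coil maps with different phases,
# scaled so that sum_n |S_n|^2 = 1 at every pixel
x = [[1 + 1j, 2 + 0j], [0 - 3j, 0.5 + 0j]]
coil_maps = [
    [[cmath.exp(1j * phase) / math.sqrt(3)] * 2 for _ in range(2)]
    for phase in (0.0, 2.0, 4.0)
]
combined = rss(coil_images(x, coil_maps))
```

Real coil maps vary smoothly in space, so the RSS image is then weighted by sqrt(sum_n |S_n|^2) rather than being exactly |x|.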
2. Train an accelerated MRI problem with neural networks
Next, we train a neural network to solve the MRI inverse problem. We provide various models specifically designed for MRI reconstruction. These are unrolled networks that require a backbone denoiser, such as UNet or DnCNN:
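# Option 1: a small UNet backbone (2 channels = real and imaginary parts)
denoiser = dinv.models.UNet(
    in_channels=2,
    out_channels=2,
    scales=2,
)
# Option 2: a tiny 2-layer DnCNN backbone (the one used in the rest of this example)
denoiser = dinv.models.DnCNN(
    in_channels=2,
    out_channels=2,
    pretrained=None,
    depth=2,
)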
These backbones can be used within specific MRI models, such as VarNet/E2E-VarNet and MoDL, for which we provide implementations:
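# Option 1: VarNet unrolled model with 2 cascades
model = dinv.models.VarNet(denoiser, num_cascades=2, mode="varnet").to(device)
# Option 2: a simple MoDL unrolled model with 2 iterations (the one trained below)
model = dinv.utils.demo.demo_mri_model(denoiser, num_iter=2, device=device).to(device)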
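Unrolled models such as MoDL alternate the backbone denoiser with a data-consistency (DC) step. For the single-coil Cartesian operator A = M F, with orthonormal F and binary mask M, the MoDL-style DC update argmin_x ||A x - y||^2 + lam ||x - z||^2 (z being the denoiser output) has a closed form in k-space: F x = (M y + lam F z) / (M + lam), applied element-wise. The sketch below shows only that step on nested lists; it is a generic illustration with a hypothetical function name, and we make no claim that demo_mri_model implements DC exactly this way.

```python
def data_consistency(k_z, y, mask, lam):
    """Closed-form MoDL-style DC step in k-space for A = M F (F orthonormal).

    k_z  : k-space of the denoiser output z, i.e. F z
    y    : measured k-space on the full grid (zeros where not sampled)
    mask : binary sampling mask M
    lam  : regularisation weight lambda > 0
    Returns F x = (M y + lam F z) / (M + lam), element-wise.
    """
    return [
        [(m * yv + lam * kz) / (m + lam) for kz, yv, m in zip(kz_row, y_row, m_row)]
        for kz_row, y_row, m_row in zip(k_z, y, mask)
    ]


# Sampled entries are pulled towards the measurements; unsampled ones keep F z
k_z = [[1 + 0j, 2 + 0j], [3 + 0j, 4 + 0j]]
y = [[10 + 0j, 0j], [30 + 0j, 0j]]
mask = [[1, 0], [1, 0]]
k_x = data_consistency(k_z, y, mask, lam=1.0)
```

As lam goes to 0, the sampled entries are replaced by the measurements (hard data consistency); larger lam trusts the denoiser more.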
Now that we have defined our architecture, we can train it with a supervised loss or a self-supervised Equivariant Imaging (EI) loss. We use the PSNR metric on the complex magnitude.
For the sake of speed in this example, we only use a very small 2-layer DnCNN inside an unrolled network with 2 cascades, and train with 2 images for 1 epoch.
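# Supervised loss, which needs ground-truth images
loss = dinv.loss.SupLoss()
# Self-supervised Equivariant Imaging loss, which only needs measurements
loss = dinv.loss.EILoss(transform=dinv.transform.CPABDiffeomorphism())
# To train with one of these, pass it to the Trainer below via `losses=loss`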
trainer = dinv.Trainer(
model=model,
physics=physics,
optimizer=torch.optim.Adam(model.parameters()),
train_dataloader=(train_dataloader := torch.utils.data.DataLoader(train_dataset)),
metrics=dinv.metric.PSNR(complex_abs=True),
epochs=1,
show_progress_bar=False,
save_path=None,
)
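PSNR on the complex magnitude means the real and imaginary channels are first combined into sqrt(re^2 + im^2) before computing 10 log10(peak^2 / MSE). Below is a minimal pure-Python sketch with hypothetical helper names, assuming a peak value of 1 (how dinv.metric.PSNR chooses its peak value is not shown in this example).

```python
import math


def complex_abs(two_channel):
    """Magnitude sqrt(re^2 + im^2) from a [real, imag] channel pair of 2D lists."""
    re, im = two_channel
    return [[math.hypot(a, b) for a, b in zip(r_row, i_row)] for r_row, i_row in zip(re, im)]


def psnr(estimate, target, max_pixel=1.0):
    """PSNR in dB: 10 * log10(max_pixel^2 / MSE)."""
    sq_err = [(a - b) ** 2 for e_row, t_row in zip(estimate, target) for a, b in zip(e_row, t_row)]
    mse = sum(sq_err) / len(sq_err)
    return 10 * math.log10(max_pixel**2 / mse)


# Two-channel estimate whose magnitude differs from the target at one pixel
estimate = [[[0.6, 0.0], [0.0, 0.6]], [[0.8, 0.1], [0.0, 0.8]]]
target = [[1.0, 0.0], [0.0, 1.0]]
score = psnr(complex_abs(estimate), target)  # MSE = 0.01 / 4, so about 26.02 dB
```

Taking the magnitude first makes the metric insensitive to a global phase in the reconstruction, which is usually irrelevant for MRI display.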
To improve results in the case of this very short training, we start training from a pretrained model state (trained on 900 images):
url = dinv.models.utils.get_weights_url(
model_name="demo", file_name="demo_tour_mri.pth"
)
ckpt = torch.hub.load_state_dict_from_url(
url, map_location=lambda storage, loc: storage, file_name="demo_tour_mri.pth"
)
trainer.model.load_state_dict(ckpt["state_dict"]) # load the state dict
trainer.optimizer.load_state_dict(ckpt["optimizer"]) # load the optimizer state dict
model = trainer.train() # train the model
trainer.plot_images = True
The model has 2376 trainable parameters
Train epoch 0: TotalLoss=0.0, PSNR=30.972
Now that our model is trained, we can test it. Notice that it improves the PSNR over the zero-filled reconstruction, both on the train (knee) set and the test (brain) set:
_ = trainer.test(train_dataloader)
_ = trainer.test(torch.utils.data.DataLoader(test_dataset))
Eval epoch 0: PSNR=30.996, PSNR no learning=29.946
Test results:
PSNR no learning: 29.946 +- 0.566
PSNR: 30.996 +- 0.604
Eval epoch 0: PSNR=29.194, PSNR no learning=28.316
Test results:
PSNR no learning: 28.316 +- 0.480
PSNR: 29.194 +- 0.098
3. Load raw FastMRI data
It is also possible to use the raw data directly. The raw multi-coil FastMRI data is provided as pairs (x, y), where y are the fully sampled k-space measurements of arbitrary size and x are the cropped root-sum-square (RSS) magnitude reconstructions.
dinv.datasets.download_archive(
dinv.utils.get_image_url("demo_fastmri_brain_multicoil.h5"),
dinv.utils.get_data_home() / "brain" / "fastmri.h5",
)
dataset = dinv.datasets.FastMRISliceDataset(
dinv.utils.get_data_home() / "brain", slice_index="middle"
)
x, y = dataset[0]
x, y = x.unsqueeze(0), y.unsqueeze(0)
print("Shapes:", x.shape, y.shape) # x (B, 1, W, W); y (B, C, N, H, W)
img_shape, kspace_shape = x.shape[-2:], y.shape[-2:]
n_coils = y.shape[2]
Shapes: torch.Size([1, 1, 213, 213]) torch.Size([1, 2, 4, 512, 213])
We can relate x and y using our deepinv.physics.MultiCoilMRI (note that since we are not provided with the ground-truth coil maps, we can only perform the adjoint operator).
physics = dinv.physics.MultiCoilMRI(
img_size=img_shape,
mask=torch.ones(kspace_shape),
coil_maps=torch.ones((n_coils,) + kspace_shape, dtype=torch.complex64),
device=device,
)
x_rss = physics.A_adjoint(y, rss=True, crop=True)
assert torch.allclose(x, x_rss)
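The check above works because, with flat coil maps and a full mask, the adjoint reduces to inverse-Fourier-transforming each coil's k-space and combining the coil images by RSS (before cropping). A pure-Python sketch of the transform-and-combine part, with hypothetical helper names, assuming an orthonormal, uncentred inverse DFT (FastMRI-style k-space is typically stored centred, so a real implementation would also apply FFT shifts):

```python
import cmath
import math


def idft1d(v):
    """Orthonormal inverse 1D DFT (naive sum)."""
    n = len(v)
    return [
        sum(v[j] * cmath.exp(2j * cmath.pi * k * j / n) for j in range(n)) / math.sqrt(n)
        for k in range(n)
    ]


def idft2d(kspace):
    """Orthonormal inverse 2D DFT: rows first, then columns."""
    rows = [idft1d(row) for row in kspace]
    cols = [idft1d(list(col)) for col in zip(*rows)]
    return [list(row) for row in zip(*cols)]


def rss_from_kspace(kspace_coils):
    """RSS image sqrt(sum_n |F^-1 y_n|^2) from a list of per-coil k-spaces."""
    images = [idft2d(k) for k in kspace_coils]
    h, w = len(images[0]), len(images[0][0])
    return [
        [math.sqrt(sum(abs(img[i][j]) ** 2 for img in images)) for j in range(w)]
        for i in range(h)
    ]


# Two 4x4 coils whose k-space holds only a DC term: coil images are the
# constants 1 and 1j, so the RSS image is sqrt(2) everywhere
coil_0 = [[4 + 0j if (i, j) == (0, 0) else 0j for j in range(4)] for i in range(4)]
coil_1 = [[4j if (i, j) == (0, 0) else 0j for j in range(4)] for i in range(4)]
x_rss_toy = rss_from_kspace([coil_0, coil_1])
```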
4. Train using raw data
We now use a mask generator to generate acceleration masks on the fly (online) during training. We use the E2E-VarNet model designed for multicoil MRI. We do not perform coil sensitivity map estimation and simply assume the maps are flat, as above. To estimate them instead, pass a model as the sensitivity_model parameter.
physics_generator = dinv.physics.generator.GaussianMaskGenerator(
img_size=kspace_shape, acceleration=4, rng=rng, device=device
)
model = dinv.models.VarNet(denoiser, num_cascades=2, mode="e2e-varnet").to(device)
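For intuition about the E2E-VarNet model: in the published E2E-VarNet formulation, each cascade updates the k-space estimate with a soft data-consistency term plus a learned refinement, k_next = k - eta * M (k - y) + G(k), where G is a CNN applied through the coil and Fourier operators. The sketch below (hypothetical names) implements only this update arithmetic, with G stubbed out; it is not deepinv's VarNet code.

```python
def varnet_cascade_step(k, y, mask, eta, refinement):
    """One E2E-VarNet-style k-space update: k - eta * M (k - y) + G(k).

    k, y       : current and measured k-space (nested lists, same shape)
    mask       : binary sampling mask M
    eta        : data-consistency step size (learned per cascade in practice)
    refinement : callable standing in for the learned term G
    """
    g = refinement(k)
    return [
        [kv - eta * m * (kv - yv) + gv for kv, yv, m, gv in zip(k_row, y_row, m_row, g_row)]
        for k_row, y_row, m_row, g_row in zip(k, y, mask, g)
    ]


def no_refinement(k):
    """Stub for G that returns zeros (no learned correction)."""
    return [[0j for _ in row] for row in k]


k = [[1 + 0j, 2 + 0j]]
y = [[5 + 0j, 0j]]
mask = [[1, 0]]
# With eta = 1 and G = 0, sampled entries are replaced by the measurements
k_next = varnet_cascade_step(k, y, mask, eta=1.0, refinement=no_refinement)
```

Unlike the closed-form MoDL step, this is a gradient-style step on the data-fidelity term, so eta controls how strongly each cascade enforces the measurements.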
Note that we need to override the base deepinv.training.Trainer to handle raw measurements: we do not want to generate k-space measurements, only mask the existing ones.
Note
We require loop_physics_generator=True and shuffle=False in the dataloader to ensure that each image is always matched with the same random mask in every epoch.
class RawFastMRITrainer(dinv.Trainer):
    def get_samples_online(self, iterators, g):
        # Get data
        x, y = next(iterators[g])
        x, y = x.to(self.device), y.to(self.device)

        # Get physics
        physics = self.physics[g]

        # Generate random mask
        params = self.physics_generator[g].step(
            batch_size=y.size(0), img_size=y.shape[-2:]
        )

        # Generate measurements directly from raw measurements
        y *= params["mask"]

        physics.update_parameters(**params)

        return x, y, physics
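Why the data order matters can be illustrated with a toy model. Below, a hypothetical sampler is reset to the same random state at the start of every epoch, so the i-th mask drawn is identical in every epoch; this is our assumption about the looping behaviour, not a description of deepinv's internals. With a fixed loading order each image keeps its mask, whereas shuffling pairs the same masks with different images.

```python
import random


class LoopingMaskSampler:
    """Hypothetical mask sampler re-seeded at the start of each epoch."""

    def __init__(self, n_lines, n_keep, seed=0):
        self.n_lines, self.n_keep, self.seed = n_lines, n_keep, seed
        self.reset()

    def reset(self):
        # Restart the random stream, so draws repeat from epoch to epoch
        self.rng = random.Random(self.seed)

    def step(self):
        # Draw n_keep distinct phase-encode line indices
        return sorted(self.rng.sample(range(self.n_lines), self.n_keep))


def run_epoch(sampler, order):
    """Pair each image index, in the given loading order, with the next mask."""
    sampler.reset()
    return {index: sampler.step() for index in order}


sampler = LoopingMaskSampler(n_lines=16, n_keep=4)
epoch_1 = run_epoch(sampler, [0, 1, 2])
epoch_2 = run_epoch(sampler, [0, 1, 2])  # same order: same image-mask pairs
epoch_3 = run_epoch(sampler, [2, 0, 1])  # shuffled: image 2 gets image 0's old mask
```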
We also need to modify the metrics so that they crop the model output before comparing it to the cropped RSS magnitude targets:
# Crop the network output to the target size, then take its complex magnitude
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.CenterCrop(x.shape[-2:]),
        dinv.metric.functional.complex_abs,
    ]
)


class CropMSE(dinv.metric.MSE):
    def forward(self, x_net=None, x=None, *args, **kwargs):
        return super().forward(transform(x_net), x, *args, **kwargs)


class CropPSNR(dinv.metric.PSNR):
    def forward(self, x_net=None, x=None, *args, **kwargs):
        return super().forward(transform(x_net), x, *args, **kwargs)
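What CropMSE computes can be spelled out on plain lists: crop the real and imaginary channels of the prediction to the target size, take the magnitude, then compute the MSE, so that anything outside the central window is ignored. A pure-Python sketch with hypothetical helper names, using a floor-rounded centre offset (torchvision's CenterCrop may round the offset differently when the size difference is odd):

```python
import math


def center_crop(img, out_h, out_w):
    """Central out_h x out_w window of a 2D list (offsets rounded down)."""
    top = (len(img) - out_h) // 2
    left = (len(img[0]) - out_w) // 2
    return [row[left:left + out_w] for row in img[top:top + out_h]]


def crop_mse(x_net, x):
    """Crop [real, imag] channels to x's size, take the magnitude, then the MSE."""
    h, w = len(x), len(x[0])
    re, im = (center_crop(channel, h, w) for channel in x_net)
    magnitude = [[math.hypot(a, b) for a, b in zip(r, i)] for r, i in zip(re, im)]
    sq_err = [(m - t) ** 2 for m_row, t_row in zip(magnitude, x) for m, t in zip(m_row, t_row)]
    return sum(sq_err) / len(sq_err)


# 4x4 prediction whose central 2x2 window has magnitude 1; the border
# (value 9) lies outside the crop and therefore does not affect the loss
re = [[9.0] * 4, [9.0, 0.6, 0.6, 9.0], [9.0, 0.6, 0.6, 9.0], [9.0] * 4]
im = [[0.0] * 4, [0.0, 0.8, 0.8, 0.0], [0.0, 0.8, 0.8, 0.0], [0.0] * 4]
target = [[1.0, 1.0], [1.0, 1.0]]
loss = crop_mse([re, im], target)
```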
trainer = RawFastMRITrainer(
model=model,
physics=physics,
physics_generator=physics_generator,
online_measurements=True,
loop_physics_generator=True,
losses=dinv.loss.SupLoss(metric=CropMSE()),
metrics=CropPSNR(),
optimizer=torch.optim.Adam(model.parameters()),
train_dataloader=torch.utils.data.DataLoader(dataset, shuffle=False),
epochs=1,
save_path=None,
show_progress_bar=False,
)
_ = trainer.train()
The model has 2372 trainable parameters
Train epoch 0: TotalLoss=0.752, CropPSNR=1.24
5. Explore 3D MRI
We can also simulate 3D MRI data. Here, we use a demo 3D brain volume of shape (181, 217, 181) from the BrainWeb dataset and simulate 3D single-coil or multi-coil Fourier measurements using deepinv.physics.MRI or deepinv.physics.MultiCoilMRI.
x = (
torch.from_numpy(
dinv.utils.demo.load_np_url(
"https://huggingface.co/datasets/deepinv/images/resolve/main/brainweb_t1_ICBM_1mm_subject_0.npy?download=true"
)
)
.unsqueeze(0)
.unsqueeze(0)
.to(device)
)
x = torch.cat([x, torch.zeros_like(x)], dim=1) # add imaginary dimension
print(x.shape) # (B, C, D, H, W) where D is depth
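# Option 1: 3D multi-coil MRI physics
physics = dinv.physics.MultiCoilMRI(img_size=x.shape[1:], three_d=True, device=device)
# Option 2: 3D single-coil MRI physics (the one used for the plot below)
physics = dinv.physics.MRI(img_size=x.shape[1:], three_d=True, device=device)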
dinv.utils.plot_ortho3D([x, physics(x)], titles=["x", "y"])
torch.Size([1, 2, 181, 217, 181])
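With three_d=True, the Fourier transform presumably acts jointly over the depth, height and width axes (D, H, W) instead of slice by slice; we state this as an assumption. A 3D DFT is separable, so it can be computed as 1D DFTs along each axis in turn. The pure-Python sketch below (hypothetical names) checks this on a tiny volume against the direct triple sum.

```python
import cmath
import itertools
import math


def dft1d(v):
    """Orthonormal forward 1D DFT (naive sum)."""
    n = len(v)
    return [
        sum(v[j] * cmath.exp(-2j * cmath.pi * k * j / n) for j in range(n)) / math.sqrt(n)
        for k in range(n)
    ]


def dft3d_separable(vol):
    """3D DFT of a D x H x W nested list via 1D DFTs along W, then H, then D."""
    d, h, w = len(vol), len(vol[0]), len(vol[0][0])
    out = [[dft1d(vol[a][b]) for b in range(h)] for a in range(d)]  # along W
    for a, c in itertools.product(range(d), range(w)):  # along H
        line = dft1d([out[a][b][c] for b in range(h)])
        for b in range(h):
            out[a][b][c] = line[b]
    for b, c in itertools.product(range(h), range(w)):  # along D
        line = dft1d([out[a][b][c] for a in range(d)])
        for a in range(d):
            out[a][b][c] = line[a]
    return out


def dft3d_direct(vol):
    """Reference orthonormal 3D DFT computed as a single triple sum."""
    d, h, w = len(vol), len(vol[0]), len(vol[0][0])
    scale = 1 / math.sqrt(d * h * w)
    return [
        [
            [
                scale
                * sum(
                    vol[a][b][c]
                    * cmath.exp(-2j * cmath.pi * (p * a / d + q * b / h + r * c / w))
                    for a, b, c in itertools.product(range(d), range(h), range(w))
                )
                for r in range(w)
            ]
            for q in range(h)
        ]
        for p in range(d)
    ]


vol = [[[float(a + 2 * b + 3 * c) for c in range(2)] for b in range(3)] for a in range(2)]
fast, ref = dft3d_separable(vol), dft3d_direct(vol)
```

Separability is also why 3D undersampling patterns can be designed on the two phase-encode axes while the readout axis is fully sampled.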
6. Explore dynamic MRI
Finally, we show how to use dynamic MRI physics for image sequence data of shape (B, C, T, H, W), where T is the time dimension. Note that this is also compatible with 3D MRI. We simulate an MRI image sequence using the first 2 knee images (T=2):
x = torch.stack([knee_dataset[i] for i in range(2)], dim=1).unsqueeze(0)
Generate a Cartesian k-t sampling mask and simulate k-t-space measurements:
physics_generator = dinv.physics.generator.EquispacedMaskGenerator(
img_size=img_size, acceleration=4, rng=rng, device=device
)
mask = physics_generator.step()["mask"]
physics = dinv.physics.DynamicMRI(mask=mask, img_size=img_size, device=device)
y = physics(x)
print(x.shape, y.shape)  # (B, C, T, H, W)
torch.Size([1, 2, 2, 128, 128]) torch.Size([1, 2, 2, 128, 128])
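In dynamic MRI, a common k-t sampling design shifts the sampled lines from frame to frame, so that all frames together cover k-space even though each frame is undersampled. The example above does not show whether EquispacedMaskGenerator varies its pattern over time; the pure-Python sketch below is a hypothetical illustration of a shifted equispaced pattern with a fully sampled centre, not that generator's behaviour.

```python
def kt_equispaced_lines(height, acceleration, n_frames, n_center=4):
    """Hypothetical k-t pattern: per frame, every `acceleration`-th phase-encode
    line starting at offset t % acceleration, plus n_center central lines."""
    center = height // 2
    start = center - n_center // 2
    center_lines = set(range(start, start + n_center))
    masks = []
    for t in range(n_frames):
        offset = t % acceleration
        lines = set(range(offset, height, acceleration)) | center_lines
        masks.append(sorted(lines))
    return masks


masks = kt_equispaced_lines(height=16, acceleration=4, n_frames=4)
# Union over frames: with n_frames >= acceleration, every line is sampled at least once
covered = set().union(*masks)
```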
Total running time of the script: (0 minutes 13.024 seconds)