Self-supervised MRI reconstruction with Artifact2Artifact
We demonstrate the self-supervised Artifact2Artifact loss for solving an undersampled sequential MRI reconstruction problem without ground truth.
The Artifact2Artifact loss was introduced in Liu et al., "RARE: Image Reconstruction using Deep Priors Learned without Groundtruth".
In our example, we use it to reconstruct static images where the k-space measurement is a time sequence: each time step (phase) consists of sampled lines, such that the whole measurement is a set of non-overlapping lines.
For a description of how Artifact2Artifact constructs the loss, see deepinv.loss.Artifact2ArtifactLoss. Note that in our implementation this is a special case of the generic splitting loss; see deepinv.loss.SplittingLoss for more details. See deepinv.loss.Phase2PhaseLoss for the related Phase2Phase loss.
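To make the splitting idea concrete, here is a minimal NumPy sketch of a measurement-domain split loss on a toy problem. It reconstructs from one group of frames and then measures the error only on the k-space lines of a different, held-out group. Everything here is an assumption for illustration: single-coil data, a 16x16 random image, disjoint vertical-line masks, and the hypothetical helpers `forward`, `zero_filled` and `split_loss`. It is not deepinv's Artifact2ArtifactLoss.

```python
import numpy as np


def forward(x, mask):
    """One toy sequential-MRI frame: orthonormal 2D FFT followed by a line mask."""
    return mask * np.fft.fft2(x, norm="ortho")


def zero_filled(y, mask):
    """Adjoint of `forward`: zero-filled inverse FFT of masked k-space."""
    return np.fft.ifft2(mask * y, norm="ortho")


def split_loss(model, frame_k, frame_masks, input_frames, target_frames):
    """Reconstruct from the input frames and score on the held-out target frames."""
    # Frames do not overlap, so summing them merges their k-space lines.
    m_in, y_in = frame_masks[input_frames].sum(0), frame_k[input_frames].sum(0)
    m_tg, y_tg = frame_masks[target_frames].sum(0), frame_k[target_frames].sum(0)
    x_hat = model(zero_filled(y_in, m_in))
    # The residual is measured only on the target lines, which the model never saw.
    return np.mean(np.abs(forward(x_hat, m_tg) - y_tg) ** 2)


rng = np.random.default_rng(0)
n, n_frames = 16, 4
img_true = rng.standard_normal((n, n))

# Hypothetical disjoint vertical-line masks: each k-space column belongs to one frame.
frame_of_column = rng.permutation(np.arange(n) % n_frames)
frame_masks = np.stack(
    [np.tile(frame_of_column == t, (n, 1)) for t in range(n_frames)]
).astype(float)
frame_k = np.stack([forward(img_true, m) for m in frame_masks])  # shape (T, n, n)


def identity(x):
    return x  # no learning: keep the zero-filled input


def oracle(x):
    return img_true  # perfect reconstruction, for illustration only


loss_identity = split_loss(identity, frame_k, frame_masks, [0, 1], [2, 3])
loss_oracle = split_loss(oracle, frame_k, frame_masks, [0, 1], [2, 3])
print(loss_identity, loss_oracle)
```

The identity baseline returns the zero-filled input, which carries no information on the held-out lines. Its loss therefore equals their full energy. Only a model that predicts unseen k-space lines can lower the loss, which is why it can be minimised without ground truth.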
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
import deepinv as dinv
from deepinv.datasets import SimpleFastMRISliceDataset
from deepinv.utils.demo import demo_mri_model, get_data_home
from deepinv.models.utils import get_weights_url
from deepinv.physics.generator import (
    GaussianMaskGenerator,
    BernoulliSplittingMaskGenerator,
)
torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
Load data
In this example, we use a mini demo subset of the single-coil FastMRI dataset (https://fastmri.org/) as the base image dataset. It consists of knee images of size 320x320, which are then resized to 128x128 for speed.
Important: by using this dataset, you confirm that you have agreed to and signed the FastMRI data use agreement (https://fastmri.med.nyu.edu/).
See also: the datasets deepinv.datasets.FastMRISliceDataset and deepinv.datasets.SimpleFastMRISliceDataset. We provide convenient datasets to easily load both raw and reconstructed FastMRI images. You can download more data on the FastMRI site (https://fastmri.med.nyu.edu/).
For speed, this demo fine-tunes the original model using a train set of size 1 and a test set of size 1. To train the original model from scratch, use a larger dataset of size ~150.
batch_size = 1
H = 128
transform = transforms.Compose([transforms.Resize(H)])
train_dataset = SimpleFastMRISliceDataset(
    get_data_home(), transform=transform, train=True, download=True, train_percent=0.5
)
test_dataset = SimpleFastMRISliceDataset(
    get_data_home(), transform=transform, train=False, train_percent=0.5
)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
Define physics
We simulate a sequential k-space sampler that, over the course of 4 phases (i.e. frames), samples 64 lines (i.e. 2x total undersampling from 128) with Gaussian weighting, plus a few extra autocalibration signal (ACS) lines in the centre of k-space. We use deepinv.physics.SequentialMRI to do this.
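A Gaussian-weighted line sampler of this kind can be sketched in a few lines of NumPy. This is an illustrative assumption, not deepinv's GaussianMaskGenerator. The hypothetical `gaussian_line_mask` helper assumes centred k-space (DC at index n // 2). It draws 64 distinct lines with probability proportional to a Gaussian centred there, with an arbitrarily chosen width `sigma`, and then forces a central band of 10 ACS lines on.

```python
import numpy as np


def gaussian_line_mask(n, n_lines, acs, sigma, rng):
    """Toy 1D line mask: Gaussian-weighted random lines plus a fully sampled ACS band.

    Hypothetical helper for illustration; not deepinv's GaussianMaskGenerator.
    """
    k = np.arange(n)
    weights = np.exp(-((k - n / 2) ** 2) / (2 * sigma**2))  # favour low frequencies
    lines = rng.choice(n, size=n_lines, replace=False, p=weights / weights.sum())
    mask = np.zeros(n)
    mask[lines] = 1.0
    mask[n // 2 - acs // 2 : n // 2 + acs // 2] = 1.0  # always sample the centre
    return mask


rng = np.random.default_rng(0)
line_mask = gaussian_line_mask(n=128, n_lines=64, acs=10, sigma=128 / 6, rng=rng)
print(int(line_mask.sum()), 128 / line_mask.sum())  # sampled lines, acceleration
```

The ACS band can overlap lines that were already drawn, so the total lies between 64 and 74 lines. The effective acceleration is therefore at most 2.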
First, we define a static 2x acceleration mask that all measurements use (of shape [B,C,H,W]):
mask_full = GaussianMaskGenerator((2, H, H), acceleration=2, device=device).step(
    batch_size=batch_size
)["mask"]
Next, we randomly distribute the sampled lines across 4 time phases to form a time-varying mask:
# Split only in horizontal direction: work with line masks of shape [B, C, W]
masks = [mask_full[..., 0, :]]
splitter = BernoulliSplittingMaskGenerator((2, H), split_ratio=0.5, device=device)
acs = 10  # number of central ACS lines kept in every phase
# Split twice: 1 -> 2 -> 4 masks
for _ in range(2):
    new_masks = []
    for m in masks:
        # Randomly split the sampled lines of m into two complementary masks
        m1 = splitter.step(batch_size=batch_size, input_mask=m)["mask"]
        m2 = m - m1
        # Add the ACS lines to both halves
        m1[..., H // 2 - acs // 2 : H // 2 + acs // 2] = 1
        m2[..., H // 2 - acs // 2 : H // 2 + acs // 2] = 1
        new_masks.extend([m1, m2])
    masks = new_masks
# Merge masks into time dimension: [B, C, T, W]
mask = torch.stack(masks, 2)
# Convert to vertical lines: [B, C, T, H, W]
mask = torch.stack([mask] * H, -2)
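The splitting loop above can be mirrored in NumPy to check two key properties. Outside the ACS band, every originally sampled line should end up in exactly one phase, and every phase should contain the full ACS band. In this sketch, the 2x mask is replaced by a random half of the lines, and BernoulliSplittingMaskGenerator is replaced by a plain Bernoulli draw. The helper `split_into_phases` is hypothetical.

```python
import numpy as np


def split_into_phases(base, n_rounds, acs_band, rng, ratio=0.5):
    """Recursively split a 1D line mask into 2**n_rounds phases, each keeping the ACS band.

    Toy re-creation of the splitting loop; the random split is a plain Bernoulli draw.
    """
    parts = [base]
    for _ in range(n_rounds):  # each round doubles the number of phase masks
        new_parts = []
        for p in parts:
            p1 = p * (rng.random(p.shape) < ratio)  # keep each sampled line w.p. `ratio`
            p2 = p - p1  # the complementary sampled lines
            p1[acs_band] = 1.0  # every phase also samples the ACS band
            p2[acs_band] = 1.0
            new_parts.extend([p1, p2])
        parts = new_parts
    return np.stack(parts)  # shape (2**n_rounds, len(base))


rng = np.random.default_rng(0)
n, acs_width = 128, 10
acs_band = slice(n // 2 - acs_width // 2, n // 2 + acs_width // 2)

# Hypothetical stand-in for the 2x mask: a random half of the n lines.
base = np.zeros(n)
base[rng.choice(n, size=n // 2, replace=False)] = 1.0

phases = split_into_phases(base, n_rounds=2, acs_band=acs_band, rng=rng)

outside = np.ones(n, dtype=bool)
outside[acs_band] = False
# Outside the ACS band, each originally sampled line lands in exactly one phase.
print(phases.shape, np.array_equal(phases[:, outside].sum(0), base[outside]))
```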
Now we define the physics using this time-varying mask of shape [B,C,T,H,W]:
physics = dinv.physics.SequentialMRI(mask=mask)
Let’s visualise the sequential measurements using a sample image (run this notebook yourself to display the video). We also visualise the frame-by-frame no-learning zero-filled reconstruction.
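Conceptually, the sequential forward model applies the same orthonormal 2D FFT to the static image in every frame and keeps a different set of k-space lines per frame. The frame-by-frame zero-filled reconstruction is then the inverse FFT of each masked frame. The NumPy sketch below illustrates this on a toy 16x16 image. It assumes single-coil data and disjoint masks that together cover all of k-space, unlike the demo, whose phases share ACS lines and cover only part of k-space. It does not reproduce deepinv.physics.SequentialMRI itself.

```python
import numpy as np

rng = np.random.default_rng(0)
n_frames, n = 4, 16
img = rng.standard_normal((n, n))  # toy static image

# Hypothetical disjoint vertical-line masks, shape (T, H, W): column j goes to frame j % T.
cols = np.arange(n) % n_frames
frame_masks = np.stack([np.tile(cols == t, (n, 1)) for t in range(n_frames)]).astype(float)

# Sequential forward model: same image, different k-space lines in each frame.
kspace = np.fft.fft2(img, norm="ortho")
frames_k = frame_masks * kspace  # (T, H, W) * (H, W) broadcasts to (T, H, W)

# Frame-by-frame "no-learning" reconstruction: zero-filled inverse FFT of each frame.
zero_filled_frames = np.fft.ifft2(frames_k, norm="ortho")  # acts on the last two axes

# Each frame alone is aliased, but these masks jointly cover all of k-space,
# so by linearity the zero-filled frames sum back to the image.
print(np.allclose(zero_filled_frames.sum(0), img))
```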
We also visualise the flattened time series, which recovers the original 2x undersampling mask (note that the actual undersampling factor is much lower because of the ACS lines):
Total acceleration: tensor(1.3617)
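One way to obtain such a figure is to flatten the [B,C,T,H,W] mask over time with a maximum, so that a location counts as sampled if any phase sampled it. This assumes acceleration is defined as the number of k-space locations divided by the number sampled. The NumPy sketch below does this for a small hypothetical mask. The demo's own plotting code is not shown here and may compute the value differently.

```python
import numpy as np

rng = np.random.default_rng(0)
n_frames, width = 4, 32
shape = (1, 2, n_frames, 32, width)  # [B, C, T, H, W]

# Hypothetical per-phase column choices plus a central band shared by all phases.
cols = rng.random((n_frames, width)) < 0.15
cols[:, width // 2 - 2 : width // 2 + 2] = True
kt_mask = np.broadcast_to(cols[None, None, :, None, :], shape).astype(float)

# Flatten the time series: a location counts as sampled if any phase sampled it.
flat_mask = kt_mask.max(axis=2)  # [B, C, H, W]

# Acceleration = number of k-space locations / number of sampled locations.
acceleration = flat_mask.size / flat_mask.sum()
print(flat_mask.shape, acceleration)
```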
Define model
As a (static) reconstruction network, we use an unrolled network (half-quadratic splitting) with a trainable denoising prior based on the DnCNN architecture, as an example of a model-based deep learning architecture from MoDL. See deepinv.utils.demo.demo_mri_model() for details.
model = demo_mri_model(device=device)
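The structure of such an unrolled half-quadratic splitting network can be sketched as a fixed number of iterations that alternate a denoising step with a data-consistency step. This NumPy toy is not the deepinv model. The trainable DnCNN is replaced by a fixed 3x3 blur (`box_denoiser`, hypothetical). The data step uses the closed form available for a single-coil masked orthonormal FFT, F x = (M y + rho F z) / (M + rho). `rho` and the iteration count are arbitrary choices.

```python
import numpy as np


def data_consistency(z, y, mask, rho):
    """Exact minimiser of ||M F x - y||^2 + rho * ||x - z||^2 for a masked orthonormal FFT.

    F is unitary and M is a 0/1 diagonal, so the solution is elementwise in k-space.
    """
    k = (mask * y + rho * np.fft.fft2(z, norm="ortho")) / (mask + rho)
    return np.fft.ifft2(k, norm="ortho")


def box_denoiser(x):
    """Placeholder prior step: 3x3 circular box blur standing in for a trained DnCNN."""
    shifted = [np.roll(x, (i, j), axis=(0, 1)) for i in (-1, 0, 1) for j in (-1, 0, 1)]
    return sum(shifted) / 9


def unrolled_hqs(y, mask, n_iter=5, rho=0.5):
    """Fixed number of HQS iterations alternating prior and data-consistency steps."""
    x = np.fft.ifft2(mask * y, norm="ortho")  # zero-filled initialisation
    for _ in range(n_iter):
        z = box_denoiser(x)  # prior (denoising) step
        x = data_consistency(z, y, mask, rho)  # data-consistency step
    return x


rng = np.random.default_rng(0)
n = 32
img = rng.standard_normal((n, n))
line_mask = np.tile(rng.random(n) < 0.5, (n, 1)).astype(float)  # random vertical lines
k_meas = line_mask * np.fft.fft2(img, norm="ortho")
x_rec = unrolled_hqs(k_meas, line_mask)
```

In a learned version, the denoiser weights are shared or unshared across iterations and trained end to end through all unrolled steps.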
Prep loss
We compute the loss on all collected lines by setting dynamic_model to False, then adapt the model to perform Artifact2Artifact. We set split_size=1 so that each artifact chunk contains only 1 frame.
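To illustrate what split_size controls, the hypothetical helper below partitions the frame indices into consecutive chunks of split_size frames. It then draws two different chunks, one as the network input and one as the target. This is an assumption about the general idea, not deepinv's selection rule, which may choose chunks differently.

```python
import numpy as np


def chunk_pair(n_frames, split_size, rng):
    """Draw two different chunks of `split_size` consecutive frames (input, target).

    Hypothetical helper; the chunking rule is an illustrative assumption.
    """
    if n_frames % split_size:
        raise ValueError("split_size must divide the number of frames")
    chunks = np.arange(n_frames).reshape(-1, split_size)  # one row per chunk
    i, j = rng.choice(len(chunks), size=2, replace=False)  # two distinct chunk indices
    return chunks[i], chunks[j]


rng = np.random.default_rng(0)
input_frames, target_frames = chunk_pair(4, split_size=1, rng=rng)
print(input_frames, target_frames)  # two different single frames out of 0..3
```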
loss = dinv.loss.Artifact2ArtifactLoss(
    (2, 4, H, H), split_size=1, dynamic_model=False, device=device
)
model = loss.adapt_model(model)
Train model
The original model was trained for 100 epochs. Here we demonstrate loading the pretrained model and then fine-tuning it for 1 epoch, reporting PSNR and SSIM. To train from scratch, simply comment out the model loading code and increase the number of epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)
# Load pretrained model
file_name = "demo_artifact2artifact_mri.pth"
url = get_weights_url(model_name="measplit", file_name=file_name)
ckpt = torch.hub.load_state_dict_from_url(
    url, map_location=lambda storage, loc: storage, file_name=file_name
)
model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
# Initialize the trainer
trainer = dinv.Trainer(
    model,
    physics=physics,
    epochs=1,
    losses=loss,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    metrics=[dinv.metric.PSNR(), dinv.metric.SSIM()],
    online_measurements=True,
    device=device,
    save_path=None,
    verbose=True,
    wandb_vis=False,
    show_progress_bar=False,
)
model = trainer.train()
The model has 187019 trainable parameters
Train epoch 0: TotalLoss=0.001, PSNR=34.176, SSIM=0.876
Test the model
trainer.plot_images = True
trainer.test(test_dataloader)
Eval epoch 0: PSNR=34.844, PSNR no learning=36.681, SSIM=0.875, SSIM no learning=0.852
Test results:
PSNR no learning: 36.681 +- 0.003
PSNR: 34.844 +- 0.000
SSIM no learning: 0.852 +- 0.000
SSIM: 0.875 +- 0.000
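The PSNR values above follow the usual definition, 10 log10(MAX^2 / MSE) in dB. Below is a minimal NumPy sketch, assuming intensities in [0, 1] so that MAX = 1. deepinv.metric.PSNR may normalise or choose the peak value differently.

```python
import numpy as np


def psnr(est, ref, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_pixel**2 / MSE)."""
    mse = np.mean(np.abs(est - ref) ** 2)
    return 10 * np.log10(max_pixel**2 / mse)


ref = np.zeros((4, 4))
est = np.full((4, 4), 0.1)  # constant error of 0.1, so MSE = 0.01
print(psnr(est, ref))  # approximately 20 dB, since 10 * log10(1 / 0.01) = 20
```

Each tenfold reduction in MSE adds 10 dB, so PSNR differences of a few dB correspond to substantial changes in squared error.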