Self-supervised MRI reconstruction with Artifact2Artifact

We demonstrate the self-supervised Artifact2Artifact loss for solving an undersampled sequential MRI reconstruction problem without ground truth.

The Artifact2Artifact loss was introduced in Liu et al. RARE: Image Reconstruction using Deep Priors Learned without Groundtruth.

In our example, we use it to reconstruct static images from k-space measurements acquired as a time sequence. Each time step (phase) samples a set of lines. Apart from a block of central ACS lines repeated in every phase, the lines of different phases do not overlap.

For a description of how Artifact2Artifact constructs the loss, see deepinv.loss.Artifact2ArtifactLoss.

Note that in our implementation, this is a special case of the generic splitting loss: see deepinv.loss.SplittingLoss for more details. See deepinv.loss.Phase2PhaseLoss for the related Phase2Phase loss.

from pathlib import Path
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

import deepinv as dinv
from deepinv.utils.demo import load_dataset, demo_mri_model
from deepinv.models.utils import get_weights_url
from deepinv.physics.generator import (
    GaussianMaskGenerator,
    BernoulliSplittingMaskGenerator,
)

torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Load data

We use a subset of single-coil knee MRI data from the fastMRI challenge, resized to 128x128. For speed, this demo fine-tunes the original model on a training set of only 5 images. Set the training set size to 150 to train the original model from scratch.

batch_size = 1
H = 128

transform = transforms.Compose([transforms.Resize(H)])

train_dataset = load_dataset(
    "fastmri_knee_singlecoil", Path("."), transform, train=True
)
test_dataset = load_dataset(
    "fastmri_knee_singlecoil", Path("."), transform, train=False
)

train_dataset = Subset(train_dataset, torch.arange(5))
test_dataset = Subset(test_dataset, torch.arange(30))

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Define physics

We simulate a sequential k-space sampler that samples 64 lines over the course of 4 phases (i.e. frames). This gives 2x total undersampling of the 128 lines. The lines are chosen with Gaussian weighting, plus a few extra ACS (autocalibration signal) lines at the centre of k-space that are repeated in every phase. We use deepinv.physics.SequentialMRI to do this.
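To make this sampling pattern concrete, here is a minimal NumPy sketch of Gaussian-weighted line selection with a fixed ACS block. It is an illustration under stated assumptions, not deepinv's GaussianMaskGenerator. The helper name `gaussian_line_mask`, the Gaussian width `sigma_frac` and the rule of drawing the non-ACS lines without replacement are all hypothetical.

```python
import numpy as np


def gaussian_line_mask(n_lines=128, acceleration=2, acs=10, sigma_frac=0.25, seed=0):
    """Hypothetical sketch: keep n_lines // acceleration phase-encode lines,
    always including `acs` central lines; the remaining lines are drawn without
    replacement with probability proportional to a Gaussian centred on k = 0."""
    rng = np.random.default_rng(seed)
    n_keep = n_lines // acceleration
    centre = n_lines // 2
    acs_idx = np.arange(centre - acs // 2, centre + acs // 2)

    # Gaussian weights over line indices, zeroed on the ACS block
    idx = np.arange(n_lines)
    w = np.exp(-0.5 * ((idx - centre) / (sigma_frac * n_lines)) ** 2)
    w[acs_idx] = 0.0
    w /= w.sum()

    extra = rng.choice(n_lines, size=n_keep - acs, replace=False, p=w)
    mask = np.zeros(n_lines, dtype=np.float32)
    mask[acs_idx] = 1.0
    mask[extra] = 1.0
    return mask


mask_1d = gaussian_line_mask()
print(int(mask_1d.sum()))  # 64 lines, i.e. 2x acceleration of 128
```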

First, we define a static 2x-acceleration mask of shape [B,C,H,W], from which the lines of all phases are drawn:

mask_full = GaussianMaskGenerator((2, H, H), acceleration=2, device=device).step(
    batch_size=batch_size
)["mask"]

Next, we randomly split the sampled lines across 4 time phases to obtain a time-varying mask:

# Split only in horizontal direction
masks = [mask_full[..., 0, :]]
splitter = BernoulliSplittingMaskGenerator((2, H), split_ratio=0.5, device=device)

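acs = 10  # number of central ACS lines, kept in every phase

# Split twice (1 -> 2 -> 4 masks), giving one mask per phase
for _ in range(2):
    new_masks = []
    for m in masks:
        m1 = splitter.step(batch_size=batch_size, input_mask=m)["mask"]
        m2 = m - m1  # complementary set of lines
        # Re-add the ACS block to both halves
        m1[..., H // 2 - acs // 2 : H // 2 + acs // 2] = 1
        m2[..., H // 2 - acs // 2 : H // 2 + acs // 2] = 1
        new_masks.extend([m1, m2])
    masks = new_masks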

# Merge masks into time dimension
mask = torch.stack(masks, 2)

# Convert to vertical lines
mask = torch.stack([mask] * H, -2)
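The two splitting rounds can be mimicked on a single 1D line pattern. This NumPy sketch replaces BernoulliSplittingMaskGenerator with an independent fair coin per line, which is an assumption about its behaviour. The helper `split_lines` is hypothetical. The sketch checks the property the tutorial relies on: outside the ACS block, the 4 phase masks are disjoint and together reproduce the original mask.

```python
import numpy as np


def split_lines(mask, n_rounds=2, acs=10, seed=0):
    """Hypothetical sketch of the splitting loop: each round splits every mask
    into two complementary halves with a fair coin per line, then re-adds
    the central ACS block to both halves."""
    rng = np.random.default_rng(seed)
    n = mask.shape[-1]
    acs_sl = slice(n // 2 - acs // 2, n // 2 + acs // 2)
    masks = [mask.astype(np.float32)]
    for _ in range(n_rounds):
        new_masks = []
        for m in masks:
            m1 = m * (rng.random(n) < 0.5)  # lines assigned to the first half
            m2 = m - m1                     # complementary lines
            m1[acs_sl] = 1.0
            m2[acs_sl] = 1.0
            new_masks.extend([m1, m2])
        masks = new_masks
    return np.stack(masks)  # shape (2 ** n_rounds, n)


# Toy 2x mask: every other line plus a central ACS block
mask_full = np.zeros(128, dtype=np.float32)
mask_full[::2] = 1.0
mask_full[59:69] = 1.0

phases = split_lines(mask_full)
outside = np.ones(128, dtype=bool)
outside[59:69] = False
print(phases.shape)  # (4, 128)
```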

Now define the physics using this time-varying mask of shape [B,C,T,H,W]:

physics = dinv.physics.SequentialMRI(mask=mask, device=device)
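Conceptually, each phase t measures y_t = M_t * F x, where F is the 2D Fourier transform and M_t is the phase mask. The NumPy sketch below implements this operator and its adjoint, and checks the adjoint with a dot-product test. The adjoint sums the per-phase zero-filled images over time. The helpers `forward` and `adjoint` are hypothetical and an orthonormal FFT is assumed; this is not deepinv's SequentialMRI code.

```python
import numpy as np


def forward(x, masks):
    """Sequential forward model: y_t = M_t * FFT2(x) for every phase t.
    x: (H, W) image; masks: (T, H, W) 0/1 k-space masks."""
    k = np.fft.fft2(x, norm="ortho")
    return masks * k[None]


def adjoint(y, masks):
    """Adjoint of `forward`: mask each phase, inverse FFT, sum over time."""
    return np.fft.ifft2(masks * y, norm="ortho").sum(axis=0)


rng = np.random.default_rng(0)
T, H, W = 4, 16, 16
# Toy phase masks made of vertical lines (columns), as in the tutorial
cols = rng.random((T, W)) < 0.3
masks = np.repeat(cols[:, None, :], H, axis=1).astype(np.float64)

x = rng.standard_normal((H, W)) + 1j * rng.standard_normal((H, W))
y_rand = rng.standard_normal((T, H, W)) + 1j * rng.standard_normal((T, H, W))

# Dot-product test: <A x, y> should equal <x, A^H y>
lhs = np.vdot(forward(x, masks), y_rand)
rhs = np.vdot(x, adjoint(y_rand, masks))
print(np.isclose(lhs, rhs))  # True
```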

Let’s visualise the sequential measurements using a sample image (run this notebook yourself to display the video). We also visualise the frame-by-frame no-learning zero-filled reconstruction.

x = next(iter(train_dataloader))
y = physics(x)
dinv.utils.plot_videos(
    [physics.repeat(x, mask), y, mask, physics.A_adjoint(y, keep_time_dim=True)],
    titles=["x", "y", "mask", "x_init"],
    display=True,
)
[Video: x, y, mask, x_init]

We also visualise the time series flattened over phases, which recovers the original 2x undersampling mask. Note that the actual acceleration is noticeably lower than 2x, because the ACS lines are included in every phase:

dinv.utils.plot(
    [x, physics.average(y), physics.average(mask), physics.A_adjoint(y)],
    titles=["x", "y", "orig mask", "x_init"],
)

print("Total acceleration:", (2 * 128 * 128) / mask.sum())
[Figure: x, y, orig mask, x_init]
Total acceleration: tensor(1.3617)
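The printed value can be checked by hand. The mask has shape [B,C,T,H,W] with B=1, so mask.sum() equals C·H times the number of lines summed over all phases. This means each ACS line is counted once per phase. The decomposition at the end assumes that mask_full holds exactly 64 lines, including all 10 ACS lines. That is an inference from the printed value, not something the tutorial states.

```python
C, H, W, T = 2, 128, 128, 4
total = C * H * W                    # numerator used in the print statement

# mask.sum() = C * H * (number of lines summed over all T phases);
# 94 is the count implied by the printed 1.3617 (128 / 1.3617 ≈ 94)
lines_summed = 94
accel = total / (C * H * lines_summed)
print(round(accel, 4))  # 1.3617

# One decomposition consistent with 94 (assumption: mask_full has 64 lines,
# including all 10 ACS lines). Non-ACS lines land in exactly one phase,
# while the ACS lines are counted once in every phase.
acs, full_lines = 10, 64
print((full_lines - acs) + T * acs)  # 94
```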

Define model

As the (static) reconstruction network, we use an unrolled network (half-quadratic splitting) with a trainable denoising prior based on the DnCNN architecture. This is an example of a model-based deep learning architecture from MoDL. See deepinv.utils.demo.demo_mri_model() for details.
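For single-coil Cartesian MRI, the data-consistency step of half-quadratic splitting has a closed form in k-space. Minimising ||M F x - y||^2 + lam ||x - z||^2 gives F x = (M y + lam F z) / (M + lam) at each frequency. The NumPy sketch below shows that step inside an unrolled loop. The box-blur denoiser, lam and the iteration count are placeholders, and `unrolled_hqs` is a hypothetical helper. See demo_mri_model() for how the tutorial's network is actually built.

```python
import numpy as np


def data_consistency(z, y, mask, lam):
    """Closed-form HQS x-update for single-coil Cartesian MRI:
    argmin_x ||M F x - y||^2 + lam * ||x - z||^2, F = orthonormal 2D FFT.
    Per frequency: F x = (M y + lam F z) / (M + lam)."""
    kz = np.fft.fft2(z, norm="ortho")
    return np.fft.ifft2((mask * y + lam * kz) / (mask + lam), norm="ortho")


def unrolled_hqs(y, mask, denoiser, n_iter=5, lam=0.05):
    """Hypothetical unrolled HQS: alternate a prior (denoising) step with
    data consistency, starting from the zero-filled reconstruction."""
    x = np.fft.ifft2(y, norm="ortho")
    for _ in range(n_iter):
        z = denoiser(x)  # a trained DnCNN prior in the tutorial
        x = data_consistency(z, y, mask, lam)
    return x


def box_blur(x):
    """Placeholder prior: 3x3 circular box filter (not a trained network)."""
    return sum(np.roll(np.roll(x, i, 0), j, 1) for i in (-1, 0, 1) for j in (-1, 0, 1)) / 9.0


rng = np.random.default_rng(0)
x_true = rng.random((32, 32))
# Vertical-line mask: each column is either fully sampled or not
mask = np.repeat((rng.random(32) < 0.5)[None, :], 32, axis=0).astype(float)
y = mask * np.fft.fft2(x_true, norm="ortho")
x_hat = unrolled_hqs(y, mask, box_blur)
```

With a very small lam, the output agrees with the measured k-space samples almost exactly, and the prior only fills in the unsampled frequencies.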

model = demo_mri_model(device=device)

Prep loss

We compute the loss on all collected lines by setting dynamic_model=False, then adapt the model to perform Artifact2Artifact. Setting split_size=1 means that each artifact chunk contains only 1 frame.

loss = dinv.loss.Artifact2ArtifactLoss(
    (2, 4, H, H), split_size=1, dynamic_model=False, device=device
)
model = loss.adapt_model(model)
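The idea behind the loss can be sketched without deepinv. Group the T frames into chunks of split_size, reconstruct from one chunk, and compare the result in k-space against a different chunk. This NumPy sketch is conceptual only. `chunk_frames`, `flatten` and `a2a_loss` are hypothetical helpers, and the chunk-selection and averaging rules are assumptions that may differ from deepinv.loss.Artifact2ArtifactLoss. Flattening the target chunk over time loosely mirrors dynamic_model=False.

```python
import numpy as np


def chunk_frames(T, split_size):
    """Group frame indices 0..T-1 into consecutive chunks of `split_size`."""
    return [list(range(s, min(s + split_size, T))) for s in range(0, T, split_size)]


def flatten(y, masks):
    """Merge phases: average k-space over the phases that sampled each location
    (ACS lines appear in several phases) and return the union mask."""
    counts = masks.sum(axis=0)
    return y.sum(axis=0) / np.maximum(counts, 1), (counts > 0).astype(masks.dtype)


def a2a_loss(y, masks, recon, split_size=1, rng=None):
    """Conceptual Artifact2Artifact-style loss (sketch, not deepinv's code).
    y, masks: (T, H, W) per-phase k-space data and 0/1 masks.
    recon(y_in, mask_in) -> static image estimate (the network)."""
    rng = np.random.default_rng() if rng is None else rng
    chunks = chunk_frames(masks.shape[0], split_size)
    if len(chunks) < 2:
        raise ValueError("need at least two chunks")
    i, j = rng.choice(len(chunks), size=2, replace=False)  # two distinct chunks
    y_in, m_in = flatten(y[chunks[i]], masks[chunks[i]])
    y_tgt, m_tgt = flatten(y[chunks[j]], masks[chunks[j]])
    x_hat = recon(y_in, m_in)
    # Compare the estimate with the other chunk's measurements in k-space
    residual = m_tgt * np.fft.fft2(x_hat, norm="ortho") - y_tgt
    return float(np.mean(np.abs(residual) ** 2))


rng = np.random.default_rng(0)
T, H, W = 4, 16, 16
masks = np.repeat((rng.random((T, W)) < 0.4)[:, None, :], H, axis=1).astype(float)
x_true = rng.random((H, W))
y = masks * np.fft.fft2(x_true, norm="ortho")[None]


def zero_filled(y_in, m_in):
    """Stand-in 'network': zero-filled inverse FFT."""
    return np.fft.ifft2(y_in, norm="ortho").real


print(chunk_frames(4, 1))  # [[0], [1], [2], [3]]
print(chunk_frames(4, 2))  # [[0, 1], [2, 3]]
loss_zf = a2a_loss(y, masks, zero_filled, split_size=1, rng=rng)
```

An oracle reconstruction that returns the true image gives zero loss in this noiseless toy setting. This is why the loss can supervise the network without ground truth.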

Train model

The original model was trained for 100 epochs. Here we load the pretrained model, fine-tune it for 1 epoch and report PSNR and SSIM. To train from scratch, comment out the model-loading code and increase the number of epochs.
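PSNR, one of the reported metrics, is 10 log10(MAX^2 / MSE). A minimal NumPy version is below. The peak value max_pixel=1.0 is an assumption, and dinv.metric.PSNR may normalise or handle complex images differently.

```python
import numpy as np


def psnr(x_hat, x, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_pixel**2 / MSE)."""
    mse = np.mean(np.abs(x_hat - x) ** 2)
    return 10.0 * np.log10(max_pixel ** 2 / mse)


x = np.zeros((4, 4))
print(psnr(x + 0.1, x))  # approximately 20 dB, since MSE = 0.01
```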

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)

# Load pretrained model
file_name = "demo_artifact2artifact_mri.pth"
url = get_weights_url(model_name="measplit", file_name=file_name)
ckpt = torch.hub.load_state_dict_from_url(
    url, map_location=lambda storage, loc: storage, file_name=file_name
)

model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])

# Initialize the trainer
trainer = dinv.Trainer(
    model,
    physics=physics,
    epochs=1,
    losses=loss,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    metrics=[dinv.metric.PSNR(), dinv.metric.SSIM()],
    online_measurements=True,
    device=device,
    save_path=None,
    verbose=True,
    wandb_vis=False,
    show_progress_bar=False,
)

model = trainer.train()
The model has 187019 trainable parameters
Train epoch 0: TotalLoss=0.005, PSNR=30.569, SSIM=0.824

Test the model

trainer.plot_images = True
trainer.test(test_dataloader)
[Figure: Ground truth, No learning, Reconstruction]
Eval epoch 0: PSNR=34.707, PSNR no learning=35.317, SSIM=0.887, SSIM no learning=0.793
Test results:
PSNR no learning: 35.317 +- 2.273
PSNR: 34.707 +- 1.726
SSIM no learning: 0.793 +- 0.069
SSIM: 0.887 +- 0.017

{'PSNR no learning': np.float64(35.31698811848958), 'PSNR no learning_std': np.float64(2.273216157077679), 'PSNR': np.float64(34.70675455729167), 'PSNR_std': np.float64(1.7261516614270602), 'SSIM no learning': np.float64(0.7928400039672852), 'SSIM no learning_std': np.float64(0.0694199089554509), 'SSIM': np.float64(0.8872030258178711), 'SSIM_std': np.float64(0.01740006355477841)}

Total running time of the script: (0 minutes 45.034 seconds)
