Self-supervised MRI reconstruction with Artifact2Artifact
We demonstrate the self-supervised Artifact2Artifact loss for solving an undersampled sequential MRI reconstruction problem without ground truth.
The Artifact2Artifact loss was introduced in Liu et al., "RARE: Image Reconstruction using Deep Priors Learned without Groundtruth".
In this example, we use it to reconstruct static images from k-space measurements acquired as a time sequence: each time step (phase) samples a set of lines, chosen so that the whole measurement is a set of non-overlapping lines.
For a description of how Artifact2Artifact constructs the loss, see deepinv.loss.Artifact2ArtifactLoss. Note that in our implementation this is a special case of the generic splitting loss; see deepinv.loss.SplittingLoss for more details. See deepinv.loss.Phase2PhaseLoss for the related Phase2Phase loss.
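To make the idea concrete, the following is a minimal NumPy illustration, not deepinv's implementation. It assumes a noise-free, single-coil operator (a masked orthonormal 2D FFT), treats each frame as one chunk, and uses hypothetical helpers (forward, zero_filled, artifact2artifact_loss). The network only sees the measurements of one randomly chosen frame and is scored on the k-space lines of a different frame.

```python
import numpy as np


def forward(x, mask):
    """Masked orthonormal 2D FFT: a simplified single-coil MRI operator (assumption)."""
    return mask * np.fft.fft2(x, norm="ortho")


def zero_filled(y, mask):
    """Adjoint of `forward`: inverse FFT of the masked k-space."""
    return np.fft.ifft2(mask * y, norm="ortho")


def artifact2artifact_loss(model, y_frames, masks, rng):
    """Hypothetical Artifact2Artifact-style loss with one frame per chunk.

    y_frames and masks have shape (T, H, W). A random input frame i and a
    different target frame j are drawn; the model only sees (y_i, M_i) and
    is penalised on how well its output predicts frame j's measurements.
    """
    T = y_frames.shape[0]
    i, j = rng.choice(T, size=2, replace=False)
    x_hat = model(y_frames[i], masks[i])
    residual = forward(x_hat, masks[j]) - y_frames[j]
    return np.mean(np.abs(residual) ** 2)


rng = np.random.default_rng(0)
H = 8
x_true = rng.random((H, H))

# Four frames, each sampling two disjoint vertical k-space lines (columns).
masks = np.zeros((4, H, H))
for t in range(4):
    masks[t][:, [t, t + 4]] = 1
y_frames = np.stack([forward(x_true, m) for m in masks])

# An oracle "network" that ignores its input and returns the true image.
oracle_loss = artifact2artifact_loss(lambda y, m: x_true, y_frames, masks, rng)
# A non-learned baseline: zero-filled reconstruction of the input frame.
zf_loss = artifact2artifact_loss(zero_filled, y_frames, masks, rng)
print(oracle_loss, zf_loss)
```

The oracle gets zero loss. A zero-filled reconstruction from frame i carries no information about frame j's disjoint lines, so its loss stays large: to reduce the loss, a network has to predict unsampled k-space from what it did see.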
from pathlib import Path
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
import deepinv as dinv
from deepinv.utils.demo import load_dataset, demo_mri_model
from deepinv.models.utils import get_weights_url
from deepinv.physics.generator import (
    GaussianMaskGenerator,
    BernoulliSplittingMaskGenerator,
)
torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
Load data
We use a subset of single-coil knee MRI data from the fastMRI challenge, resized to 128x128. For speed, this demo fine-tunes the original model on a training set of only 5 images and evaluates on 30 test images. Set the training subset size to 150 to train the original model from scratch.
batch_size = 1
H = 128
transform = transforms.Compose([transforms.Resize(H)])
train_dataset = load_dataset("fastmri_knee_singlecoil", transform, train=True)
test_dataset = load_dataset("fastmri_knee_singlecoil", transform, train=False)
train_dataset = Subset(train_dataset, torch.arange(5))
test_dataset = Subset(test_dataset, torch.arange(30))
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
Define physics
We simulate a sequential k-space sampler that, over the course of 4 phases (i.e. frames), samples 64 lines (i.e. 2x total undersampling of 128 lines) with Gaussian weighting, plus a few extra lines at the centre of k-space for the autocalibration signal (ACS). We use deepinv.physics.SequentialMRI to do this.
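A Gaussian-weighted line mask of this kind can be sketched as follows. This is a hypothetical NumPy helper, not GaussianMaskGenerator: the Gaussian width sigma is an arbitrary assumption, and the ACS lines are counted inside the 64-line budget here, whereas in this example they are added again to every phase during splitting, which is why the final acceleration ends up below 2.

```python
import numpy as np


def gaussian_line_mask(H, acceleration, acs, sigma, rng):
    """Hypothetical 1D variable-density line mask (not deepinv's generator).

    Keeps `acs` central lines, then draws the remaining lines without
    replacement with probability proportional to a Gaussian centred on the
    k-space centre, so that H // acceleration lines are sampled in total.
    """
    n_lines = H // acceleration
    mask = np.zeros(H)
    centre = H // 2
    mask[centre - acs // 2 : centre + acs // 2] = 1  # fully sampled ACS band

    candidates = np.flatnonzero(mask == 0)
    weights = np.exp(-0.5 * ((candidates - centre) / sigma) ** 2)
    n_extra = n_lines - int(mask.sum())
    chosen = rng.choice(candidates, size=n_extra, replace=False, p=weights / weights.sum())
    mask[chosen] = 1
    return mask


rng = np.random.default_rng(0)
mask_1d = gaussian_line_mask(H=128, acceleration=2, acs=10, sigma=32.0, rng=rng)
print(int(mask_1d.sum()), 128 / mask_1d.sum())  # 64 2.0
```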
First, we define a static 2x acceleration mask that all measurements use (of shape [B,C,H,W]):
mask_full = GaussianMaskGenerator((2, H, H), acceleration=2, device=device).step(
batch_size=batch_size
)["mask"]
Next, we randomly split the sampled lines across 4 time phases to form a time-varying mask:
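# Split only along the horizontal (W) direction, using a 1D mask of shape [B, C, W]
masks = [mask_full[..., 0, :]]
splitter = BernoulliSplittingMaskGenerator((2, H), split_ratio=0.5, device=device)
acs = 10  # number of central ACS lines kept in every phase
# Split twice, recursively, to obtain 4 masks
for _ in range(2):
    new_masks = []
    for m in masks:
        # m1 keeps a random subset of m's lines; m2 keeps the rest
        m1 = splitter.step(batch_size=batch_size, input_mask=m)["mask"]
        m2 = m - m1
        # Add the central ACS lines back into both halves
        m1[..., H // 2 - acs // 2 : H // 2 + acs // 2] = 1
        m2[..., H // 2 - acs // 2 : H // 2 + acs // 2] = 1
        new_masks.extend([m1, m2])
    masks = new_masks
# Merge masks into the time dimension: shape [B, C, T, W]
mask = torch.stack(masks, 2)
# Repeat along H to convert to vertical lines: shape [B, C, T, H, W]
mask = torch.stack([mask] * H, -2)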
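The recursive split above can be mirrored in NumPy to check its two key properties: outside the ACS band, the four phase masks are disjoint and add back up to the original mask, and every phase contains the full ACS band. The helper split_lines is hypothetical, and an independent Bernoulli draw per line stands in for BernoulliSplittingMaskGenerator, whose exact sampling scheme may differ.

```python
import numpy as np


def split_lines(mask_1d, n_rounds, acs, rng, split_ratio=0.5):
    """NumPy sketch of the recursive splitting (hypothetical helper).

    Each round splits every mask m into m1 (each sampled line kept with
    probability split_ratio) and its complement m2 = m - m1, then forces
    the `acs` central lines back into both halves.
    """
    H = mask_1d.shape[-1]
    centre = slice(H // 2 - acs // 2, H // 2 + acs // 2)
    masks = [mask_1d]
    for _ in range(n_rounds):
        new_masks = []
        for m in masks:
            m1 = m * (rng.random(H) < split_ratio)  # random subset of m's lines
            m2 = m - m1                             # the complementary lines
            m1[centre] = 1
            m2[centre] = 1
            new_masks.extend([m1, m2])
        masks = new_masks
    return np.stack(masks)  # shape (2 ** n_rounds, H)


rng = np.random.default_rng(0)
H, acs = 128, 10
mask_1d = (rng.random(H) < 0.5).astype(float)  # toy stand-in for the 2x mask
phases = split_lines(mask_1d, n_rounds=2, acs=acs, rng=rng)
print(phases.shape)  # (4, 128)
```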
Now define physics using this time-varying mask of shape [B,C,T,H,W]:
physics = dinv.physics.SequentialMRI(mask=mask)
Let’s visualise the sequential measurements using a sample image (run this notebook yourself to display the video). We also visualise the frame-by-frame no-learning zero-filled reconstruction.
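The frame-by-frame zero-filled reconstruction can be sketched as follows, assuming a noise-free, single-coil, orthonormal-FFT model. The helper zero_filled_frames is hypothetical, and the toy masks sample every 4th column (shifted per frame) rather than the Gaussian-weighted lines used in this example.

```python
import numpy as np


def zero_filled_frames(x, masks):
    """Simulate each phase's k-space and return its zero-filled magnitude image.

    x has shape (H, W); masks has shape (T, H, W). Unsampled k-space entries
    are left at zero before the inverse FFT.
    """
    kspace = masks * np.fft.fft2(x, norm="ortho")      # broadcasts over T
    return np.abs(np.fft.ifft2(kspace, norm="ortho"))  # inverse FFT over the last two axes


rng = np.random.default_rng(0)
x = rng.random((16, 16))
masks = np.zeros((4, 16, 16))
for t in range(4):
    masks[t][:, t::4] = 1  # each phase samples every 4th column, shifted by t
recons = zero_filled_frames(x, masks)
print(recons.shape)  # (4, 16, 16)
```

Keeping only every 4th column folds the image onto itself 4 times along that axis, so each frame shows strong aliasing; a fully sampled mask returns the original image.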
We also visualise the flattened time series, which recovers the original 2x undersampling mask (note that the actual acceleration is lower than 2, because the ACS lines are added to every phase):
Total acceleration: tensor(1.3617)
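Flattening the time series amounts to taking the union (logical OR) of the per-phase masks, and an acceleration factor can be defined as the number of k-space locations divided by the number acquired. The sketch below assumes this definition; deepinv's exact computation of the printed value is not shown here. Lines shared by several phases, such as the ACS lines, are counted only once. Under this definition, the reported 1.3617 is consistent with about 94 of the 128 lines being sampled.

```python
import numpy as np


def flatten_time(masks):
    """Union of the per-phase masks: logical OR over the time axis 0."""
    return masks.max(axis=0)


def acceleration(mask):
    """Acceleration factor = total k-space locations / acquired locations."""
    return mask.size / mask.sum()


# Toy example: 4 phases of 8x8 k-space; phase t samples column t plus a
# shared central column 4 (playing the role of an ACS line).
masks = np.zeros((4, 8, 8))
for t in range(4):
    masks[t][:, [t, 4]] = 1

print(acceleration(masks[0]))                # 4.0 for a single phase
print(acceleration(flatten_time(masks)))     # 1.6: 5 of 8 columns after flattening
print(round(128 / 94, 4))                    # 1.3617
```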
Define model
As a (static) reconstruction network, we use an unrolled network (half-quadratic splitting) with a trainable denoising prior based on the DnCNN architecture, as an example of a model-based deep learning architecture from MoDL. See deepinv.utils.demo.demo_mri_model() for details.
model = demo_mri_model(device=device)
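One unrolled half-quadratic splitting scheme can be sketched in NumPy, assuming a single-coil Cartesian (masked-FFT) operator and a fixed penalty mu. For such an operator the data-consistency step has a closed form, because the orthonormal FFT diagonalises it in k-space. The denoiser argument is a placeholder for the trainable DnCNN, and neither the iteration count nor mu is taken from demo_mri_model(); this is an illustration, not that function's code.

```python
import numpy as np


def data_consistency(z, y, mask, mu):
    """Solve argmin_x ||M F x - y||^2 + mu ||x - z||^2 in closed form.

    With an orthonormal FFT F and a binary mask M, the normal equations
    become diagonal in k-space: F x = (M y + mu F z) / (M + mu).
    """
    k = (mask * y + mu * np.fft.fft2(z, norm="ortho")) / (mask + mu)
    return np.fft.ifft2(k, norm="ortho")


def unrolled_hqs(y, mask, denoiser, n_iter=5, mu=0.1):
    """Alternate a prior (denoising) step and a data-consistency step."""
    x = np.fft.ifft2(mask * y, norm="ortho")  # zero-filled initialisation
    for _ in range(n_iter):
        z = denoiser(x)                       # prior step (a trained CNN in practice)
        x = data_consistency(z, y, mask, mu)  # data-fidelity step
    return x


rng = np.random.default_rng(0)
x_true = rng.random((8, 8))
mask = np.zeros((8, 8))
mask[:, ::2] = 1  # sample every other column
y = mask * np.fft.fft2(x_true, norm="ortho")

x_id = unrolled_hqs(y, mask, denoiser=lambda x: x)          # identity "denoiser"
x_oracle = unrolled_hqs(y, mask, denoiser=lambda x: x_true)  # oracle "denoiser"
```

With an identity denoiser the iterations leave the zero-filled image unchanged, while an oracle denoiser that returns the true image gives back exactly that image, since it already agrees with the measured lines.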
Prep loss
Compute the loss on all collected lines by setting dynamic_model=False. Then adapt the model to perform Artifact2Artifact. We set split_size=1 so that each artifact chunk contains only 1 frame.
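As an illustration of the split_size parameter, the sketch below groups the 4 frames into consecutive chunks and, assuming a static model consumes a single mask, collapses a chunk into the union of its frames' lines. The helpers time_chunks and chunk_mask are hypothetical, and deepinv may assign frames to chunks differently.

```python
import numpy as np


def time_chunks(T, split_size):
    """Group frame indices 0..T-1 into consecutive chunks of `split_size` frames."""
    return [list(range(s, min(s + split_size, T))) for s in range(0, T, split_size)]


def chunk_mask(masks, idx):
    """Collapse a chunk into one static mask: the union (logical OR) of its frames' lines."""
    return masks[idx].max(axis=0)


# Toy sequential mask: 4 frames of 8x8 k-space, frame t samples column t only.
masks = np.zeros((4, 8, 8))
for t in range(4):
    masks[t][:, t] = 1

print(time_chunks(4, split_size=1))  # [[0], [1], [2], [3]]
print(time_chunks(4, split_size=2))  # [[0, 1], [2, 3]]
static = chunk_mask(masks, time_chunks(4, split_size=2)[0])
print(static[0])  # [1. 1. 0. 0. 0. 0. 0. 0.]
```

With split_size=1 every chunk is a single frame, giving the largest number of distinct input/target pairs; a larger split_size gives fewer, more densely sampled chunks.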
loss = dinv.loss.Artifact2ArtifactLoss(
(2, 4, H, H), split_size=1, dynamic_model=False, device=device
)
model = loss.adapt_model(model)
Train model
The original model was trained for 100 epochs. Here we load the pretrained model and then fine-tune it for 1 epoch, reporting PSNR and SSIM. To train from scratch, comment out the model-loading code and increase the number of epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-8)
# Load pretrained model
file_name = "demo_artifact2artifact_mri.pth"
url = get_weights_url(model_name="measplit", file_name=file_name)
ckpt = torch.hub.load_state_dict_from_url(
url, map_location=lambda storage, loc: storage, file_name=file_name
)
model.load_state_dict(ckpt["state_dict"])
optimizer.load_state_dict(ckpt["optimizer"])
# Initialize the trainer
trainer = dinv.Trainer(
model,
physics=physics,
epochs=1,
losses=loss,
optimizer=optimizer,
train_dataloader=train_dataloader,
metrics=[dinv.metric.PSNR(), dinv.metric.SSIM()],
online_measurements=True,
device=device,
save_path=None,
verbose=True,
wandb_vis=False,
show_progress_bar=False,
)
model = trainer.train()
The model has 187019 trainable parameters
Train epoch 0: TotalLoss=0.005, PSNR=30.569, SSIM=0.824
Test the model
trainer.plot_images = True
trainer.test(test_dataloader)
Eval epoch 0: PSNR=34.707, PSNR no learning=35.317, SSIM=0.887, SSIM no learning=0.793
Test results:
PSNR no learning: 35.317 +- 2.273
PSNR: 34.707 +- 1.726
SSIM no learning: 0.793 +- 0.069
SSIM: 0.887 +- 0.017
{'PSNR no learning': np.float64(35.31698811848958), 'PSNR no learning_std': np.float64(2.273216157077679), 'PSNR': np.float64(34.70675455729167), 'PSNR_std': np.float64(1.7261516614270602), 'SSIM no learning': np.float64(0.7928400039672852), 'SSIM no learning_std': np.float64(0.0694199089554509), 'SSIM': np.float64(0.8872030258178711), 'SSIM_std': np.float64(0.01740006355477841)}
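PSNR follows the usual definition 10 log10(max_pixel^2 / MSE). A minimal NumPy version is sketched below, assuming images scaled so that the peak value is 1.0 (an assumption about the metric's settings here). The "no learning" rows presumably apply the same metrics to the non-learned, zero-filled reconstruction rather than to the network output. PSNR and SSIM can disagree, as they do in these results: the fine-tuned network scores slightly lower PSNR but clearly higher SSIM than the baseline.

```python
import numpy as np


def psnr(x_hat, x, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_pixel**2 / MSE)."""
    mse = np.mean((x_hat - x) ** 2)
    return 10 * np.log10(max_pixel**2 / mse)


x = np.zeros((4, 4))
print(psnr(x + 0.1, x))   # about 20 dB (MSE = 0.01)
print(psnr(x + 0.01, x))  # about 40 dB: 10x smaller error adds 20 dB
```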
Total running time of the script: (0 minutes 44.946 seconds)