Implementing DPS
In this tutorial, we go over the steps of the Diffusion Posterior Sampling (DPS) algorithm introduced in Chung et al. The full algorithm is implemented in deepinv.sampling.DPS().
Installing dependencies
Let us import the relevant packages and load a sample image of size 64 x 64. This will be used as our ground-truth image.
Note
We work with an image of size 64 x 64 to reduce the computational time of this example. The algorithm works best with images of size 256 x 256.
import numpy as np
import torch
import deepinv as dinv
from deepinv.utils.plotting import plot
from deepinv.optim.data_fidelity import L2
from deepinv.utils.demo import load_url_image, get_image_url
from tqdm import tqdm # to visualize progress
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
url = get_image_url("butterfly.png")
x_true = load_url_image(url=url, img_size=64).to(device)
x = x_true.clone()
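In this tutorial, we consider random inpainting as the inverse problem. The forward operator is implemented in deepinv.physics.Inpainting(). In our example, 90% of the pixels are masked out at random, and the measurements are additionally corrupted by additive white Gaussian noise (AWGN) with standard deviation 12.75/255.
sigma = 12.75 / 255.0  # standard deviation of the AWGN
physics = dinv.physics.Inpainting(
    (3, x.shape[-2], x.shape[-1]),  # image size (C, H, W)
    mask=0.1,  # each pixel is kept with probability 0.1, so about 90% are masked out
    pixelwise=True,  # the same mask is shared by all color channels
    device=device,
    noise_model=dinv.physics.GaussianNoise(sigma=sigma),
)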
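The NumPy sketch below gives a rough mental model of this forward operator. It applies a random mask that keeps each pixel with probability 0.1, so about 90% of the pixels are masked out. It then adds AWGN with standard deviation 12.75/255 to every entry.

This is our own illustration, not deepinv's implementation. The helper name `random_inpainting` is hypothetical. Sharing one mask across the three color channels is an assumption that mirrors a pixel-wise mask.

```python
import numpy as np


def random_inpainting(x, keep_prob=0.1, sigma=12.75 / 255, seed=0):
    """Hypothetical helper: mask out pixels at random, then add AWGN.

    x has shape (C, H, W); one Bernoulli(keep_prob) mask is drawn per pixel
    and shared by all channels, and noise is added to every entry.
    """
    rng = np.random.default_rng(seed)
    mask = (rng.random(x.shape[1:]) < keep_prob).astype(x.dtype)  # shape (H, W)
    noise = sigma * rng.standard_normal(x.shape).astype(x.dtype)
    y = mask[None] * x + noise
    return y, mask


x = np.random.default_rng(1).random((3, 64, 64)).astype(np.float32)
y, mask = random_inpainting(x)
print(f"fraction of masked pixels: {1 - mask.mean():.3f}")  # close to 0.9
```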
Diffusion model loading
We use a pre-trained diffusion model that was also used for the DiffPIR algorithm, namely the one trained on the FFHQ 256x256 dataset. This diffusion model was trained on human face images, which are very different from the image we consider in our example. Nevertheless, we will see later on that DPS generalizes sufficiently well even in this case.
model = dinv.models.DiffUNet(large_model=False).to(device)
Define diffusion schedule
We use the standard linear diffusion noise schedule. Once \(\beta_t\) is defined to follow a linear schedule that interpolates between \(\beta_{\rm min}\) and \(\beta_{\rm max}\), we have the following additional definitions: \(\alpha_t := 1 - \beta_t\), \(\bar\alpha_t := \prod_{j=1}^t \alpha_j\). The following equations will also be useful later on (we always assume that \(\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\) hereafter):
\[
\mathbf{x}_t = \sqrt{\alpha_t}\,\mathbf{x}_{t-1} + \sqrt{1 - \alpha_t}\,\mathbf{\epsilon} = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1 - \bar\alpha_t}\,\mathbf{\epsilon},
\]
where the second equality (in distribution) follows from unrolling the recursion with the reparametrization trick.
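As a quick numerical check of the closed-form expression, the NumPy sketch below propagates the mean and variance of the one-step recursion through the linear schedule. It uses the same \(\beta_{\rm min} = 0.1/1000\) and \(\beta_{\rm max} = 20/1000\) as get_betas in the code that follows. It then compares the results with \(\sqrt{\bar\alpha_t}\,\mathbf{x}_0\) and \(1 - \bar\alpha_t\).

A scalar \(x_0\) stands in for one pixel. This is an illustration only, not part of the DPS implementation.

```python
import numpy as np

T = 1000
betas = np.linspace(0.1 / 1000, 20 / 1000, T)  # linear schedule
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

x0 = 0.7  # an arbitrary scalar "pixel" value
mean, var = x0, 0.0  # x_0 is deterministic
means, variances = [], []
for a in alphas:
    # one forward step: x_t = sqrt(a) * x_{t-1} + sqrt(1 - a) * eps
    mean = np.sqrt(a) * mean
    var = a * var + (1.0 - a)
    means.append(mean)
    variances.append(var)

print(np.allclose(means, np.sqrt(alpha_bar) * x0))  # True
print(np.allclose(variances, 1.0 - alpha_bar))  # True
print(alpha_bar[-1])  # close to 0: x_T is almost pure noise
```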
num_train_timesteps = 1000  # Number of timesteps used during training
def get_betas(
    beta_start=0.1 / 1000, beta_end=20 / 1000, num_train_timesteps=num_train_timesteps
):
    betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
    betas = torch.from_numpy(betas).to(device)
    return betas
# Utility function to let us easily retrieve \bar\alpha_t
def compute_alpha(beta, t):
    beta = torch.cat([torch.zeros(1).to(beta.device), beta], dim=0)
    a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)
    return a
betas = get_betas()
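compute_alpha prepends \(\beta = 0\) before taking the cumulative product. For the 0-based timesteps used in the code, index t + 1 of the padded product is therefore \(\bar\alpha\) at that timestep, and t = -1 maps to 1 (a noise-free image). The sampling loop later relies on this to use -1 as its final "next" timestep.

The NumPy mirror below, with the hypothetical name compute_alpha_np, only illustrates this indexing.

```python
import numpy as np


def compute_alpha_np(betas, t):
    # Hypothetical NumPy mirror of compute_alpha: pad with beta = 0 so index 0 is 1
    padded = np.concatenate([np.zeros(1), betas])
    return np.cumprod(1.0 - padded)[np.asarray(t) + 1]


betas = np.linspace(0.1 / 1000, 20 / 1000, 1000)
print(compute_alpha_np(betas, -1))  # 1.0: "timestep -1" is the clean image
print(np.isclose(compute_alpha_np(betas, 0), 1.0 - betas[0]))  # True
print(np.isclose(compute_alpha_np(betas, 999), np.prod(1.0 - betas)))  # True
```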
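The DPS algorithm
Now that the inverse problem is defined, we can apply the DPS algorithm to solve it. DPS is a diffusion algorithm that alternates between a denoising step, a gradient step and a reverse diffusion sampling step. For \(t\) decreasing from \(T\) to \(1\) (with \(\bar\alpha_0 = 1\)), it reads:
\[
\begin{aligned}
\hat{\mathbf{x}}_t &= \mathrm{D}_{\sigma_t}\big(\mathbf{x}_t / \sqrt{\bar\alpha_t}\big), \qquad \sigma_t = \sqrt{1 - \bar\alpha_t}/\sqrt{\bar\alpha_t}, \\
\mathbf{g}_t &= \nabla_{\mathbf{x}_t} \log p\big(\mathbf{y} \mid \hat{\mathbf{x}}_t(\mathbf{x}_t)\big), \\
\mathbf{\epsilon}_t &\sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \\
\mathbf{x}_{t-1} &= a_t\,\mathbf{x}_t + b_t\,\hat{\mathbf{x}}_t + \tilde{\sigma}_t\,\mathbf{\epsilon}_t + \mathbf{g}_t,
\end{aligned}
\]
where \(\mathrm{D}_{\sigma}(\cdot)\) is a denoising network for noise level \(\sigma\), \(\eta\) is a hyperparameter, and the constants \(\tilde{\sigma}_t, a_t, b_t\) are defined as
\[
\tilde{\sigma}_t = \eta \sqrt{\Big(1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}\Big)\frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}}, \qquad
a_t = \frac{\sqrt{1 - \bar\alpha_{t-1} - \tilde{\sigma}_t^2}}{\sqrt{1 - \bar\alpha_t}}, \qquad
b_t = \sqrt{\bar\alpha_{t-1}} - \sqrt{1 - \bar\alpha_{t-1} - \tilde{\sigma}_t^2}\,\frac{\sqrt{\bar\alpha_t}}{\sqrt{1 - \bar\alpha_t}}.
\]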
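One way to sanity-check these constants is to assume a perfect denoiser (\(\hat{\mathbf{x}}_t = \mathbf{x}_0\)) and \(\mathbf{x}_t = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\mathbf{\epsilon}\). Leaving out \(\mathbf{g}_t\), \(\mathbf{x}_{t-1}\) then has mean \(\sqrt{\bar\alpha_{t-1}}\,\mathbf{x}_0\) and variance \(1 - \bar\alpha_{t-1}\). This holds because \(a_t\sqrt{\bar\alpha_t} + b_t = \sqrt{\bar\alpha_{t-1}}\) and \(a_t^2(1-\bar\alpha_t) + \tilde\sigma_t^2 = 1 - \bar\alpha_{t-1}\).

The NumPy sketch below checks both identities on a few timestep pairs of the linear schedule, for \(\eta \in \{0, 1\}\). The helper name ddim_coefficients is our own.

```python
import numpy as np


def ddim_coefficients(ab_t, ab_prev, eta=1.0):
    """Return (sigma_tilde, a_t, b_t) for x_{t-1} = a_t x_t + b_t x_hat + sigma_tilde eps."""
    sigma_tilde = eta * np.sqrt((1 - ab_t / ab_prev) * (1 - ab_prev) / (1 - ab_t))
    c2 = np.sqrt(1 - ab_prev - sigma_tilde**2)
    a_t = c2 / np.sqrt(1 - ab_t)
    b_t = np.sqrt(ab_prev) - c2 * np.sqrt(ab_t) / np.sqrt(1 - ab_t)
    return sigma_tilde, a_t, b_t


betas = np.linspace(0.1 / 1000, 20 / 1000, 1000)
alpha_bar = np.cumprod(1 - betas)
for t, t_prev in [(999, 994), (500, 495), (10, 5)]:
    ab_t, ab_prev = alpha_bar[t], alpha_bar[t_prev]
    for eta in (0.0, 1.0):
        s, a, b = ddim_coefficients(ab_t, ab_prev, eta)
        # mean of x_{t-1} given x_0 (perfect denoiser, no likelihood term)
        assert np.isclose(a * np.sqrt(ab_t) + b, np.sqrt(ab_prev))
        # variance of x_{t-1} given x_0
        assert np.isclose(a**2 * (1 - ab_t) + s**2, 1 - ab_prev)
print("marginals preserved")
```

With \(\eta = 1\), \(\tilde\sigma_t^2\) is the DDPM posterior variance, which is why this setting recovers ancestral sampling.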
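Denoising step
The first step of DPS applies a denoiser to the current image \(\mathbf{x}_t\), with noise standard deviation \(\sigma_t = \sqrt{1 - \bar\alpha_t}/\sqrt{\bar\alpha_t}\). A denoiser trained to minimize the mean squared error approximates the posterior mean \(\mathbb{E}[\mathbf{x}_0|\mathbf{x}_t]\).

To illustrate this step in isolation, the following code samples a noisy image \(\mathbf{x}_t \sim q(\mathbf{x}_t|\mathbf{x}_0)\) at an arbitrary timestep. It then computes the posterior mean with the denoiser.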
t = torch.ones(1, device=device) * 200 # choose some arbitrary timestep
at = compute_alpha(betas, t.long())
sigmat = (1 - at).sqrt() / at.sqrt()
x0 = x_true
xt = x0 + sigmat * torch.randn_like(x0)
# apply denoiser
x0_t = model(xt, sigmat)
# Visualize
imgs = [x0, xt, x0_t]
plot(
    imgs,
    titles=["ground-truth", "noisy", "posterior mean"],
)
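Why does a denoiser give the posterior mean, and how does that relate to the score? Consider a scalar toy prior \(x_0 \sim \mathcal{N}(\mu, s^2)\) with \(x_t = x_0 + \sigma\epsilon\), the same additive-noise form as in the snippet above. In this case both sides of Tweedie's formula \(\mathbb{E}[x_0|x_t] = x_t + \sigma^2 \nabla \log p(x_t)\) have closed forms.

The NumPy sketch below checks that they agree. It is an illustration with made-up numbers, not something the network computes.

```python
import numpy as np

mu, s, sigma = 0.3, 0.5, 0.8  # toy prior mean/std and noise level
xt = np.linspace(-3.0, 3.0, 7)  # a few noisy observations

# Exact posterior mean for a Gaussian prior with Gaussian noise
posterior_mean = mu + s**2 / (s**2 + sigma**2) * (xt - mu)

# Tweedie: x_t + sigma^2 * score, where the marginal is p(x_t) = N(mu, s^2 + sigma^2)
score = -(xt - mu) / (s**2 + sigma**2)
tweedie = xt + sigma**2 * score

print(np.allclose(posterior_mean, tweedie))  # True
```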
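DPS approximation
To perform gradient-based posterior sampling with diffusion models, we need to compute \(\nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t|\mathbf{y})\). Applying Bayes' rule, we have
\[
\nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t|\mathbf{y}) = \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t) + \nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\mathbf{x}_t),
\]
since the normalizing constant \(p(\mathbf{y})\) does not depend on \(\mathbf{x}_t\).

For the former term, we can plug in the score estimated from the denoiser via Tweedie's formula. The latter term is intractable, so DPS proposes the following approximation (for details, see Theorem 1 of Chung et al.):
\[
\nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\mathbf{x}_t) \approx \nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\hat{\mathbf{x}}_t), \qquad \hat{\mathbf{x}}_t := \mathbb{E}[\mathbf{x}_0|\mathbf{x}_t].
\]
Remarkably, we can now compute the latter term when the measurement noise is Gaussian with standard deviation \(\sigma_y\):
\[
\nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\hat{\mathbf{x}}_t) = -\frac{1}{2\sigma_y^2} \nabla_{\mathbf{x}_t} \|\mathbf{y} - \mathbf{A}\hat{\mathbf{x}}_t\|_2^2,
\]
where \(\mathbf{A}\) denotes the forward operator. Since \(\hat{\mathbf{x}}_t\) is the output of the denoising network applied to \(\mathbf{x}_t\), the gradient w.r.t. \(\mathbf{x}_t\) can be computed through automatic differentiation.

Let's see how this can be done in PyTorch. To take the gradient w.r.t. a tensor, we first have to enable gradient computation with tensor.requires_grad_().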
Note
The DPS iterations in this tutorial assume that the images are in the range [-1, 1], whereas standard denoisers usually take and return images in the range [0, 1]. This is why we rescale the images before and after applying the denoiser.
x0 = x_true * 2.0 - 1.0  # [0, 1] -> [-1, 1]
data_fidelity = L2()
y = physics(x0)  # measurements of the rescaled image
# xt ~ q(xt|x0)
i = 200  # choose some arbitrary timestep
t = (torch.ones(1) * i).to(device)
at = compute_alpha(betas, t.long())
sigma_cur = (1 - at).sqrt() / at.sqrt()
xt = x0 + sigma_cur * torch.randn_like(x0)
# DPS
with torch.enable_grad():
    # Turn on gradient
    xt.requires_grad_()
    # normalize to [0,1], denoise, and rescale to [-1, 1]
    x0_t = model(xt / 2 + 0.5, sigma_cur / 2) * 2 - 1
    # DPS data term: l2 norm of the residual A x0_t - y (up to a constant factor)
    ll = data_fidelity(x0_t, y, physics).sqrt().sum()
    # Take gradient w.r.t. xt
    grad_ll = torch.autograd.grad(outputs=ll, inputs=xt)[0]
# Visualize (detach tensors that are part of the autograd graph)
imgs = [x0, xt.detach(), x0_t.detach(), grad_ll]
plot(
    imgs,
    titles=["ground-truth", "noisy", "posterior mean", "gradient"],
)
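The autograd call above differentiates through the denoiser. We can check the same quantity by hand on a toy problem where that derivative is known in closed form. Assume a linear "denoiser" \(\hat{\mathbf{x}}(\mathbf{x}_t) = c\,\mathbf{x}_t + d\) with constants \(c, d\) and a masking operator \(\mathbf{A}\). The DPS term \(\nabla_{\mathbf{x}_t} \|\mathbf{y} - \mathbf{A}\hat{\mathbf{x}}\|_2\) then equals \(c\,\mathbf{A}^\top(\mathbf{A}\hat{\mathbf{x}} - \mathbf{y})/\|\mathbf{A}\hat{\mathbf{x}} - \mathbf{y}\|_2\).

The NumPy sketch below compares this with central finite differences. The linear denoiser stands in for the network and is purely an assumption of the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20
mask = (rng.random(n) < 0.3).astype(float)  # A = diag(mask), a masking operator
mask[0] = 1.0  # make sure at least one entry is observed
y = mask * rng.standard_normal(n)
c, d = 0.6, 0.1  # toy linear "denoiser": x_hat(x_t) = c * x_t + d


def residual_norm(xt):
    """DPS data term ||y - A x_hat(x_t)||_2 for the toy denoiser."""
    x_hat = c * xt + d
    return np.linalg.norm(y - mask * x_hat)


xt = rng.standard_normal(n)

# Chain rule by hand (A is diagonal, so A^T = A)
r = mask * (c * xt + d) - y
grad = c * mask * r / np.linalg.norm(r)

# Central finite differences, one coordinate at a time
h = 1e-6
fd = np.array(
    [(residual_norm(xt + h * e) - residual_norm(xt - h * e)) / (2 * h) for e in np.eye(n)]
)
print(np.allclose(grad, fd, atol=1e-6))  # True
```

Here \(\|\text{grad}\|_2 = c\), whatever the size of the residual. This is one consequence of differentiating the norm rather than the squared norm.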
DPS Algorithm
Now that we have covered all the key components of DPS, we are ready to define the algorithm. For every denoising timestep, the algorithm iterates the following steps:
Get \(\hat{\mathbf{x}}_t\) using the denoiser network.
Compute \(\nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\hat{\mathbf{x}}_t)\) through backpropagation.
Perform reverse diffusion sampling with DDPM(IM), corresponding to an update with \(\nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)\).
Take a gradient step with \(\nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\hat{\mathbf{x}}_t)\).
There are two caveats here.

First, in the original work, DPS used DDPM ancestral sampling. The DDIM sampler generalizes DDPM in the sense that it recovers DDPM when \(\eta = 1.0\), so here we consider DDIM sampling. One can freely choose the \(\eta\) parameter, but since DPS is designed to run with 1000 neural function evaluations (NFEs), it is advisable to keep \(\eta = 1.0\).

Second, the log-likelihood gradient step is weighted differently in practice. The actual implementation uses a static step size \(\rho\) times the gradient of the \(\ell_2\) norm (not its square) of the residual:
\[
\nabla_{\mathbf{x}_t} \log p(\mathbf{y}|\hat{\mathbf{x}}_t) \approx -\rho\, \nabla_{\mathbf{x}_t} \|\mathbf{y} - \mathbf{A}\hat{\mathbf{x}}_t\|_2.
\]
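The NumPy sketch below runs the whole iteration without a GPU or a pretrained network, on a toy problem of our own choosing. The signal has 1,000 entries with a standard normal prior. Random masking keeps about 10% of the entries, and Gaussian noise is added to the observed ones.

For this prior, the posterior-mean denoiser is exact: \(\mathbb{E}[\mathbf{x}_0|\mathbf{x}_t] = \sqrt{\bar\alpha_t}\,\mathbf{x}_t\). Its Jacobian is \(\sqrt{\bar\alpha_t}\,\mathbf{I}\), so the likelihood gradient can be written by hand instead of using autograd. The schedule, 200 steps, \(\eta = 1\) and a unit step size \(\rho\) mirror the tutorial. Everything else is an assumption of the sketch.

```python
import numpy as np

# Toy problem (our assumption): n entries with a standard normal prior, so the
# exact posterior-mean denoiser is E[x_0 | x_t] = sqrt(alpha_bar_t) * x_t.
rng = np.random.default_rng(0)
n, keep_prob, sigma_y = 1000, 0.1, 0.05
rho, eta = 1.0, 1.0  # likelihood step size and DDIM eta

x_true = rng.standard_normal(n)
mask = (rng.random(n) < keep_prob).astype(float)  # A = diag(mask)
y = mask * (x_true + sigma_y * rng.standard_normal(n))  # noisy observed entries

# Linear schedule; alpha_bar[t + 1] is alpha_bar at timestep t, and t = -1 gives 1
betas = np.linspace(0.1 / 1000, 20 / 1000, 1000)
alpha_bar = np.concatenate([[1.0], np.cumprod(1.0 - betas)])

num_steps = 200
skip = 1000 // num_steps
seq = list(range(0, 1000, skip))
seq_next = [-1] + seq[:-1]

x = rng.standard_normal(n)  # x_T
initial_residual = np.linalg.norm(mask * x - y)
for i, j in zip(reversed(seq), reversed(seq_next)):
    at, at_next = alpha_bar[i + 1], alpha_bar[j + 1]
    # 1. denoising step (exact for this prior)
    x0_t = np.sqrt(at) * x
    # 2. gradient of ||y - A x0_t|| w.r.t. x_t, via the chain rule
    r = mask * x0_t - y
    norm_grad = np.sqrt(at) * mask * r / np.linalg.norm(r)
    # 3. fresh noise and DDIM constants
    sigma_tilde = eta * np.sqrt((1 - at / at_next) * (1 - at_next) / (1 - at))
    c2 = np.sqrt(1 - at_next - sigma_tilde**2)
    epsilon = rng.standard_normal(n)
    # 4. DDIM step followed by the likelihood step
    x = (
        (np.sqrt(at_next) - c2 * np.sqrt(at) / np.sqrt(1 - at)) * x0_t
        + sigma_tilde * epsilon
        + c2 * x / np.sqrt(1 - at)
        - rho * norm_grad
    )

final_residual = np.linalg.norm(mask * x - y)
print(f"residual on observed entries: {initial_residual:.2f} -> {final_residual:.2f}")
```

Because the prior is independent across entries, only the observed entries are pulled toward \(\mathbf{y}\). The masked ones are simply drawn from the prior.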
With these in mind, let us solve the inverse problem with DPS!
Note
We only use 200 steps to reduce the computational time of this example. As suggested by the authors of DPS, the algorithm works best with num_steps = 1000.
num_steps = 200
skip = num_train_timesteps // num_steps
batch_size = 1
eta = 1.0
seq = range(0, num_train_timesteps, skip)
seq_next = [-1] + list(seq[:-1])
time_pairs = list(zip(reversed(seq), reversed(seq_next)))
# measurement
x0 = x_true * 2.0 - 1.0
y = physics(x0.to(device))
# initial sample from x_T
x = torch.randn_like(x0)
xs = [x]
x0_preds = []
for i, j in tqdm(time_pairs):
    t = (torch.ones(batch_size) * i).to(device)
    next_t = (torch.ones(batch_size) * j).to(device)
    at = compute_alpha(betas, t.long())
    at_next = compute_alpha(betas, next_t.long())
    xt = xs[-1].to(device)
    with torch.enable_grad():
        xt.requires_grad_()
        # 1. denoising step
        aux_x = xt / (2 * at.sqrt()) + 0.5  # renormalize in [0, 1]
        sigma_cur = (1 - at).sqrt() / at.sqrt()  # sigma_t
        x0_t = 2 * model(aux_x, sigma_cur / 2) - 1
        x0_t = torch.clip(x0_t, -1.0, 1.0)  # optional
        # 2. likelihood gradient approximation
        l2_loss = data_fidelity(x0_t, y, physics).sqrt().sum()
        norm_grad = torch.autograd.grad(outputs=l2_loss, inputs=xt)[0]
    norm_grad = norm_grad.detach()
    sigma_tilde = ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() * eta
    c2 = ((1 - at_next) - sigma_tilde**2).sqrt()
    # 3. noise step
    epsilon = torch.randn_like(xt)
    # 4. DDPM(IM) step, followed by the likelihood gradient step
    xt_next = (
        (at_next.sqrt() - c2 * at.sqrt() / (1 - at).sqrt()) * x0_t
        + sigma_tilde * epsilon
        + c2 * xt / (1 - at).sqrt()
        - norm_grad
    )
    # detach so that the autograd graph is not carried over to the next iteration
    x0_preds.append(x0_t.detach().to("cpu"))
    xs.append(xt_next.detach().to("cpu"))
recon = xs[-1]
# plot the results
x = recon / 2 + 0.5
imgs = [y, x, x_true]
plot(imgs, titles=["measurement", "model output", "groundtruth"])
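To see which timesteps the sampling loop visits, the same seq/seq_next/time_pairs construction can be run with small, purely illustrative sizes: 20 training timesteps and 4 sampling steps. Each pair (i, j) goes from the current timestep i to the next, smaller one j. The last pair ends at j = -1, which compute_alpha maps to \(\bar\alpha = 1\) (a clean image).

```python
num_train_timesteps, num_steps = 20, 4  # illustrative sizes only
skip = num_train_timesteps // num_steps
seq = range(0, num_train_timesteps, skip)
seq_next = [-1] + list(seq[:-1])
time_pairs = list(zip(reversed(seq), reversed(seq_next)))
print(time_pairs)  # [(15, 10), (10, 5), (5, 0), (0, -1)]
```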
Using DPS in your inverse problem
You can readily use this algorithm via the deepinv.sampling.DPS() class.
y = physics(x_true)
dps = dinv.sampling.DPS(
    dinv.models.DiffUNet().to(device),
    data_fidelity=dinv.optim.data_fidelity.L2(),
    device=device,
)
xhat = dps(y, physics)
Total running time of the script: (2 minutes 7.494 seconds)