Building your diffusion posterior sampling method using SDEs
This demo shows how to use deepinv.sampling.PosteriorDiffusion to perform posterior sampling. If no data-fidelity term is specified, it can also be used to perform unconditional image generation with arbitrary denoisers.
This method requires:
- A well-trained denoiser that handles a range of noise levels, ideally including large ones (e.g., deepinv.models.NCSNpp).
- A (noisy) data-fidelity term (e.g., deepinv.sampling.DPSDataFidelity).
- A drift term \(f(x, t)\) and a diffusion term \(g(t)\) for the forward-time SDE, defined through a deepinv.sampling.DiffusionSDE (e.g., deepinv.sampling.VarianceExplodingDiffusion).
The deepinv.sampling.PosteriorDiffusion class can be used to perform posterior sampling for inverse problems.
Consider the acquisition model:
\[ y = \noise{\forw{x}}, \]
where \(\forw{x}\) is the forward operator (e.g., a convolutional operator) and \(\noise{\cdot}\) is the noise operator (e.g., Gaussian noise). This class defines the reverse-time SDE for the posterior distribution \(p(x|y)\) given the data \(y\):
\[ d\, x_t = \left( f(x_t, t) - g(t)^2 \nabla_{x_t} \log p_t(x_t | y) \right) dt + g(t)\, d w_t, \]
where \(f\) is the drift term, \(g\) is the diffusion coefficient and \(w\) is a standard Brownian motion.
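To make the roles of \(f\), \(g\) and \(w\) concrete, here is a minimal NumPy sketch of an Euler-Maruyama discretization of the reverse-time SDE \(d x_t = (f - g^2 \nabla \log p_t)\, dt + g\, d w_t\), integrated from \(t = 1\) down to \(t = 0\). It is not deepinv's solver: it assumes a 1D toy problem with Gaussian data \(x_0 \sim \mathcal{N}(0, s^2)\), zero drift and a constant diffusion coefficient, so that the (unconditional) score is known exactly. In posterior sampling, the score would be replaced by the conditional one.

```python
import numpy as np


def reverse_sde_euler(x, timesteps, drift, diffusion, score, rng):
    """Euler-Maruyama for dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw, run backwards in time."""
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        dt = t_next - t  # negative, since time decreases from 1 towards 0
        g = diffusion(t)
        noise = rng.standard_normal(x.shape)
        x = x + (drift(x, t) - g**2 * score(x, t)) * dt + g * np.sqrt(-dt) * noise
    return x


# Toy setting (assumption): x_0 ~ N(0, s^2) and forward SDE dx = g dw with constant g,
# so that x_t ~ N(0, s^2 + g^2 t) and the score is known in closed form.
s, g_const = 0.5, 2.0


def drift(x, t):
    return np.zeros_like(x)  # f(x, t) = 0


def diffusion(t):
    return g_const  # g(t) is constant here


def score(x, t):
    return -x / (s**2 + g_const**2 * t)  # exact grad_x log p_t(x)


rng = np.random.default_rng(0)
timesteps = np.linspace(1.0, 0.0, 500)
x_T = np.sqrt(s**2 + g_const**2) * rng.standard_normal(20000)  # samples from p_1
x_0 = reverse_sde_euler(x_T, timesteps, drift, diffusion, score, rng)
print(x_0.std())  # should be close to s = 0.5
```

The reverse pass turns wide noise (standard deviation about 2.06) back into samples whose spread matches the data distribution.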
The drift term and the diffusion coefficient are defined by the underlying (unconditional) forward-time SDE sde. In this example, we use two well-known SDEs from the literature: the Variance-Exploding (VE) and the Variance-Preserving (VP, or DDPM) SDEs.
The (conditional) score function \(\nabla_{x_t} \log p_t(x_t | y)\) can be decomposed using Bayes' rule:
\[ \nabla_{x_t} \log p_t(x_t | y) = \nabla_{x_t} \log p_t(x_t) + \nabla_{x_t} \log p_t(y | x_t). \]
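As a sanity check of this decomposition (posterior score = prior score + likelihood score), here is a hypothetical scalar example in which every term is available in closed form: a Gaussian prior, a linear measurement \(y = a x + n\) with Gaussian noise, and no diffusion time (equivalently, \(t = 0\)). In the diffusion setting, \(\log p_t(y | x_t)\) is intractable, which is why it has to be approximated by a data-fidelity term.

```python
import numpy as np

# Hypothetical scalar Gaussian model (not from deepinv):
# prior x ~ N(0, s2), measurement y = a * x + n with n ~ N(0, sigma2).
s2, a, sigma2, y = 1.5, 0.8, 0.1, 0.3


def prior_score(x):
    return -x / s2  # grad_x log p(x)


def likelihood_score(x):
    return a * (y - a * x) / sigma2  # grad_x log p(y | x)


# Closed-form Gaussian posterior p(x | y) = N(mu, v)
v = 1.0 / (1.0 / s2 + a**2 / sigma2)
mu = v * a * y / sigma2


def posterior_score(x):
    return -(x - mu) / v  # grad_x log p(x | y)


xs = np.linspace(-3, 3, 7)
gap = np.max(np.abs(posterior_score(xs) - (prior_score(xs) + likelihood_score(xs))))
print(gap)  # zero up to floating-point rounding
```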
The first term is the score function of the unconditional SDE, which is typically approximated by an MMSE denoiser (denoiser) using the well-known Tweedie's formula, while the second term is approximated by the (noisy) data-fidelity term (data_fidelity). Several data-fidelity terms are implemented in the library; see the user guide.
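Tweedie's formula states that, for \(x_\sigma = x + \sigma z\) with \(z \sim \mathcal{N}(0, I)\), the score of the noisy marginal is \(\nabla \log p_\sigma(x_\sigma) = (D(x_\sigma, \sigma) - x_\sigma)/\sigma^2\), where \(D\) is the MMSE denoiser. The sketch below checks this numerically on an assumed 1D Gaussian prior, for which the MMSE denoiser is linear. It only illustrates the identity; it is not how deepinv wraps a trained network, and the scaling differs for SDEs with a non-zero drift such as VP.

```python
import numpy as np

# Assumed toy prior: x ~ N(m, s^2); noisy observation x_sigma = x + sigma * z.
m, s, sigma = 0.7, 0.5, 0.3


def mmse_denoiser(x_noisy, sigma):
    """Exact MMSE denoiser E[x | x_noisy] for the Gaussian prior."""
    return m + s**2 / (s**2 + sigma**2) * (x_noisy - m)


def tweedie_score(x_noisy, sigma):
    """Score of the noisy marginal obtained from the denoiser via Tweedie's formula."""
    return (mmse_denoiser(x_noisy, sigma) - x_noisy) / sigma**2


def exact_score(x_noisy, sigma):
    """Exact score of the noisy marginal p_sigma = N(m, s^2 + sigma^2)."""
    return -(x_noisy - m) / (s**2 + sigma**2)


xs = np.linspace(-2, 3, 11)
print(np.allclose(tweedie_score(xs, sigma), exact_score(xs, sigma)))  # True
```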
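Let us import the necessary modules and define the denoiser and the SDE.
In this first example, we use the Variance-Exploding SDE, whose forward process is defined as:
\[ d\, x_t = g(t)\, d w_t, \]
i.e., the drift is \(f(x_t, t) = 0\) and the diffusion coefficient \(g(t)\) is parameterized by the noise levels \(\sigma_{\mathrm{min}}\) and \(\sigma_{\mathrm{max}}\) (sigma_min and sigma_max below).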
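A common choice for the VE noise schedule, used by Song et al., is the geometric schedule \(\sigma(t) = \sigma_{\mathrm{min}} (\sigma_{\mathrm{max}}/\sigma_{\mathrm{min}})^t\), for which \(g(t)^2 = \frac{d}{dt}\sigma(t)^2\). The sketch below assumes this schedule (deepinv's VarianceExplodingDiffusion may parameterize it differently) with the sigma_min and sigma_max values used in this example. It shows how the perturbation grows with \(t\), using the usual approximation \(x_t \approx x_0 + \sigma(t) z\) on a random stand-in image.

```python
import numpy as np

# Assumption: geometric VE schedule sigma(t) = sigma_min * (sigma_max / sigma_min) ** t
sigma_min, sigma_max = 0.02, 20.0


def sigma(t):
    return sigma_min * (sigma_max / sigma_min) ** t


def g_squared(t):
    # g(t)^2 = d sigma(t)^2 / dt for this schedule
    return 2.0 * np.log(sigma_max / sigma_min) * sigma(t) ** 2


rng = np.random.default_rng(0)
x0 = rng.random((3, 64, 64))  # random stand-in for a clean 3 x 64 x 64 image
for t in (0.001, 0.5, 1.0):
    x_t = x0 + sigma(t) * rng.standard_normal(x0.shape)  # approximate VE perturbation
    print(f"t={t:5.3f}  sigma={sigma(t):7.3f}  std(x_t - x0)={np.std(x_t - x0):7.3f}")
```

At \(t = 1\) the noise standard deviation (20) dwarfs the image values, which is why sampling can start from pure Gaussian noise.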
```python
import torch
import deepinv as dinv
from deepinv.models import NCSNpp

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float64
figsize = 2.5
gif_frequency = 10  # Increase this value to reduce the time spent saving GIFs

from deepinv.sampling import (
    PosteriorDiffusion,
    DPSDataFidelity,
    EulerSolver,
    VarianceExplodingDiffusion,
)
from deepinv.optim import ZeroFidelity

# In this example, we use the pre-trained FFHQ-64 model from the
# EDM framework: https://arxiv.org/pdf/2206.00364 .
# The network architecture is from Song et al: https://arxiv.org/abs/2011.13456 .
denoiser = NCSNpp(pretrained="download").to(device)

# Samples are obtained by calling the PosteriorDiffusion model, which integrates
# the SDE with the chosen solver (here, Euler).
# The reproducibility of the SDE solver can be controlled by providing a pseudo-random number generator.
num_steps = 150
rng = torch.Generator(device).manual_seed(42)
timesteps = torch.linspace(1, 0.001, num_steps)
solver = EulerSolver(timesteps=timesteps, rng=rng)

sigma_min = 0.02
sigma_max = 20
sde = VarianceExplodingDiffusion(
    sigma_max=sigma_max,
    sigma_min=sigma_min,
    alpha=0.5,
    device=device,
    dtype=dtype,
)
```
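The solver is driven entirely by its time grid. Assuming an Euler-type solver performs one update per pair of consecutive timesteps (consistent with the 149 iterations reported when this example runs), the grid above yields num_steps - 1 updates with a constant, negative step size, since the reverse-time SDE is integrated from \(t = 1\) down to \(t = 0.001\). The NumPy sketch below mirrors torch.linspace(1, 0.001, num_steps) to make this explicit.

```python
import numpy as np

# Mirrors the grid above: torch.linspace(1, 0.001, 150), here with NumPy.
num_steps = 150
timesteps = np.linspace(1.0, 0.001, num_steps)
step_sizes = np.diff(timesteps)  # one Euler update per consecutive pair of timesteps

print(len(step_sizes))  # number of updates: num_steps - 1
print(step_sizes[0])    # constant negative step: time runs backwards
```

Doubling the number of timesteps roughly halves the step size, trading computation time for a smaller discretization error.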
Reverse-time SDE as a sampling process
When the data-fidelity term is zero (here, ZeroFidelity), posterior diffusion reduces to unconditional diffusion. Sampling is then performed by solving the reverse-time SDE, i.e., by generating a reverse-time trajectory from noise to a clean image.
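Conceptually, the unconditional sampler combines the pieces above: the denoiser provides the score via Tweedie's formula, the reverse-time SDE is discretized on the timestep grid, and intermediate states are recorded when a trajectory is requested. The following NumPy sketch illustrates this under stated assumptions: per-pixel Gaussian data (so the exact MMSE denoiser is known), the geometric VE schedule of Song et al., and a plain Euler-Maruyama update. The function sample_ve is hypothetical and only mimics the role of get_trajectory in PosteriorDiffusion.

```python
import numpy as np

# Assumptions: data x_0 ~ N(0, s^2) per pixel, geometric VE schedule,
# score obtained from the denoiser via Tweedie: (D(x, sigma) - x) / sigma^2.
s, sigma_min, sigma_max = 0.5, 0.02, 20.0
log_ratio = np.log(sigma_max / sigma_min)


def sigma(t):
    return sigma_min * (sigma_max / sigma_min) ** t


def denoiser(x, sig):
    return s**2 / (s**2 + sig**2) * x  # exact E[x_0 | x_0 + sig * z] for this toy prior


def sample_ve(shape, timesteps, rng, get_trajectory=False):
    x = np.sqrt(s**2 + sigma(timesteps[0]) ** 2) * rng.standard_normal(shape)
    trajectory = [x.copy()]
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        dt = t_next - t  # negative step
        sig = sigma(t)
        g2 = 2.0 * log_ratio * sig**2  # g(t)^2 = d sigma^2 / dt
        score = (denoiser(x, sig) - x) / sig**2  # Tweedie's formula
        x = x - g2 * score * dt + np.sqrt(-g2 * dt) * rng.standard_normal(shape)
        trajectory.append(x.copy())
    return (x, np.stack(trajectory)) if get_trajectory else x


rng = np.random.default_rng(1)
timesteps = np.linspace(1.0, 0.001, 150)
sample, trajectory = sample_ve((4096,), timesteps, rng, get_trajectory=True)
frames = trajectory[::10]  # subsample frames, like trajectory[::gif_frequency]
print(trajectory.shape, frames.shape, sample.std())
```

The recorded trajectory holds the initial noise plus one state per update, which is why subsampling it before writing a GIF saves time.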
```python
model = PosteriorDiffusion(
    data_fidelity=ZeroFidelity(),
    sde=sde,
    denoiser=denoiser,
    solver=solver,
    dtype=dtype,
    device=device,
    verbose=True,
)
sample_seed_1, trajectory_seed_1 = model(
    y=None,
    physics=None,
    x_init=(1, 3, 64, 64),
    seed=1,
    get_trajectory=True,
)
dinv.utils.plot(
    sample_seed_1,
    titles="Unconditional generation",
    save_fn="sde_sample.png",
    figsize=(figsize, figsize),
)
dinv.utils.save_videos(
    trajectory_seed_1.cpu()[::gif_frequency],
    time_dim=0,
    titles=["VE-SDE Trajectory"],
    save_fn="sde_trajectory.gif",
    figsize=(figsize, figsize),
)
```
We obtain the following unconditional sample


When a data-fidelity term is given, together with the measurements and the physics, this class can be used to perform posterior sampling for inverse problems.
For example, consider the inpainting problem, where part of the pixels of an image are missing (here, a random mask keeps roughly half of them) and we want to recover the original image.
We can use deepinv.sampling.DPSDataFidelity as the data-fidelity term.
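DPS approximates the intractable likelihood score by differentiating a data-fidelity term evaluated at the denoised estimate, roughly \(\nabla_{x_t} \log p_t(y | x_t) \approx -\zeta \nabla_{x_t} \tfrac{1}{2}\|y - A(D(x_t, t))\|^2\) for some step size \(\zeta\). The sketch below is only an illustration under simplifying assumptions: a noiseless inpainting operator \(A = \mathrm{diag}(\text{mask})\) and a hypothetical linear denoiser \(D(x) = c x\), so that the chain rule can be written by hand and checked against finite differences. DPSDataFidelity differentiates through the actual denoiser, and the exact weighting used in the DPS paper and in deepinv may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 16
mask = (rng.random(n) < 0.5).astype(float)  # inpainting operator A = diag(mask), ~half the pixels kept
x_true = rng.random(n)
y = mask * x_true                           # noiseless inpainting measurement
c = 0.8                                     # hypothetical linear denoiser D(x) = c * x


def fidelity(x_t):
    """0.5 * ||y - A D(x_t)||^2, a DPS-style data-fidelity term (up to weighting)."""
    r = y - mask * (c * x_t)
    return 0.5 * np.sum(r**2)


def fidelity_grad(x_t):
    """Chain rule: J_D^T A^T (A D(x_t) - y), with J_D = c * I and A^T = A here."""
    return c * mask * (mask * (c * x_t) - y)


x_t = rng.standard_normal(n)
eps = 1e-6
fd = np.array([(fidelity(x_t + eps * e) - fidelity(x_t - eps * e)) / (2 * eps) for e in np.eye(n)])
print(np.allclose(fd, fidelity_grad(x_t), atol=1e-6))  # True
```

Subtracting a multiple of this gradient at each step pulls the denoised estimate towards consistency with the observed pixels, while masked-out pixels receive no gradient.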
```python
x = sample_seed_1
physics = dinv.physics.Inpainting(tensor_size=x.shape[1:], mask=0.5, device=device)
y = physics(x)

model = PosteriorDiffusion(
    data_fidelity=DPSDataFidelity(denoiser=denoiser),
    denoiser=denoiser,
    sde=sde,
    solver=solver,
    dtype=dtype,
    device=device,
    verbose=True,
)

# To perform posterior sampling, we provide the measurements and the physics
# (the solver was already given when building the model).
# When the physics is given, the initial point can be inferred from it if x_init is not given explicitly.
seed_1 = 11
x_hat, trajectory = model(
    y,
    physics,
    seed=seed_1,
    get_trajectory=True,
)

# Here, we plot the original image, the measurement and the posterior sample
dinv.utils.plot(
    [x, y, x_hat],
    show=True,
    titles=["Original", "Measurement", "Posterior sample"],
    save_fn="posterior_sample.png",
    figsize=(figsize * 3, figsize),
)

# We can also save the trajectory of the posterior sample
dinv.utils.save_videos(
    trajectory.cpu()[::gif_frequency],
    time_dim=0,
    titles=["Posterior sample with VE"],
    save_fn="posterior_trajectory.gif",
    figsize=(figsize, figsize),
)
```
We obtain the following posterior sample and trajectory


Note
- Reproducibility: if the parameter rng is given, the same sample is generated whenever the same seed is used. Different seeds yield different samples.
- Parallel sampling: multiple samples can be drawn in parallel by giving the initial shape, e.g., x_init = (B, C, H, W).
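Both points of this note follow from all randomness flowing through a seeded generator and from operations being vectorized over the leading batch dimension. The toy function below is hypothetical (it is not deepinv code) and mimics this behaviour with NumPy: the same seed reproduces the output exactly, a different seed changes it, and a shape (B, C, H, W) produces B samples at once.

```python
import numpy as np


def toy_sampler(x_init_shape, seed):
    """Hypothetical stand-in for a sampler: all randomness comes from one seeded generator."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(x_init_shape)  # initial noise
    for _ in range(10):
        x = 0.9 * x + 0.1 * rng.standard_normal(x_init_shape)  # noisy update
    return x


a = toy_sampler((1, 3, 8, 8), seed=1)
b = toy_sampler((1, 3, 8, 8), seed=1)
c = toy_sampler((1, 3, 8, 8), seed=2)
batch = toy_sampler((4, 3, 8, 8), seed=1)  # B = 4 samples drawn in parallel

print(np.array_equal(a, b), np.array_equal(a, c), batch.shape)  # True False (4, 3, 8, 8)
```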
Varying the SDE
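The underlying SDE used for sampling can also be changed.
For example, we can use the Variance-Preserving (VP, or DDPM) SDE, implemented in deepinv.sampling.VariancePreservingDiffusion, whose forward drift and diffusion terms are defined as:
\[ f(x_t, t) = -\frac{1}{2}\beta(t)\, x_t \qquad \text{and} \qquad g(t) = \sqrt{\beta(t)}, \qquad \text{with } \beta(t) = \beta_{\mathrm{min}} + t\left(\beta_{\mathrm{max}} - \beta_{\mathrm{min}}\right). \]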
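The name "variance preserving" can be checked numerically: for unit-variance data, \(x_t\) keeps a fraction \(\exp(-\tfrac12\int_0^t \beta(s)\,ds)\) of \(x_0\) while its variance stays equal to 1 for all \(t\). The sketch below simulates the forward SDE \(d x = -\tfrac12 \beta(t) x\, dt + \sqrt{\beta(t)}\, d w\) with Euler-Maruyama, assuming the linear schedule \(\beta(t) = \beta_{\mathrm{min}} + t(\beta_{\mathrm{max}} - \beta_{\mathrm{min}})\) with \(\beta_{\mathrm{min}} = 0.1\) and \(\beta_{\mathrm{max}} = 20\) (the values of Song et al.; deepinv's defaults are not verified here).

```python
import numpy as np

# Assumed linear VP schedule (Song et al., 2021)
beta_min, beta_max = 0.1, 20.0


def beta(t):
    return beta_min + t * (beta_max - beta_min)


def mean_coeff(t):
    # exp(-0.5 * int_0^t beta(s) ds): fraction of x_0 that survives in x_t
    return np.exp(-0.5 * (beta_min * t + 0.5 * t**2 * (beta_max - beta_min)))


# Euler-Maruyama simulation of dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dw
rng = np.random.default_rng(0)
x = rng.standard_normal(20000)  # unit-variance data
timesteps = np.linspace(0.0, 1.0, 1001)
for t, t_next in zip(timesteps[:-1], timesteps[1:]):
    dt = t_next - t
    x = x - 0.5 * beta(t) * x * dt + np.sqrt(beta(t) * dt) * rng.standard_normal(x.shape)

print(x.var(), mean_coeff(1.0))  # variance stays close to 1; almost no signal is left at t = 1
```

In contrast with VE, where the noise level grows to sigma_max, the VP process shrinks the signal while adding noise, so the final distribution is close to a standard Gaussian.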
```python
from deepinv.sampling import VariancePreservingDiffusion

sde = VariancePreservingDiffusion(device=device, dtype=dtype)
model = PosteriorDiffusion(
    data_fidelity=DPSDataFidelity(denoiser=denoiser),
    denoiser=denoiser,
    sde=sde,
    solver=solver,
    device=device,
    dtype=dtype,
    verbose=True,
)
# Use a finer time grid (300 timesteps) for the VP-SDE instead of the solver's default one
x_hat_vp, trajectory_vp = model(
    y,
    physics,
    seed=111,
    timesteps=torch.linspace(1, 0.001, 300),
    get_trajectory=True,
)
dinv.utils.plot(
    [x_hat, x_hat_vp],
    titles=[
        "Posterior sample with VE",
        "Posterior sample with VP",
    ],
    save_fn="posterior_sample_ve_vp.png",
    figsize=(figsize * 2, figsize),
)
# We can also save the trajectory of the VP posterior sample
dinv.utils.save_videos(
    trajectory_vp.cpu()[::gif_frequency],
    time_dim=0,
    titles=["Posterior sample with VP"],
    save_fn="posterior_trajectory_vp.gif",
    figsize=(figsize, figsize),
)
```
We can compare the sampling trajectories obtained with each underlying SDE.



Plug-and-play Posterior Sampling with arbitrary denoisers
The deepinv.sampling.PosteriorDiffusion class can be used with any (well-trained) denoiser for posterior sampling. For example, we can use deepinv.models.DRUNet. We can also change the underlying SDE, for example by changing the sigma_max value.
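To see what changing sigma_max does, here is a minimal sketch that assumes the standard geometric VE schedule of Song et al. (2021), sigma(t) = sigma_min * (sigma_max / sigma_min) ** t for t in [0, 1]. The helper `ve_sigma` is hypothetical, and deepinv's VarianceExplodingDiffusion (including its `alpha` argument) may parametrize the schedule differently.

```python
def ve_sigma(t: float, sigma_min: float, sigma_max: float) -> float:
    """Noise std at time t in [0, 1] for an assumed geometric VE schedule."""
    return sigma_min * (sigma_max / sigma_min) ** t


# Compare the two settings used on this page: 7.0 (DRUNet) and 100 (DiffUNet).
for sigma_max in (7.0, 100.0):
    levels = [ve_sigma(t, 0.02, sigma_max) for t in (0.0, 0.5, 1.0)]
    print(f"sigma_max={sigma_max:6.1f}: " + ", ".join(f"{s:.3f}" for s in levels))
```

Under this assumption, sigma_max sets the noise level at which sampling starts: it should be large relative to the image intensities so that the initial sample carries essentially no information about the data, but the denoiser must also behave sensibly at that noise level, which is why the requirements above ask for a denoiser trained on large noise levels.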
sigma_min = 0.02
sigma_max = 7.0
rng = torch.Generator(device)
dtype = torch.float32
# 300 time points from t=1 down to t=0.001, i.e. 299 solver steps
timesteps = torch.linspace(1, 0.001, 300)
solver = EulerSolver(timesteps=timesteps, rng=rng)
denoiser = dinv.models.DRUNet(pretrained="download").to(device)
sde = VarianceExplodingDiffusion(
sigma_max=sigma_max, sigma_min=sigma_min, alpha=0.75, device=device, dtype=dtype
)
x = dinv.utils.load_url_image(
dinv.utils.demo.get_image_url("butterfly.png"), img_size=256, resize_mode="resize"
).to(device)
mask = torch.ones_like(x)
mask[..., 100:150, 125:175] = 0.0
physics = dinv.physics.Inpainting(
mask=mask,
tensor_size=x.shape[1:],
device=device,
)
y = physics(x)
model = PosteriorDiffusion(
data_fidelity=DPSDataFidelity(denoiser=denoiser),
denoiser=denoiser,
sde=sde,
solver=solver,
dtype=dtype,
device=device,
verbose=True,
)
# To perform posterior sampling, we provide the measurements and the physics (the solver was set when building the model).
x_hat, trajectory = model(
y=y,
physics=physics,
seed=12,
get_trajectory=True,
)
# Here, we plot the original image, the measurement and the posterior sample
dinv.utils.plot(
[x, y, x_hat.clip(0, 1)],
titles=["Original", "Measurement", "Posterior sample DRUNet"],
figsize=(figsize * 3, figsize),
save_fn="posterior_sample_DRUNet.png",
)
# We can also save the trajectory of the posterior sample
dinv.utils.save_videos(
trajectory[::gif_frequency].clip(0, 1),
time_dim=0,
titles=["Posterior trajectory DRUNet"],
save_fn="posterior_sample_DRUNet.gif",
figsize=(figsize, figsize),
)
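The two ingredients used above, a denoiser turned into a score via Tweedie's formula and a time-stepping solver, can be illustrated on a toy problem. The sketch below is not deepinv's EulerSolver: it is a plain Euler–Maruyama integration of the reverse-time VE SDE, assuming the geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min) ** t (so g(t)^2 = d[sigma(t)^2]/dt) and a 1-D Gaussian prior whose exact MMSE denoiser (`toy_denoiser`, hypothetical) stands in for DRUNet.

```python
import numpy as np

SIGMA_MIN, SIGMA_MAX = 0.02, 7.0
LOG_RATIO = np.log(SIGMA_MAX / SIGMA_MIN)
DATA_STD = 0.5  # toy prior: x0 ~ N(0, DATA_STD**2)


def sigma(t):
    # Assumed geometric VE noise schedule.
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def g2(t):
    # Squared diffusion coefficient g(t)^2 = d[sigma(t)^2]/dt for this schedule.
    return 2.0 * LOG_RATIO * sigma(t) ** 2


def toy_denoiser(x, s):
    # Exact MMSE denoiser for a zero-mean Gaussian prior (stand-in for DRUNet).
    return DATA_STD**2 / (DATA_STD**2 + s**2) * x


def score(x, s):
    # Tweedie's formula: grad log p_t(x) = (D(x, sigma) - x) / sigma^2.
    return (toy_denoiser(x, s) - x) / s**2


def sample(n, timesteps, rng):
    # Start from the (wide) marginal at the first time point.
    x = rng.standard_normal(n) * np.sqrt(DATA_STD**2 + sigma(timesteps[0]) ** 2)
    for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
        dt = t_cur - t_next  # positive step size; time runs backwards
        drift = g2(t_cur) * score(x, sigma(t_cur)) * dt
        noise = np.sqrt(g2(t_cur) * dt) * rng.standard_normal(n)
        x = x + drift + noise
    return x


rng = np.random.default_rng(0)
timesteps = np.linspace(1.0, 0.001, 300)  # 300 points -> 299 Euler steps
samples = sample(20_000, timesteps, rng)
print(f"sample std {samples.std():.3f} (target {DATA_STD})")
```

Because the prior is Gaussian, the exact score is known in closed form, so the printed standard deviation can be checked against DATA_STD; with a real image denoiser the same update applies, only `toy_denoiser` changes.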
We obtain the following posterior trajectory


We can switch to a different denoiser, for example, the DiffUNet denoiser from the EDM framework.
denoiser = dinv.models.DiffUNet(pretrained="download").to(device)
sigma_min = 0.02
sigma_max = 100
rng = torch.Generator(device)
timesteps = torch.linspace(1, 0.001, 200)
solver = EulerSolver(timesteps=timesteps, rng=rng)
sde = VarianceExplodingDiffusion(
sigma_max=sigma_max, sigma_min=sigma_min, alpha=1.0, device=device, dtype=dtype
)
# sde = VariancePreservingDiffusion(device=device, dtype=dtype)
x = dinv.utils.load_url_image(
dinv.utils.demo.get_image_url("celeba_example.jpg"),
img_size=256,
resize_mode="resize",
).to(device)
physics = dinv.physics.Inpainting(
mask=0.5,
tensor_size=x.shape[1:],
device=device,
)
y = physics(x)
model = PosteriorDiffusion(
data_fidelity=DPSDataFidelity(denoiser=denoiser),
denoiser=denoiser,
sde=sde,
solver=solver,
dtype=dtype,
device=device,
verbose=True,
)
# To perform posterior sampling, we provide the measurements and the physics (the solver was set when building the model).
x_hat, trajectory = model(
y=y,
physics=physics,
seed=None,
get_trajectory=True,
)
# Here, we plot the original image, the measurement and the posterior sample
dinv.utils.plot(
[x, y, x_hat],
show=True,
titles=["Original", "Measurement", "Posterior sample DiffUNet"],
save_fn="posterior_sample_DiffUNet.png",
figsize=(figsize * 3, figsize),
)
# We can also save the trajectory of the posterior sample
dinv.utils.save_videos(
trajectory[::gif_frequency],
time_dim=0,
titles=["Posterior trajectory DiffUNet"],
save_fn="posterior_sample_DiffUNet.gif",
figsize=(figsize, figsize),
)
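The data-fidelity term supplies the second part of the conditional score. As a rough sketch of the DPS idea (Chung et al., 2023), the likelihood gradient is approximated by differentiating the measurement residual through the denoiser's estimate of the clean image. This is not deepinv's DPSDataFidelity code: the shrinkage `toy_denoiser` is hypothetical, the residual norm follows the DPS paper, and reading mask=0.5 as "keep each pixel with probability 0.5" is an assumption about the Inpainting operator.

```python
import torch

torch.manual_seed(0)


def toy_denoiser(x, sigma):
    # Hypothetical stand-in for a trained denoiser D(x, sigma): simple shrinkage.
    return x / (1.0 + sigma**2)


x_true = torch.rand(1, 3, 8, 8)                # clean image in [0, 1]
mask = (torch.rand(1, 1, 8, 8) < 0.5).float()  # assumed meaning of mask=0.5
y = mask * x_true                              # noiseless inpainting measurement


def dps_grad(x_t, sigma):
    """Gradient of ||y - A(D(x_t, sigma))|| with respect to x_t (DPS-style sketch)."""
    x_t = x_t.detach().requires_grad_(True)
    x0_hat = toy_denoiser(x_t, sigma)          # denoiser estimate of the clean image
    residual = torch.linalg.vector_norm(y - mask * x0_hat)
    (grad,) = torch.autograd.grad(residual, x_t)
    return grad


sigma = 0.5
x_t = x_true + sigma * torch.randn_like(x_true)
grad = dps_grad(x_t, sigma)
# Pixels hidden by the mask receive no guidance from the measurement.
print(grad[(mask == 0).expand_as(grad)].abs().max().item())
```

In DPS, x_t is moved against this gradient (x_t - step * grad) at each solver step, which pulls the observed pixels towards y while the prior fills in the masked ones; the step size and exact scaling used by deepinv may differ.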
We obtain the following posterior trajectory


Total running time of the script: (64 minutes 15.953 seconds)